viceroy_lib/
acl.rs

1use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::sync::Arc;
6
7/// Acls is a mapping of names to acl.
8#[derive(Clone, Debug, Default)]
9pub struct Acls {
10    acls: HashMap<String, Arc<Acl>>,
11}
12
13impl Acls {
14    pub fn new() -> Self {
15        Self {
16            acls: HashMap::new(),
17        }
18    }
19
20    pub fn get_acl(&self, name: &str) -> Option<&Arc<Acl>> {
21        self.acls.get(name)
22    }
23
24    pub fn insert(&mut self, name: String, acl: Acl) {
25        self.acls.insert(name, Arc::new(acl));
26    }
27}
28
29/// An acl is a collection of acl entries.
30///
31/// The JSON representation of this struct intentionally matches the JSON
32/// format used to create/update ACLs via api.fastly.com. The goal being
33/// to allow users to use the same JSON in Viceroy as in production.
34///
35/// Example:
36///
37/// ```json
38///    { "entries": [
39///        { "op": "create", "prefix": "1.2.3.0/24", "action": "BLOCK" },
40///        { "op": "create", "prefix": "23.23.23.23/32", "action": "ALLOW" },
41///        { "op": "update", "prefix": "FACE::/32", "action": "ALLOW" }
42///    ]}
43/// ```
44///
45/// Note that, in Viceroy, the `op` field is ignored.
46#[derive(Debug, Default, Deserialize)]
47pub struct Acl {
48    pub(crate) entries: Vec<Entry>,
49}
50
51impl Acl {
52    /// Lookup performs a naive lookup of the given IP address
53    /// over the acls entries.
54    ///
55    /// If the IP matches multiple ACL entries, then:
56    /// - The most specific match is returned (longest mask),
57    /// - and in case of a tie, the last entry wins.
58    pub fn lookup(&self, ip: IpAddr) -> Option<&Entry> {
59        self.entries.iter().fold(None, |acc, entry| {
60            if let Some(mask) = entry.prefix.is_match(ip) {
61                if acc.is_none_or(|prev_match: &Entry| mask >= prev_match.prefix.mask) {
62                    return Some(entry);
63                }
64            }
65            acc
66        })
67    }
68}
69
70/// An entry is an IP prefix and its associated action.
71#[derive(Debug, Deserialize, Serialize, PartialEq)]
72pub struct Entry {
73    prefix: Prefix,
74    action: Action,
75}
76
/// An IP network: a base address plus the number of leading bits (`mask`)
/// that an address must share with it to be considered inside the network.
#[derive(Debug, PartialEq)]
pub struct Prefix {
    /// Base address of the network. [`Prefix::new`] zeroes every bit past `mask`.
    ip: IpAddr,
    /// Number of leading bits of `ip` that are significant.
    mask: u8,
}

impl Prefix {
    /// Builds a prefix from `ip` and `mask`, clearing every bit of `ip` past the mask.
    ///
    /// The mask is clamped to `1..=32` for IPv4 and `1..=128` for IPv6. The
    /// lower bound keeps the shift amount below the integer width, so a mask
    /// of 0 becomes 1 rather than overflowing the shift.
    pub(crate) fn new(ip: IpAddr, mask: u8) -> Self {
        match ip {
            IpAddr::V4(addr) => {
                let mask = mask.clamp(1, 32);
                let network = addr.to_bits() & (u32::MAX << (32 - mask));
                Self {
                    ip: Ipv4Addr::from_bits(network).into(),
                    mask,
                }
            }
            IpAddr::V6(addr) => {
                let mask = mask.clamp(1, 128);
                let network = addr.to_bits() & (u128::MAX << (128 - mask));
                Self {
                    ip: Ipv6Addr::from_bits(network).into(),
                    mask,
                }
            }
        }
    }

    /// Checks whether `ip` lies inside this prefix.
    ///
    /// Returns `Some(self.mask)` on a match and `None` otherwise. An address of
    /// the other IP family never matches, because the masked candidate keeps
    /// `ip`'s family and so can never equal `self.ip`.
    pub(crate) fn is_match(&self, ip: IpAddr) -> Option<u8> {
        let candidate = Self::new(ip, self.mask);
        (candidate.ip == self.ip).then_some(self.mask)
    }
}

impl Display for Prefix {
    /// Formats the prefix in CIDR notation, e.g. `192.168.0.0/16`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}/{}", self.ip, self.mask)
    }
}
126
127impl<'de> Deserialize<'de> for Prefix {
128    fn deserialize<D>(de: D) -> Result<Self, D::Error>
129    where
130        D: Deserializer<'de>,
131    {
132        let v = String::deserialize(de)?;
133        let (ip, mask) = v.split_once('/').ok_or(D::Error::custom(format!(
134            "invalid format '{}': want IP/MASK",
135            v
136        )))?;
137
138        let mask = mask
139            .parse::<u8>()
140            .map_err(|err| D::Error::custom(format!("invalid prefix {}: {}", mask, err)))?;
141
142        // Detect whether the IP is v4 or v6.
143        let ip = match ip.contains(':') {
144            false => {
145                if !(1..=32).contains(&mask) {
146                    return Err(D::Error::custom(format!(
147                        "mask outside allowed range [1, 32]: {}",
148                        mask
149                    )));
150                }
151                ip.parse::<Ipv4Addr>().map(IpAddr::V4)
152            }
153            true => {
154                if !(1..=128).contains(&mask) {
155                    return Err(D::Error::custom(format!(
156                        "mask outside allowed range [1, 128]: {}",
157                        mask
158                    )));
159                }
160                ip.parse::<Ipv6Addr>().map(IpAddr::V6)
161            }
162        }
163        .map_err(|err| D::Error::custom(format!("invalid ip address {}: {}", ip, err)))?;
164
165        Ok(Self::new(ip, mask))
166    }
167}
168
169impl Serialize for Prefix {
170    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
171    where
172        S: Serializer,
173    {
174        serializer.serialize_str(format!("{}", self).as_str())
175    }
176}
177
/// Canonical string form of [`Action::Allow`].
const ACTION_ALLOW: &str = "ALLOW";
/// Canonical string form of [`Action::Block`].
const ACTION_BLOCK: &str = "BLOCK";

/// The action attached to an ACL entry.
#[derive(Clone, Debug, PartialEq)]
pub enum Action {
    /// Encoded as `"ALLOW"`.
    Allow,
    /// Encoded as `"BLOCK"`.
    Block,
    /// Any other action string, stored as it was given.
    Other(String),
}
188
189impl<'de> Deserialize<'de> for Action {
190    fn deserialize<D>(de: D) -> Result<Self, D::Error>
191    where
192        D: Deserializer<'de>,
193    {
194        let action = String::deserialize(de)?;
195        Ok(match action.to_uppercase().as_str() {
196            ACTION_ALLOW => Self::Allow,
197            ACTION_BLOCK => Self::Block,
198            _ => Self::Other(action),
199        })
200    }
201}
202
203impl Serialize for Action {
204    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205    where
206        S: Serializer,
207    {
208        match self {
209            Self::Allow => serializer.serialize_str(ACTION_ALLOW),
210            Self::Block => serializer.serialize_str(ACTION_BLOCK),
211            Self::Other(other) => serializer.serialize_str(format!("Other({})", other).as_str()),
212        }
213    }
214}
215
#[test]
fn prefix_is_match() {
    // IPv4: a /16 covers every 192.168.x.y address and nothing next to it.
    let v4_net = Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), 16);
    for member in [
        Ipv4Addr::new(192, 168, 100, 0),
        Ipv4Addr::new(192, 168, 200, 200),
    ] {
        assert_eq!(v4_net.is_match(member.into()), Some(16));
    }
    for outsider in [Ipv4Addr::new(192, 167, 0, 0), Ipv4Addr::new(192, 169, 0, 0)] {
        assert_eq!(v4_net.is_match(outsider.into()), None);
    }

    // IPv6: only the leading 16 bits are compared.
    let v6_net = Prefix::new(Ipv6Addr::new(0xFACE, 0, 0, 0, 0, 0, 0, 0).into(), 16);
    let v6_member = Ipv6Addr::new(0xFACE, 1, 2, 3, 4, 5, 6, 7);
    assert_eq!(v6_net.is_match(v6_member.into()), Some(16));

    // An IPv4 address and its IPv4-mapped IPv6 form never match each other's prefix.
    let plain = Ipv4Addr::new(192, 168, 200, 200);
    let mapped = plain.to_ipv6_mapped();
    assert_eq!(Prefix::new(plain.into(), 8).is_match(mapped.into()), None);
    assert_eq!(Prefix::new(mapped.into(), 8).is_match(plain.into()), None);
}
244
#[test]
fn acl_lookup() {
    let acl = Acl {
        entries: vec![
            Entry {
                prefix: Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), 16),
                action: Action::Block,
            },
            Entry {
                prefix: Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), 24),
                action: Action::Block,
            },
            Entry {
                prefix: Prefix::new(Ipv4Addr::new(192, 168, 100, 0).into(), 8),
                action: Action::Block,
            },
        ],
    };

    // Inside all three prefixes: the /24 wins as the longest mask, even though
    // the /8 appears after it.
    assert_eq!(
        acl.lookup(Ipv4Addr::new(192, 168, 100, 1).into()),
        Some(&acl.entries[1])
    );
    // Inside the /16 and the /8 only.
    assert_eq!(
        acl.lookup(Ipv4Addr::new(192, 168, 200, 1).into()),
        Some(&acl.entries[0])
    );
    // Inside the /8 only.
    assert_eq!(
        acl.lookup(Ipv4Addr::new(192, 1, 1, 1).into()),
        Some(&acl.entries[2])
    );
    // Outside every prefix.
    assert_eq!(acl.lookup(Ipv4Addr::new(1, 1, 1, 1).into()), None);
    // An IPv4-mapped IPv6 address does not match IPv4 prefixes.
    assert_eq!(
        acl.lookup(Ipv4Addr::new(192, 168, 100, 1).to_ipv6_mapped().into()),
        None
    );

    // When several matching entries share the longest mask, the last one wins.
    // The shorter /8, even though it is last, must lose.
    let tied = Acl {
        entries: vec![
            Entry {
                prefix: Prefix::new(Ipv6Addr::new(0xFACE, 0, 0, 0, 0, 0, 0, 0).into(), 16),
                action: Action::Block,
            },
            Entry {
                prefix: Prefix::new(Ipv6Addr::new(0xFACE, 0xFFFF, 0, 0, 0, 0, 0, 0).into(), 16),
                action: Action::Allow,
            },
            Entry {
                prefix: Prefix::new(Ipv6Addr::new(0xFA00, 0, 0, 0, 0, 0, 0, 0).into(), 8),
                action: Action::Block,
            },
        ],
    };
    assert_eq!(
        tied.lookup(Ipv6Addr::new(0xFACE, 1, 0, 0, 0, 0, 0, 1).into()),
        Some(&tied.entries[1])
    );
}
289
#[test]
fn acl_json_parse() {
    // The `op` fields below must be ignored. They are included only to show
    // that the JSON used with api.fastly.com to create/modify ACLs is accepted
    // by Viceroy unchanged.
    let input = r#"
    { "entries": [
        { "op": "create", "prefix": "1.2.3.0/24", "action": "BLOCK" },
        { "op": "update", "prefix": "192.168.0.0/16", "action": "BLOCK" },
        { "op": "create", "prefix": "23.23.23.23/32", "action": "ALLOW" },
        { "op": "update", "prefix": "1.2.3.4/32", "action": "ALLOW" },
        { "op": "update", "prefix": "1.2.3.4/8", "action": "ALLOW" }
    ]}
    "#;
    let acl: Acl = serde_json::from_str(input).expect("can decode");

    // Expected prefixes are spelled out as raw fields rather than built with
    // `Prefix::new`, so the test also pins the host-bit normalization done while
    // decoding (e.g. `1.2.3.4/8` -> `1.0.0.0/8`).
    let v4_entry = |octets: [u8; 4], mask: u8, action: Action| Entry {
        prefix: Prefix {
            ip: IpAddr::V4(Ipv4Addr::from(octets)),
            mask,
        },
        action,
    };
    let want = vec![
        v4_entry([1, 2, 3, 0], 24, Action::Block),
        v4_entry([192, 168, 0, 0], 16, Action::Block),
        v4_entry([23, 23, 23, 23], 32, Action::Allow),
        v4_entry([1, 2, 3, 4], 32, Action::Allow),
        v4_entry([1, 0, 0, 0], 8, Action::Allow),
    ];

    assert_eq!(acl.entries, want);
}
346
#[test]
fn prefix_json_roundtrip() {
    // (input, expected re-encoding): host bits past the mask are dropped.
    let roundtrips = [
        ("255.255.255.255/32", "255.255.255.255/32"),
        ("255.255.255.255/8", "255.0.0.0/8"),
        ("2002::1234:abcd:ffff:c0a8:101/64", "2002:0:0:1234::/64"),
        ("2000::AB/32", "2000::/32"),
    ];
    for (input, want) in roundtrips {
        let quoted = format!("\"{}\"", input);
        let prefix: Prefix = serde_json::from_str(quoted.as_str()).expect("can decode");
        let got = serde_json::to_string(&prefix).expect("can encode");
        assert_eq!(
            got,
            format!("\"{}\"", want),
            "'{}' roundtrip: got {}, want {}",
            input,
            got,
            want
        );
    }

    let rejected = [
        // Invalid prefix.
        "\"1.2.3.4/33\"",
        "\"200::/129\"",
        "\"200::/none\"",
        // Invalid IP.
        "\"1.2.3.four/16\"",
        "\"200::end/32\"",
        // Invalid format.
        "\"1.2.3.4\"",
        "\"200::\"",
    ];
    for encoded in rejected {
        assert!(serde_json::from_str::<Prefix>(encoded).is_err());
    }
}
382
#[test]
fn action_json_roundtrip() {
    // Known actions decode case-insensitively and re-encode in upper case;
    // anything else is re-encoded wrapped as `Other(..)` with its original spelling.
    let cases = [
        ("ALLOW", "ALLOW"),
        ("allow", "ALLOW"),
        ("BLOCK", "BLOCK"),
        ("block", "BLOCK"),
        ("POTATO", "Other(POTATO)"),
        ("potato", "Other(potato)"),
    ];
    for (input, want) in cases {
        let quoted = format!("\"{}\"", input);
        let action: Action = serde_json::from_str(quoted.as_str()).expect("can decode");
        let got = serde_json::to_string(&action).expect("can encode");
        assert_eq!(
            got,
            format!("\"{}\"", want),
            "'{}' roundtrip: got {}, want {}",
            input,
            got,
            want
        );
    }
}