atc_router/
ast.rs

1use crate::schema::Schema;
2use cidr::IpCidr;
3use regex::Regex;
4use std::net::IpAddr;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
10#[derive(Debug)]
11pub enum Expression {
12    Logical(Box<LogicalExpression>),
13    Predicate(Predicate),
14}
15
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17#[derive(Debug)]
18pub enum LogicalExpression {
19    And(Expression, Expression),
20    Or(Expression, Expression),
21    Not(Expression),
22}
23
/// A transformation wrapped around the field on the left-hand side of a
/// predicate, written like a function call: `lower(field)` or `any(field)`.
///
/// Field-less, so it is `Copy`; `Hash` lets it be used as a set/map key.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LhsTransformations {
    /// `lower(...)` — presumably compares against the lower-cased field value.
    Lower,
    /// `any(...)` — presumably matches if any one of the field's values matches.
    Any,
}
30
/// Comparison operator of a predicate; each variant's doc gives its syntax.
///
/// Field-less, so it is `Copy`; `Hash` lets it be used as a set/map key.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// `==`
    Equals,
    /// `!=`
    NotEquals,
    /// `~`
    Regex,
    /// `^=`
    Prefix,
    /// `=^`
    Postfix,
    /// `>`
    Greater,
    /// `>=`
    GreaterOrEqual,
    /// `<`
    Less,
    /// `<=`
    LessOrEqual,
    /// `in`
    In,
    /// `not in`
    NotIn,
    /// `contains`
    Contains,
}
47
48#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
49#[derive(Debug, Clone)]
50pub enum Value {
51    String(String),
52    IpCidr(IpCidr),
53    IpAddr(IpAddr),
54    Int(i64),
55    #[cfg_attr(feature = "serde", serde(with = "serde_regex"))]
56    Regex(Regex),
57}
58
59impl PartialEq for Value {
60    fn eq(&self, other: &Self) -> bool {
61        match (self, other) {
62            (Self::Regex(_), _) | (_, Self::Regex(_)) => {
63                panic!("Regexes can not be compared using eq")
64            }
65            (Self::String(s1), Self::String(s2)) => s1 == s2,
66            (Self::IpCidr(i1), Self::IpCidr(i2)) => i1 == i2,
67            (Self::IpAddr(i1), Self::IpAddr(i2)) => i1 == i2,
68            (Self::Int(i1), Self::Int(i2)) => i1 == i2,
69            _ => false,
70        }
71    }
72}
73
74impl Value {
75    pub fn my_type(&self) -> Type {
76        match self {
77            Value::String(_) => Type::String,
78            Value::IpCidr(_) => Type::IpCidr,
79            Value::IpAddr(_) => Type::IpAddr,
80            Value::Int(_) => Type::Int,
81            Value::Regex(_) => Type::Regex,
82        }
83    }
84}
85
86impl From<String> for Value {
87    fn from(v: String) -> Self {
88        Value::String(v)
89    }
90}
91
/// Type tag shared by literal values (one variant per `Value` variant) and
/// schema fields.
///
/// `#[repr(C)]` gives the enum the layout of a C enum — presumably so the tag
/// can cross an FFI boundary. Keep the variant order stable, because it
/// determines the discriminant values.
///
/// The enum is `Copy` so that callers holding a `&Type` (e.g. from a schema
/// lookup) can take it by value. `Hash` allows it to be used as a map/set key.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[repr(C)]
pub enum Type {
    String,
    IpCidr,
    IpAddr,
    Int,
    Regex,
}
102
103#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
104#[derive(Debug)]
105pub struct Lhs {
106    pub var_name: String,
107    pub transformations: Vec<LhsTransformations>,
108}
109
110impl Lhs {
111    pub fn my_type<'a>(&self, schema: &'a Schema) -> Option<&'a Type> {
112        schema.type_of(&self.var_name)
113    }
114
115    pub fn get_transformations(&self) -> (bool, bool) {
116        let mut lower = false;
117        let mut any = false;
118
119        self.transformations.iter().for_each(|i| match i {
120            LhsTransformations::Any => any = true,
121            LhsTransformations::Lower => lower = true,
122        });
123
124        (lower, any)
125    }
126}
127
128#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
129#[derive(Debug)]
130pub struct Predicate {
131    pub lhs: Lhs,
132    pub rhs: Value,
133    pub op: BinaryOperator,
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use std::fmt;

    // Test-only `Display` impls that render a parsed AST back into a fully
    // parenthesized canonical form, so the tests below can check parsing and
    // operator precedence by plain string comparison. They write directly into
    // the formatter instead of building intermediate `String`s.

    impl fmt::Display for Expression {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                Expression::Logical(logical) => write!(f, "{}", logical),
                Expression::Predicate(predicate) => write!(f, "{}", predicate),
            }
        }
    }

    impl fmt::Display for LogicalExpression {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                LogicalExpression::And(left, right) => write!(f, "({} && {})", left, right),
                LogicalExpression::Or(left, right) => write!(f, "({} || {})", left, right),
                LogicalExpression::Not(e) => write!(f, "!({})", e),
            }
        }
    }

    impl fmt::Display for LhsTransformations {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str(match self {
                LhsTransformations::Lower => "lower",
                LhsTransformations::Any => "any",
            })
        }
    }

    impl fmt::Display for Value {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                Value::String(s) => write!(f, "\"{}\"", s),
                Value::IpCidr(cidr) => write!(f, "{}", cidr),
                Value::IpAddr(addr) => write!(f, "{}", addr),
                Value::Int(i) => write!(f, "{}", i),
                Value::Regex(re) => write!(f, "\"{}\"", re),
            }
        }
    }

    impl fmt::Display for Lhs {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            // transformations[0] is rendered innermost: [Any, Lower] becomes
            // `lower(any(var))`. Emitting the opening calls outermost-first and
            // then the closing parens avoids rebuilding the string per call.
            for transformation in self.transformations.iter().rev() {
                write!(f, "{}(", transformation)?;
            }
            f.write_str(&self.var_name)?;
            for _ in &self.transformations {
                f.write_str(")")?;
            }
            Ok(())
        }
    }

    impl fmt::Display for BinaryOperator {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            let symbol = match self {
                Self::Equals => "==",
                Self::NotEquals => "!=",
                Self::Regex => "~",
                Self::Prefix => "^=",
                Self::Postfix => "=^",
                Self::Greater => ">",
                Self::GreaterOrEqual => ">=",
                Self::Less => "<",
                Self::LessOrEqual => "<=",
                Self::In => "in",
                Self::NotIn => "not in",
                Self::Contains => "contains",
            };
            f.write_str(symbol)
        }
    }

    impl fmt::Display for Predicate {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "({} {} {})", self.lhs, self.op, self.rhs)
        }
    }

    /// Parses every `input` and asserts that its canonical rendering equals
    /// `expected`; failures name the offending input.
    fn assert_renders(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            let expr = parse(input)
                .unwrap_or_else(|err| panic!("failed to parse {:?}: {:?}", input, err));
            assert_eq!(expr.to_string(), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn expr_op_and_prec() {
        // As the expected strings show, `||` binds tighter than `&&`, and both
        // are left-associative.
        assert_renders(&[
            ("a > 0", "(a > 0)"),
            ("a in \"abc\"", "(a in \"abc\")"),
            ("a == 1 && b != 2", "((a == 1) && (b != 2))"),
            (
                "a ^= \"1\" && b =^ \"2\" || c >= 3",
                "((a ^= \"1\") && ((b =^ \"2\") || (c >= 3)))",
            ),
            (
                "a == 1 && b != 2 || c >= 3",
                "((a == 1) && ((b != 2) || (c >= 3)))",
            ),
            (
                "a > 1 || b < 2 && c <= 3 || d not in \"foo\"",
                "(((a > 1) || (b < 2)) && ((c <= 3) || (d not in \"foo\")))",
            ),
            (
                "a > 1 || ((b < 2) && (c <= 3)) || d not in \"foo\"",
                "(((a > 1) || ((b < 2) && (c <= 3))) || (d not in \"foo\"))",
            ),
            ("!(a == 1)", "!((a == 1))"),
            (
                "!(a == 1) && b == 2 && !(c == 3) && d >= 4",
                "(((!((a == 1)) && (b == 2)) && !((c == 3))) && (d >= 4))",
            ),
            (
                "!(a == 1 || b == 2 && c == 3) && d == 4",
                "(!((((a == 1) || (b == 2)) && (c == 3))) && (d == 4))",
            ),
        ]);
    }

    #[test]
    fn expr_var_name_and_ip() {
        assert_renders(&[
            // ipv4_literal
            ("kong.foo in 1.1.1.1", "(kong.foo in 1.1.1.1)"),
            // ipv4_cidr_literal
            (
                "kong.foo.foo2 in 10.0.0.0/24",
                "(kong.foo.foo2 in 10.0.0.0/24)",
            ),
            // ipv6_literal (a bare address; previously this case duplicated
            // the CIDR input below, leaving plain IPv6 addresses untested)
            (
                "kong.foo.foo3 in 2001:db8::1",
                "(kong.foo.foo3 in 2001:db8::1)",
            ),
            // ipv6_cidr_literal
            (
                "kong.foo.foo4 in 2001:db8::/32",
                "(kong.foo.foo4 in 2001:db8::/32)",
            ),
        ]);
    }

    #[test]
    fn expr_regex() {
        assert_renders(&[
            // regex_literal
            (
                "kong.foo.foo5 ~ \"^foo.*$\"",
                "(kong.foo.foo5 ~ \"^foo.*$\")",
            ),
            // regex_literal
            (
                "kong.foo.foo6 ~ \"^foo.*$\"",
                "(kong.foo.foo6 ~ \"^foo.*$\")",
            ),
        ]);
    }

    #[test]
    fn expr_digits() {
        assert_renders(&[
            // dec literal
            ("kong.foo.foo7 == 123", "(kong.foo.foo7 == 123)"),
            // hex literal
            ("kong.foo.foo8 == 0x123", "(kong.foo.foo8 == 291)"),
            // oct literal
            ("kong.foo.foo9 == 0123", "(kong.foo.foo9 == 83)"),
            // dec negative literal
            ("kong.foo.foo10 == -123", "(kong.foo.foo10 == -123)"),
            // hex negative literal
            ("kong.foo.foo11 == -0x123", "(kong.foo.foo11 == -291)"),
            // oct negative literal
            ("kong.foo.foo12 == -0123", "(kong.foo.foo12 == -83)"),
        ]);
    }

    #[test]
    fn expr_transformations() {
        assert_renders(&[
            // lower
            (
                "lower(kong.foo.foo13) == \"foo\"",
                "(lower(kong.foo.foo13) == \"foo\")",
            ),
            // any
            (
                "any(kong.foo.foo14) == \"foo\"",
                "(any(kong.foo.foo14) == \"foo\")",
            ),
        ]);
    }

    #[test]
    fn expr_transformations_nested() {
        assert_renders(&[
            // lower + lower
            (
                "lower(lower(kong.foo.foo15)) == \"foo\"",
                "(lower(lower(kong.foo.foo15)) == \"foo\")",
            ),
            // lower + any
            (
                "lower(any(kong.foo.foo16)) == \"foo\"",
                "(lower(any(kong.foo.foo16)) == \"foo\")",
            ),
            // any + lower
            (
                "any(lower(kong.foo.foo17)) == \"foo\"",
                "(any(lower(kong.foo.foo17)) == \"foo\")",
            ),
            // any + any
            (
                "any(any(kong.foo.foo18)) == \"foo\"",
                "(any(any(kong.foo.foo18)) == \"foo\")",
            ),
        ]);
    }

    #[test]
    fn str_unicode_test() {
        assert_renders(&[
            // cjk chars
            ("t_msg in \"你好\"", "(t_msg in \"你好\")"),
            // the same chars written as Rust `\u{..}` escapes; rustc resolves
            // them, so the parser sees the literal characters
            ("t_msg in \"\u{4f60}\u{597d}\"", "(t_msg in \"你好\")"),
        ]);
    }

    #[test]
    fn rawstr_test() {
        assert_renders(&[
            // `\d` is not a valid escape, but raw strings keep it verbatim
            (r##"a == r#"/path/to/\d+"#"##, r#"(a == "/path/to/\d+")"#),
            // `\n` stays as a literal backslash + `n` inside a raw string
            (r##"a == r#"/path/to/\n+"#"##, r#"(a == "/path/to/\n+")"#),
        ]);
    }
}