// atc_router/semantics.rs

1use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value};
2use crate::schema::Schema;
3use std::collections::HashMap;
4
/// Outcome of a semantic check: `Ok(())` when the checked item is valid,
/// otherwise a human-readable message describing the problem found.
type ValidationResult = std::result::Result<(), String>;
6
7pub trait Validate {
8    fn validate(&self, schema: &Schema) -> ValidationResult;
9}
10
/// Tracks how often fields are referenced, using a map from field name to
/// occurrence count.
pub trait FieldCounter {
    /// Adds the field references made by `self` to `map`.
    fn add_to_counter(&self, map: &mut HashMap<String, usize>);

    /// Withdraws the field references made by `self` from `map`; presumably
    /// meant to be paired with an earlier `add_to_counter` call on the same
    /// value (confirm with implementors).
    fn remove_from_counter(&self, map: &mut HashMap<String, usize>);
}
15
16impl FieldCounter for Expression {
17    fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
18        match self {
19            Expression::Logical(l) => match l.as_ref() {
20                LogicalExpression::And(l, r) => {
21                    l.add_to_counter(map);
22                    r.add_to_counter(map);
23                }
24                LogicalExpression::Or(l, r) => {
25                    l.add_to_counter(map);
26                    r.add_to_counter(map);
27                }
28                LogicalExpression::Not(r) => {
29                    r.add_to_counter(map);
30                }
31            },
32            Expression::Predicate(p) => {
33                *map.entry(p.lhs.var_name.clone()).or_default() += 1;
34            }
35        }
36    }
37
38    fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
39        match self {
40            Expression::Logical(l) => match l.as_ref() {
41                LogicalExpression::And(l, r) => {
42                    l.remove_from_counter(map);
43                    r.remove_from_counter(map);
44                }
45                LogicalExpression::Or(l, r) => {
46                    l.remove_from_counter(map);
47                    r.remove_from_counter(map);
48                }
49                LogicalExpression::Not(r) => {
50                    r.remove_from_counter(map);
51                }
52            },
53            Expression::Predicate(p) => {
54                let val = map.get_mut(&p.lhs.var_name).unwrap();
55                *val -= 1;
56
57                if *val == 0 {
58                    assert!(map.remove(&p.lhs.var_name).is_some());
59                }
60            }
61        }
62    }
63}
64
65impl Validate for Expression {
66    fn validate(&self, schema: &Schema) -> ValidationResult {
67        match self {
68            Expression::Logical(l) => {
69                match l.as_ref() {
70                    LogicalExpression::And(l, r) => {
71                        l.validate(schema)?;
72                        r.validate(schema)?;
73                    }
74                    LogicalExpression::Or(l, r) => {
75                        l.validate(schema)?;
76                        r.validate(schema)?;
77                    }
78                    LogicalExpression::Not(r) => {
79                        r.validate(schema)?;
80                    }
81                }
82
83                Ok(())
84            }
85            Expression::Predicate(p) => {
86                // lhs and rhs must be the same type
87                let lhs_type = p.lhs.my_type(schema);
88                if lhs_type.is_none() {
89                    return Err("Unknown LHS field".to_string());
90                }
91                let lhs_type = lhs_type.unwrap();
92
93                if p.op != BinaryOperator::Regex // Regex RHS is always Regex, and LHS is always String
94                    && p.op != BinaryOperator::In // In/NotIn supports IPAddr in IpCidr
95                    && p.op != BinaryOperator::NotIn
96                    && lhs_type != &p.rhs.my_type()
97                {
98                    return Err(
99                        "Type mismatch between the LHS and RHS values of predicate".to_string()
100                    );
101                }
102
103                let (lower, _any) = p.lhs.get_transformations();
104
105                // LHS transformations only makes sense with string fields
106                if lower && lhs_type != &Type::String {
107                    return Err(
108                        "lower-case transformation function only supported with String type fields"
109                            .to_string(),
110                    );
111                }
112
113                match p.op {
114                    BinaryOperator::Equals | BinaryOperator::NotEquals => { Ok(()) }
115                    BinaryOperator::Regex => {
116                        // unchecked path above
117                        if lhs_type == &Type::String {
118                            Ok(())
119                        } else {
120                            Err("Regex operators only supports string operands".to_string())
121                        }
122                    },
123                    BinaryOperator::Prefix | BinaryOperator::Postfix => {
124                        match p.rhs {
125                            Value::String(_) => {
126                                Ok(())
127                            }
128                            _ => Err("Regex/Prefix/Postfix operators only supports string operands".to_string())
129                        }
130                    },
131                    BinaryOperator::Greater | BinaryOperator::GreaterOrEqual | BinaryOperator::Less | BinaryOperator::LessOrEqual => {
132                        match p.rhs {
133                            Value::Int(_) => {
134                                Ok(())
135                            }
136                            _ => Err("Greater/GreaterOrEqual/Lesser/LesserOrEqual operators only supports integer operands".to_string())
137                        }
138                    },
139                    BinaryOperator::In | BinaryOperator::NotIn => {
140                        // unchecked path above
141                        match (lhs_type, &p.rhs,) {
142                            (Type::IpAddr, Value::IpCidr(_)) => {
143                                Ok(())
144                            }
145                            _ => Err("In/NotIn operators only supports IP in CIDR".to_string())
146                        }
147                    },
148                    BinaryOperator::Contains => {
149                        match p.rhs {
150                            Value::String(_) => {
151                                Ok(())
152                            }
153                            _ => Err("Contains operator only supports string operands".to_string())
154                        }
155                    }
156                }
157            }
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use lazy_static::lazy_static;

    lazy_static! {
        static ref SCHEMA: Schema = {
            let mut s = Schema::default();
            s.add_field("string", Type::String);
            s.add_field("int", Type::Int);
            s.add_field("ipaddr", Type::IpAddr);
            s
        };
    }

    /// Parses `input`, panicking with the offending input on a parse error.
    fn parse_or_panic(input: &str) -> Expression {
        parse(input).unwrap_or_else(|e| panic!("failed to parse {:?}: {:?}", input, e))
    }

    /// Asserts that `input` parses and validates against `SCHEMA`.
    fn assert_valid(input: &str) {
        if let Err(e) = parse_or_panic(input).validate(&SCHEMA) {
            panic!("expected {:?} to validate, got error: {}", input, e);
        }
    }

    /// Asserts that `input` parses but fails validation against `SCHEMA`.
    fn assert_invalid(input: &str) {
        assert!(
            parse_or_panic(input).validate(&SCHEMA).is_err(),
            "expected {:?} to fail validation",
            input
        );
    }

    #[test]
    fn unknown_field() {
        let expression = parse_or_panic(r#"unkn == "abc""#);
        assert_eq!(
            expression.validate(&SCHEMA).unwrap_err(),
            "Unknown LHS field"
        );
    }

    #[test]
    fn string_lhs() {
        for input in &[
            r#"string == "abc""#,
            r#"string != "abc""#,
            r#"string ~ "abc""#,
            r#"string ^= "abc""#,
            r#"string =^ "abc""#,
            r#"lower(string) =^ "abc""#,
        ] {
            assert_valid(input);
        }

        for input in &[
            r#"string == 192.168.0.1"#,
            r#"string == 192.168.0.0/24"#,
            r#"string == 123"#,
            r#"string in "abc""#,
            // same-typed operands, but ordering operators need integers
            r#"string > "abc""#,
        ] {
            assert_invalid(input);
        }
    }

    #[test]
    fn ipaddr_lhs() {
        for input in &[
            r#"ipaddr == 192.168.0.1"#,
            r#"ipaddr == fd00::1"#,
            r#"ipaddr in 192.168.0.0/24"#,
            r#"ipaddr in fd00::/64"#,
            r#"ipaddr not in 192.168.0.0/24"#,
            r#"ipaddr not in fd00::/64"#,
        ] {
            assert_valid(input);
        }

        for input in &[
            r#"ipaddr == "abc""#,
            r#"ipaddr == 123"#,
            r#"ipaddr in 192.168.0.1"#,
            r#"ipaddr in fd00::1"#,
            r#"ipaddr == 192.168.0.0/24"#,
            r#"ipaddr == fd00::/64"#,
            r#"lower(ipaddr) == fd00::1"#,
            // regex matching is only defined for String fields
            r#"ipaddr ~ "abc""#,
        ] {
            assert_invalid(input);
        }
    }

    #[test]
    fn int_lhs() {
        for input in &[
            r#"int == 123"#,
            r#"int >= 123"#,
            r#"int <= 123"#,
            r#"int > 123"#,
            r#"int < 123"#,
        ] {
            assert_valid(input);
        }

        for input in &[
            r#"int == "abc""#,
            r#"int in 192.168.0.0/24"#,
            r#"lower(int) == 123"#,
            // regex matching is only defined for String fields
            r#"int ~ "abc""#,
        ] {
            assert_invalid(input);
        }
    }
}