1use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value};
2use crate::schema::Schema;
3use std::collections::HashMap;
4
/// Outcome of validating an expression: `Ok(())` when it is valid, otherwise
/// `Err` holding a human-readable description of the problem.
type ValidationResult = std::result::Result<(), String>;
6
7pub trait Validate {
8 fn validate(&self, schema: &Schema) -> ValidationResult;
9}
10
/// Bookkeeping of field references in a caller-owned map from field name to
/// reference count.
pub trait FieldCounter {
    /// Records the field references of `self` into `counts`.
    fn add_to_counter(&self, counts: &mut HashMap<String, usize>);

    /// Withdraws the field references of `self` from `counts`; intended as the
    /// inverse of [`FieldCounter::add_to_counter`].
    fn remove_from_counter(&self, counts: &mut HashMap<String, usize>);
}
15
16impl FieldCounter for Expression {
17 fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
18 match self {
19 Expression::Logical(l) => match l.as_ref() {
20 LogicalExpression::And(l, r) => {
21 l.add_to_counter(map);
22 r.add_to_counter(map);
23 }
24 LogicalExpression::Or(l, r) => {
25 l.add_to_counter(map);
26 r.add_to_counter(map);
27 }
28 LogicalExpression::Not(r) => {
29 r.add_to_counter(map);
30 }
31 },
32 Expression::Predicate(p) => {
33 *map.entry(p.lhs.var_name.clone()).or_default() += 1;
34 }
35 }
36 }
37
38 fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
39 match self {
40 Expression::Logical(l) => match l.as_ref() {
41 LogicalExpression::And(l, r) => {
42 l.remove_from_counter(map);
43 r.remove_from_counter(map);
44 }
45 LogicalExpression::Or(l, r) => {
46 l.remove_from_counter(map);
47 r.remove_from_counter(map);
48 }
49 LogicalExpression::Not(r) => {
50 r.remove_from_counter(map);
51 }
52 },
53 Expression::Predicate(p) => {
54 let val = map.get_mut(&p.lhs.var_name).unwrap();
55 *val -= 1;
56
57 if *val == 0 {
58 assert!(map.remove(&p.lhs.var_name).is_some());
59 }
60 }
61 }
62 }
63}
64
65impl Validate for Expression {
66 fn validate(&self, schema: &Schema) -> ValidationResult {
67 match self {
68 Expression::Logical(l) => {
69 match l.as_ref() {
70 LogicalExpression::And(l, r) => {
71 l.validate(schema)?;
72 r.validate(schema)?;
73 }
74 LogicalExpression::Or(l, r) => {
75 l.validate(schema)?;
76 r.validate(schema)?;
77 }
78 LogicalExpression::Not(r) => {
79 r.validate(schema)?;
80 }
81 }
82
83 Ok(())
84 }
85 Expression::Predicate(p) => {
86 let lhs_type = p.lhs.my_type(schema);
88 if lhs_type.is_none() {
89 return Err("Unknown LHS field".to_string());
90 }
91 let lhs_type = lhs_type.unwrap();
92
93 if p.op != BinaryOperator::Regex && p.op != BinaryOperator::In && p.op != BinaryOperator::NotIn
96 && lhs_type != &p.rhs.my_type()
97 {
98 return Err(
99 "Type mismatch between the LHS and RHS values of predicate".to_string()
100 );
101 }
102
103 let (lower, _any) = p.lhs.get_transformations();
104
105 if lower && lhs_type != &Type::String {
107 return Err(
108 "lower-case transformation function only supported with String type fields"
109 .to_string(),
110 );
111 }
112
113 match p.op {
114 BinaryOperator::Equals | BinaryOperator::NotEquals => { Ok(()) }
115 BinaryOperator::Regex => {
116 if lhs_type == &Type::String {
118 Ok(())
119 } else {
120 Err("Regex operators only supports string operands".to_string())
121 }
122 },
123 BinaryOperator::Prefix | BinaryOperator::Postfix => {
124 match p.rhs {
125 Value::String(_) => {
126 Ok(())
127 }
128 _ => Err("Regex/Prefix/Postfix operators only supports string operands".to_string())
129 }
130 },
131 BinaryOperator::Greater | BinaryOperator::GreaterOrEqual | BinaryOperator::Less | BinaryOperator::LessOrEqual => {
132 match p.rhs {
133 Value::Int(_) => {
134 Ok(())
135 }
136 _ => Err("Greater/GreaterOrEqual/Lesser/LesserOrEqual operators only supports integer operands".to_string())
137 }
138 },
139 BinaryOperator::In | BinaryOperator::NotIn => {
140 match (lhs_type, &p.rhs,) {
142 (Type::IpAddr, Value::IpCidr(_)) => {
143 Ok(())
144 }
145 _ => Err("In/NotIn operators only supports IP in CIDR".to_string())
146 }
147 },
148 BinaryOperator::Contains => {
149 match p.rhs {
150 Value::String(_) => {
151 Ok(())
152 }
153 _ => Err("Contains operator only supports string operands".to_string())
154 }
155 }
156 }
157 }
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use lazy_static::lazy_static;

    lazy_static! {
        // Schema shared by every test: one field of each supported type.
        static ref SCHEMA: Schema = {
            let mut schema = Schema::default();
            schema.add_field("string", Type::String);
            schema.add_field("int", Type::Int);
            schema.add_field("ipaddr", Type::IpAddr);
            schema
        };
    }

    /// Parses every input and requires it to pass validation against `SCHEMA`.
    fn assert_all_valid(inputs: &[&str]) {
        for &input in inputs {
            parse(input).unwrap().validate(&SCHEMA).unwrap();
        }
    }

    /// Parses every input (parsing itself must succeed) and requires validation
    /// against `SCHEMA` to reject it.
    fn assert_all_invalid(inputs: &[&str]) {
        for &input in inputs {
            let parsed = parse(input).unwrap();
            assert!(parsed.validate(&SCHEMA).is_err());
        }
    }

    #[test]
    fn unknown_field() {
        let parsed = parse(r#"unkn == "abc""#).unwrap();
        assert_eq!(parsed.validate(&SCHEMA).unwrap_err(), "Unknown LHS field");
    }

    #[test]
    fn string_lhs() {
        assert_all_valid(&[
            r#"string == "abc""#,
            r#"string != "abc""#,
            r#"string ~ "abc""#,
            r#"string ^= "abc""#,
            r#"string =^ "abc""#,
            r#"lower(string) =^ "abc""#,
        ]);

        assert_all_invalid(&[
            r#"string == 192.168.0.1"#,
            r#"string == 192.168.0.0/24"#,
            r#"string == 123"#,
            r#"string in "abc""#,
        ]);
    }

    #[test]
    fn ipaddr_lhs() {
        assert_all_valid(&[
            r#"ipaddr == 192.168.0.1"#,
            r#"ipaddr == fd00::1"#,
            r#"ipaddr in 192.168.0.0/24"#,
            r#"ipaddr in fd00::/64"#,
            r#"ipaddr not in 192.168.0.0/24"#,
            r#"ipaddr not in fd00::/64"#,
        ]);

        assert_all_invalid(&[
            r#"ipaddr == "abc""#,
            r#"ipaddr == 123"#,
            r#"ipaddr in 192.168.0.1"#,
            r#"ipaddr in fd00::1"#,
            r#"ipaddr == 192.168.0.0/24"#,
            r#"ipaddr == fd00::/64"#,
            r#"lower(ipaddr) == fd00::1"#,
        ]);
    }

    #[test]
    fn int_lhs() {
        assert_all_valid(&[
            r#"int == 123"#,
            r#"int >= 123"#,
            r#"int <= 123"#,
            r#"int > 123"#,
            r#"int < 123"#,
        ]);

        assert_all_invalid(&[
            r#"int == "abc""#,
            r#"int in 192.168.0.0/24"#,
            r#"lower(int) == 123"#,
        ]);
    }
}