1use crate::schema::Schema;
2use cidr::IpCidr;
3use regex::Regex;
4use std::net::IpAddr;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
10#[derive(Debug)]
11pub enum Expression {
12 Logical(Box<LogicalExpression>),
13 Predicate(Predicate),
14}
15
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17#[derive(Debug)]
18pub enum LogicalExpression {
19 And(Expression, Expression),
20 Or(Expression, Expression),
21 Not(Expression),
22}
23
/// A transformation wrapped around the left-hand-side field of a predicate,
/// written like a function call: `lower(field)`, `any(field)`.
/// Transformations may be nested, e.g. `lower(any(field))`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq)]
pub enum LhsTransformations {
    /// `lower(...)`. Presumably lowercases the field value before comparing.
    /// TODO confirm against the evaluator.
    Lower,
    /// `any(...)`. Presumably matches when any of the field's values matches.
    /// TODO confirm against the evaluator.
    Any,
}
30
/// Operator joining a predicate's left-hand side to its right-hand value.
///
/// Each variant's doc shows the surface syntax the expression language uses
/// for it. Variant order is kept stable because serialized forms may rely on it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq)]
pub enum BinaryOperator {
    /// `==`
    Equals,
    /// `!=`
    NotEquals,
    /// `~`: match against a regular-expression value.
    Regex,
    /// `^=`: presumably "starts with".
    Prefix,
    /// `=^`: presumably "ends with".
    Postfix,
    /// `>`
    Greater,
    /// `>=`
    GreaterOrEqual,
    /// `<`
    Less,
    /// `<=`
    LessOrEqual,
    /// `in`
    In,
    /// `not in`
    NotIn,
    /// `contains`
    Contains,
}
47
48#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
49#[derive(Debug, Clone)]
50pub enum Value {
51 String(String),
52 IpCidr(IpCidr),
53 IpAddr(IpAddr),
54 Int(i64),
55 #[cfg_attr(feature = "serde", serde(with = "serde_regex"))]
56 Regex(Regex),
57}
58
59impl PartialEq for Value {
60 fn eq(&self, other: &Self) -> bool {
61 match (self, other) {
62 (Self::Regex(_), _) | (_, Self::Regex(_)) => {
63 panic!("Regexes can not be compared using eq")
64 }
65 (Self::String(s1), Self::String(s2)) => s1 == s2,
66 (Self::IpCidr(i1), Self::IpCidr(i2)) => i1 == i2,
67 (Self::IpAddr(i1), Self::IpAddr(i2)) => i1 == i2,
68 (Self::Int(i1), Self::Int(i2)) => i1 == i2,
69 _ => false,
70 }
71 }
72}
73
74impl Value {
75 pub fn my_type(&self) -> Type {
76 match self {
77 Value::String(_) => Type::String,
78 Value::IpCidr(_) => Type::IpCidr,
79 Value::IpAddr(_) => Type::IpAddr,
80 Value::Int(_) => Type::Int,
81 Value::Regex(_) => Type::Regex,
82 }
83 }
84}
85
86impl From<String> for Value {
87 fn from(v: String) -> Self {
88 Value::String(v)
89 }
90}
91
/// Type tag for a [`Value`]. It is also what the schema reports for a field
/// (see `Lhs::my_type`).
///
/// `#[repr(C)]` pins the discriminants to declaration order (0 through 4).
/// Presumably this is because the tag crosses an FFI boundary, so do not
/// reorder the variants without checking that.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Eq, PartialEq)]
#[repr(C)]
pub enum Type {
    /// Tag for [`Value::String`].
    String,
    /// Tag for [`Value::IpCidr`].
    IpCidr,
    /// Tag for [`Value::IpAddr`].
    IpAddr,
    /// Tag for [`Value::Int`].
    Int,
    /// Tag for [`Value::Regex`].
    Regex,
}
102
103#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
104#[derive(Debug)]
105pub struct Lhs {
106 pub var_name: String,
107 pub transformations: Vec<LhsTransformations>,
108}
109
110impl Lhs {
111 pub fn my_type<'a>(&self, schema: &'a Schema) -> Option<&'a Type> {
112 schema.type_of(&self.var_name)
113 }
114
115 pub fn get_transformations(&self) -> (bool, bool) {
116 let mut lower = false;
117 let mut any = false;
118
119 self.transformations.iter().for_each(|i| match i {
120 LhsTransformations::Any => any = true,
121 LhsTransformations::Lower => lower = true,
122 });
123
124 (lower, any)
125 }
126}
127
128#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
129#[derive(Debug)]
130pub struct Predicate {
131 pub lhs: Lhs,
132 pub rhs: Value,
133 pub op: BinaryOperator,
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use std::fmt;

    // Test-only `Display` impls. They render an AST back into the expression
    // language with every sub-expression fully parenthesised, so the tests can
    // see exactly how the parser grouped operands. Each impl writes straight
    // into the formatter rather than building intermediate `String`s.

    impl fmt::Display for Expression {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                Expression::Logical(logical) => write!(f, "{}", logical),
                Expression::Predicate(predicate) => write!(f, "{}", predicate),
            }
        }
    }

    impl fmt::Display for LogicalExpression {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                LogicalExpression::And(left, right) => write!(f, "({} && {})", left, right),
                LogicalExpression::Or(left, right) => write!(f, "({} || {})", left, right),
                LogicalExpression::Not(e) => write!(f, "!({})", e),
            }
        }
    }

    impl fmt::Display for LhsTransformations {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str(match self {
                LhsTransformations::Lower => "lower",
                LhsTransformations::Any => "any",
            })
        }
    }

    impl fmt::Display for Value {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                Value::String(s) => write!(f, "\"{}\"", s),
                Value::IpCidr(cidr) => write!(f, "{}", cidr),
                Value::IpAddr(addr) => write!(f, "{}", addr),
                Value::Int(i) => write!(f, "{}", i),
                Value::Regex(re) => write!(f, "\"{}\"", re),
            }
        }
    }

    impl fmt::Display for Lhs {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            // The first stored transformation is rendered innermost:
            // `[Lower, Any]` on `x` prints as `any(lower(x))`. Open the calls
            // outermost-first, then close them all. This is linear, unlike
            // re-`format!`ing the whole string once per transformation.
            for transformation in self.transformations.iter().rev() {
                write!(f, "{}(", transformation)?;
            }
            f.write_str(&self.var_name)?;
            for _ in &self.transformations {
                f.write_str(")")?;
            }
            Ok(())
        }
    }

    impl fmt::Display for BinaryOperator {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str(match self {
                BinaryOperator::Equals => "==",
                BinaryOperator::NotEquals => "!=",
                BinaryOperator::Regex => "~",
                BinaryOperator::Prefix => "^=",
                BinaryOperator::Postfix => "=^",
                BinaryOperator::Greater => ">",
                BinaryOperator::GreaterOrEqual => ">=",
                BinaryOperator::Less => "<",
                BinaryOperator::LessOrEqual => "<=",
                BinaryOperator::In => "in",
                BinaryOperator::NotIn => "not in",
                BinaryOperator::Contains => "contains",
            })
        }
    }

    impl fmt::Display for Predicate {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "({} {} {})", self.lhs, self.op, self.rhs)
        }
    }

    /// Parses every `(input, expected)` pair and asserts that the parsed
    /// expression renders as `expected`. The failing input is named in the
    /// panic message.
    fn assert_renders(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            let expr = parse(input)
                .unwrap_or_else(|err| panic!("failed to parse {:?}: {:?}", input, err));
            assert_eq!(expr.to_string(), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn expr_op_and_prec() {
        // These expectations pin the parser's current grouping. `||` binds
        // tighter than `&&` (`a && b || c` => `a && (b || c)`), the reverse of
        // C/Rust precedence, and both associate to the left.
        // NOTE(review): confirm this precedence is intended.
        assert_renders(&[
            ("a > 0", "(a > 0)"),
            ("a in \"abc\"", "(a in \"abc\")"),
            ("a == 1 && b != 2", "((a == 1) && (b != 2))"),
            (
                "a ^= \"1\" && b =^ \"2\" || c >= 3",
                "((a ^= \"1\") && ((b =^ \"2\") || (c >= 3)))",
            ),
            (
                "a == 1 && b != 2 || c >= 3",
                "((a == 1) && ((b != 2) || (c >= 3)))",
            ),
            (
                "a > 1 || b < 2 && c <= 3 || d not in \"foo\"",
                "(((a > 1) || (b < 2)) && ((c <= 3) || (d not in \"foo\")))",
            ),
            (
                "a > 1 || ((b < 2) && (c <= 3)) || d not in \"foo\"",
                "(((a > 1) || ((b < 2) && (c <= 3))) || (d not in \"foo\"))",
            ),
            ("!(a == 1)", "!((a == 1))"),
            (
                "!(a == 1) && b == 2 && !(c == 3) && d >= 4",
                "(((!((a == 1)) && (b == 2)) && !((c == 3))) && (d >= 4))",
            ),
            (
                "!(a == 1 || b == 2 && c == 3) && d == 4",
                "(!((((a == 1) || (b == 2)) && (c == 3))) && (d == 4))",
            ),
        ]);
    }

    #[test]
    fn expr_var_name_and_ip() {
        assert_renders(&[
            ("kong.foo in 1.1.1.1", "(kong.foo in 1.1.1.1)"),
            (
                "kong.foo.foo2 in 10.0.0.0/24",
                "(kong.foo.foo2 in 10.0.0.0/24)",
            ),
            (
                "kong.foo.foo3 in 2001:db8::/32",
                "(kong.foo.foo3 in 2001:db8::/32)",
            ),
            // NOTE(review): same literal as the previous case. Presumably a
            // different IPv6 form (e.g. a bare address) was intended here.
            (
                "kong.foo.foo4 in 2001:db8::/32",
                "(kong.foo.foo4 in 2001:db8::/32)",
            ),
        ]);
    }

    #[test]
    fn expr_regex() {
        assert_renders(&[
            (
                "kong.foo.foo5 ~ \"^foo.*$\"",
                "(kong.foo.foo5 ~ \"^foo.*$\")",
            ),
            // NOTE(review): same pattern as the previous case. Presumably a
            // different pattern was intended here.
            (
                "kong.foo.foo6 ~ \"^foo.*$\"",
                "(kong.foo.foo6 ~ \"^foo.*$\")",
            ),
        ]);
    }

    #[test]
    fn expr_digits() {
        // A leading `0x` is hexadecimal (0x123 == 291) and a leading `0` is
        // octal (0123 == 83). Negative forms behave the same way.
        assert_renders(&[
            ("kong.foo.foo7 == 123", "(kong.foo.foo7 == 123)"),
            ("kong.foo.foo8 == 0x123", "(kong.foo.foo8 == 291)"),
            ("kong.foo.foo9 == 0123", "(kong.foo.foo9 == 83)"),
            ("kong.foo.foo10 == -123", "(kong.foo.foo10 == -123)"),
            ("kong.foo.foo11 == -0x123", "(kong.foo.foo11 == -291)"),
            ("kong.foo.foo12 == -0123", "(kong.foo.foo12 == -83)"),
        ]);
    }

    #[test]
    fn expr_transformations() {
        assert_renders(&[
            (
                "lower(kong.foo.foo13) == \"foo\"",
                "(lower(kong.foo.foo13) == \"foo\")",
            ),
            (
                "any(kong.foo.foo14) == \"foo\"",
                "(any(kong.foo.foo14) == \"foo\")",
            ),
        ]);
    }

    #[test]
    fn expr_transformations_nested() {
        assert_renders(&[
            (
                "lower(lower(kong.foo.foo15)) == \"foo\"",
                "(lower(lower(kong.foo.foo15)) == \"foo\")",
            ),
            (
                "lower(any(kong.foo.foo16)) == \"foo\"",
                "(lower(any(kong.foo.foo16)) == \"foo\")",
            ),
            (
                "any(lower(kong.foo.foo17)) == \"foo\"",
                "(any(lower(kong.foo.foo17)) == \"foo\")",
            ),
            (
                "any(any(kong.foo.foo18)) == \"foo\"",
                "(any(any(kong.foo.foo18)) == \"foo\")",
            ),
        ]);
    }

    #[test]
    fn str_unicode_test() {
        assert_renders(&[
            ("t_msg in \"你好\"", "(t_msg in \"你好\")"),
            // NOTE(review): rustc decodes `\u{...}` inside this Rust string
            // literal, so the parser receives exactly the same text as in the
            // previous case. To exercise escape handling in the parser itself,
            // if it has any, the backslash would need to be escaped (`\\u`).
            ("t_msg in \"\u{4f60}\u{597d}\"", "(t_msg in \"你好\")"),
        ]);
    }

    #[test]
    fn rawstr_test() {
        assert_renders(&[
            (r##"a == r#"/path/to/\d+"#"##, r#"(a == "/path/to/\d+")"#),
            (r##"a == r#"/path/to/\n+"#"##, r#"(a == "/path/to/\n+")"#),
        ]);
    }
}