1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16 And(Vec<Expr<'a, F>>),
17 Or(Vec<Expr<'a, F>>),
18 Not(Box<Expr<'a, F>>),
19 Pred(Filter<'a, F>),
20 Compare {
21 left: ScalarExpr<F>,
22 op: CompareOp,
23 right: ScalarExpr<F>,
24 },
25}
26
27impl<'a, F> Expr<'a, F> {
28 #[inline]
30 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
31 Expr::And(fs.into_iter().map(Expr::Pred).collect())
32 }
33
34 #[inline]
36 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
37 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
38 }
39
40 #[allow(clippy::should_implement_trait)]
42 #[inline]
43 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
44 Expr::Not(Box::new(e))
45 }
46}
47
48#[derive(Clone, Debug)]
50pub enum ScalarExpr<F> {
51 Column(F),
52 Literal(Literal),
53 Binary {
54 left: Box<ScalarExpr<F>>,
55 op: BinaryOp,
56 right: Box<ScalarExpr<F>>,
57 },
58 Aggregate(AggregateCall<F>),
61}
62
/// Aggregate function call over fields identified by `F`.
///
/// Every variant except `CountStar` names the single field it aggregates. How each
/// aggregate is computed (null handling, overflow, empty input) is defined by the
/// evaluator, not by this type.
///
/// `PartialEq`/`Eq` are derived (bounded on `F`), matching the sibling operator
/// enums, so aggregate calls can be compared structurally.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AggregateCall<F> {
    /// Aggregate with no field argument. By its name this is a SQL-style `COUNT(*)`;
    /// confirm against the evaluator.
    CountStar,
    /// Count over the given field.
    Count(F),
    /// Sum of the given field.
    Sum(F),
    /// Minimum of the given field.
    Min(F),
    /// Maximum of the given field.
    Max(F),
    /// Count of null entries in the given field (exact semantics per the evaluator).
    CountNulls(F),
}
73
74impl<F> ScalarExpr<F> {
75 #[inline]
76 pub fn column(field: F) -> Self {
77 Self::Column(field)
78 }
79
80 #[inline]
81 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
82 Self::Literal(lit.into())
83 }
84
85 #[inline]
86 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
87 Self::Binary {
88 left: Box::new(left),
89 op,
90 right: Box::new(right),
91 }
92 }
93
94 #[inline]
95 pub fn aggregate(call: AggregateCall<F>) -> Self {
96 Self::Aggregate(call)
97 }
98}
99
/// Arithmetic operator joining the two operands of a [`ScalarExpr::Binary`].
///
/// Runtime behavior (overflow, division by zero, type coercion) is decided by the
/// evaluator, not here. Variant order is significant: it fixes the implicit
/// discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
}
109
/// Comparison operator relating the two operands of an [`Expr::Compare`].
///
/// How values of differing types compare is decided by the evaluator, not here.
/// Variant order is significant: it fixes the implicit discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
}
120
121#[derive(Debug, Clone)]
123pub struct Filter<'a, F> {
124 pub field_id: F,
125 pub op: Operator<'a>,
126}
127
128#[derive(Debug, Clone)]
133pub enum Operator<'a> {
134 Equals(Literal),
135 Range {
136 lower: Bound<Literal>,
137 upper: Bound<Literal>,
138 },
139 GreaterThan(Literal),
140 GreaterThanOrEquals(Literal),
141 LessThan(Literal),
142 LessThanOrEquals(Literal),
143 In(&'a [Literal]),
144 StartsWith {
145 pattern: &'a str,
146 case_sensitive: bool,
147 },
148 EndsWith {
149 pattern: &'a str,
150 case_sensitive: bool,
151 },
152 Contains {
153 pattern: &'a str,
154 case_sensitive: bool,
155 },
156}
157
158impl<'a> Operator<'a> {
159 #[inline]
160 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
161 Operator::StartsWith {
162 pattern,
163 case_sensitive,
164 }
165 }
166
167 #[inline]
168 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
169 Operator::EndsWith {
170 pattern,
171 case_sensitive,
172 }
173 }
174
175 #[inline]
176 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
177 Operator::Contains {
178 pattern,
179 case_sensitive,
180 }
181 }
182}
183
// Unit tests for the expression AST. They only construct trees and inspect their
// shape with `match`; nothing here evaluates an expression.
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test for the `all_of` / `any_of` / `not` helpers: checks the top-level
    // variant and the child count of each result. `field_id` is an untyped integer
    // literal here, so `F` defaults to `i32`.
    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1.clone(), f2.clone()]);
        let not_all = Expr::not(all);
        match any {
            Expr::Or(v) => assert_eq!(v.len(), 2),
            _ => panic!("expected Or"),
        }
        // `Not` holds a `Box`, so deref to match the inner node by value.
        match not_all {
            Expr::Not(inner) => match *inner {
                Expr::And(v) => assert_eq!(v.len(), 2),
                _ => panic!("expected And inside Not"),
            },
            _ => panic!("expected Not"),
        }
    }

    // Builds, by hand (not via the helpers for the inner nodes):
    //
    //   OR(
    //     AND(Pred(f1), OR(Pred(f2), NOT(Pred(f3)))),
    //     AND(NOT(Pred(f1)), Pred(f4))
    //   )
    //
    // and walks every node, identifying leaves by `field_id`.
    #[test]
    fn complex_nested_shape() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        // `In` borrows this array, so it must outlive `f3` and every tree built from it.
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![
                Expr::Pred(f2.clone()),
                Expr::not(Expr::Pred(f3.clone())),
            ]),
        ]);
        let right = Expr::And(vec![
            Expr::not(Expr::Pred(f1.clone())),
            Expr::Pred(f4.clone()),
        ]);
        let top = Expr::Or(vec![left, right]);

        match top {
            Expr::Or(branches) => {
                assert_eq!(branches.len(), 2);
                // Left branch: AND(Pred(f1), OR(Pred(f2), NOT(Pred(f3)))).
                match &branches[0] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        match &v[0] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 1)
                            }
                            _ => panic!("expected Pred(f1) in left-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Or(or_vec) => {
                                assert_eq!(or_vec.len(), 2);
                                match &or_vec[0] {
                                    Expr::Pred(Filter { field_id, .. }) => {
                                        assert_eq!(*field_id, 2)
                                    }
                                    _ => panic!("expected Pred(f2) in left-AND[1].OR[0]"),
                                }
                                match &or_vec[1] {
                                    Expr::Not(inner) => match inner.as_ref() {
                                        Expr::Pred(Filter { field_id, .. }) => {
                                            assert_eq!(*field_id, 3)
                                        }
                                        // The trailing `\` continues the string literal
                                        // and drops the next line's leading whitespace.
                                        _ => panic!(
                                            "expected Not(Pred(f3)) in \
                                             left-AND[1].OR[1]"
                                        ),
                                    },
                                    _ => panic!("expected Not(...) in left-AND[1].OR[1]"),
                                }
                            }
                            _ => panic!("expected OR in left-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on left branch of top OR"),
                }
                // Right branch: AND(NOT(Pred(f1)), Pred(f4)).
                match &branches[1] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        match &v[0] {
                            Expr::Not(inner) => match inner.as_ref() {
                                Expr::Pred(Filter { field_id, .. }) => {
                                    assert_eq!(*field_id, 1)
                                }
                                _ => panic!("expected Not(Pred(f1)) in right-AND[0]"),
                            },
                            _ => panic!("expected Not(...) in right-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 4)
                            }
                            _ => panic!("expected Pred(f4) in right-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on right branch of top OR"),
                }
            }
            _ => panic!("expected top-level OR"),
        }
    }

    // A `Range` keeps its two `Bound`s exactly as given. `&str` values converted with
    // `.into()` come back as `Literal::String`.
    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                if let Bound::Included(l) = lower {
                    assert_eq!(*l, Literal::String("aaa".to_string()));
                } else {
                    panic!("lower bound should be Included");
                }

                if let Bound::Excluded(u) = upper {
                    assert_eq!(*u, Literal::String("bbb".to_string()));
                } else {
                    panic!("upper bound should be Excluded");
                }
            }
            _ => panic!("expected Range operator"),
        }
    }

    // `all_of` / `any_of` wrap each filter in `Pred` and keep the input order
    // (checked in both forward and reversed order).
    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        match and_expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 3);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!(),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!(),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!(),
                }
            }
            _ => panic!("expected And"),
        }

        // Same filters in reverse: the `Or` children must come out reversed as well.
        let or_expr = Expr::any_of(vec![f3.clone(), f2.clone(), f1.clone()]);
        match or_expr {
            Expr::Or(v) => {
                assert_eq!(v.len(), 3);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!(),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!(),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!(),
                }
            }
            _ => panic!("expected Or"),
        }
    }

    // `In` borrows the caller's array. The `starts_with` / `ends_with` / `contains`
    // helpers store the pattern and the case flag unchanged.
    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "pre");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "suf");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
        match f_ct.op {
            Operator::Contains {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "mid");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
    }

    // The field identifier type `F` is fully generic; here `F = &str`.
    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1.clone(), f2.clone()]);

        match expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 2);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "name")
                    }
                    _ => panic!("expected Pred(name)"),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "status")
                    }
                    _ => panic!("expected Pred(status)"),
                }
            }
            _ => panic!("expected And"),
        }
    }

    // Wraps one predicate in 64 nested `Not`s, then walks the chain iteratively
    // (not recursively), counting wrappers until the leaf is reached.
    #[test]
    fn very_deep_not_chain() {
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let mut expr = base;
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        let mut count = 0usize;
        let mut cur = &expr;
        loop {
            match cur {
                Expr::Not(inner) => {
                    count += 1;
                    // `inner` is `&Box<Expr>`; deref coercion yields `&Expr`.
                    cur = inner;
                }
                Expr::Pred(Filter { field_id, .. }) => {
                    assert_eq!(*field_id, 42);
                    break;
                }
                _ => panic!("unexpected node inside deep NOT chain"),
            }
        }
        assert_eq!(count, 64);
    }

    // `.into()` conversions: integer literals become `Literal::Integer` and `&str`
    // becomes `Literal::String`.
    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }
}
572}