1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17 And(Vec<Expr<'a, F>>),
18 Or(Vec<Expr<'a, F>>),
19 Not(Box<Expr<'a, F>>),
20 Pred(Filter<'a, F>),
21 Compare {
22 left: ScalarExpr<F>,
23 op: CompareOp,
24 right: ScalarExpr<F>,
25 },
26 InList {
27 expr: ScalarExpr<F>,
28 list: Vec<ScalarExpr<F>>,
29 negated: bool,
30 },
31 IsNull {
35 expr: ScalarExpr<F>,
36 negated: bool,
37 },
38 Literal(bool),
41 Exists(SubqueryExpr),
43}
44
/// Opaque numeric handle identifying a subquery referenced from an expression.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct SubqueryId(pub u32);
48
49#[derive(Clone, Debug)]
51pub struct SubqueryExpr {
52 pub id: SubqueryId,
54 pub negated: bool,
56}
57
58#[derive(Clone, Debug)]
60pub struct ScalarSubqueryExpr {
61 pub id: SubqueryId,
63 pub data_type: DataType,
65}
66
67impl<'a, F> Expr<'a, F> {
68 #[inline]
70 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
71 Expr::And(fs.into_iter().map(Expr::Pred).collect())
72 }
73
74 #[inline]
76 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
77 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
78 }
79
80 #[allow(clippy::should_implement_trait)]
82 #[inline]
83 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
84 Expr::Not(Box::new(e))
85 }
86
87 pub fn is_full_range_for(&self, expected_field: &F) -> bool
89 where
90 F: PartialEq,
91 {
92 matches!(
93 self,
94 Expr::Pred(Filter {
95 field_id,
96 op:
97 Operator::Range {
98 lower: Bound::Unbounded,
99 upper: Bound::Unbounded,
100 },
101 }) if field_id == expected_field
102 )
103 }
104}
105
106#[derive(Clone, Debug)]
108pub enum ScalarExpr<F> {
109 Column(F),
110 Literal(Literal),
111 Binary {
112 left: Box<ScalarExpr<F>>,
113 op: BinaryOp,
114 right: Box<ScalarExpr<F>>,
115 },
116 Not(Box<ScalarExpr<F>>),
118 IsNull {
121 expr: Box<ScalarExpr<F>>,
122 negated: bool,
123 },
124 Aggregate(AggregateCall<F>),
127 GetField {
131 base: Box<ScalarExpr<F>>,
132 field_name: String,
133 },
134 Cast {
136 expr: Box<ScalarExpr<F>>,
137 data_type: DataType,
138 },
139 Compare {
141 left: Box<ScalarExpr<F>>,
142 op: CompareOp,
143 right: Box<ScalarExpr<F>>,
144 },
145 Coalesce(Vec<ScalarExpr<F>>),
147 ScalarSubquery(ScalarSubqueryExpr),
149 Case {
151 operand: Option<Box<ScalarExpr<F>>>,
153 branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
155 else_expr: Option<Box<ScalarExpr<F>>>,
157 },
158 Random,
163}
164
165#[derive(Clone, Debug)]
170pub enum AggregateCall<F> {
171 CountStar,
172 Count {
173 expr: Box<ScalarExpr<F>>,
174 distinct: bool,
175 },
176 Sum {
177 expr: Box<ScalarExpr<F>>,
178 distinct: bool,
179 },
180 Total {
181 expr: Box<ScalarExpr<F>>,
182 distinct: bool,
183 },
184 Avg {
185 expr: Box<ScalarExpr<F>>,
186 distinct: bool,
187 },
188 Min(Box<ScalarExpr<F>>),
189 Max(Box<ScalarExpr<F>>),
190 CountNulls(Box<ScalarExpr<F>>),
191 GroupConcat {
192 expr: Box<ScalarExpr<F>>,
193 distinct: bool,
194 separator: Option<String>,
195 },
196}
197
198impl<F> ScalarExpr<F> {
199 #[inline]
200 pub fn column(field: F) -> Self {
201 Self::Column(field)
202 }
203
204 #[inline]
205 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
206 Self::Literal(lit.into())
207 }
208
209 #[inline]
210 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
211 Self::Binary {
212 left: Box::new(left),
213 op,
214 right: Box::new(right),
215 }
216 }
217
218 #[inline]
219 pub fn logical_not(expr: Self) -> Self {
220 Self::Not(Box::new(expr))
221 }
222
223 #[inline]
224 pub fn is_null(expr: Self, negated: bool) -> Self {
225 Self::IsNull {
226 expr: Box::new(expr),
227 negated,
228 }
229 }
230
231 #[inline]
232 pub fn aggregate(call: AggregateCall<F>) -> Self {
233 Self::Aggregate(call)
234 }
235
236 #[inline]
237 pub fn get_field(base: Self, field_name: String) -> Self {
238 Self::GetField {
239 base: Box::new(base),
240 field_name,
241 }
242 }
243
244 #[inline]
245 pub fn cast(expr: Self, data_type: DataType) -> Self {
246 Self::Cast {
247 expr: Box::new(expr),
248 data_type,
249 }
250 }
251
252 #[inline]
253 pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
254 Self::Compare {
255 left: Box::new(left),
256 op,
257 right: Box::new(right),
258 }
259 }
260
261 #[inline]
262 pub fn coalesce(exprs: Vec<Self>) -> Self {
263 Self::Coalesce(exprs)
264 }
265
266 #[inline]
267 pub fn scalar_subquery(id: SubqueryId, data_type: DataType) -> Self {
268 Self::ScalarSubquery(ScalarSubqueryExpr { id, data_type })
269 }
270
271 #[inline]
272 pub fn case(
273 operand: Option<Self>,
274 branches: Vec<(Self, Self)>,
275 else_expr: Option<Self>,
276 ) -> Self {
277 Self::Case {
278 operand: operand.map(Box::new),
279 branches,
280 else_expr: else_expr.map(Box::new),
281 }
282 }
283
284 #[inline]
285 pub fn random() -> Self {
286 Self::Random
287 }
288}
289
/// Binary operators usable in [`ScalarExpr::Binary`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Subtract,
    Multiply,
    Divide,
    Modulo,
    And,
    Or,
    BitwiseShiftLeft,
    BitwiseShiftRight,
}

impl BinaryOp {
    /// Textual symbol for the operator (e.g. `"+"`, `"AND"`, `"<<"`).
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Add => "+",
            Self::Subtract => "-",
            Self::Multiply => "*",
            Self::Divide => "/",
            Self::Modulo => "%",
            Self::And => "AND",
            Self::Or => "OR",
            Self::BitwiseShiftLeft => "<<",
            Self::BitwiseShiftRight => ">>",
        }
    }
}
320
/// Comparison operators usable in `Expr::Compare` and `ScalarExpr::Compare`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompareOp {
    Eq,
    NotEq,
    Lt,
    LtEq,
    Gt,
    GtEq,
}

impl CompareOp {
    /// Textual symbol for the operator; inequality renders as `"!="`.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Lt => "<",
            Self::LtEq => "<=",
            Self::Gt => ">",
            Self::GtEq => ">=",
        }
    }
}
345
346#[derive(Debug, Clone)]
348pub struct Filter<'a, F> {
349 pub field_id: F,
350 pub op: Operator<'a>,
351}
352
353#[derive(Debug, Clone)]
358pub enum Operator<'a> {
359 Equals(Literal),
360 Range {
361 lower: Bound<Literal>,
362 upper: Bound<Literal>,
363 },
364 GreaterThan(Literal),
365 GreaterThanOrEquals(Literal),
366 LessThan(Literal),
367 LessThanOrEquals(Literal),
368 In(&'a [Literal]),
369 StartsWith {
370 pattern: String,
371 case_sensitive: bool,
372 },
373 EndsWith {
374 pattern: String,
375 case_sensitive: bool,
376 },
377 Contains {
378 pattern: String,
379 case_sensitive: bool,
380 },
381 IsNull,
382 IsNotNull,
383}
384
385impl<'a> Operator<'a> {
386 #[inline]
387 pub fn starts_with(pattern: String, case_sensitive: bool) -> Self {
388 Operator::StartsWith {
389 pattern,
390 case_sensitive,
391 }
392 }
393
394 #[inline]
395 pub fn ends_with(pattern: String, case_sensitive: bool) -> Self {
396 Operator::EndsWith {
397 pattern,
398 case_sensitive,
399 }
400 }
401
402 #[inline]
403 pub fn contains(pattern: String, case_sensitive: bool) -> Self {
404 Operator::Contains {
405 pattern,
406 case_sensitive,
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        // Last use of f1/f2: move instead of cloning.
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);
        match any {
            Expr::Or(v) => assert_eq!(v.len(), 2),
            _ => panic!("expected Or"),
        }
        match not_all {
            Expr::Not(inner) => match *inner {
                Expr::And(v) => assert_eq!(v.len(), 2),
                _ => panic!("expected And inside Not"),
            },
            _ => panic!("expected Not"),
        }
    }

    #[test]
    fn complex_nested_shape() {
        // (f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4)
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre".to_string(), true),
        };

        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        match top {
            Expr::Or(branches) => {
                assert_eq!(branches.len(), 2);
                match &branches[0] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        match &v[0] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 1)
                            }
                            _ => panic!("expected Pred(f1) in left-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Or(or_vec) => {
                                assert_eq!(or_vec.len(), 2);
                                match &or_vec[0] {
                                    Expr::Pred(Filter { field_id, .. }) => {
                                        assert_eq!(*field_id, 2)
                                    }
                                    _ => panic!("expected Pred(f2) in left-AND[1].OR[0]"),
                                }
                                match &or_vec[1] {
                                    Expr::Not(inner) => match inner.as_ref() {
                                        Expr::Pred(Filter { field_id, .. }) => {
                                            assert_eq!(*field_id, 3)
                                        }
                                        _ => panic!(
                                            "expected Not(Pred(f3)) in \
                                             left-AND[1].OR[1]"
                                        ),
                                    },
                                    _ => panic!("expected Not(...) in left-AND[1].OR[1]"),
                                }
                            }
                            _ => panic!("expected OR in left-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on left branch of top OR"),
                }
                match &branches[1] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        match &v[0] {
                            Expr::Not(inner) => match inner.as_ref() {
                                Expr::Pred(Filter { field_id, .. }) => {
                                    assert_eq!(*field_id, 1)
                                }
                                _ => panic!("expected Not(Pred(f1)) in right-AND[0]"),
                            },
                            _ => panic!("expected Not(...) in right-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 4)
                            }
                            _ => panic!("expected Pred(f4) in right-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on right branch of top OR"),
                }
            }
            _ => panic!("expected top-level OR"),
        }
    }

    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                if let Bound::Included(l) = lower {
                    assert_eq!(*l, Literal::String("aaa".to_string()));
                } else {
                    panic!("lower bound should be Included");
                }

                if let Bound::Excluded(u) = upper {
                    assert_eq!(*u, Literal::String("bbb".to_string()));
                } else {
                    panic!("upper bound should be Excluded");
                }
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        match and_expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 3);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!(),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!(),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!(),
                }
            }
            _ => panic!("expected And"),
        }

        // Last use of the filters: move them, in reversed order.
        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        match or_expr {
            Expr::Or(v) => {
                assert_eq!(v.len(), 3);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!(),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!(),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!(),
                }
            }
            _ => panic!("expected Or"),
        }
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("expected In"),
        }

        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre".to_string(), true),
        };
        let f5 = Filter {
            field_id: 5u32,
            op: Operator::ends_with("suf".to_string(), true),
        };
        let f6 = Filter {
            field_id: 6u32,
            op: Operator::contains("mid".to_string(), true),
        };

        match f4.op {
            Operator::StartsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "pre");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
        match f5.op {
            Operator::EndsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "suf");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
        match f6.op {
            Operator::Contains {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "mid");
                assert!(case_sensitive);
            }
            _ => panic!(),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        match expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 2);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "name")
                    }
                    _ => panic!("expected Pred(name)"),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "status")
                    }
                    _ => panic!("expected Pred(status)"),
                }
            }
            _ => panic!("expected And"),
        }
    }

    #[test]
    fn very_deep_not_chain() {
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let mut expr = base;
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        // Walk iteratively so the check itself does not recurse.
        let mut count = 0usize;
        let mut cur = &expr;
        loop {
            match cur {
                Expr::Not(inner) => {
                    count += 1;
                    cur = inner;
                }
                Expr::Pred(Filter { field_id, .. }) => {
                    assert_eq!(*field_id, 42);
                    break;
                }
                _ => panic!("unexpected node inside deep NOT chain"),
            }
        }
        assert_eq!(count, 64);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Int128(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Int128(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    #[test]
    fn full_range_detection() {
        let full = Expr::Pred(Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Unbounded,
                upper: Bound::Unbounded,
            },
        });
        assert!(full.is_full_range_for(&7));
        assert!(!full.is_full_range_for(&8));

        // One bounded end is not a full range.
        let half_open = Expr::Pred(Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("a".into()),
                upper: Bound::Unbounded,
            },
        });
        assert!(!half_open.is_full_range_for(&7));

        let constant: Expr<'_, u32> = Expr::Literal(true);
        assert!(!constant.is_full_range_for(&7));

        // Only a bare predicate qualifies; wrapping it hides the range.
        let wrapped = Expr::not(full);
        assert!(!wrapped.is_full_range_for(&7));
    }

    #[test]
    fn operator_symbols() {
        assert_eq!(BinaryOp::Add.as_str(), "+");
        assert_eq!(BinaryOp::Modulo.as_str(), "%");
        assert_eq!(BinaryOp::And.as_str(), "AND");
        assert_eq!(BinaryOp::BitwiseShiftLeft.as_str(), "<<");
        assert_eq!(CompareOp::Eq.as_str(), "=");
        assert_eq!(CompareOp::NotEq.as_str(), "!=");
        assert_eq!(CompareOp::LtEq.as_str(), "<=");
    }

    #[test]
    fn scalar_builders_box_their_children() {
        let sum = ScalarExpr::binary(
            ScalarExpr::column(1u32),
            BinaryOp::Add,
            ScalarExpr::literal("x"),
        );
        match sum {
            ScalarExpr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, ScalarExpr::Column(1)));
                assert!(matches!(*right, ScalarExpr::Literal(_)));
            }
            _ => panic!("expected Binary"),
        }

        let case = ScalarExpr::case(
            None,
            vec![(ScalarExpr::column(1u32), ScalarExpr::column(2u32))],
            Some(ScalarExpr::column(3u32)),
        );
        match case {
            ScalarExpr::Case {
                operand,
                branches,
                else_expr,
            } => {
                assert!(operand.is_none());
                assert_eq!(branches.len(), 1);
                assert!(matches!(
                    else_expr.as_deref(),
                    Some(ScalarExpr::Column(3))
                ));
            }
            _ => panic!("expected Case"),
        }
    }
}
799}