1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
/// Boolean expression tree whose leaves refer to fields through the generic
/// identifier type `F`.
///
/// The lifetime `'a` is the one carried by [`Filter`], whose operator may
/// borrow a slice of literals. This type only describes the shape of an
/// expression; evaluation is not implemented in this file.
#[derive(Clone, Debug)]
pub enum Expr<'a, F> {
    /// Logical AND node holding its operand expressions.
    And(Vec<Expr<'a, F>>),
    /// Logical OR node holding its operand expressions.
    Or(Vec<Expr<'a, F>>),
    /// Logical NOT node wrapping a single boxed operand.
    Not(Box<Expr<'a, F>>),
    /// Leaf predicate: one [`Filter`] applied to one field.
    Pred(Filter<'a, F>),
    /// Comparison of two scalar expressions using `op`.
    Compare {
        left: ScalarExpr<F>,
        op: CompareOp,
        right: ScalarExpr<F>,
    },
    /// List-membership node: `expr` tested against the entries of `list`.
    /// NOTE(review): `negated` presumably selects the `NOT IN` form — confirm
    /// against the evaluator.
    InList {
        expr: ScalarExpr<F>,
        list: Vec<ScalarExpr<F>>,
        negated: bool,
    },
    /// Null test on a scalar expression.
    /// NOTE(review): `negated` presumably selects `IS NOT NULL` — confirm
    /// against the evaluator.
    IsNull {
        expr: ScalarExpr<F>,
        negated: bool,
    },
    /// Constant boolean value.
    Literal(bool),
    /// Existence test referring to a subquery by id (see [`SubqueryExpr`]).
    Exists(SubqueryExpr),
}
44
/// Newtype around a `u32` used to identify a subquery.
///
/// It is stored in [`SubqueryExpr`] and [`ScalarSubqueryExpr`]; how ids are
/// allocated and resolved is not defined in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);
48
/// Payload of [`Expr::Exists`]: a reference to a subquery plus a negation flag.
#[derive(Clone, Debug)]
pub struct SubqueryExpr {
    /// Identifier of the referenced subquery.
    pub id: SubqueryId,
    /// NOTE(review): presumably `true` for the `NOT EXISTS` form — confirm
    /// against the evaluator.
    pub negated: bool,
}
57
/// Payload of [`ScalarExpr::ScalarSubquery`]: a reference to a subquery
/// together with an Arrow data type.
#[derive(Clone, Debug)]
pub struct ScalarSubqueryExpr {
    /// Identifier of the referenced subquery.
    pub id: SubqueryId,
    /// Arrow data type attached to the subquery.
    /// NOTE(review): presumably the type of the value it produces — confirm.
    pub data_type: DataType,
}
66
67impl<'a, F> Expr<'a, F> {
68 #[inline]
70 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
71 Expr::And(fs.into_iter().map(Expr::Pred).collect())
72 }
73
74 #[inline]
76 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
77 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
78 }
79
80 #[allow(clippy::should_implement_trait)]
82 #[inline]
83 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
84 Expr::Not(Box::new(e))
85 }
86
87 pub fn is_full_range_for(&self, expected_field: &F) -> bool
89 where
90 F: PartialEq,
91 {
92 matches!(
93 self,
94 Expr::Pred(Filter {
95 field_id,
96 op:
97 Operator::Range {
98 lower: Bound::Unbounded,
99 upper: Bound::Unbounded,
100 },
101 }) if field_id == expected_field
102 )
103 }
104
105 pub fn is_trivially_true(&self) -> bool {
110 match self {
111 Expr::Pred(Filter {
112 op:
113 Operator::Range {
114 lower: Bound::Unbounded,
115 upper: Bound::Unbounded,
116 },
117 ..
118 }) => true,
119 Expr::Literal(value) => *value,
120 _ => false,
121 }
122 }
123}
124
/// Scalar (value-producing) expression tree over fields identified by `F`.
///
/// Recursive operands are boxed. This type only describes the shape of an
/// expression; evaluation is not implemented in this file.
#[derive(Clone, Debug)]
pub enum ScalarExpr<F> {
    /// Reference to a field.
    Column(F),
    /// Constant literal value.
    Literal(Literal),
    /// Binary operation `left op right` (see [`BinaryOp`]).
    Binary {
        left: Box<ScalarExpr<F>>,
        op: BinaryOp,
        right: Box<ScalarExpr<F>>,
    },
    /// Logical NOT applied to a scalar expression.
    Not(Box<ScalarExpr<F>>),
    /// Null test on a scalar expression.
    /// NOTE(review): `negated` presumably selects `IS NOT NULL` — confirm.
    IsNull {
        expr: Box<ScalarExpr<F>>,
        negated: bool,
    },
    /// Aggregate function call (see [`AggregateCall`]).
    Aggregate(AggregateCall<F>),
    /// Access to the field named `field_name` of the value produced by `base`.
    /// NOTE(review): presumably struct-field access — confirm.
    GetField {
        base: Box<ScalarExpr<F>>,
        field_name: String,
    },
    /// Cast of `expr` to the Arrow `data_type`.
    Cast {
        expr: Box<ScalarExpr<F>>,
        data_type: DataType,
    },
    /// Comparison `left op right` (see [`CompareOp`]) used as a scalar value.
    Compare {
        left: Box<ScalarExpr<F>>,
        op: CompareOp,
        right: Box<ScalarExpr<F>>,
    },
    /// Coalesce over the listed expressions.
    /// NOTE(review): presumably SQL `COALESCE` (first non-null) — confirm.
    Coalesce(Vec<ScalarExpr<F>>),
    /// Reference to a subquery used as a scalar (see [`ScalarSubqueryExpr`]).
    ScalarSubquery(ScalarSubqueryExpr),
    /// `CASE` expression with an optional `operand`, a list of `branches`
    /// (pairs of expressions), and an optional `else_expr`.
    /// NOTE(review): each pair is presumably (WHEN, THEN) — confirm.
    Case {
        operand: Option<Box<ScalarExpr<F>>>,
        branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
        else_expr: Option<Box<ScalarExpr<F>>>,
    },
    /// Argument-less node representing a random value.
    /// NOTE(review): range and distribution are decided by the evaluator.
    Random,
}
183
/// Aggregate function call carried by [`ScalarExpr::Aggregate`].
///
/// Where present, the `distinct` flag presumably restricts the aggregate to
/// distinct input values; the aggregation itself is implemented elsewhere.
#[derive(Clone, Debug)]
pub enum AggregateCall<F> {
    /// Row count taking no argument expression.
    CountStar,
    /// Count over `expr`.
    Count {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Sum over `expr`.
    Sum {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Total over `expr`.
    /// NOTE(review): how this differs from `Sum` is defined by the evaluator
    /// (possibly SQLite-style `total()`) — confirm.
    Total {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Average over `expr`.
    Avg {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Minimum of the argument expression.
    Min(Box<ScalarExpr<F>>),
    /// Maximum of the argument expression.
    Max(Box<ScalarExpr<F>>),
    /// Count of nulls produced by the argument expression.
    CountNulls(Box<ScalarExpr<F>>),
    /// Concatenation over `expr` with an optional `separator`.
    /// NOTE(review): the separator used for `None` is chosen by the evaluator.
    GroupConcat {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
        separator: Option<String>,
    },
}
216
217impl<F> ScalarExpr<F> {
218 #[inline]
219 pub fn column(field: F) -> Self {
220 Self::Column(field)
221 }
222
223 #[inline]
224 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
225 Self::Literal(lit.into())
226 }
227
228 #[inline]
229 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
230 Self::Binary {
231 left: Box::new(left),
232 op,
233 right: Box::new(right),
234 }
235 }
236
237 #[inline]
238 pub fn logical_not(expr: Self) -> Self {
239 Self::Not(Box::new(expr))
240 }
241
242 #[inline]
243 pub fn is_null(expr: Self, negated: bool) -> Self {
244 Self::IsNull {
245 expr: Box::new(expr),
246 negated,
247 }
248 }
249
250 #[inline]
251 pub fn aggregate(call: AggregateCall<F>) -> Self {
252 Self::Aggregate(call)
253 }
254
255 #[inline]
256 pub fn get_field(base: Self, field_name: String) -> Self {
257 Self::GetField {
258 base: Box::new(base),
259 field_name,
260 }
261 }
262
263 #[inline]
264 pub fn cast(expr: Self, data_type: DataType) -> Self {
265 Self::Cast {
266 expr: Box::new(expr),
267 data_type,
268 }
269 }
270
271 #[inline]
272 pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
273 Self::Compare {
274 left: Box::new(left),
275 op,
276 right: Box::new(right),
277 }
278 }
279
280 #[inline]
281 pub fn coalesce(exprs: Vec<Self>) -> Self {
282 Self::Coalesce(exprs)
283 }
284
285 #[inline]
286 pub fn scalar_subquery(id: SubqueryId, data_type: DataType) -> Self {
287 Self::ScalarSubquery(ScalarSubqueryExpr { id, data_type })
288 }
289
290 #[inline]
291 pub fn case(
292 operand: Option<Self>,
293 branches: Vec<(Self, Self)>,
294 else_expr: Option<Self>,
295 ) -> Self {
296 Self::Case {
297 operand: operand.map(Box::new),
298 branches,
299 else_expr: else_expr.map(Box::new),
300 }
301 }
302
303 #[inline]
304 pub fn random() -> Self {
305 Self::Random
306 }
307}
308
/// Operator of a [`ScalarExpr::Binary`] node: arithmetic, logical, or
/// bit-shift.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Rendered as `+`.
    Add,
    /// Rendered as `-`.
    Subtract,
    /// Rendered as `*`.
    Multiply,
    /// Rendered as `/`.
    Divide,
    /// Rendered as `%`.
    Modulo,
    /// Rendered as `AND`.
    And,
    /// Rendered as `OR`.
    Or,
    /// Rendered as `<<`.
    BitwiseShiftLeft,
    /// Rendered as `>>`.
    BitwiseShiftRight,
}

impl BinaryOp {
    /// Returns the fixed textual symbol for this operator.
    ///
    /// Arithmetic and shift operators map to their punctuation; the logical
    /// operators map to the upper-case keywords `AND` and `OR`.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Add => "+",
            Self::Subtract => "-",
            Self::Multiply => "*",
            Self::Divide => "/",
            Self::Modulo => "%",
            Self::And => "AND",
            Self::Or => "OR",
            Self::BitwiseShiftLeft => "<<",
            Self::BitwiseShiftRight => ">>",
        }
    }
}
339
/// Comparison operator used by [`Expr::Compare`] and
/// [`ScalarExpr::Compare`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompareOp {
    /// Rendered as `=`.
    Eq,
    /// Rendered as `!=`.
    NotEq,
    /// Rendered as `<`.
    Lt,
    /// Rendered as `<=`.
    LtEq,
    /// Rendered as `>`.
    Gt,
    /// Rendered as `>=`.
    GtEq,
}

impl CompareOp {
    /// Returns the fixed textual symbol for this comparison operator.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Lt => "<",
            Self::LtEq => "<=",
            Self::Gt => ">",
            Self::GtEq => ">=",
        }
    }
}
364
/// Single-field predicate: an [`Operator`] applied to the field identified by
/// `field_id`.
///
/// `F` is the caller's field identifier type; the tests in this file use
/// both integers and string slices.
#[derive(Debug, Clone)]
pub struct Filter<'a, F> {
    /// Identifier of the field the operator applies to.
    pub field_id: F,
    /// Operator (with its operands) applied to the field.
    pub op: Operator<'a>,
}
371
/// Operator, together with its operands, that a [`Filter`] applies to a field.
///
/// `'a` is the lifetime of the literal slice borrowed by [`Operator::In`];
/// every other variant owns its data. Matching semantics are implemented
/// elsewhere.
#[derive(Debug, Clone)]
pub enum Operator<'a> {
    /// Equality with a literal.
    Equals(Literal),
    /// Range with independently specified lower and upper bounds. A range
    /// that is `Unbounded` at both ends is what `Expr::is_full_range_for`
    /// and `Expr::is_trivially_true` recognise.
    Range {
        lower: Bound<Literal>,
        upper: Bound<Literal>,
    },
    /// Strictly greater than a literal.
    GreaterThan(Literal),
    /// Greater than or equal to a literal.
    GreaterThanOrEquals(Literal),
    /// Strictly less than a literal.
    LessThan(Literal),
    /// Less than or equal to a literal.
    LessThanOrEquals(Literal),
    /// Membership in a borrowed slice of literals.
    In(&'a [Literal]),
    /// Prefix match against `pattern`.
    /// NOTE(review): `case_sensitive` presumably toggles case-sensitive
    /// matching — confirm against the matcher.
    StartsWith {
        pattern: String,
        case_sensitive: bool,
    },
    /// Suffix match against `pattern` (same `case_sensitive` caveat).
    EndsWith {
        pattern: String,
        case_sensitive: bool,
    },
    /// Substring match against `pattern` (same `case_sensitive` caveat).
    Contains {
        pattern: String,
        case_sensitive: bool,
    },
    /// Field is null.
    IsNull,
    /// Field is not null.
    IsNotNull,
}
403
404impl<'a> Operator<'a> {
405 #[inline]
406 pub fn starts_with(pattern: String, case_sensitive: bool) -> Self {
407 Operator::StartsWith {
408 pattern,
409 case_sensitive,
410 }
411 }
412
413 #[inline]
414 pub fn ends_with(pattern: String, case_sensitive: bool) -> Self {
415 Operator::EndsWith {
416 pattern,
417 case_sensitive,
418 }
419 }
420
421 #[inline]
422 pub fn contains(pattern: String, case_sensitive: bool) -> Self {
423 Operator::Contains {
424 pattern,
425 case_sensitive,
426 }
427 }
428}
429
430#[cfg(test)]
431mod tests {
432 use super::*;
433
434 #[test]
435 fn build_simple_exprs() {
436 let f1 = Filter {
437 field_id: 1,
438 op: Operator::Equals("abc".into()),
439 };
440 let f2 = Filter {
441 field_id: 2,
442 op: Operator::LessThan("zzz".into()),
443 };
444 let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
445 let any = Expr::any_of(vec![f1.clone(), f2.clone()]);
446 let not_all = Expr::not(all);
447 match any {
448 Expr::Or(v) => assert_eq!(v.len(), 2),
449 _ => panic!("expected Or"),
450 }
451 match not_all {
452 Expr::Not(inner) => match *inner {
453 Expr::And(v) => assert_eq!(v.len(), 2),
454 _ => panic!("expected And inside Not"),
455 },
456 _ => panic!("expected Not"),
457 }
458 }
459
460 #[test]
461 fn complex_nested_shape() {
462 let f1 = Filter {
467 field_id: 1u32,
468 op: Operator::Equals("a".into()),
469 };
470 let f2 = Filter {
471 field_id: 2u32,
472 op: Operator::LessThan("zzz".into()),
473 };
474 let in_values = ["x".into(), "y".into(), "z".into()];
475 let f3 = Filter {
476 field_id: 3u32,
477 op: Operator::In(&in_values),
478 };
479 let f4 = Filter {
480 field_id: 4u32,
481 op: Operator::starts_with("pre".to_string(), true),
482 };
483
484 let left = Expr::And(vec![
486 Expr::Pred(f1.clone()),
487 Expr::Or(vec![
488 Expr::Pred(f2.clone()),
489 Expr::not(Expr::Pred(f3.clone())),
490 ]),
491 ]);
492 let right = Expr::And(vec![
493 Expr::not(Expr::Pred(f1.clone())),
494 Expr::Pred(f4.clone()),
495 ]);
496 let top = Expr::Or(vec![left, right]);
497
498 match top {
500 Expr::Or(branches) => {
501 assert_eq!(branches.len(), 2);
502 match &branches[0] {
503 Expr::And(v) => {
504 assert_eq!(v.len(), 2);
505 match &v[0] {
507 Expr::Pred(Filter { field_id, .. }) => {
508 assert_eq!(*field_id, 1)
509 }
510 _ => panic!("expected Pred(f1) in left-AND[0]"),
511 }
512 match &v[1] {
513 Expr::Or(or_vec) => {
514 assert_eq!(or_vec.len(), 2);
515 match &or_vec[0] {
516 Expr::Pred(Filter { field_id, .. }) => {
517 assert_eq!(*field_id, 2)
518 }
519 _ => panic!("expected Pred(f2) in left-AND[1].OR[0]"),
520 }
521 match &or_vec[1] {
522 Expr::Not(inner) => match inner.as_ref() {
523 Expr::Pred(Filter { field_id, .. }) => {
524 assert_eq!(*field_id, 3)
525 }
526 _ => panic!(
527 "expected Not(Pred(f3)) in \
528 left-AND[1].OR[1]"
529 ),
530 },
531 _ => panic!("expected Not(...) in left-AND[1].OR[1]"),
532 }
533 }
534 _ => panic!("expected OR in left-AND[1]"),
535 }
536 }
537 _ => panic!("expected AND on left branch of top OR"),
538 }
539 match &branches[1] {
540 Expr::And(v) => {
541 assert_eq!(v.len(), 2);
542 match &v[0] {
544 Expr::Not(inner) => match inner.as_ref() {
545 Expr::Pred(Filter { field_id, .. }) => {
546 assert_eq!(*field_id, 1)
547 }
548 _ => panic!("expected Not(Pred(f1)) in right-AND[0]"),
549 },
550 _ => panic!("expected Not(...) in right-AND[0]"),
551 }
552 match &v[1] {
553 Expr::Pred(Filter { field_id, .. }) => {
554 assert_eq!(*field_id, 4)
555 }
556 _ => panic!("expected Pred(f4) in right-AND[1]"),
557 }
558 }
559 _ => panic!("expected AND on right branch of top OR"),
560 }
561 }
562 _ => panic!("expected top-level OR"),
563 }
564 }
565
566 #[test]
567 fn range_bounds_roundtrip() {
568 let f = Filter {
570 field_id: 7u32,
571 op: Operator::Range {
572 lower: Bound::Included("aaa".into()),
573 upper: Bound::Excluded("bbb".into()),
574 },
575 };
576
577 match &f.op {
578 Operator::Range { lower, upper } => {
579 if let Bound::Included(l) = lower {
580 assert_eq!(*l, Literal::String("aaa".to_string()));
581 } else {
582 panic!("lower bound should be Included");
583 }
584
585 if let Bound::Excluded(u) = upper {
586 assert_eq!(*u, Literal::String("bbb".to_string()));
587 } else {
588 panic!("upper bound should be Excluded");
589 }
590 }
591 _ => panic!("expected Range operator"),
592 }
593 }
594
595 #[test]
596 fn helper_builders_preserve_structure_and_order() {
597 let f1 = Filter {
598 field_id: 1u32,
599 op: Operator::Equals("a".into()),
600 };
601 let f2 = Filter {
602 field_id: 2u32,
603 op: Operator::Equals("b".into()),
604 };
605 let f3 = Filter {
606 field_id: 3u32,
607 op: Operator::Equals("c".into()),
608 };
609
610 let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
611 match and_expr {
612 Expr::And(v) => {
613 assert_eq!(v.len(), 3);
614 match &v[0] {
616 Expr::Pred(Filter { field_id, .. }) => {
617 assert_eq!(*field_id, 1)
618 }
619 _ => panic!(),
620 }
621 match &v[1] {
622 Expr::Pred(Filter { field_id, .. }) => {
623 assert_eq!(*field_id, 2)
624 }
625 _ => panic!(),
626 }
627 match &v[2] {
628 Expr::Pred(Filter { field_id, .. }) => {
629 assert_eq!(*field_id, 3)
630 }
631 _ => panic!(),
632 }
633 }
634 _ => panic!("expected And"),
635 }
636
637 let or_expr = Expr::any_of(vec![f3.clone(), f2.clone(), f1.clone()]);
638 match or_expr {
639 Expr::Or(v) => {
640 assert_eq!(v.len(), 3);
641 match &v[0] {
643 Expr::Pred(Filter { field_id, .. }) => {
644 assert_eq!(*field_id, 3)
645 }
646 _ => panic!(),
647 }
648 match &v[1] {
649 Expr::Pred(Filter { field_id, .. }) => {
650 assert_eq!(*field_id, 2)
651 }
652 _ => panic!(),
653 }
654 match &v[2] {
655 Expr::Pred(Filter { field_id, .. }) => {
656 assert_eq!(*field_id, 1)
657 }
658 _ => panic!(),
659 }
660 }
661 _ => panic!("expected Or"),
662 }
663 }
664
665 #[test]
666 fn set_and_pattern_ops_hold_borrowed_slices() {
667 let in_values = ["aa".into(), "bb".into(), "cc".into()];
668 let f_in = Filter {
669 field_id: 9u32,
670 op: Operator::In(&in_values),
671 };
672 match f_in.op {
673 Operator::In(arr) => {
674 assert_eq!(arr.len(), 3);
675 }
676 _ => panic!("expected In"),
677 }
678
679 let f4 = Filter {
680 field_id: 4u32,
681 op: Operator::starts_with("pre".to_string(), true),
682 };
683 let f5 = Filter {
684 field_id: 5u32,
685 op: Operator::ends_with("suf".to_string(), true),
686 };
687 let f6 = Filter {
688 field_id: 6u32,
689 op: Operator::contains("mid".to_string(), true),
690 };
691
692 match f4.op {
693 Operator::StartsWith {
694 pattern: b,
695 case_sensitive,
696 } => {
697 assert_eq!(b, "pre");
698 assert!(case_sensitive);
699 }
700 _ => panic!(),
701 }
702 match f5.op {
703 Operator::EndsWith {
704 pattern: b,
705 case_sensitive,
706 } => {
707 assert_eq!(b, "suf");
708 assert!(case_sensitive);
709 }
710 _ => panic!(),
711 }
712 match f6.op {
713 Operator::Contains {
714 pattern: b,
715 case_sensitive,
716 } => {
717 assert_eq!(b, "mid");
718 assert!(case_sensitive);
719 }
720 _ => panic!(),
721 }
722 }
723
724 #[test]
725 fn generic_field_id_works_with_strings() {
726 let f1 = Filter {
728 field_id: "name",
729 op: Operator::Equals("alice".into()),
730 };
731 let f2 = Filter {
732 field_id: "status",
733 op: Operator::GreaterThanOrEquals("active".into()),
734 };
735 let expr = Expr::all_of(vec![f1.clone(), f2.clone()]);
736
737 match expr {
738 Expr::And(v) => {
739 assert_eq!(v.len(), 2);
740 match &v[0] {
741 Expr::Pred(Filter { field_id, .. }) => {
742 assert_eq!(*field_id, "name")
743 }
744 _ => panic!("expected Pred(name)"),
745 }
746 match &v[1] {
747 Expr::Pred(Filter { field_id, .. }) => {
748 assert_eq!(*field_id, "status")
749 }
750 _ => panic!("expected Pred(status)"),
751 }
752 }
753 _ => panic!("expected And"),
754 }
755 }
756
757 #[test]
758 fn very_deep_not_chain() {
759 let base = Expr::Pred(Filter {
761 field_id: 42u32,
762 op: Operator::Equals("x".into()),
763 });
764 let mut expr = base;
765 for _ in 0..64 {
766 expr = Expr::not(expr);
767 }
768
769 let mut count = 0usize;
771 let mut cur = &expr;
772 loop {
773 match cur {
774 Expr::Not(inner) => {
775 count += 1;
776 cur = inner;
777 }
778 Expr::Pred(Filter { field_id, .. }) => {
779 assert_eq!(*field_id, 42);
780 break;
781 }
782 _ => panic!("unexpected node inside deep NOT chain"),
783 }
784 }
785 assert_eq!(count, 64);
786 }
787
788 #[test]
789 fn literal_construction() {
790 let f = Filter {
791 field_id: "my_u64_col",
792 op: Operator::Range {
793 lower: Bound::Included(150.into()),
794 upper: Bound::Excluded(300.into()),
795 },
796 };
797
798 match f.op {
799 Operator::Range { lower, upper } => {
800 assert_eq!(lower, Bound::Included(Literal::Int128(150)));
801 assert_eq!(upper, Bound::Excluded(Literal::Int128(300)));
802 }
803 _ => panic!("Expected a range operator"),
804 }
805
806 let f2 = Filter {
807 field_id: "my_str_col",
808 op: Operator::Equals("hello".into()),
809 };
810
811 match f2.op {
812 Operator::Equals(lit) => {
813 assert_eq!(lit, Literal::String("hello".to_string()));
814 }
815 _ => panic!("Expected an equals operator"),
816 }
817 }
818}