1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17 And(Vec<Expr<'a, F>>),
18 Or(Vec<Expr<'a, F>>),
19 Not(Box<Expr<'a, F>>),
20 Pred(Filter<'a, F>),
21 Compare {
22 left: ScalarExpr<F>,
23 op: CompareOp,
24 right: ScalarExpr<F>,
25 },
26 InList {
27 expr: ScalarExpr<F>,
28 list: Vec<ScalarExpr<F>>,
29 negated: bool,
30 },
31 IsNull {
35 expr: ScalarExpr<F>,
36 negated: bool,
37 },
38 Literal(bool),
41 Exists(SubqueryExpr),
43}
44
/// Numeric identifier used by expression nodes to refer to a subquery.
///
/// This is a newtype over `u32`; the wrapped value is public.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubqueryId(pub u32);
48
49#[derive(Clone, Debug)]
51pub struct SubqueryExpr {
52 pub id: SubqueryId,
54 pub negated: bool,
56}
57
58#[derive(Clone, Debug)]
60pub struct ScalarSubqueryExpr {
61 pub id: SubqueryId,
63}
64
65impl<'a, F> Expr<'a, F> {
66 #[inline]
68 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69 Expr::And(fs.into_iter().map(Expr::Pred).collect())
70 }
71
72 #[inline]
74 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76 }
77
78 #[allow(clippy::should_implement_trait)]
80 #[inline]
81 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82 Expr::Not(Box::new(e))
83 }
84}
85
86#[derive(Clone, Debug)]
88pub enum ScalarExpr<F> {
89 Column(F),
90 Literal(Literal),
91 Binary {
92 left: Box<ScalarExpr<F>>,
93 op: BinaryOp,
94 right: Box<ScalarExpr<F>>,
95 },
96 Not(Box<ScalarExpr<F>>),
98 IsNull {
101 expr: Box<ScalarExpr<F>>,
102 negated: bool,
103 },
104 Aggregate(AggregateCall<F>),
107 GetField {
111 base: Box<ScalarExpr<F>>,
112 field_name: String,
113 },
114 Cast {
116 expr: Box<ScalarExpr<F>>,
117 data_type: DataType,
118 },
119 Compare {
121 left: Box<ScalarExpr<F>>,
122 op: CompareOp,
123 right: Box<ScalarExpr<F>>,
124 },
125 Coalesce(Vec<ScalarExpr<F>>),
127 ScalarSubquery(ScalarSubqueryExpr),
129 Case {
131 operand: Option<Box<ScalarExpr<F>>>,
133 branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
135 else_expr: Option<Box<ScalarExpr<F>>>,
137 },
138 Random,
143}
144
145#[derive(Clone, Debug)]
150pub enum AggregateCall<F> {
151 CountStar,
152 Count {
153 expr: Box<ScalarExpr<F>>,
154 distinct: bool,
155 },
156 Sum {
157 expr: Box<ScalarExpr<F>>,
158 distinct: bool,
159 },
160 Total {
161 expr: Box<ScalarExpr<F>>,
162 distinct: bool,
163 },
164 Avg {
165 expr: Box<ScalarExpr<F>>,
166 distinct: bool,
167 },
168 Min(Box<ScalarExpr<F>>),
169 Max(Box<ScalarExpr<F>>),
170 CountNulls(Box<ScalarExpr<F>>),
171 GroupConcat {
172 expr: Box<ScalarExpr<F>>,
173 distinct: bool,
174 separator: Option<String>,
175 },
176}
177
178impl<F> ScalarExpr<F> {
179 #[inline]
180 pub fn column(field: F) -> Self {
181 Self::Column(field)
182 }
183
184 #[inline]
185 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
186 Self::Literal(lit.into())
187 }
188
189 #[inline]
190 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
191 Self::Binary {
192 left: Box::new(left),
193 op,
194 right: Box::new(right),
195 }
196 }
197
198 #[inline]
199 pub fn logical_not(expr: Self) -> Self {
200 Self::Not(Box::new(expr))
201 }
202
203 #[inline]
204 pub fn is_null(expr: Self, negated: bool) -> Self {
205 Self::IsNull {
206 expr: Box::new(expr),
207 negated,
208 }
209 }
210
211 #[inline]
212 pub fn aggregate(call: AggregateCall<F>) -> Self {
213 Self::Aggregate(call)
214 }
215
216 #[inline]
217 pub fn get_field(base: Self, field_name: String) -> Self {
218 Self::GetField {
219 base: Box::new(base),
220 field_name,
221 }
222 }
223
224 #[inline]
225 pub fn cast(expr: Self, data_type: DataType) -> Self {
226 Self::Cast {
227 expr: Box::new(expr),
228 data_type,
229 }
230 }
231
232 #[inline]
233 pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
234 Self::Compare {
235 left: Box::new(left),
236 op,
237 right: Box::new(right),
238 }
239 }
240
241 #[inline]
242 pub fn coalesce(exprs: Vec<Self>) -> Self {
243 Self::Coalesce(exprs)
244 }
245
246 #[inline]
247 pub fn scalar_subquery(id: SubqueryId) -> Self {
248 Self::ScalarSubquery(ScalarSubqueryExpr { id })
249 }
250
251 #[inline]
252 pub fn case(
253 operand: Option<Self>,
254 branches: Vec<(Self, Self)>,
255 else_expr: Option<Self>,
256 ) -> Self {
257 Self::Case {
258 operand: operand.map(Box::new),
259 branches,
260 else_expr: else_expr.map(Box::new),
261 }
262 }
263
264 #[inline]
265 pub fn random() -> Self {
266 Self::Random
267 }
268}
269
/// Operator carried by [`ScalarExpr::Binary`].
///
/// Variant order is part of the type (it fixes the discriminant values), so
/// new operators should be appended rather than inserted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Modulo.
    Modulo,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// Bitwise left shift.
    BitwiseShiftLeft,
    /// Bitwise right shift.
    BitwiseShiftRight,
}
283
/// Comparison operator carried by [`Expr::Compare`] and
/// [`ScalarExpr::Compare`].
///
/// Variant order is part of the type (it fixes the discriminant values), so
/// new operators should be appended rather than inserted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    LtEq,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    GtEq,
}
294
295#[derive(Debug, Clone)]
297pub struct Filter<'a, F> {
298 pub field_id: F,
299 pub op: Operator<'a>,
300}
301
302#[derive(Debug, Clone)]
307pub enum Operator<'a> {
308 Equals(Literal),
309 Range {
310 lower: Bound<Literal>,
311 upper: Bound<Literal>,
312 },
313 GreaterThan(Literal),
314 GreaterThanOrEquals(Literal),
315 LessThan(Literal),
316 LessThanOrEquals(Literal),
317 In(&'a [Literal]),
318 StartsWith {
319 pattern: &'a str,
320 case_sensitive: bool,
321 },
322 EndsWith {
323 pattern: &'a str,
324 case_sensitive: bool,
325 },
326 Contains {
327 pattern: &'a str,
328 case_sensitive: bool,
329 },
330 IsNull,
331 IsNotNull,
332}
333
334impl<'a> Operator<'a> {
335 #[inline]
336 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
337 Operator::StartsWith {
338 pattern,
339 case_sensitive,
340 }
341 }
342
343 #[inline]
344 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
345 Operator::EndsWith {
346 pattern,
347 case_sensitive,
348 }
349 }
350
351 #[inline]
352 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
353 Operator::Contains {
354 pattern,
355 case_sensitive,
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// `all_of`, `any_of` and `not` produce the expected top-level variants.
    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1.clone(), f2.clone()]);
        let not_all = Expr::not(all);

        let Expr::Or(or_children) = any else {
            panic!("expected Or");
        };
        assert_eq!(or_children.len(), 2);

        let Expr::Not(inner) = not_all else {
            panic!("expected Not");
        };
        let Expr::And(and_children) = *inner else {
            panic!("expected And inside Not");
        };
        assert_eq!(and_children.len(), 2);
    }

    /// A hand-built `(f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4)` tree keeps
    /// its shape and field ids.
    #[test]
    fn complex_nested_shape() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        let inner_or = Expr::Or(vec![
            Expr::Pred(f2.clone()),
            Expr::not(Expr::Pred(f3.clone())),
        ]);
        let left = Expr::And(vec![Expr::Pred(f1.clone()), inner_or]);
        let right = Expr::And(vec![
            Expr::not(Expr::Pred(f1.clone())),
            Expr::Pred(f4.clone()),
        ]);
        let top = Expr::Or(vec![left, right]);

        let Expr::Or(branches) = top else {
            panic!("expected top-level OR");
        };
        assert_eq!(branches.len(), 2);

        // Left branch: f1 AND (f2 OR NOT f3).
        let Expr::And(left_children) = &branches[0] else {
            panic!("expected AND on left branch of top OR");
        };
        assert_eq!(left_children.len(), 2);
        let Expr::Pred(Filter { field_id: id, .. }) = &left_children[0] else {
            panic!("expected Pred(f1) in left-AND[0]");
        };
        assert_eq!(*id, 1);
        let Expr::Or(or_children) = &left_children[1] else {
            panic!("expected OR in left-AND[1]");
        };
        assert_eq!(or_children.len(), 2);
        let Expr::Pred(Filter { field_id: id, .. }) = &or_children[0] else {
            panic!("expected Pred(f2) in left-AND[1].OR[0]");
        };
        assert_eq!(*id, 2);
        let Expr::Not(negated) = &or_children[1] else {
            panic!("expected Not(...) in left-AND[1].OR[1]");
        };
        let Expr::Pred(Filter { field_id: id, .. }) = negated.as_ref() else {
            panic!("expected Not(Pred(f3)) in left-AND[1].OR[1]");
        };
        assert_eq!(*id, 3);

        // Right branch: NOT f1 AND f4.
        let Expr::And(right_children) = &branches[1] else {
            panic!("expected AND on right branch of top OR");
        };
        assert_eq!(right_children.len(), 2);
        let Expr::Not(negated) = &right_children[0] else {
            panic!("expected Not(...) in right-AND[0]");
        };
        let Expr::Pred(Filter { field_id: id, .. }) = negated.as_ref() else {
            panic!("expected Not(Pred(f1)) in right-AND[0]");
        };
        assert_eq!(*id, 1);
        let Expr::Pred(Filter { field_id: id, .. }) = &right_children[1] else {
            panic!("expected Pred(f4) in right-AND[1]");
        };
        assert_eq!(*id, 4);
    }

    /// `Operator::Range` stores the bounds it was built with.
    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        let Operator::Range { lower, upper } = &f.op else {
            panic!("expected Range operator");
        };

        match lower {
            Bound::Included(l) => assert_eq!(*l, Literal::String("aaa".to_string())),
            _ => panic!("lower bound should be Included"),
        }

        match upper {
            Bound::Excluded(u) => assert_eq!(*u, Literal::String("bbb".to_string())),
            _ => panic!("upper bound should be Excluded"),
        }
    }

    /// `all_of` / `any_of` keep every filter, in the order supplied.
    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let Expr::And(children) = and_expr else {
            panic!("expected And");
        };
        assert_eq!(children.len(), 3);
        for (child, expected) in children.iter().zip([1u32, 2, 3]) {
            let Expr::Pred(Filter { field_id, .. }) = child else {
                panic!();
            };
            assert_eq!(*field_id, expected);
        }

        let or_expr = Expr::any_of(vec![f3.clone(), f2.clone(), f1.clone()]);
        let Expr::Or(children) = or_expr else {
            panic!("expected Or");
        };
        assert_eq!(children.len(), 3);
        for (child, expected) in children.iter().zip([3u32, 2, 1]) {
            let Expr::Pred(Filter { field_id, .. }) = child else {
                panic!();
            };
            assert_eq!(*field_id, expected);
        }
    }

    /// `In` and the pattern operators hold the borrowed data they were given.
    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        let Operator::In(arr) = f_in.op else {
            panic!("expected In");
        };
        assert_eq!(arr.len(), 3);

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        let Operator::StartsWith {
            pattern,
            case_sensitive,
        } = f_sw.op
        else {
            panic!();
        };
        assert_eq!(pattern, "pre");
        assert!(case_sensitive);

        let Operator::EndsWith {
            pattern,
            case_sensitive,
        } = f_ew.op
        else {
            panic!();
        };
        assert_eq!(pattern, "suf");
        assert!(case_sensitive);

        let Operator::Contains {
            pattern,
            case_sensitive,
        } = f_ct.op
        else {
            panic!();
        };
        assert_eq!(pattern, "mid");
        assert!(case_sensitive);
    }

    /// The field-id type parameter also works with string identifiers.
    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1.clone(), f2.clone()]);

        let Expr::And(children) = expr else {
            panic!("expected And");
        };
        assert_eq!(children.len(), 2);

        let Expr::Pred(Filter { field_id, .. }) = &children[0] else {
            panic!("expected Pred(name)");
        };
        assert_eq!(*field_id, "name");

        let Expr::Pred(Filter { field_id, .. }) = &children[1] else {
            panic!("expected Pred(status)");
        };
        assert_eq!(*field_id, "status");
    }

    /// 64 nested `Not` wrappers can be built and walked back down to the leaf.
    #[test]
    fn very_deep_not_chain() {
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let expr = (0..64).fold(base, |acc, _| Expr::not(acc));

        // Peel `Not` layers one at a time, counting them on the way down.
        let mut count = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            count += 1;
            cur = inner;
        }

        let Expr::Pred(Filter { field_id, .. }) = cur else {
            panic!("unexpected node inside deep NOT chain");
        };
        assert_eq!(*field_id, 42);
        assert_eq!(count, 64);
    }

    /// `into()` conversions produce the expected `Literal` variants.
    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        let Operator::Range { lower, upper } = f.op else {
            panic!("Expected a range operator");
        };
        assert_eq!(lower, Bound::Included(Literal::Integer(150)));
        assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        let Operator::Equals(lit) = f2.op else {
            panic!("Expected an equals operator");
        };
        assert_eq!(lit, Literal::String("hello".to_string()));
    }
}