1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17 And(Vec<Expr<'a, F>>),
18 Or(Vec<Expr<'a, F>>),
19 Not(Box<Expr<'a, F>>),
20 Pred(Filter<'a, F>),
21 Compare {
22 left: ScalarExpr<F>,
23 op: CompareOp,
24 right: ScalarExpr<F>,
25 },
26 InList {
27 expr: ScalarExpr<F>,
28 list: Vec<ScalarExpr<F>>,
29 negated: bool,
30 },
31 IsNull {
35 expr: ScalarExpr<F>,
36 negated: bool,
37 },
38 Literal(bool),
41 Exists(SubqueryExpr),
43}
44
/// Identifier for a subquery, wrapping a `u32`.
///
/// Cheap to copy and usable as a hash-map key. How ids are allocated and
/// resolved is decided by code outside this definition.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);
48
49#[derive(Clone, Debug)]
51pub struct SubqueryExpr {
52 pub id: SubqueryId,
54 pub negated: bool,
56}
57
58#[derive(Clone, Debug)]
60pub struct ScalarSubqueryExpr {
61 pub id: SubqueryId,
63}
64
65impl<'a, F> Expr<'a, F> {
66 #[inline]
68 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69 Expr::And(fs.into_iter().map(Expr::Pred).collect())
70 }
71
72 #[inline]
74 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76 }
77
78 #[allow(clippy::should_implement_trait)]
80 #[inline]
81 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82 Expr::Not(Box::new(e))
83 }
84}
85
86#[derive(Clone, Debug)]
88pub enum ScalarExpr<F> {
89 Column(F),
90 Literal(Literal),
91 Binary {
92 left: Box<ScalarExpr<F>>,
93 op: BinaryOp,
94 right: Box<ScalarExpr<F>>,
95 },
96 Not(Box<ScalarExpr<F>>),
98 IsNull {
101 expr: Box<ScalarExpr<F>>,
102 negated: bool,
103 },
104 Aggregate(AggregateCall<F>),
107 GetField {
111 base: Box<ScalarExpr<F>>,
112 field_name: String,
113 },
114 Cast {
116 expr: Box<ScalarExpr<F>>,
117 data_type: DataType,
118 },
119 Compare {
121 left: Box<ScalarExpr<F>>,
122 op: CompareOp,
123 right: Box<ScalarExpr<F>>,
124 },
125 Coalesce(Vec<ScalarExpr<F>>),
127 ScalarSubquery(ScalarSubqueryExpr),
129 Case {
131 operand: Option<Box<ScalarExpr<F>>>,
133 branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
135 else_expr: Option<Box<ScalarExpr<F>>>,
137 },
138}
139
140#[derive(Clone, Debug)]
145pub enum AggregateCall<F> {
146 CountStar,
147 Count {
148 expr: Box<ScalarExpr<F>>,
149 distinct: bool,
150 },
151 Sum {
152 expr: Box<ScalarExpr<F>>,
153 distinct: bool,
154 },
155 Total {
156 expr: Box<ScalarExpr<F>>,
157 distinct: bool,
158 },
159 Avg {
160 expr: Box<ScalarExpr<F>>,
161 distinct: bool,
162 },
163 Min(Box<ScalarExpr<F>>),
164 Max(Box<ScalarExpr<F>>),
165 CountNulls(Box<ScalarExpr<F>>),
166 GroupConcat {
167 expr: Box<ScalarExpr<F>>,
168 distinct: bool,
169 separator: Option<String>,
170 },
171}
172
173impl<F> ScalarExpr<F> {
174 #[inline]
175 pub fn column(field: F) -> Self {
176 Self::Column(field)
177 }
178
179 #[inline]
180 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
181 Self::Literal(lit.into())
182 }
183
184 #[inline]
185 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
186 Self::Binary {
187 left: Box::new(left),
188 op,
189 right: Box::new(right),
190 }
191 }
192
193 #[inline]
194 pub fn logical_not(expr: Self) -> Self {
195 Self::Not(Box::new(expr))
196 }
197
198 #[inline]
199 pub fn is_null(expr: Self, negated: bool) -> Self {
200 Self::IsNull {
201 expr: Box::new(expr),
202 negated,
203 }
204 }
205
206 #[inline]
207 pub fn aggregate(call: AggregateCall<F>) -> Self {
208 Self::Aggregate(call)
209 }
210
211 #[inline]
212 pub fn get_field(base: Self, field_name: String) -> Self {
213 Self::GetField {
214 base: Box::new(base),
215 field_name,
216 }
217 }
218
219 #[inline]
220 pub fn cast(expr: Self, data_type: DataType) -> Self {
221 Self::Cast {
222 expr: Box::new(expr),
223 data_type,
224 }
225 }
226
227 #[inline]
228 pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
229 Self::Compare {
230 left: Box::new(left),
231 op,
232 right: Box::new(right),
233 }
234 }
235
236 #[inline]
237 pub fn coalesce(exprs: Vec<Self>) -> Self {
238 Self::Coalesce(exprs)
239 }
240
241 #[inline]
242 pub fn scalar_subquery(id: SubqueryId) -> Self {
243 Self::ScalarSubquery(ScalarSubqueryExpr { id })
244 }
245
246 #[inline]
247 pub fn case(
248 operand: Option<Self>,
249 branches: Vec<(Self, Self)>,
250 else_expr: Option<Self>,
251 ) -> Self {
252 Self::Case {
253 operand: operand.map(Box::new),
254 branches,
255 else_expr: else_expr.map(Box::new),
256 }
257 }
258}
259
/// Operator of a binary scalar expression (`ScalarExpr::Binary`).
///
/// Evaluation semantics (overflow, null handling, operand typing) are
/// defined by the evaluator, not by this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Remainder.
    Modulo,
    /// Boolean conjunction in scalar position.
    And,
    /// Boolean disjunction in scalar position.
    Or,
    /// Bitwise left shift.
    BitwiseShiftLeft,
    /// Bitwise right shift.
    BitwiseShiftRight,
}
273
/// Comparison operator used by `Expr::Compare` and `ScalarExpr::Compare`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Strictly less than.
    Lt,
    /// Less than or equal.
    LtEq,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal.
    GtEq,
}
284
285#[derive(Debug, Clone)]
287pub struct Filter<'a, F> {
288 pub field_id: F,
289 pub op: Operator<'a>,
290}
291
292#[derive(Debug, Clone)]
297pub enum Operator<'a> {
298 Equals(Literal),
299 Range {
300 lower: Bound<Literal>,
301 upper: Bound<Literal>,
302 },
303 GreaterThan(Literal),
304 GreaterThanOrEquals(Literal),
305 LessThan(Literal),
306 LessThanOrEquals(Literal),
307 In(&'a [Literal]),
308 StartsWith {
309 pattern: &'a str,
310 case_sensitive: bool,
311 },
312 EndsWith {
313 pattern: &'a str,
314 case_sensitive: bool,
315 },
316 Contains {
317 pattern: &'a str,
318 case_sensitive: bool,
319 },
320 IsNull,
321 IsNotNull,
322}
323
324impl<'a> Operator<'a> {
325 #[inline]
326 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
327 Operator::StartsWith {
328 pattern,
329 case_sensitive,
330 }
331 }
332
333 #[inline]
334 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
335 Operator::EndsWith {
336 pattern,
337 case_sensitive,
338 }
339 }
340
341 #[inline]
342 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
343 Operator::Contains {
344 pattern,
345 case_sensitive,
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Returns the field id of a bare `Expr::Pred`; panics, naming `ctx`, on
    /// any other node.
    fn pred_id<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            other => panic!("expected Pred at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the field id of a `Not(Pred(..))` node; panics, naming `ctx`,
    /// on any other shape.
    fn negated_pred_id<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Not(inner) => pred_id(inner, ctx),
            other => panic!("expected Not(...) at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the children of an `Expr::And`; panics, naming `ctx`, on any
    /// other node.
    fn and_children<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            other => panic!("expected And at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the children of an `Expr::Or`; panics, naming `ctx`, on any
    /// other node.
    fn or_children<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            other => panic!("expected Or at {}, got {:?}", ctx, other),
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        // Last use of f1/f2: move them instead of cloning again.
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of result").len(), 2);
        match &not_all {
            Expr::Not(inner) => assert_eq!(and_children(inner, "inside Not").len(), 2),
            other => panic!("expected Not, got {:?}", other),
        }
    }

    #[test]
    fn complex_nested_shape() {
        // Shape under test: (f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4)
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top level");
        assert_eq!(branches.len(), 2);

        let left_children = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_children.len(), 2);
        assert_eq!(*pred_id(&left_children[0], "left-AND[0]"), 1);
        let inner_or = or_children(&left_children[1], "left-AND[1]");
        assert_eq!(inner_or.len(), 2);
        assert_eq!(*pred_id(&inner_or[0], "left-AND[1].OR[0]"), 2);
        assert_eq!(*negated_pred_id(&inner_or[1], "left-AND[1].OR[1]"), 3);

        let right_children = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_children.len(), 2);
        assert_eq!(*negated_pred_id(&right_children[0], "right-AND[0]"), 1);
        assert_eq!(*pred_id(&right_children[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(
                    *lower,
                    Bound::Included(Literal::String("aaa".to_string())),
                    "lower bound should be Included(\"aaa\")"
                );
                assert_eq!(
                    *upper,
                    Bound::Excluded(Literal::String("bbb".to_string())),
                    "upper bound should be Excluded(\"bbb\")"
                );
            }
            other => panic!("expected Range operator, got {:?}", other),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids: Vec<u32> = and_children(&and_expr, "all_of result")
            .iter()
            .map(|child| *pred_id(child, "all_of child"))
            .collect();
        assert_eq!(and_ids, vec![1, 2, 3]);

        // Reversed input order must be reflected in the output order.
        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids: Vec<u32> = or_children(&or_expr, "any_of result")
            .iter()
            .map(|child| *pred_id(child, "any_of child"))
            .collect();
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => assert_eq!(arr.len(), 3),
            other => panic!("expected In, got {:?}", other),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            other => panic!("expected StartsWith, got {:?}", other),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            other => panic!("expected EndsWith, got {:?}", other),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            other => panic!("expected Contains, got {:?}", other),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let children = and_children(&expr, "all_of result");
        assert_eq!(children.len(), 2);
        assert_eq!(*pred_id(&children[0], "Pred(name)"), "name");
        assert_eq!(*pred_id(&children[1], "Pred(status)"), "status");
    }

    #[test]
    fn very_deep_not_chain() {
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let expr = (0..64).fold(base, |acc, _| Expr::not(acc));

        // Walk down the chain iteratively, counting the Not layers.
        let mut depth = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            depth += 1;
            cur = inner;
        }
        assert_eq!(*pred_id(cur, "bottom of deep NOT chain"), 42);
        assert_eq!(depth, 64);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            other => panic!("Expected a range operator, got {:?}", other),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            other => panic!("Expected an equals operator, got {:?}", other),
        }
    }
}