1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
/// Boolean expression tree, generic over the field identifier type `F`.
///
/// The lifetime `'a` is the one carried by [`Filter`] (used by the `Pred`
/// variant), whose operator may hold borrowed slices and strings.
#[derive(Clone, Debug)]
pub enum Expr<'a, F> {
    /// Logical AND over the child expressions, kept in the given order.
    And(Vec<Expr<'a, F>>),
    /// Logical OR over the child expressions, kept in the given order.
    Or(Vec<Expr<'a, F>>),
    /// Logical negation of the boxed child expression.
    Not(Box<Expr<'a, F>>),
    /// Leaf predicate: one [`Filter`] applied to one field.
    Pred(Filter<'a, F>),
    /// Comparison of two scalar expressions using `op`.
    Compare {
        left: ScalarExpr<F>,
        op: CompareOp,
        right: ScalarExpr<F>,
    },
    /// Membership test of `expr` against the expressions in `list`.
    InList {
        expr: ScalarExpr<F>,
        list: Vec<ScalarExpr<F>>,
        // NOTE(review): presumably `true` means `NOT IN` — confirm against the evaluator.
        negated: bool,
    },
    /// Null test on a scalar expression.
    IsNull {
        expr: ScalarExpr<F>,
        // NOTE(review): presumably `true` means `IS NOT NULL` — confirm against the evaluator.
        negated: bool,
    },
    /// Constant boolean value.
    Literal(bool),
    /// Test on the subquery referenced by the wrapped [`SubqueryExpr`].
    // NOTE(review): the name suggests SQL `EXISTS` semantics — confirm with the planner.
    Exists(SubqueryExpr),
}
44
/// Newtype around a `u32` used as the `id` field of [`SubqueryExpr`] and
/// [`ScalarSubqueryExpr`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);
48
/// Payload of [`Expr::Exists`]: a reference to a subquery plus a negation flag.
#[derive(Clone, Debug)]
pub struct SubqueryExpr {
    /// Identifier of the referenced subquery.
    pub id: SubqueryId,
    // NOTE(review): presumably `true` means `NOT EXISTS` — confirm against the evaluator.
    pub negated: bool,
}
57
/// Payload of [`ScalarExpr::ScalarSubquery`]: a reference to a subquery by id.
#[derive(Clone, Debug)]
pub struct ScalarSubqueryExpr {
    /// Identifier of the referenced subquery.
    pub id: SubqueryId,
}
64
65impl<'a, F> Expr<'a, F> {
66 #[inline]
68 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69 Expr::And(fs.into_iter().map(Expr::Pred).collect())
70 }
71
72 #[inline]
74 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76 }
77
78 #[allow(clippy::should_implement_trait)]
80 #[inline]
81 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82 Expr::Not(Box::new(e))
83 }
84}
85
/// Value-producing expression tree, generic over the field identifier type `F`.
#[derive(Clone, Debug)]
pub enum ScalarExpr<F> {
    /// Reference to the field identified by `F`.
    Column(F),
    /// Constant value.
    Literal(Literal),
    /// Arithmetic combination of two sub-expressions using a [`BinaryOp`].
    Binary {
        left: Box<ScalarExpr<F>>,
        op: BinaryOp,
        right: Box<ScalarExpr<F>>,
    },
    /// Negation of the boxed sub-expression (built by `ScalarExpr::logical_not`).
    Not(Box<ScalarExpr<F>>),
    /// Null test on a sub-expression.
    IsNull {
        expr: Box<ScalarExpr<F>>,
        // NOTE(review): presumably `true` means `IS NOT NULL` — confirm against the evaluator.
        negated: bool,
    },
    /// Aggregate function call; see [`AggregateCall`].
    Aggregate(AggregateCall<F>),
    /// Access of the field named `field_name` on the value produced by `base`.
    // NOTE(review): presumably struct-field access — confirm which base types are accepted.
    GetField {
        base: Box<ScalarExpr<F>>,
        field_name: String,
    },
    /// Conversion of `expr` to the Arrow `data_type`.
    Cast {
        expr: Box<ScalarExpr<F>>,
        data_type: DataType,
    },
    /// Comparison of two sub-expressions using a [`CompareOp`].
    Compare {
        left: Box<ScalarExpr<F>>,
        op: CompareOp,
        right: Box<ScalarExpr<F>>,
    },
    /// Ordered list of candidate expressions.
    // NOTE(review): presumably SQL `COALESCE` (first non-null wins) — confirm.
    Coalesce(Vec<ScalarExpr<F>>),
    /// Reference to a subquery by id; see [`ScalarSubqueryExpr`].
    ScalarSubquery(ScalarSubqueryExpr),
    /// Conditional expression with an optional operand, a list of branch
    /// pairs, and an optional fallback.
    Case {
        /// Optional expression supplied alongside the branches.
        operand: Option<Box<ScalarExpr<F>>>,
        // NOTE(review): each pair is presumably `(when, then)` — confirm the tuple order.
        branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
        /// Optional fallback expression.
        else_expr: Option<Box<ScalarExpr<F>>>,
    },
}
139
/// Aggregate function call carried by [`ScalarExpr::Aggregate`].
// NOTE(review): the per-variant semantics below are inferred from the variant
// names; confirm against the aggregate evaluator.
#[derive(Clone, Debug)]
pub enum AggregateCall<F> {
    /// Argument-less count (presumably `COUNT(*)`).
    CountStar,
    /// Count over `expr`; `distinct` presumably restricts it to distinct values.
    Count {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Sum over `expr`; `distinct` presumably restricts it to distinct values.
    Sum {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Average over `expr`; `distinct` presumably restricts it to distinct values.
    Avg {
        expr: Box<ScalarExpr<F>>,
        distinct: bool,
    },
    /// Minimum of the boxed expression.
    Min(Box<ScalarExpr<F>>),
    /// Maximum of the boxed expression.
    Max(Box<ScalarExpr<F>>),
    /// Count of nulls in the boxed expression.
    CountNulls(Box<ScalarExpr<F>>),
}
163
164impl<F> ScalarExpr<F> {
165 #[inline]
166 pub fn column(field: F) -> Self {
167 Self::Column(field)
168 }
169
170 #[inline]
171 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
172 Self::Literal(lit.into())
173 }
174
175 #[inline]
176 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
177 Self::Binary {
178 left: Box::new(left),
179 op,
180 right: Box::new(right),
181 }
182 }
183
184 #[inline]
185 pub fn logical_not(expr: Self) -> Self {
186 Self::Not(Box::new(expr))
187 }
188
189 #[inline]
190 pub fn is_null(expr: Self, negated: bool) -> Self {
191 Self::IsNull {
192 expr: Box::new(expr),
193 negated,
194 }
195 }
196
197 #[inline]
198 pub fn aggregate(call: AggregateCall<F>) -> Self {
199 Self::Aggregate(call)
200 }
201
202 #[inline]
203 pub fn get_field(base: Self, field_name: String) -> Self {
204 Self::GetField {
205 base: Box::new(base),
206 field_name,
207 }
208 }
209
210 #[inline]
211 pub fn cast(expr: Self, data_type: DataType) -> Self {
212 Self::Cast {
213 expr: Box::new(expr),
214 data_type,
215 }
216 }
217
218 #[inline]
219 pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
220 Self::Compare {
221 left: Box::new(left),
222 op,
223 right: Box::new(right),
224 }
225 }
226
227 #[inline]
228 pub fn coalesce(exprs: Vec<Self>) -> Self {
229 Self::Coalesce(exprs)
230 }
231
232 #[inline]
233 pub fn scalar_subquery(id: SubqueryId) -> Self {
234 Self::ScalarSubquery(ScalarSubqueryExpr { id })
235 }
236
237 #[inline]
238 pub fn case(
239 operand: Option<Self>,
240 branches: Vec<(Self, Self)>,
241 else_expr: Option<Self>,
242 ) -> Self {
243 Self::Case {
244 operand: operand.map(Box::new),
245 branches,
246 else_expr: else_expr.map(Box::new),
247 }
248 }
249}
250
/// Arithmetic operator used by [`ScalarExpr::Binary`].
///
/// Derives `Hash` in addition to `Eq` so operators can be used as keys in
/// hashed collections, matching the other small value types in this module
/// (e.g. [`SubqueryId`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Remainder.
    Modulo,
}
260
/// Comparison operator used by [`Expr::Compare`] and [`ScalarExpr::Compare`].
///
/// Derives `Hash` in addition to `Eq` so operators can be used as keys in
/// hashed collections, matching the other small value types in this module
/// (e.g. [`SubqueryId`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    LtEq,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    GtEq,
}
271
/// A predicate on a single field: the operator `op` applied to the field
/// identified by `field_id`.
///
/// `F` is the caller-chosen field identifier type; the tests in this file use
/// both integers and `&str`.
#[derive(Debug, Clone)]
pub struct Filter<'a, F> {
    /// Identifier of the field the operator applies to.
    pub field_id: F,
    /// Operator (and its operands) to apply to that field.
    pub op: Operator<'a>,
}
278
/// Operator applied to a single field by a [`Filter`].
///
/// The lifetime `'a` bounds the borrowed slice held by `In` and the borrowed
/// pattern strings held by `StartsWith`, `EndsWith` and `Contains`.
#[derive(Debug, Clone)]
pub enum Operator<'a> {
    /// Equality against a literal.
    Equals(Literal),
    /// Interval described by independent lower and upper [`Bound`]s, each of
    /// which may be included, excluded or unbounded.
    Range {
        lower: Bound<Literal>,
        upper: Bound<Literal>,
    },
    /// Strictly greater than the literal.
    GreaterThan(Literal),
    /// Greater than or equal to the literal.
    GreaterThanOrEquals(Literal),
    /// Strictly less than the literal.
    LessThan(Literal),
    /// Less than or equal to the literal.
    LessThanOrEquals(Literal),
    /// Membership in a borrowed slice of literals.
    In(&'a [Literal]),
    /// Prefix match against a borrowed pattern.
    StartsWith {
        pattern: &'a str,
        case_sensitive: bool,
    },
    /// Suffix match against a borrowed pattern.
    EndsWith {
        pattern: &'a str,
        case_sensitive: bool,
    },
    /// Substring match against a borrowed pattern.
    Contains {
        pattern: &'a str,
        case_sensitive: bool,
    },
    /// Null test.
    IsNull,
    /// Non-null test.
    IsNotNull,
}
310
311impl<'a> Operator<'a> {
312 #[inline]
313 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
314 Operator::StartsWith {
315 pattern,
316 case_sensitive,
317 }
318 }
319
320 #[inline]
321 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
322 Operator::EndsWith {
323 pattern,
324 case_sensitive,
325 }
326 }
327
328 #[inline]
329 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
330 Operator::Contains {
331 pattern,
332 case_sensitive,
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Returns the children of an `And` node; panics naming `ctx` otherwise.
    fn and_children<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            other => panic!("expected And at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the children of an `Or` node; panics naming `ctx` otherwise.
    fn or_children<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            other => panic!("expected Or at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the operand of a `Not` node; panics naming `ctx` otherwise.
    fn not_inner<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => &**inner,
            other => panic!("expected Not at {}, got {:?}", ctx, other),
        }
    }

    /// Returns the field id of a `Pred` node; panics naming `ctx` otherwise.
    fn pred_field<'e, 'a, F: Debug>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            other => panic!("expected Pred at {}, got {:?}", ctx, other),
        }
    }

    /// Collects the field ids of a slice of `Pred` nodes, in order.
    fn pred_fields<'a, F: Clone + Debug>(exprs: &[Expr<'a, F>], ctx: &str) -> Vec<F> {
        exprs
            .iter()
            .map(|expr| pred_field(expr, ctx).clone())
            .collect()
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of").len(), 2);
        let inner = not_inner(&not_all, "not(all_of)");
        assert_eq!(and_children(inner, "inside Not").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // Shape under test: (f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4).
        // Only f1 appears twice, so it is the only filter that needs a clone.
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top level");
        assert_eq!(branches.len(), 2);

        let left_and = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(*pred_field(&left_and[0], "left-AND[0]"), 1);

        let left_or = or_children(&left_and[1], "left-AND[1]");
        assert_eq!(left_or.len(), 2);
        assert_eq!(*pred_field(&left_or[0], "left-AND[1].OR[0]"), 2);
        let negated_f3 = not_inner(&left_or[1], "left-AND[1].OR[1]");
        assert_eq!(*pred_field(negated_f3, "left-AND[1].OR[1] inner"), 3);

        let right_and = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated_f1 = not_inner(&right_and[0], "right-AND[0]");
        assert_eq!(*pred_field(negated_f1, "right-AND[0] inner"), 1);
        assert_eq!(*pred_field(&right_and[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                // Comparing whole `Bound`s checks both the bound kind and the value.
                assert_eq!(
                    *lower,
                    Bound::Included(Literal::String("aaa".to_string()))
                );
                assert_eq!(
                    *upper,
                    Bound::Excluded(Literal::String("bbb".to_string()))
                );
            }
            other => panic!("expected Range operator, got {:?}", other),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids = pred_fields(and_children(&and_expr, "all_of"), "all_of child");
        assert_eq!(and_ids, vec![1, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids = pred_fields(or_children(&or_expr, "any_of"), "any_of child");
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(values) => assert_eq!(values.len(), 3),
            other => panic!("expected In, got {:?}", other),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        assert!(
            matches!(
                f_sw.op,
                Operator::StartsWith {
                    pattern: "pre",
                    case_sensitive: true,
                }
            ),
            "expected case-sensitive StartsWith(\"pre\"), got {:?}",
            f_sw.op
        );
        assert!(
            matches!(
                f_ew.op,
                Operator::EndsWith {
                    pattern: "suf",
                    case_sensitive: true,
                }
            ),
            "expected case-sensitive EndsWith(\"suf\"), got {:?}",
            f_ew.op
        );
        assert!(
            matches!(
                f_ct.op,
                Operator::Contains {
                    pattern: "mid",
                    case_sensitive: true,
                }
            ),
            "expected case-sensitive Contains(\"mid\"), got {:?}",
            f_ct.op
        );
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let names = pred_fields(and_children(&expr, "all_of"), "all_of child");
        assert_eq!(names, vec!["name", "status"]);
    }

    #[test]
    fn very_deep_not_chain() {
        let mut expr = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        // Peel the NOT layers iteratively, counting them on the way down.
        let mut depth = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            depth += 1;
            cur = &**inner;
        }

        assert_eq!(*pred_field(cur, "bottom of deep NOT chain"), 42);
        assert_eq!(depth, 64);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            other => panic!("Expected a range operator, got {:?}", other),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            other => panic!("Expected an equals operator, got {:?}", other),
        }
    }
}