llkv_expr/
expr.rs

1//! Type-aware, Arrow-native predicate AST.
2//!
3//! This module defines a small predicate-expression AST that is decoupled
4//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
5//! is deferred to the consumer (e.g., a table/scan layer) which knows the
6//! column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14/// Logical expression over predicates.
15#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17    And(Vec<Expr<'a, F>>),
18    Or(Vec<Expr<'a, F>>),
19    Not(Box<Expr<'a, F>>),
20    Pred(Filter<'a, F>),
21    Compare {
22        left: ScalarExpr<F>,
23        op: CompareOp,
24        right: ScalarExpr<F>,
25    },
26    InList {
27        expr: ScalarExpr<F>,
28        list: Vec<ScalarExpr<F>>,
29        negated: bool,
30    },
31    /// Check if a scalar expression IS NULL or IS NOT NULL.
32    /// For simple column references, prefer `Pred(Filter { op: IsNull/IsNotNull })` for optimization.
33    /// This variant handles complex expressions like `(col1 + col2) IS NULL`.
34    IsNull {
35        expr: ScalarExpr<F>,
36        negated: bool,
37    },
38    /// A literal boolean value (true/false).
39    /// Used for conditions that are always true or always false (e.g., empty IN lists).
40    Literal(bool),
41    /// Correlated subquery evaluated in a boolean context.
42    Exists(SubqueryExpr),
43}
44
/// Identifier of a subquery definition.
///
/// This is only a handle: [`SubqueryExpr`] (boolean `EXISTS` predicates) and
/// [`ScalarSubqueryExpr`] (scalar subqueries) use it to refer to a definition
/// attached to the parent filter or projection. It carries no metadata itself.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);
48
49/// Correlated subquery used within a predicate expression.
50#[derive(Clone, Debug)]
51pub struct SubqueryExpr {
52    /// Identifier referencing the subquery definition attached to the parent filter.
53    pub id: SubqueryId,
54    /// True when the SQL contained `NOT EXISTS`.
55    pub negated: bool,
56}
57
58/// Scalar subquery evaluated as part of a scalar expression.
59#[derive(Clone, Debug)]
60pub struct ScalarSubqueryExpr {
61    /// Identifier referencing the subquery definition attached to the parent projection.
62    pub id: SubqueryId,
63}
64
65impl<'a, F> Expr<'a, F> {
66    /// Build an AND of filters.
67    #[inline]
68    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69        Expr::And(fs.into_iter().map(Expr::Pred).collect())
70    }
71
72    /// Build an OR of filters.
73    #[inline]
74    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76    }
77
78    /// Wrap an expression in a logical NOT.
79    #[allow(clippy::should_implement_trait)]
80    #[inline]
81    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82        Expr::Not(Box::new(e))
83    }
84}
85
86/// Arithmetic scalar expression that can reference multiple fields.
87#[derive(Clone, Debug)]
88pub enum ScalarExpr<F> {
89    Column(F),
90    Literal(Literal),
91    Binary {
92        left: Box<ScalarExpr<F>>,
93        op: BinaryOp,
94        right: Box<ScalarExpr<F>>,
95    },
96    /// Logical NOT returning 1 for falsey inputs, 0 for truthy inputs, and NULL for NULL inputs.
97    Not(Box<ScalarExpr<F>>),
98    /// NULL test returning 1 when the operand is NULL (or NOT NULL when `negated` is true) and 0 otherwise.
99    /// Returns NULL when the operand cannot be determined.
100    IsNull {
101        expr: Box<ScalarExpr<F>>,
102        negated: bool,
103    },
104    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
105    /// This is used in expressions like COUNT(*) + 1
106    Aggregate(AggregateCall<F>),
107    /// Extract a field from a struct expression.
108    /// For example: `user.address.city` would be represented as
109    /// GetField { base: GetField { base: Column(user), field_name: "address" }, field_name: "city" }
110    GetField {
111        base: Box<ScalarExpr<F>>,
112        field_name: String,
113    },
114    /// Explicit type cast to an Arrow data type.
115    Cast {
116        expr: Box<ScalarExpr<F>>,
117        data_type: DataType,
118    },
119    /// Comparison producing a boolean (1/0) result.
120    Compare {
121        left: Box<ScalarExpr<F>>,
122        op: CompareOp,
123        right: Box<ScalarExpr<F>>,
124    },
125    /// First non-null expression in the provided list.
126    Coalesce(Vec<ScalarExpr<F>>),
127    /// Scalar subquery evaluated per input row.
128    ScalarSubquery(ScalarSubqueryExpr),
129    /// SQL CASE expression with optional operand and ELSE branch.
130    Case {
131        /// Optional operand for simple CASE (e.g., `CASE x WHEN ...`).
132        operand: Option<Box<ScalarExpr<F>>>,
133        /// Ordered (WHEN, THEN) branches.
134        branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
135        /// Optional ELSE result.
136        else_expr: Option<Box<ScalarExpr<F>>>,
137    },
138}
139
140/// Aggregate function call within a scalar expression.
141///
142/// Each variant (except `CountStar`) operates on an expression rather than just a column.
143/// This allows aggregates like `AVG(col1 + col2)` or `SUM(-col1)` to work correctly.
144#[derive(Clone, Debug)]
145pub enum AggregateCall<F> {
146    CountStar,
147    Count {
148        expr: Box<ScalarExpr<F>>,
149        distinct: bool,
150    },
151    Sum {
152        expr: Box<ScalarExpr<F>>,
153        distinct: bool,
154    },
155    Total {
156        expr: Box<ScalarExpr<F>>,
157        distinct: bool,
158    },
159    Avg {
160        expr: Box<ScalarExpr<F>>,
161        distinct: bool,
162    },
163    Min(Box<ScalarExpr<F>>),
164    Max(Box<ScalarExpr<F>>),
165    CountNulls(Box<ScalarExpr<F>>),
166    GroupConcat {
167        expr: Box<ScalarExpr<F>>,
168        distinct: bool,
169        separator: Option<String>,
170    },
171}
172
173impl<F> ScalarExpr<F> {
174    #[inline]
175    pub fn column(field: F) -> Self {
176        Self::Column(field)
177    }
178
179    #[inline]
180    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
181        Self::Literal(lit.into())
182    }
183
184    #[inline]
185    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
186        Self::Binary {
187            left: Box::new(left),
188            op,
189            right: Box::new(right),
190        }
191    }
192
193    #[inline]
194    pub fn logical_not(expr: Self) -> Self {
195        Self::Not(Box::new(expr))
196    }
197
198    #[inline]
199    pub fn is_null(expr: Self, negated: bool) -> Self {
200        Self::IsNull {
201            expr: Box::new(expr),
202            negated,
203        }
204    }
205
206    #[inline]
207    pub fn aggregate(call: AggregateCall<F>) -> Self {
208        Self::Aggregate(call)
209    }
210
211    #[inline]
212    pub fn get_field(base: Self, field_name: String) -> Self {
213        Self::GetField {
214            base: Box::new(base),
215            field_name,
216        }
217    }
218
219    #[inline]
220    pub fn cast(expr: Self, data_type: DataType) -> Self {
221        Self::Cast {
222            expr: Box::new(expr),
223            data_type,
224        }
225    }
226
227    #[inline]
228    pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
229        Self::Compare {
230            left: Box::new(left),
231            op,
232            right: Box::new(right),
233        }
234    }
235
236    #[inline]
237    pub fn coalesce(exprs: Vec<Self>) -> Self {
238        Self::Coalesce(exprs)
239    }
240
241    #[inline]
242    pub fn scalar_subquery(id: SubqueryId) -> Self {
243        Self::ScalarSubquery(ScalarSubqueryExpr { id })
244    }
245
246    #[inline]
247    pub fn case(
248        operand: Option<Self>,
249        branches: Vec<(Self, Self)>,
250        else_expr: Option<Self>,
251    ) -> Self {
252        Self::Case {
253            operand: operand.map(Box::new),
254            branches,
255            else_expr: else_expr.map(Box::new),
256        }
257    }
258}
259
/// Binary operator for [`ScalarExpr::Binary`].
///
/// This is not limited to arithmetic: it also includes the logical
/// connectives `And`/`Or` and left/right bit shifts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Remainder.
    Modulo,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// Left shift (`<<`).
    BitwiseShiftLeft,
    /// Right shift (`>>`).
    BitwiseShiftRight,
}
273
/// Comparison operator shared by boolean and scalar comparison nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Strictly less than.
    Lt,
    /// Less than or equal.
    LtEq,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal.
    GtEq,
}
284
285/// Single predicate against a field.
286#[derive(Debug, Clone)]
287pub struct Filter<'a, F> {
288    pub field_id: F,
289    pub op: Operator<'a>,
290}
291
292/// Comparison/matching operators over untyped `Literal`s.
293///
294/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
295/// common case of small, static IN lists built at call sites.
296#[derive(Debug, Clone)]
297pub enum Operator<'a> {
298    Equals(Literal),
299    Range {
300        lower: Bound<Literal>,
301        upper: Bound<Literal>,
302    },
303    GreaterThan(Literal),
304    GreaterThanOrEquals(Literal),
305    LessThan(Literal),
306    LessThanOrEquals(Literal),
307    In(&'a [Literal]),
308    StartsWith {
309        pattern: &'a str,
310        case_sensitive: bool,
311    },
312    EndsWith {
313        pattern: &'a str,
314        case_sensitive: bool,
315    },
316    Contains {
317        pattern: &'a str,
318        case_sensitive: bool,
319    },
320    IsNull,
321    IsNotNull,
322}
323
324impl<'a> Operator<'a> {
325    #[inline]
326    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
327        Operator::StartsWith {
328            pattern,
329            case_sensitive,
330        }
331    }
332
333    #[inline]
334    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
335        Operator::EndsWith {
336            pattern,
337            case_sensitive,
338        }
339    }
340
341    #[inline]
342    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
343        Operator::Contains {
344            pattern,
345            case_sensitive,
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Field id of a bare `Expr::Pred` node, or `None` for any other variant.
    fn pred_field<F: Copy>(expr: &Expr<'_, F>) -> Option<F> {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => Some(*field_id),
            _ => None,
        }
    }

    /// Field id of an `Expr::Not(Expr::Pred(..))` node, or `None` otherwise.
    fn negated_pred_field<F: Copy>(expr: &Expr<'_, F>) -> Option<F> {
        match expr {
            Expr::Not(inner) => pred_field(&**inner),
            _ => None,
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        // Last use of the filters: move them instead of cloning again.
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        match any {
            Expr::Or(v) => assert_eq!(v.len(), 2),
            _ => panic!("expected Or"),
        }
        match not_all {
            Expr::Not(inner) => match *inner {
                Expr::And(v) => assert_eq!(v.len(), 2),
                _ => panic!("expected And inside Not"),
            },
            _ => panic!("expected Not"),
        }
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = match top {
            Expr::Or(branches) => branches,
            _ => panic!("expected top-level OR"),
        };
        assert_eq!(branches.len(), 2);

        // Left branch: AND [Pred(f1), OR [Pred(f2), Not(Pred(f3))]]
        let left_and = match &branches[0] {
            Expr::And(v) => v,
            _ => panic!("expected AND on left branch of top OR"),
        };
        assert_eq!(left_and.len(), 2);
        assert_eq!(
            pred_field(&left_and[0]),
            Some(1),
            "expected Pred(f1) in left-AND[0]"
        );
        let or_vec = match &left_and[1] {
            Expr::Or(v) => v,
            _ => panic!("expected OR in left-AND[1]"),
        };
        assert_eq!(or_vec.len(), 2);
        assert_eq!(
            pred_field(&or_vec[0]),
            Some(2),
            "expected Pred(f2) in left-AND[1].OR[0]"
        );
        assert_eq!(
            negated_pred_field(&or_vec[1]),
            Some(3),
            "expected Not(Pred(f3)) in left-AND[1].OR[1]"
        );

        // Right branch: AND [Not(Pred(f1)), Pred(f4)]
        let right_and = match &branches[1] {
            Expr::And(v) => v,
            _ => panic!("expected AND on right branch of top OR"),
        };
        assert_eq!(right_and.len(), 2);
        assert_eq!(
            negated_pred_field(&right_and[0]),
            Some(1),
            "expected Not(Pred(f1)) in right-AND[0]"
        );
        assert_eq!(
            pred_field(&right_and[1]),
            Some(4),
            "expected Pred(f4) in right-AND[1]"
        );
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                // Compare whole bounds so the bound kind is checked too.
                assert_eq!(*lower, Bound::Included(Literal::String("aaa".to_string())));
                assert_eq!(*upper, Bound::Excluded(Literal::String("bbb".to_string())));
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        match and_expr {
            Expr::And(v) => {
                let ids: Vec<Option<u32>> = v.iter().map(pred_field).collect();
                assert_eq!(ids, vec![Some(1), Some(2), Some(3)]);
            }
            _ => panic!("expected And"),
        }

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        match or_expr {
            Expr::Or(v) => {
                let ids: Vec<Option<u32>> = v.iter().map(pred_field).collect();
                assert_eq!(ids, vec![Some(3), Some(2), Some(1)]);
            }
            _ => panic!("expected Or"),
        }
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            _ => panic!("expected StartsWith"),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            _ => panic!("expected EndsWith"),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            _ => panic!("expected Contains"),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        match expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 2);
                assert_eq!(pred_field(&v[0]), Some("name"), "expected Pred(name)");
                assert_eq!(pred_field(&v[1]), Some("status"), "expected Pred(status)");
            }
            _ => panic!("expected And"),
        }
    }

    #[test]
    fn very_deep_not_chain() {
        const DEPTH: usize = 64;

        // Build Not(Not(...Not(Pred)...)) of depth DEPTH.
        let mut expr = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        for _ in 0..DEPTH {
            expr = Expr::not(expr);
        }

        // Walk down the chain, counting NOT wrappers until the leaf.
        let mut count = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            count += 1;
            cur = inner;
        }
        assert_eq!(
            pred_field(cur),
            Some(42),
            "unexpected node inside deep NOT chain"
        );
        assert_eq!(count, DEPTH);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    #[test]
    fn scalar_builders_box_operands() {
        let sum = ScalarExpr::binary(
            ScalarExpr::column(1u32),
            BinaryOp::Add,
            ScalarExpr::literal("x"),
        );
        match sum {
            ScalarExpr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, ScalarExpr::Column(1)));
                assert!(matches!(
                    *right,
                    ScalarExpr::Literal(ref lit) if *lit == Literal::String("x".to_string())
                ));
            }
            _ => panic!("expected Binary"),
        }

        let case = ScalarExpr::case(
            None,
            vec![(ScalarExpr::column(2u32), ScalarExpr::literal("a"))],
            Some(ScalarExpr::literal("b")),
        );
        match case {
            ScalarExpr::Case {
                operand,
                branches,
                else_expr,
            } => {
                assert!(operand.is_none());
                assert_eq!(branches.len(), 1);
                assert!(else_expr.is_some());
            }
            _ => panic!("expected Case"),
        }
    }
}