llkv_expr/
expr.rs

//! Type-aware, Arrow-native predicate AST.
//!
//! This module defines a small expression AST — boolean predicates ([`Expr`])
//! and the scalar expressions they operate on ([`ScalarExpr`]) — that is
//! decoupled from Arrow's concrete scalar types by using `Literal`. Concrete
//! typing is deferred to the consumer (e.g., a table/scan layer), which knows
//! the column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14/// Logical expression over predicates.
15#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17    And(Vec<Expr<'a, F>>),
18    Or(Vec<Expr<'a, F>>),
19    Not(Box<Expr<'a, F>>),
20    Pred(Filter<'a, F>),
21    Compare {
22        left: ScalarExpr<F>,
23        op: CompareOp,
24        right: ScalarExpr<F>,
25    },
26    InList {
27        expr: ScalarExpr<F>,
28        list: Vec<ScalarExpr<F>>,
29        negated: bool,
30    },
31    /// Check if a scalar expression IS NULL or IS NOT NULL.
32    /// For simple column references, prefer `Pred(Filter { op: IsNull/IsNotNull })` for optimization.
33    /// This variant handles complex expressions like `(col1 + col2) IS NULL`.
34    IsNull {
35        expr: ScalarExpr<F>,
36        negated: bool,
37    },
38    /// A literal boolean value (true/false).
39    /// Used for conditions that are always true or always false (e.g., empty IN lists).
40    Literal(bool),
41    /// Correlated subquery evaluated in a boolean context.
42    Exists(SubqueryExpr),
43}
44
/// Opaque numeric handle identifying a subquery.
///
/// The id carries no metadata of its own. It only references a subquery
/// definition stored alongside the expression: on the parent filter for
/// predicate subqueries ([`SubqueryExpr`]), or on the parent projection for
/// scalar subqueries ([`ScalarSubqueryExpr`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);

/// Correlated subquery used within a predicate expression
/// (the payload of [`Expr::Exists`]).
#[derive(Clone, Debug)]
pub struct SubqueryExpr {
    /// Identifier referencing the subquery definition attached to the parent filter.
    pub id: SubqueryId,
    /// True when the SQL contained `NOT EXISTS`.
    pub negated: bool,
}

/// Scalar subquery evaluated as part of a scalar expression
/// (the payload of [`ScalarExpr::ScalarSubquery`]).
#[derive(Clone, Debug)]
pub struct ScalarSubqueryExpr {
    /// Identifier referencing the subquery definition attached to the parent projection.
    pub id: SubqueryId,
}
64
65impl<'a, F> Expr<'a, F> {
66    /// Build an AND of filters.
67    #[inline]
68    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69        Expr::And(fs.into_iter().map(Expr::Pred).collect())
70    }
71
72    /// Build an OR of filters.
73    #[inline]
74    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76    }
77
78    /// Wrap an expression in a logical NOT.
79    #[allow(clippy::should_implement_trait)]
80    #[inline]
81    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82        Expr::Not(Box::new(e))
83    }
84}
85
86/// Arithmetic scalar expression that can reference multiple fields.
87#[derive(Clone, Debug)]
88pub enum ScalarExpr<F> {
89    Column(F),
90    Literal(Literal),
91    Binary {
92        left: Box<ScalarExpr<F>>,
93        op: BinaryOp,
94        right: Box<ScalarExpr<F>>,
95    },
96    /// Logical NOT returning 1 for falsey inputs, 0 for truthy inputs, and NULL for NULL inputs.
97    Not(Box<ScalarExpr<F>>),
98    /// NULL test returning 1 when the operand is NULL (or NOT NULL when `negated` is true) and 0 otherwise.
99    /// Returns NULL when the operand cannot be determined.
100    IsNull {
101        expr: Box<ScalarExpr<F>>,
102        negated: bool,
103    },
104    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
105    /// This is used in expressions like COUNT(*) + 1
106    Aggregate(AggregateCall<F>),
107    /// Extract a field from a struct expression.
108    /// For example: `user.address.city` would be represented as
109    /// GetField { base: GetField { base: Column(user), field_name: "address" }, field_name: "city" }
110    GetField {
111        base: Box<ScalarExpr<F>>,
112        field_name: String,
113    },
114    /// Explicit type cast to an Arrow data type.
115    Cast {
116        expr: Box<ScalarExpr<F>>,
117        data_type: DataType,
118    },
119    /// Comparison producing a boolean (1/0) result.
120    Compare {
121        left: Box<ScalarExpr<F>>,
122        op: CompareOp,
123        right: Box<ScalarExpr<F>>,
124    },
125    /// First non-null expression in the provided list.
126    Coalesce(Vec<ScalarExpr<F>>),
127    /// Scalar subquery evaluated per input row.
128    ScalarSubquery(ScalarSubqueryExpr),
129    /// SQL CASE expression with optional operand and ELSE branch.
130    Case {
131        /// Optional operand for simple CASE (e.g., `CASE x WHEN ...`).
132        operand: Option<Box<ScalarExpr<F>>>,
133        /// Ordered (WHEN, THEN) branches.
134        branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
135        /// Optional ELSE result.
136        else_expr: Option<Box<ScalarExpr<F>>>,
137    },
138    /// Random number generator returning a float in [0.0, 1.0).
139    ///
140    /// Follows the PostgreSQL/DuckDB standard: each evaluation produces a new
141    /// pseudo-random value. No seed control is exposed at the SQL level.
142    Random,
143}
144
145/// Aggregate function call within a scalar expression.
146///
147/// Each variant (except `CountStar`) operates on an expression rather than just a column.
148/// This allows aggregates like `AVG(col1 + col2)` or `SUM(-col1)` to work correctly.
149#[derive(Clone, Debug)]
150pub enum AggregateCall<F> {
151    CountStar,
152    Count {
153        expr: Box<ScalarExpr<F>>,
154        distinct: bool,
155    },
156    Sum {
157        expr: Box<ScalarExpr<F>>,
158        distinct: bool,
159    },
160    Total {
161        expr: Box<ScalarExpr<F>>,
162        distinct: bool,
163    },
164    Avg {
165        expr: Box<ScalarExpr<F>>,
166        distinct: bool,
167    },
168    Min(Box<ScalarExpr<F>>),
169    Max(Box<ScalarExpr<F>>),
170    CountNulls(Box<ScalarExpr<F>>),
171    GroupConcat {
172        expr: Box<ScalarExpr<F>>,
173        distinct: bool,
174        separator: Option<String>,
175    },
176}
177
178impl<F> ScalarExpr<F> {
179    #[inline]
180    pub fn column(field: F) -> Self {
181        Self::Column(field)
182    }
183
184    #[inline]
185    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
186        Self::Literal(lit.into())
187    }
188
189    #[inline]
190    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
191        Self::Binary {
192            left: Box::new(left),
193            op,
194            right: Box::new(right),
195        }
196    }
197
198    #[inline]
199    pub fn logical_not(expr: Self) -> Self {
200        Self::Not(Box::new(expr))
201    }
202
203    #[inline]
204    pub fn is_null(expr: Self, negated: bool) -> Self {
205        Self::IsNull {
206            expr: Box::new(expr),
207            negated,
208        }
209    }
210
211    #[inline]
212    pub fn aggregate(call: AggregateCall<F>) -> Self {
213        Self::Aggregate(call)
214    }
215
216    #[inline]
217    pub fn get_field(base: Self, field_name: String) -> Self {
218        Self::GetField {
219            base: Box::new(base),
220            field_name,
221        }
222    }
223
224    #[inline]
225    pub fn cast(expr: Self, data_type: DataType) -> Self {
226        Self::Cast {
227            expr: Box::new(expr),
228            data_type,
229        }
230    }
231
232    #[inline]
233    pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
234        Self::Compare {
235            left: Box::new(left),
236            op,
237            right: Box::new(right),
238        }
239    }
240
241    #[inline]
242    pub fn coalesce(exprs: Vec<Self>) -> Self {
243        Self::Coalesce(exprs)
244    }
245
246    #[inline]
247    pub fn scalar_subquery(id: SubqueryId) -> Self {
248        Self::ScalarSubquery(ScalarSubqueryExpr { id })
249    }
250
251    #[inline]
252    pub fn case(
253        operand: Option<Self>,
254        branches: Vec<(Self, Self)>,
255        else_expr: Option<Self>,
256    ) -> Self {
257        Self::Case {
258            operand: operand.map(Box::new),
259            branches,
260            else_expr: else_expr.map(Box::new),
261        }
262    }
263
264    #[inline]
265    pub fn random() -> Self {
266        Self::Random
267    }
268}
269
/// Binary operator for [`ScalarExpr`]: arithmetic, logical, and bit-shift forms.
///
/// Variant order is kept stable, since `as` casts to an integer depend on it.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum BinaryOp {
    /// `left + right`
    Add,
    /// `left - right`
    Subtract,
    /// `left * right`
    Multiply,
    /// `left / right`
    Divide,
    /// `left % right`
    Modulo,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// `left << right`
    BitwiseShiftLeft,
    /// `left >> right`
    BitwiseShiftRight,
}
283
/// Comparison operator for scalar expressions.
///
/// Variant order is kept stable, since `as` casts to an integer depend on it.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum CompareOp {
    /// `=`
    Eq,
    /// `<>` / `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
}
294
295/// Single predicate against a field.
296#[derive(Debug, Clone)]
297pub struct Filter<'a, F> {
298    pub field_id: F,
299    pub op: Operator<'a>,
300}
301
302/// Comparison/matching operators over untyped `Literal`s.
303///
304/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
305/// common case of small, static IN lists built at call sites.
306#[derive(Debug, Clone)]
307pub enum Operator<'a> {
308    Equals(Literal),
309    Range {
310        lower: Bound<Literal>,
311        upper: Bound<Literal>,
312    },
313    GreaterThan(Literal),
314    GreaterThanOrEquals(Literal),
315    LessThan(Literal),
316    LessThanOrEquals(Literal),
317    In(&'a [Literal]),
318    StartsWith {
319        pattern: &'a str,
320        case_sensitive: bool,
321    },
322    EndsWith {
323        pattern: &'a str,
324        case_sensitive: bool,
325    },
326    Contains {
327        pattern: &'a str,
328        case_sensitive: bool,
329    },
330    IsNull,
331    IsNotNull,
332}
333
334impl<'a> Operator<'a> {
335    #[inline]
336    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
337        Operator::StartsWith {
338            pattern,
339            case_sensitive,
340        }
341    }
342
343    #[inline]
344    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
345        Operator::EndsWith {
346            pattern,
347            case_sensitive,
348        }
349    }
350
351    #[inline]
352    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
353        Operator::Contains {
354            pattern,
355            case_sensitive,
356        }
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    // Shape helpers: each unwraps one expected node kind. On a different shape it
    // panics with `ctx` (a path such as "left-AND[1].OR[0]"), so the tests below
    // can assert tree shapes without nested `match` pyramids.

    /// Field id of an `Expr::Pred` node.
    fn pred_field_id<F: Copy>(expr: &Expr<'_, F>, ctx: &str) -> F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => *field_id,
            _ => panic!("expected Pred at {ctx}"),
        }
    }

    /// Children of an `Expr::And` node.
    fn and_children<'x, 'e, F>(expr: &'x Expr<'e, F>, ctx: &str) -> &'x [Expr<'e, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And at {ctx}"),
        }
    }

    /// Children of an `Expr::Or` node.
    fn or_children<'x, 'e, F>(expr: &'x Expr<'e, F>, ctx: &str) -> &'x [Expr<'e, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or at {ctx}"),
        }
    }

    /// Operand of an `Expr::Not` node.
    fn not_operand<'x, 'e, F>(expr: &'x Expr<'e, F>, ctx: &str) -> &'x Expr<'e, F> {
        match expr {
            Expr::Not(inner) => inner.as_ref(),
            _ => panic!("expected Not at {ctx}"),
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of").len(), 2);
        let negated = not_operand(&not_all, "not(all_of)");
        assert_eq!(and_children(negated, "inside Not").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top");
        assert_eq!(branches.len(), 2);

        // Left branch: AND[Pred(f1), OR[Pred(f2), NOT Pred(f3)]]
        let left = and_children(&branches[0], "top.OR[0]");
        assert_eq!(left.len(), 2);
        assert_eq!(pred_field_id(&left[0], "left-AND[0]"), 1);
        let inner_or = or_children(&left[1], "left-AND[1]");
        assert_eq!(inner_or.len(), 2);
        assert_eq!(pred_field_id(&inner_or[0], "left-AND[1].OR[0]"), 2);
        let negated_f3 = not_operand(&inner_or[1], "left-AND[1].OR[1]");
        assert_eq!(pred_field_id(negated_f3, "left-AND[1].OR[1].NOT"), 3);

        // Right branch: AND[NOT Pred(f1), Pred(f4)]
        let right = and_children(&branches[1], "top.OR[1]");
        assert_eq!(right.len(), 2);
        let negated_f1 = not_operand(&right[0], "right-AND[0]");
        assert_eq!(pred_field_id(negated_f1, "right-AND[0].NOT"), 1);
        assert_eq!(pred_field_id(&right[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(*lower, Bound::Included(Literal::String("aaa".to_string())));
                assert_eq!(*upper, Bound::Excluded(Literal::String("bbb".to_string())));
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids: Vec<u32> = and_children(&and_expr, "all_of")
            .iter()
            .map(|child| pred_field_id(child, "all_of child"))
            .collect();
        assert_eq!(and_ids, [1, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids: Vec<u32> = or_children(&or_expr, "any_of")
            .iter()
            .map(|child| pred_field_id(child, "any_of child"))
            .collect();
        assert_eq!(or_ids, [3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        assert!(matches!(f_in.op, Operator::In(values) if values.len() == 3));

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        assert!(matches!(
            f_sw.op,
            Operator::StartsWith {
                pattern: "pre",
                case_sensitive: true
            }
        ));
        assert!(matches!(
            f_ew.op,
            Operator::EndsWith {
                pattern: "suf",
                case_sensitive: true
            }
        ));
        assert!(matches!(
            f_ct.op,
            Operator::Contains {
                pattern: "mid",
                case_sensitive: true
            }
        ));

        // The case-sensitivity flag must be carried through as given, not forced to `true`.
        assert!(matches!(
            Operator::ends_with("x", false),
            Operator::EndsWith {
                pattern: "x",
                case_sensitive: false
            }
        ));
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let ids: Vec<&str> = and_children(&expr, "all_of")
            .iter()
            .map(|child| pred_field_id(child, "all_of child"))
            .collect();
        assert_eq!(ids, ["name", "status"]);
    }

    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth 64
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let expr = (0..64).fold(base, |acc, _| Expr::not(acc));

        // Count nested NOTs, then check the leaf underneath them.
        let mut depth = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            depth += 1;
            cur = inner.as_ref();
        }
        assert_eq!(depth, 64);
        assert_eq!(pred_field_id(cur, "bottom of NOT chain"), 42);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    #[test]
    fn scalar_builders_box_operands() {
        let sum = ScalarExpr::binary(
            ScalarExpr::column(1u32),
            BinaryOp::Add,
            ScalarExpr::literal(Literal::Integer(2)),
        );
        match &sum {
            ScalarExpr::Binary { left, op, right } => {
                assert!(matches!(left.as_ref(), ScalarExpr::Column(1)));
                assert_eq!(*op, BinaryOp::Add);
                assert!(matches!(
                    right.as_ref(),
                    ScalarExpr::Literal(Literal::Integer(2))
                ));
            }
            _ => panic!("expected Binary"),
        }

        let case = ScalarExpr::case(
            Some(ScalarExpr::column(7u32)),
            vec![(ScalarExpr::literal("a"), ScalarExpr::literal("b"))],
            None,
        );
        match case {
            ScalarExpr::Case {
                operand,
                branches,
                else_expr,
            } => {
                assert!(matches!(operand.as_deref(), Some(ScalarExpr::Column(7))));
                assert_eq!(branches.len(), 1);
                assert!(else_expr.is_none());
            }
            _ => panic!("expected Case"),
        }
    }
}