llkv_expr/
expr.rs

//! Type-aware, Arrow-native predicate AST.
//!
//! This module defines a small predicate-expression AST that is decoupled
//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
//! is deferred to the consumer (e.g., a table/scan layer) which knows the
//! column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use arrow::datatypes::DataType;
12use std::ops::Bound;
13
14/// Logical expression over predicates.
15#[derive(Clone, Debug)]
16pub enum Expr<'a, F> {
17    And(Vec<Expr<'a, F>>),
18    Or(Vec<Expr<'a, F>>),
19    Not(Box<Expr<'a, F>>),
20    Pred(Filter<'a, F>),
21    Compare {
22        left: ScalarExpr<F>,
23        op: CompareOp,
24        right: ScalarExpr<F>,
25    },
26    InList {
27        expr: ScalarExpr<F>,
28        list: Vec<ScalarExpr<F>>,
29        negated: bool,
30    },
31    /// Check if a scalar expression IS NULL or IS NOT NULL.
32    /// For simple column references, prefer `Pred(Filter { op: IsNull/IsNotNull })` for optimization.
33    /// This variant handles complex expressions like `(col1 + col2) IS NULL`.
34    IsNull {
35        expr: ScalarExpr<F>,
36        negated: bool,
37    },
38    /// A literal boolean value (true/false).
39    /// Used for conditions that are always true or always false (e.g., empty IN lists).
40    Literal(bool),
41    /// Correlated subquery evaluated in a boolean context.
42    Exists(SubqueryExpr),
43}
44
/// Identifier of a subquery definition.
///
/// Expression nodes such as `SubqueryExpr` and `ScalarSubqueryExpr` carry this
/// id to reference a subquery definition that is stored outside the
/// expression tree itself.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SubqueryId(pub u32);
48
49/// Correlated subquery used within a predicate expression.
50#[derive(Clone, Debug)]
51pub struct SubqueryExpr {
52    /// Identifier referencing the subquery definition attached to the parent filter.
53    pub id: SubqueryId,
54    /// True when the SQL contained `NOT EXISTS`.
55    pub negated: bool,
56}
57
58/// Scalar subquery evaluated as part of a scalar expression.
59#[derive(Clone, Debug)]
60pub struct ScalarSubqueryExpr {
61    /// Identifier referencing the subquery definition attached to the parent projection.
62    pub id: SubqueryId,
63}
64
65impl<'a, F> Expr<'a, F> {
66    /// Build an AND of filters.
67    #[inline]
68    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
69        Expr::And(fs.into_iter().map(Expr::Pred).collect())
70    }
71
72    /// Build an OR of filters.
73    #[inline]
74    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
75        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
76    }
77
78    /// Wrap an expression in a logical NOT.
79    #[allow(clippy::should_implement_trait)]
80    #[inline]
81    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
82        Expr::Not(Box::new(e))
83    }
84}
85
86/// Arithmetic scalar expression that can reference multiple fields.
87#[derive(Clone, Debug)]
88pub enum ScalarExpr<F> {
89    Column(F),
90    Literal(Literal),
91    Binary {
92        left: Box<ScalarExpr<F>>,
93        op: BinaryOp,
94        right: Box<ScalarExpr<F>>,
95    },
96    /// Logical NOT returning 1 for falsey inputs, 0 for truthy inputs, and NULL for NULL inputs.
97    Not(Box<ScalarExpr<F>>),
98    /// NULL test returning 1 when the operand is NULL (or NOT NULL when `negated` is true) and 0 otherwise.
99    /// Returns NULL when the operand cannot be determined.
100    IsNull {
101        expr: Box<ScalarExpr<F>>,
102        negated: bool,
103    },
104    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
105    /// This is used in expressions like COUNT(*) + 1
106    Aggregate(AggregateCall<F>),
107    /// Extract a field from a struct expression.
108    /// For example: `user.address.city` would be represented as
109    /// GetField { base: GetField { base: Column(user), field_name: "address" }, field_name: "city" }
110    GetField {
111        base: Box<ScalarExpr<F>>,
112        field_name: String,
113    },
114    /// Explicit type cast to an Arrow data type.
115    Cast {
116        expr: Box<ScalarExpr<F>>,
117        data_type: DataType,
118    },
119    /// Comparison producing a boolean (1/0) result.
120    Compare {
121        left: Box<ScalarExpr<F>>,
122        op: CompareOp,
123        right: Box<ScalarExpr<F>>,
124    },
125    /// First non-null expression in the provided list.
126    Coalesce(Vec<ScalarExpr<F>>),
127    /// Scalar subquery evaluated per input row.
128    ScalarSubquery(ScalarSubqueryExpr),
129    /// SQL CASE expression with optional operand and ELSE branch.
130    Case {
131        /// Optional operand for simple CASE (e.g., `CASE x WHEN ...`).
132        operand: Option<Box<ScalarExpr<F>>>,
133        /// Ordered (WHEN, THEN) branches.
134        branches: Vec<(ScalarExpr<F>, ScalarExpr<F>)>,
135        /// Optional ELSE result.
136        else_expr: Option<Box<ScalarExpr<F>>>,
137    },
138}
139
140/// Aggregate function call within a scalar expression.
141///
142/// Each variant (except `CountStar`) operates on an expression rather than just a column.
143/// This allows aggregates like `AVG(col1 + col2)` or `SUM(-col1)` to work correctly.
144#[derive(Clone, Debug)]
145pub enum AggregateCall<F> {
146    CountStar,
147    Count {
148        expr: Box<ScalarExpr<F>>,
149        distinct: bool,
150    },
151    Sum {
152        expr: Box<ScalarExpr<F>>,
153        distinct: bool,
154    },
155    Avg {
156        expr: Box<ScalarExpr<F>>,
157        distinct: bool,
158    },
159    Min(Box<ScalarExpr<F>>),
160    Max(Box<ScalarExpr<F>>),
161    CountNulls(Box<ScalarExpr<F>>),
162}
163
164impl<F> ScalarExpr<F> {
165    #[inline]
166    pub fn column(field: F) -> Self {
167        Self::Column(field)
168    }
169
170    #[inline]
171    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
172        Self::Literal(lit.into())
173    }
174
175    #[inline]
176    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
177        Self::Binary {
178            left: Box::new(left),
179            op,
180            right: Box::new(right),
181        }
182    }
183
184    #[inline]
185    pub fn logical_not(expr: Self) -> Self {
186        Self::Not(Box::new(expr))
187    }
188
189    #[inline]
190    pub fn is_null(expr: Self, negated: bool) -> Self {
191        Self::IsNull {
192            expr: Box::new(expr),
193            negated,
194        }
195    }
196
197    #[inline]
198    pub fn aggregate(call: AggregateCall<F>) -> Self {
199        Self::Aggregate(call)
200    }
201
202    #[inline]
203    pub fn get_field(base: Self, field_name: String) -> Self {
204        Self::GetField {
205            base: Box::new(base),
206            field_name,
207        }
208    }
209
210    #[inline]
211    pub fn cast(expr: Self, data_type: DataType) -> Self {
212        Self::Cast {
213            expr: Box::new(expr),
214            data_type,
215        }
216    }
217
218    #[inline]
219    pub fn compare(left: Self, op: CompareOp, right: Self) -> Self {
220        Self::Compare {
221            left: Box::new(left),
222            op,
223            right: Box::new(right),
224        }
225    }
226
227    #[inline]
228    pub fn coalesce(exprs: Vec<Self>) -> Self {
229        Self::Coalesce(exprs)
230    }
231
232    #[inline]
233    pub fn scalar_subquery(id: SubqueryId) -> Self {
234        Self::ScalarSubquery(ScalarSubqueryExpr { id })
235    }
236
237    #[inline]
238    pub fn case(
239        operand: Option<Self>,
240        branches: Vec<(Self, Self)>,
241        else_expr: Option<Self>,
242    ) -> Self {
243        Self::Case {
244            operand: operand.map(Box::new),
245            branches,
246            else_expr: else_expr.map(Box::new),
247        }
248    }
249}
250
/// Arithmetic operator for `ScalarExpr`.
///
/// Derives `Hash` alongside `Eq` so operators can be used as map/set keys.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Remainder.
    Modulo,
}
260
/// Comparison operator for scalar expressions.
///
/// Derives `Hash` alongside `Eq` so operators can be used as map/set keys.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CompareOp {
    /// Equal (`=`).
    Eq,
    /// Not equal (`<>` / `!=`).
    NotEq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    GtEq,
}
271
272/// Single predicate against a field.
273#[derive(Debug, Clone)]
274pub struct Filter<'a, F> {
275    pub field_id: F,
276    pub op: Operator<'a>,
277}
278
279/// Comparison/matching operators over untyped `Literal`s.
280///
281/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
282/// common case of small, static IN lists built at call sites.
283#[derive(Debug, Clone)]
284pub enum Operator<'a> {
285    Equals(Literal),
286    Range {
287        lower: Bound<Literal>,
288        upper: Bound<Literal>,
289    },
290    GreaterThan(Literal),
291    GreaterThanOrEquals(Literal),
292    LessThan(Literal),
293    LessThanOrEquals(Literal),
294    In(&'a [Literal]),
295    StartsWith {
296        pattern: &'a str,
297        case_sensitive: bool,
298    },
299    EndsWith {
300        pattern: &'a str,
301        case_sensitive: bool,
302    },
303    Contains {
304        pattern: &'a str,
305        case_sensitive: bool,
306    },
307    IsNull,
308    IsNotNull,
309}
310
311impl<'a> Operator<'a> {
312    #[inline]
313    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
314        Operator::StartsWith {
315            pattern,
316            case_sensitive,
317        }
318    }
319
320    #[inline]
321    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
322        Operator::EndsWith {
323            pattern,
324            case_sensitive,
325        }
326    }
327
328    #[inline]
329    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
330        Operator::Contains {
331            pattern,
332            case_sensitive,
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Equals` filter on `field_id` comparing against `value`.
    fn eq_filter<F>(field_id: F, value: &'static str) -> Filter<'static, F> {
        Filter {
            field_id,
            op: Operator::Equals(value.into()),
        }
    }

    /// Returns the field id of a `Pred` leaf; panics (naming `ctx`) otherwise.
    fn pred_field<'e, F>(expr: &'e Expr<'_, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            _ => panic!("expected Pred in {}", ctx),
        }
    }

    /// Returns the children of an `And` node; panics (naming `ctx`) otherwise.
    fn and_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And in {}", ctx),
        }
    }

    /// Returns the children of an `Or` node; panics (naming `ctx`) otherwise.
    fn or_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or in {}", ctx),
        }
    }

    /// Returns the operand of a `Not` node; panics (naming `ctx`) otherwise.
    fn not_inner<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => inner.as_ref(),
            _ => panic!("expected Not(...) in {}", ctx),
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = eq_filter(1, "abc");
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of result").len(), 2);
        let inner = not_inner(&not_all, "not(all_of) result");
        assert_eq!(and_children(inner, "operand of Not").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = eq_filter(1u32, "a");
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top level");
        assert_eq!(branches.len(), 2);

        // Left branch: AND [Pred(f1), OR [Pred(f2), Not(Pred(f3))]]
        let left_and = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(*pred_field(&left_and[0], "left-AND[0]"), 1);

        let left_or = or_children(&left_and[1], "left-AND[1]");
        assert_eq!(left_or.len(), 2);
        assert_eq!(*pred_field(&left_or[0], "left-AND[1].OR[0]"), 2);
        let negated_f3 = not_inner(&left_or[1], "left-AND[1].OR[1]");
        assert_eq!(*pred_field(negated_f3, "Not in left-AND[1].OR[1]"), 3);

        // Right branch: AND [Not(Pred(f1)), Pred(f4)]
        let right_and = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated_f1 = not_inner(&right_and[0], "right-AND[0]");
        assert_eq!(*pred_field(negated_f1, "Not in right-AND[0]"), 1);
        assert_eq!(*pred_field(&right_and[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match f.op {
            Operator::Range {
                lower: Bound::Included(lo),
                upper: Bound::Excluded(hi),
            } => {
                assert_eq!(lo, Literal::String("aaa".to_string()));
                assert_eq!(hi, Literal::String("bbb".to_string()));
            }
            other => panic!(
                "expected Range with Included lower and Excluded upper, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = eq_filter(1u32, "a");
        let f2 = eq_filter(2u32, "b");
        let f3 = eq_filter(3u32, "c");

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids: Vec<u32> = and_children(&and_expr, "all_of result")
            .iter()
            .map(|child| *pred_field(child, "all_of child"))
            .collect();
        assert_eq!(and_ids, vec![1, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids: Vec<u32> = or_children(&or_expr, "any_of result")
            .iter()
            .map(|child| *pred_field(child, "any_of child"))
            .collect();
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => assert_eq!(arr.len(), 3),
            other => panic!("expected In, got {:?}", other),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            other => panic!("expected StartsWith, got {:?}", other),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            other => panic!("expected EndsWith, got {:?}", other),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            other => panic!("expected Contains, got {:?}", other),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = eq_filter("name", "alice");
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let children = and_children(&expr, "all_of result");
        assert_eq!(children.len(), 2);
        assert_eq!(*pred_field(&children[0], "And[0]"), "name");
        assert_eq!(*pred_field(&children[1], "And[1]"), "status");
    }

    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth 64
        let depth = 64usize;
        let base = Expr::Pred(eq_filter(42u32, "x"));
        let expr = (0..depth).fold(base, |acc, _| Expr::not(acc));

        // Peel the NOT layers off, counting them, until a non-Not node is hit.
        let mut count = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            count += 1;
            cur = inner;
        }
        assert_eq!(count, depth);
        assert_eq!(*pred_field(cur, "bottom of deep NOT chain"), 42);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }
}