llkv_expr/
expr.rs

1//! Type-aware, Arrow-native predicate AST.
2//!
3//! This module defines a small predicate-expression AST that is decoupled
4//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
5//! is deferred to the consumer (e.g., a table/scan layer) which knows the
6//! column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13/// Logical expression over predicates.
14#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16    And(Vec<Expr<'a, F>>),
17    Or(Vec<Expr<'a, F>>),
18    Not(Box<Expr<'a, F>>),
19    Pred(Filter<'a, F>),
20    Compare {
21        left: ScalarExpr<F>,
22        op: CompareOp,
23        right: ScalarExpr<F>,
24    },
25    InList {
26        expr: ScalarExpr<F>,
27        list: Vec<ScalarExpr<F>>,
28        negated: bool,
29    },
30    /// A literal boolean value (true/false).
31    /// Used for conditions that are always true or always false (e.g., empty IN lists).
32    Literal(bool),
33}
34
35impl<'a, F> Expr<'a, F> {
36    /// Build an AND of filters.
37    #[inline]
38    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
39        Expr::And(fs.into_iter().map(Expr::Pred).collect())
40    }
41
42    /// Build an OR of filters.
43    #[inline]
44    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
45        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
46    }
47
48    /// Wrap an expression in a logical NOT.
49    #[allow(clippy::should_implement_trait)]
50    #[inline]
51    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
52        Expr::Not(Box::new(e))
53    }
54}
55
56/// Arithmetic scalar expression that can reference multiple fields.
57#[derive(Clone, Debug)]
58pub enum ScalarExpr<F> {
59    Column(F),
60    Literal(Literal),
61    Binary {
62        left: Box<ScalarExpr<F>>,
63        op: BinaryOp,
64        right: Box<ScalarExpr<F>>,
65    },
66    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
67    /// This is used in expressions like COUNT(*) + 1
68    Aggregate(AggregateCall<F>),
69    /// Extract a field from a struct expression.
70    /// For example: `user.address.city` would be represented as
71    /// GetField { base: GetField { base: Column(user), field_name: "address" }, field_name: "city" }
72    GetField {
73        base: Box<ScalarExpr<F>>,
74        field_name: String,
75    },
76}
77
/// Aggregate function call within a scalar expression.
///
/// Every variant except `CountStar` names the field it aggregates over.
/// Equality and hashing are derived (matching the other operator enums in
/// this module) so identical calls can be compared, used as map keys, or
/// deduplicated.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum AggregateCall<F> {
    /// `COUNT(*)`; takes no field argument.
    CountStar,
    /// `COUNT(field)`.
    Count(F),
    /// `SUM(field)`.
    Sum(F),
    /// `MIN(field)`.
    Min(F),
    /// `MAX(field)`.
    Max(F),
    /// Number of null entries in the field (exact semantics are defined by
    /// the evaluator).
    CountNulls(F),
}
88
89impl<F> ScalarExpr<F> {
90    #[inline]
91    pub fn column(field: F) -> Self {
92        Self::Column(field)
93    }
94
95    #[inline]
96    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
97        Self::Literal(lit.into())
98    }
99
100    #[inline]
101    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
102        Self::Binary {
103            left: Box::new(left),
104            op,
105            right: Box::new(right),
106        }
107    }
108
109    #[inline]
110    pub fn aggregate(call: AggregateCall<F>) -> Self {
111        Self::Aggregate(call)
112    }
113
114    #[inline]
115    pub fn get_field(base: Self, field_name: String) -> Self {
116        Self::GetField {
117            base: Box::new(base),
118            field_name,
119        }
120    }
121}
122
/// Arithmetic operator used by [`ScalarExpr::Binary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Subtract,
    /// Multiplication (`*`).
    Multiply,
    /// Division (`/`).
    Divide,
    /// Remainder (`%`).
    Modulo,
}
132
/// Comparison operator used by [`Expr::Compare`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal (`=`).
    Eq,
    /// Not equal (`<>` / `!=`).
    NotEq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    GtEq,
}
143
144/// Single predicate against a field.
145#[derive(Debug, Clone)]
146pub struct Filter<'a, F> {
147    pub field_id: F,
148    pub op: Operator<'a>,
149}
150
151/// Comparison/matching operators over untyped `Literal`s.
152///
153/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
154/// common case of small, static IN lists built at call sites.
155#[derive(Debug, Clone)]
156pub enum Operator<'a> {
157    Equals(Literal),
158    Range {
159        lower: Bound<Literal>,
160        upper: Bound<Literal>,
161    },
162    GreaterThan(Literal),
163    GreaterThanOrEquals(Literal),
164    LessThan(Literal),
165    LessThanOrEquals(Literal),
166    In(&'a [Literal]),
167    StartsWith {
168        pattern: &'a str,
169        case_sensitive: bool,
170    },
171    EndsWith {
172        pattern: &'a str,
173        case_sensitive: bool,
174    },
175    Contains {
176        pattern: &'a str,
177        case_sensitive: bool,
178    },
179    IsNull,
180    IsNotNull,
181}
182
183impl<'a> Operator<'a> {
184    #[inline]
185    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
186        Operator::StartsWith {
187            pattern,
188            case_sensitive,
189        }
190    }
191
192    #[inline]
193    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
194        Operator::EndsWith {
195            pattern,
196            case_sensitive,
197        }
198    }
199
200    #[inline]
201    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
202        Operator::Contains {
203            pattern,
204            case_sensitive,
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    // Shape tests for the predicate (`Expr`/`Filter`/`Operator`) and scalar
    // (`ScalarExpr`) AST builders: they check that constructors produce the
    // expected node structure and preserve operand order and flags.
    use super::*;

    // all_of / any_of / not wrap filters in the expected connectives.
    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1.clone(), f2.clone()]);
        let not_all = Expr::not(all);
        match any {
            Expr::Or(v) => assert_eq!(v.len(), 2),
            _ => panic!("expected Or"),
        }
        match not_all {
            Expr::Not(inner) => match *inner {
                Expr::And(v) => assert_eq!(v.len(), 2),
                _ => panic!("expected And inside Not"),
            },
            _ => panic!("expected Not"),
        }
    }

    // Hand-built nested AND/OR/NOT tree keeps its exact shape.
    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![
                Expr::Pred(f2.clone()),
                Expr::not(Expr::Pred(f3.clone())),
            ]),
        ]);
        let right = Expr::And(vec![
            Expr::not(Expr::Pred(f1.clone())),
            Expr::Pred(f4.clone()),
        ]);
        let top = Expr::Or(vec![left, right]);

        // Shape checks
        match top {
            Expr::Or(branches) => {
                assert_eq!(branches.len(), 2);
                match &branches[0] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        // AND: [Pred(f1), OR(...)]
                        match &v[0] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 1)
                            }
                            _ => panic!("expected Pred(f1) in left-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Or(or_vec) => {
                                assert_eq!(or_vec.len(), 2);
                                match &or_vec[0] {
                                    Expr::Pred(Filter { field_id, .. }) => {
                                        assert_eq!(*field_id, 2)
                                    }
                                    _ => panic!("expected Pred(f2) in left-AND[1].OR[0]"),
                                }
                                match &or_vec[1] {
                                    Expr::Not(inner) => match inner.as_ref() {
                                        Expr::Pred(Filter { field_id, .. }) => {
                                            assert_eq!(*field_id, 3)
                                        }
                                        _ => panic!(
                                            "expected Not(Pred(f3)) in \
                                             left-AND[1].OR[1]"
                                        ),
                                    },
                                    _ => panic!("expected Not(...) in left-AND[1].OR[1]"),
                                }
                            }
                            _ => panic!("expected OR in left-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on left branch of top OR"),
                }
                match &branches[1] {
                    Expr::And(v) => {
                        assert_eq!(v.len(), 2);
                        // AND: [Not(f1), Pred(f4)]
                        match &v[0] {
                            Expr::Not(inner) => match inner.as_ref() {
                                Expr::Pred(Filter { field_id, .. }) => {
                                    assert_eq!(*field_id, 1)
                                }
                                _ => panic!("expected Not(Pred(f1)) in right-AND[0]"),
                            },
                            _ => panic!("expected Not(...) in right-AND[0]"),
                        }
                        match &v[1] {
                            Expr::Pred(Filter { field_id, .. }) => {
                                assert_eq!(*field_id, 4)
                            }
                            _ => panic!("expected Pred(f4) in right-AND[1]"),
                        }
                    }
                    _ => panic!("expected AND on right branch of top OR"),
                }
            }
            _ => panic!("expected top-level OR"),
        }
    }

    // Range bounds keep their inclusive/exclusive kind and literal values.
    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                if let Bound::Included(l) = lower {
                    assert_eq!(*l, Literal::String("aaa".to_string()));
                } else {
                    panic!("lower bound should be Included");
                }

                if let Bound::Excluded(u) = upper {
                    assert_eq!(*u, Literal::String("bbb".to_string()));
                } else {
                    panic!("upper bound should be Excluded");
                }
            }
            _ => panic!("expected Range operator"),
        }
    }

    // all_of / any_of keep the caller's filter order.
    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        match and_expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 3);
                // Expect Pred(1), Pred(2), Pred(3) in order
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!("expected Pred(1) at And[0]"),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!("expected Pred(2) at And[1]"),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!("expected Pred(3) at And[2]"),
                }
            }
            _ => panic!("expected And"),
        }

        let or_expr = Expr::any_of(vec![f3.clone(), f2.clone(), f1.clone()]);
        match or_expr {
            Expr::Or(v) => {
                assert_eq!(v.len(), 3);
                // Expect Pred(3), Pred(2), Pred(1) in order
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 3)
                    }
                    _ => panic!("expected Pred(3) at Or[0]"),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 2)
                    }
                    _ => panic!("expected Pred(2) at Or[1]"),
                }
                match &v[2] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, 1)
                    }
                    _ => panic!("expected Pred(1) at Or[2]"),
                }
            }
            _ => panic!("expected Or"),
        }
    }

    // Empty input to the builders yields empty connectives (no panic).
    #[test]
    fn empty_builders_produce_empty_connectives() {
        let empty_and: Expr<'static, u32> = Expr::all_of(Vec::new());
        let empty_or: Expr<'static, u32> = Expr::any_of(Vec::new());
        match empty_and {
            Expr::And(v) => assert!(v.is_empty()),
            _ => panic!("expected And"),
        }
        match empty_or {
            Expr::Or(v) => assert!(v.is_empty()),
            _ => panic!("expected Or"),
        }
    }

    // In and the pattern operators hold borrowed operands unchanged.
    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "pre");
                assert!(case_sensitive);
            }
            _ => panic!("expected StartsWith"),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "suf");
                assert!(case_sensitive);
            }
            _ => panic!("expected EndsWith"),
        }
        match f_ct.op {
            Operator::Contains {
                pattern: b,
                case_sensitive,
            } => {
                assert_eq!(b, "mid");
                assert!(case_sensitive);
            }
            _ => panic!("expected Contains"),
        }
    }

    // The pattern builders also carry `case_sensitive = false` through.
    #[test]
    fn pattern_builders_preserve_case_insensitive_flag() {
        let ops = [
            Operator::starts_with("Pre", false),
            Operator::ends_with("Suf", false),
            Operator::contains("Mid", false),
        ];
        for op in &ops {
            match op {
                Operator::StartsWith {
                    pattern,
                    case_sensitive,
                } => {
                    assert_eq!(*pattern, "Pre");
                    assert!(!*case_sensitive);
                }
                Operator::EndsWith {
                    pattern,
                    case_sensitive,
                } => {
                    assert_eq!(*pattern, "Suf");
                    assert!(!*case_sensitive);
                }
                Operator::Contains {
                    pattern,
                    case_sensitive,
                } => {
                    assert_eq!(*pattern, "Mid");
                    assert!(!*case_sensitive);
                }
                other => panic!("unexpected operator {:?}", other),
            }
        }
    }

    // Field ids are generic: &'static str works as well as integers.
    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1.clone(), f2.clone()]);

        match expr {
            Expr::And(v) => {
                assert_eq!(v.len(), 2);
                match &v[0] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "name")
                    }
                    _ => panic!("expected Pred(name)"),
                }
                match &v[1] {
                    Expr::Pred(Filter { field_id, .. }) => {
                        assert_eq!(*field_id, "status")
                    }
                    _ => panic!("expected Pred(status)"),
                }
            }
            _ => panic!("expected And"),
        }
    }

    // Deeply nested NOTs are preserved one level per `Expr::not` call.
    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth 64
        let base = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let mut expr = base;
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        // Count nested NOTs
        let mut count = 0usize;
        let mut cur = &expr;
        loop {
            match cur {
                Expr::Not(inner) => {
                    count += 1;
                    cur = inner;
                }
                Expr::Pred(Filter { field_id, .. }) => {
                    assert_eq!(*field_id, 42);
                    break;
                }
                _ => panic!("unexpected node inside deep NOT chain"),
            }
        }
        assert_eq!(count, 64);
    }

    // `.into()` on integer and string values yields the expected Literal variants.
    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    // ScalarExpr builders produce Binary / Aggregate / nested GetField nodes.
    #[test]
    fn scalar_builders_build_expected_nodes() {
        // col(1) + 5
        let sum = ScalarExpr::binary(
            ScalarExpr::column(1u32),
            BinaryOp::Add,
            ScalarExpr::literal(5),
        );
        match sum {
            ScalarExpr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, ScalarExpr::Column(1)));
                assert!(matches!(
                    *right,
                    ScalarExpr::Literal(Literal::Integer(5))
                ));
            }
            _ => panic!("expected Binary"),
        }

        let count = ScalarExpr::<u32>::aggregate(AggregateCall::CountStar);
        assert!(matches!(
            count,
            ScalarExpr::Aggregate(AggregateCall::CountStar)
        ));

        // user.address.city
        let city = ScalarExpr::get_field(
            ScalarExpr::get_field(ScalarExpr::column("user"), "address".to_string()),
            "city".to_string(),
        );
        match city {
            ScalarExpr::GetField { base, field_name } => {
                assert_eq!(field_name, "city");
                match base.as_ref() {
                    ScalarExpr::GetField {
                        base: inner,
                        field_name,
                    } => {
                        assert_eq!(field_name, "address");
                        assert!(matches!(inner.as_ref(), ScalarExpr::Column("user")));
                    }
                    _ => panic!("expected nested GetField"),
                }
            }
            _ => panic!("expected GetField"),
        }
    }

    // Compare, InList and Literal variants hold the operands they were given.
    #[test]
    fn compare_in_list_and_literal_variants() {
        // col(1) >= 10
        let cmp = Expr::Compare {
            left: ScalarExpr::column(1u32),
            op: CompareOp::GtEq,
            right: ScalarExpr::literal(10),
        };
        match cmp {
            Expr::Compare { left, op, right } => {
                assert_eq!(op, CompareOp::GtEq);
                assert!(matches!(left, ScalarExpr::Column(1)));
                assert!(matches!(
                    right,
                    ScalarExpr::Literal(Literal::Integer(10))
                ));
            }
            _ => panic!("expected Compare"),
        }

        // col(2) NOT IN ("a", "b")
        let in_list = Expr::InList {
            expr: ScalarExpr::column(2u32),
            list: vec![ScalarExpr::literal("a"), ScalarExpr::literal("b")],
            negated: true,
        };
        match in_list {
            Expr::InList {
                expr,
                list,
                negated,
            } => {
                assert!(matches!(expr, ScalarExpr::Column(2)));
                assert_eq!(list.len(), 2);
                assert!(negated);
            }
            _ => panic!("expected InList"),
        }

        let always_false: Expr<'static, u32> = Expr::Literal(false);
        assert!(matches!(always_false, Expr::Literal(false)));
    }
}