llkv_expr/
expr.rs

//! Type-aware, Arrow-native predicate AST.
//!
//! This module defines a small predicate-expression AST that is decoupled
//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
//! is deferred to the consumer (e.g., a table/scan layer), which knows the
//! column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13/// Logical expression over predicates.
14#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16    And(Vec<Expr<'a, F>>),
17    Or(Vec<Expr<'a, F>>),
18    Not(Box<Expr<'a, F>>),
19    Pred(Filter<'a, F>),
20    Compare {
21        left: ScalarExpr<F>,
22        op: CompareOp,
23        right: ScalarExpr<F>,
24    },
25}
26
27impl<'a, F> Expr<'a, F> {
28    /// Build an AND of filters.
29    #[inline]
30    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
31        Expr::And(fs.into_iter().map(Expr::Pred).collect())
32    }
33
34    /// Build an OR of filters.
35    #[inline]
36    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
37        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
38    }
39
40    /// Wrap an expression in a logical NOT.
41    #[allow(clippy::should_implement_trait)]
42    #[inline]
43    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
44        Expr::Not(Box::new(e))
45    }
46}
47
48/// Arithmetic scalar expression that can reference multiple fields.
49#[derive(Clone, Debug)]
50pub enum ScalarExpr<F> {
51    Column(F),
52    Literal(Literal),
53    Binary {
54        left: Box<ScalarExpr<F>>,
55        op: BinaryOp,
56        right: Box<ScalarExpr<F>>,
57    },
58    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
59    /// This is used in expressions like COUNT(*) + 1
60    Aggregate(AggregateCall<F>),
61    /// Extract a field from a struct expression.
62    /// For example: `user.address.city` would be represented as
63    /// GetField { base: GetField { base: Column(user), field_name: "address" }, field_name: "city" }
64    GetField {
65        base: Box<ScalarExpr<F>>,
66        field_name: String,
67    },
68}
69
/// Aggregate function call embedded in a [`ScalarExpr`].
///
/// Every variant except `CountStar` carries the field it aggregates over.
/// The type derives `PartialEq`/`Eq`/`Hash`, so aggregate calls can be
/// compared and used as map/set keys whenever `F` supports it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum AggregateCall<F> {
    /// `COUNT(*)`.
    CountStar,
    /// `COUNT(field)`.
    Count(F),
    /// `SUM(field)`.
    Sum(F),
    /// `MIN(field)`.
    Min(F),
    /// `MAX(field)`.
    Max(F),
    /// Null-value count for the field (meaning inferred from the name;
    /// confirm against the evaluator).
    CountNulls(F),
}
80
81impl<F> ScalarExpr<F> {
82    #[inline]
83    pub fn column(field: F) -> Self {
84        Self::Column(field)
85    }
86
87    #[inline]
88    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
89        Self::Literal(lit.into())
90    }
91
92    #[inline]
93    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
94        Self::Binary {
95            left: Box::new(left),
96            op,
97            right: Box::new(right),
98        }
99    }
100
101    #[inline]
102    pub fn aggregate(call: AggregateCall<F>) -> Self {
103        Self::Aggregate(call)
104    }
105
106    #[inline]
107    pub fn get_field(base: Self, field_name: String) -> Self {
108        Self::GetField {
109            base: Box::new(base),
110            field_name,
111        }
112    }
113}
114
/// Arithmetic operator for [`ScalarExpr`] binary nodes.
///
/// Derives `Hash` (alongside `Eq`) so operators can key hash maps and sets,
/// e.g. dispatch tables from operator to kernel.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BinaryOp {
    /// `left + right`
    Add,
    /// `left - right`
    Subtract,
    /// `left * right`
    Multiply,
    /// `left / right`
    Divide,
    /// `left % right`
    Modulo,
}
124
/// Comparison operator for scalar expressions (see `Expr::Compare`).
///
/// Derives `Hash` (alongside `Eq`) so operators can key hash maps and sets.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CompareOp {
    /// `left == right`
    Eq,
    /// `left != right`
    NotEq,
    /// `left < right`
    Lt,
    /// `left <= right`
    LtEq,
    /// `left > right`
    Gt,
    /// `left >= right`
    GtEq,
}
135
136/// Single predicate against a field.
137#[derive(Debug, Clone)]
138pub struct Filter<'a, F> {
139    pub field_id: F,
140    pub op: Operator<'a>,
141}
142
143/// Comparison/matching operators over untyped `Literal`s.
144///
145/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
146/// common case of small, static IN lists built at call sites.
147#[derive(Debug, Clone)]
148pub enum Operator<'a> {
149    Equals(Literal),
150    Range {
151        lower: Bound<Literal>,
152        upper: Bound<Literal>,
153    },
154    GreaterThan(Literal),
155    GreaterThanOrEquals(Literal),
156    LessThan(Literal),
157    LessThanOrEquals(Literal),
158    In(&'a [Literal]),
159    StartsWith {
160        pattern: &'a str,
161        case_sensitive: bool,
162    },
163    EndsWith {
164        pattern: &'a str,
165        case_sensitive: bool,
166    },
167    Contains {
168        pattern: &'a str,
169        case_sensitive: bool,
170    },
171    IsNull,
172    IsNotNull,
173}
174
175impl<'a> Operator<'a> {
176    #[inline]
177    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
178        Operator::StartsWith {
179            pattern,
180            case_sensitive,
181        }
182    }
183
184    #[inline]
185    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
186        Operator::EndsWith {
187            pattern,
188            case_sensitive,
189        }
190    }
191
192    #[inline]
193    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
194        Operator::Contains {
195            pattern,
196            case_sensitive,
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Filter` in tests.
    fn filter<'a, F>(field_id: F, op: Operator<'a>) -> Filter<'a, F> {
        Filter { field_id, op }
    }

    /// Children of an `And` node; panics (mentioning `ctx`) on any other node.
    fn and_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And in {}", ctx),
        }
    }

    /// Children of an `Or` node; panics (mentioning `ctx`) on any other node.
    fn or_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or in {}", ctx),
        }
    }

    /// Operand of a `Not` node; panics (mentioning `ctx`) on any other node.
    fn not_inner<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => inner.as_ref(),
            _ => panic!("expected Not in {}", ctx),
        }
    }

    /// Field id of a `Pred` node; panics (mentioning `ctx`) on any other node.
    fn pred_field<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            _ => panic!("expected Pred in {}", ctx),
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = filter(1, Operator::Equals("abc".into()));
        let f2 = filter(2, Operator::LessThan("zzz".into()));
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of").len(), 2);
        let inner = not_inner(&not_all, "not(all_of)");
        assert_eq!(and_children(inner, "not(all_of)").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = filter(1u32, Operator::Equals("a".into()));
        let f2 = filter(2u32, Operator::LessThan("zzz".into()));
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = filter(3u32, Operator::In(&in_values));
        let f4 = filter(4u32, Operator::starts_with("pre", true));

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top-level");
        assert_eq!(branches.len(), 2);

        // Left branch: AND[Pred(f1), OR[Pred(f2), Not(Pred(f3))]]
        let left_and = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(*pred_field(&left_and[0], "left-AND[0]"), 1);
        let inner_or = or_children(&left_and[1], "left-AND[1]");
        assert_eq!(inner_or.len(), 2);
        assert_eq!(*pred_field(&inner_or[0], "left-AND[1].OR[0]"), 2);
        let negated = not_inner(&inner_or[1], "left-AND[1].OR[1]");
        assert_eq!(*pred_field(negated, "left-AND[1].OR[1]"), 3);

        // Right branch: AND[Not(Pred(f1)), Pred(f4)]
        let right_and = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated = not_inner(&right_and[0], "right-AND[0]");
        assert_eq!(*pred_field(negated, "right-AND[0]"), 1);
        assert_eq!(*pred_field(&right_and[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = filter(
            7u32,
            Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        );

        match &f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(*lower, Bound::Included(Literal::String("aaa".to_string())));
                assert_eq!(*upper, Bound::Excluded(Literal::String("bbb".to_string())));
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = filter(1u32, Operator::Equals("a".into()));
        let f2 = filter(2u32, Operator::Equals("b".into()));
        let f3 = filter(3u32, Operator::Equals("c".into()));

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids: Vec<u32> = and_children(&and_expr, "all_of")
            .iter()
            .map(|child| *pred_field(child, "all_of child"))
            .collect();
        assert_eq!(and_ids, vec![1, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids: Vec<u32> = or_children(&or_expr, "any_of")
            .iter()
            .map(|child| *pred_field(child, "any_of child"))
            .collect();
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn empty_builders_produce_childless_nodes() {
        // The builders do not simplify: an empty input yields an empty node.
        let and_expr: Expr<'static, u32> = Expr::all_of(Vec::new());
        assert!(and_children(&and_expr, "all_of([])").is_empty());
        let or_expr: Expr<'static, u32> = Expr::any_of(Vec::new());
        assert!(or_children(&or_expr, "any_of([])").is_empty());
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = filter(9u32, Operator::In(&in_values));
        match f_in.op {
            Operator::In(values) => assert_eq!(values.len(), 3),
            _ => panic!("expected In"),
        }

        let f_sw = filter(10u32, Operator::starts_with("pre", true));
        let f_ew = filter(11u32, Operator::ends_with("suf", true));
        let f_ct = filter(12u32, Operator::contains("mid", true));

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            _ => panic!("expected StartsWith"),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            _ => panic!("expected EndsWith"),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            _ => panic!("expected Contains"),
        }

        // The case-insensitive flag must round-trip as well.
        match Operator::contains("mid", false) {
            Operator::Contains { case_sensitive, .. } => assert!(!case_sensitive),
            _ => panic!("expected Contains"),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = filter("name", Operator::Equals("alice".into()));
        let f2 = filter("status", Operator::GreaterThanOrEquals("active".into()));
        let expr = Expr::all_of(vec![f1, f2]);

        let children = and_children(&expr, "all_of");
        assert_eq!(children.len(), 2);
        assert_eq!(*pred_field(&children[0], "And[0]"), "name");
        assert_eq!(*pred_field(&children[1], "And[1]"), "status");
    }

    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth 64
        let mut expr = Expr::Pred(filter(42u32, Operator::Equals("x".into())));
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        // Count nested NOTs
        let mut count = 0usize;
        let mut cur = &expr;
        loop {
            match cur {
                Expr::Not(inner) => {
                    count += 1;
                    cur = inner;
                }
                Expr::Pred(Filter { field_id, .. }) => {
                    assert_eq!(*field_id, 42);
                    break;
                }
                _ => panic!("unexpected node inside deep NOT chain"),
            }
        }
        assert_eq!(count, 64);
    }

    #[test]
    fn literal_construction() {
        let f = filter(
            "my_u64_col",
            Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        );

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = filter("my_str_col", Operator::Equals("hello".into()));
        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    #[test]
    fn scalar_expr_builders() {
        let expr: ScalarExpr<u32> = ScalarExpr::binary(
            ScalarExpr::column(1),
            BinaryOp::Add,
            ScalarExpr::literal(Literal::Integer(5)),
        );
        match expr {
            ScalarExpr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, ScalarExpr::Column(1)));
                match *right {
                    ScalarExpr::Literal(lit) => assert_eq!(lit, Literal::Integer(5)),
                    _ => panic!("expected Literal on the right of Binary"),
                }
            }
            _ => panic!("expected Binary"),
        }

        let agg: ScalarExpr<u32> = ScalarExpr::aggregate(AggregateCall::Sum(3));
        assert!(matches!(agg, ScalarExpr::Aggregate(AggregateCall::Sum(3))));
    }

    #[test]
    fn nested_get_field() {
        // user.address.city
        let expr = ScalarExpr::get_field(
            ScalarExpr::get_field(ScalarExpr::column("user"), "address".to_string()),
            "city".to_string(),
        );
        match expr {
            ScalarExpr::GetField { base, field_name } => {
                assert_eq!(field_name, "city");
                match *base {
                    ScalarExpr::GetField { base, field_name } => {
                        assert_eq!(field_name, "address");
                        assert!(matches!(*base, ScalarExpr::Column("user")));
                    }
                    _ => panic!("expected inner GetField"),
                }
            }
            _ => panic!("expected outer GetField"),
        }
    }

    #[test]
    fn compare_expr_holds_operands() {
        // COUNT(*) > 10
        let expr: Expr<'static, u32> = Expr::Compare {
            left: ScalarExpr::aggregate(AggregateCall::CountStar),
            op: CompareOp::Gt,
            right: ScalarExpr::literal(Literal::Integer(10)),
        };
        match expr {
            Expr::Compare { left, op, right } => {
                assert_eq!(op, CompareOp::Gt);
                assert!(matches!(
                    left,
                    ScalarExpr::Aggregate(AggregateCall::CountStar)
                ));
                match right {
                    ScalarExpr::Literal(lit) => assert_eq!(lit, Literal::Integer(10)),
                    _ => panic!("expected Literal on the right of Compare"),
                }
            }
            _ => panic!("expected Compare"),
        }
    }
}