llkv_expr/
expr.rs

1//! Type-aware, Arrow-native predicate AST.
2//!
3//! This module defines a small predicate-expression AST that is decoupled
4//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
5//! is deferred to the consumer (e.g., a table/scan layer) which knows the
6//! column types and can coerce `Literal` into native values.
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13/// Logical expression over predicates.
14#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16    And(Vec<Expr<'a, F>>),
17    Or(Vec<Expr<'a, F>>),
18    Not(Box<Expr<'a, F>>),
19    Pred(Filter<'a, F>),
20    Compare {
21        left: ScalarExpr<F>,
22        op: CompareOp,
23        right: ScalarExpr<F>,
24    },
25}
26
27impl<'a, F> Expr<'a, F> {
28    /// Build an AND of filters.
29    #[inline]
30    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
31        Expr::And(fs.into_iter().map(Expr::Pred).collect())
32    }
33
34    /// Build an OR of filters.
35    #[inline]
36    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
37        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
38    }
39
40    /// Wrap an expression in a logical NOT.
41    #[allow(clippy::should_implement_trait)]
42    #[inline]
43    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
44        Expr::Not(Box::new(e))
45    }
46}
47
48/// Arithmetic scalar expression that can reference multiple fields.
49#[derive(Clone, Debug)]
50pub enum ScalarExpr<F> {
51    Column(F),
52    Literal(Literal),
53    Binary {
54        left: Box<ScalarExpr<F>>,
55        op: BinaryOp,
56        right: Box<ScalarExpr<F>>,
57    },
58    /// Aggregate function call (e.g., COUNT(*), SUM(col), etc.)
59    /// This is used in expressions like COUNT(*) + 1
60    Aggregate(AggregateCall<F>),
61}
62
/// Aggregate function invocation usable inside a scalar expression.
///
/// `F` is the field identifier type. The exact evaluation semantics of
/// each aggregate (e.g. null handling) are defined by the consumer.
#[derive(Debug, Clone)]
pub enum AggregateCall<F> {
    /// `COUNT(*)`; takes no field argument.
    CountStar,
    /// `COUNT` over the given field.
    Count(F),
    /// `SUM` over the given field.
    Sum(F),
    /// `MIN` over the given field.
    Min(F),
    /// `MAX` over the given field.
    Max(F),
    /// Count of nulls for the given field (per the variant name).
    CountNulls(F),
}
73
74impl<F> ScalarExpr<F> {
75    #[inline]
76    pub fn column(field: F) -> Self {
77        Self::Column(field)
78    }
79
80    #[inline]
81    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
82        Self::Literal(lit.into())
83    }
84
85    #[inline]
86    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
87        Self::Binary {
88            left: Box::new(left),
89            op,
90            right: Box::new(right),
91        }
92    }
93
94    #[inline]
95    pub fn aggregate(call: AggregateCall<F>) -> Self {
96        Self::Aggregate(call)
97    }
98}
99
/// Arithmetic operator used by the `Binary` node of [`ScalarExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Subtract,
    /// Multiplication.
    Multiply,
    /// Division.
    Divide,
    /// Modulo.
    Modulo,
}
109
/// Comparison operator applied between two scalar expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    LtEq,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    GtEq,
}
120
121/// Single predicate against a field.
122#[derive(Debug, Clone)]
123pub struct Filter<'a, F> {
124    pub field_id: F,
125    pub op: Operator<'a>,
126}
127
128/// Comparison/matching operators over untyped `Literal`s.
129///
130/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
131/// common case of small, static IN lists built at call sites.
132#[derive(Debug, Clone)]
133pub enum Operator<'a> {
134    Equals(Literal),
135    Range {
136        lower: Bound<Literal>,
137        upper: Bound<Literal>,
138    },
139    GreaterThan(Literal),
140    GreaterThanOrEquals(Literal),
141    LessThan(Literal),
142    LessThanOrEquals(Literal),
143    In(&'a [Literal]),
144    StartsWith {
145        pattern: &'a str,
146        case_sensitive: bool,
147    },
148    EndsWith {
149        pattern: &'a str,
150        case_sensitive: bool,
151    },
152    Contains {
153        pattern: &'a str,
154        case_sensitive: bool,
155    },
156}
157
158impl<'a> Operator<'a> {
159    #[inline]
160    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
161        Operator::StartsWith {
162            pattern,
163            case_sensitive,
164        }
165    }
166
167    #[inline]
168    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
169        Operator::EndsWith {
170            pattern,
171            case_sensitive,
172        }
173    }
174
175    #[inline]
176    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
177        Operator::Contains {
178            pattern,
179            case_sensitive,
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    //! Structural tests for the AST builders.
    //!
    //! Nothing is evaluated here: the tests only check that the constructed
    //! trees have the expected shape, ordering and payloads. The small
    //! accessor helpers below unwrap one node kind each and panic with the
    //! supplied context, so a failing shape check names the offending node.

    use super::*;

    /// Returns the children of an `And` node; panics naming `ctx` otherwise.
    fn and_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And at {}", ctx),
        }
    }

    /// Returns the children of an `Or` node; panics naming `ctx` otherwise.
    fn or_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or at {}", ctx),
        }
    }

    /// Returns the operand of a `Not` node; panics naming `ctx` otherwise.
    fn not_inner<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => inner.as_ref(),
            _ => panic!("expected Not at {}", ctx),
        }
    }

    /// Returns the field id of a `Pred` node; panics naming `ctx` otherwise.
    fn pred_field_id<'e, F>(expr: &'e Expr<'_, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            _ => panic!("expected Pred at {}", ctx),
        }
    }

    /// Collects the field ids of `children`, all of which must be `Pred`.
    fn pred_field_ids<F: Copy>(children: &[Expr<'_, F>], ctx: &str) -> Vec<F> {
        children
            .iter()
            .map(|child| *pred_field_id(child, ctx))
            .collect()
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of result").len(), 2);

        let negated = not_inner(&not_all, "not(all_of) result");
        assert_eq!(and_children(negated, "inside Not").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        // Shape checks
        let branches = or_children(&top, "top level");
        assert_eq!(branches.len(), 2);

        // Left branch, AND: [Pred(f1), OR(Pred(f2), Not(Pred(f3)))]
        let left_and = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(*pred_field_id(&left_and[0], "left-AND[0]"), 1);

        let inner_or = or_children(&left_and[1], "left-AND[1]");
        assert_eq!(inner_or.len(), 2);
        assert_eq!(*pred_field_id(&inner_or[0], "left-AND[1].OR[0]"), 2);
        let negated_f3 = not_inner(&inner_or[1], "left-AND[1].OR[1]");
        assert_eq!(
            *pred_field_id(negated_f3, "inside Not at left-AND[1].OR[1]"),
            3
        );

        // Right branch, AND: [Not(Pred(f1)), Pred(f4)]
        let right_and = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated_f1 = not_inner(&right_and[0], "right-AND[0]");
        assert_eq!(*pred_field_id(negated_f1, "inside Not at right-AND[0]"), 1);
        assert_eq!(*pred_field_id(&right_and[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                // Comparing whole bounds checks both the bound kind and the
                // literal payload in a single assertion.
                assert_eq!(
                    *lower,
                    Bound::Included(Literal::String("aaa".to_string()))
                );
                assert_eq!(
                    *upper,
                    Bound::Excluded(Literal::String("bbb".to_string()))
                );
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        // Equality against a fixed-size array checks length and order at once.
        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids = pred_field_ids(and_children(&and_expr, "all_of result"), "all_of child");
        assert_eq!(and_ids, [1u32, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids = pred_field_ids(or_children(&or_expr, "any_of result"), "any_of child");
        assert_eq!(or_ids, [3u32, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        assert!(
            matches!(
                f_sw.op,
                Operator::StartsWith {
                    pattern: "pre",
                    case_sensitive: true
                }
            ),
            "unexpected operator: {:?}",
            f_sw.op
        );
        assert!(
            matches!(
                f_ew.op,
                Operator::EndsWith {
                    pattern: "suf",
                    case_sensitive: true
                }
            ),
            "unexpected operator: {:?}",
            f_ew.op
        );
        assert!(
            matches!(
                f_ct.op,
                Operator::Contains {
                    pattern: "mid",
                    case_sensitive: true
                }
            ),
            "unexpected operator: {:?}",
            f_ct.op
        );
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let ids = pred_field_ids(and_children(&expr, "all_of result"), "all_of child");
        assert_eq!(ids, ["name", "status"]);
    }

    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth DEPTH
        const DEPTH: usize = 64;

        let mut expr = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        for _ in 0..DEPTH {
            expr = Expr::not(expr);
        }

        // Peel off NOT layers, counting them, until a non-NOT node is reached.
        let mut count = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            count += 1;
            cur = inner;
        }

        // The node under the chain must be the original predicate.
        assert_eq!(*pred_field_id(cur, "bottom of NOT chain"), 42);
        assert_eq!(count, DEPTH);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }
}