llkv_expr/
expr.rs

//! Type-aware, Arrow-native predicate AST.
//!
//! This module defines a small predicate-expression AST that is decoupled
//! from Arrow's concrete scalar types by using `Literal`. Concrete typing
//! is deferred to the consumer (e.g., a table/scan layer), which knows the
//! column types and can coerce `Literal` into native values. Besides
//! single-field predicates, expressions can also compare arithmetic
//! combinations of several fields (`ScalarExpr`).
7
8#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13/// Logical expression over predicates.
14#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16    And(Vec<Expr<'a, F>>),
17    Or(Vec<Expr<'a, F>>),
18    Not(Box<Expr<'a, F>>),
19    Pred(Filter<'a, F>),
20    Compare {
21        left: ScalarExpr<F>,
22        op: CompareOp,
23        right: ScalarExpr<F>,
24    },
25}
26
27impl<'a, F> Expr<'a, F> {
28    /// Build an AND of filters.
29    #[inline]
30    pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
31        Expr::And(fs.into_iter().map(Expr::Pred).collect())
32    }
33
34    /// Build an OR of filters.
35    #[inline]
36    pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
37        Expr::Or(fs.into_iter().map(Expr::Pred).collect())
38    }
39
40    /// Wrap an expression in a logical NOT.
41    #[allow(clippy::should_implement_trait)]
42    #[inline]
43    pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
44        Expr::Not(Box::new(e))
45    }
46}
47
48/// Arithmetic scalar expression that can reference multiple fields.
49#[derive(Clone, Debug)]
50pub enum ScalarExpr<F> {
51    Column(F),
52    Literal(Literal),
53    Binary {
54        left: Box<ScalarExpr<F>>,
55        op: BinaryOp,
56        right: Box<ScalarExpr<F>>,
57    },
58}
59
60impl<F> ScalarExpr<F> {
61    #[inline]
62    pub fn column(field: F) -> Self {
63        Self::Column(field)
64    }
65
66    #[inline]
67    pub fn literal<L: Into<Literal>>(lit: L) -> Self {
68        Self::Literal(lit.into())
69    }
70
71    #[inline]
72    pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
73        Self::Binary {
74            left: Box::new(left),
75            op,
76            right: Box::new(right),
77        }
78    }
79}
80
/// Arithmetic operator used by [`ScalarExpr::Binary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `left + right`
    Add,
    /// `left - right`
    Subtract,
    /// `left * right`
    Multiply,
    /// `left / right`
    Divide,
}
89
/// Relational operator for [`Expr::Compare`] between two scalar expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// `left == right`
    Eq,
    /// `left != right`
    NotEq,
    /// `left < right`
    Lt,
    /// `left <= right`
    LtEq,
    /// `left > right`
    Gt,
    /// `left >= right`
    GtEq,
}
100
101/// Single predicate against a field.
102#[derive(Debug, Clone)]
103pub struct Filter<'a, F> {
104    pub field_id: F,
105    pub op: Operator<'a>,
106}
107
108/// Comparison/matching operators over untyped `Literal`s.
109///
110/// `In` accepts a borrowed slice of `Literal`s to avoid allocations in the
111/// common case of small, static IN lists built at call sites.
112#[derive(Debug, Clone)]
113pub enum Operator<'a> {
114    Equals(Literal),
115    Range {
116        lower: Bound<Literal>,
117        upper: Bound<Literal>,
118    },
119    GreaterThan(Literal),
120    GreaterThanOrEquals(Literal),
121    LessThan(Literal),
122    LessThanOrEquals(Literal),
123    In(&'a [Literal]),
124    StartsWith {
125        pattern: &'a str,
126        case_sensitive: bool,
127    },
128    EndsWith {
129        pattern: &'a str,
130        case_sensitive: bool,
131    },
132    Contains {
133        pattern: &'a str,
134        case_sensitive: bool,
135    },
136}
137
138impl<'a> Operator<'a> {
139    #[inline]
140    pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
141        Operator::StartsWith {
142            pattern,
143            case_sensitive,
144        }
145    }
146
147    #[inline]
148    pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
149        Operator::EndsWith {
150            pattern,
151            case_sensitive,
152        }
153    }
154
155    #[inline]
156    pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
157        Operator::Contains {
158            pattern,
159            case_sensitive,
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    // Shape helpers: each one unwraps a single expected node kind and panics
    // with `ctx` in the message otherwise, so the shape tests read top-down
    // instead of as deeply nested `match` ladders.

    /// `field_id` of a `Pred` node.
    fn pred_field<F: Copy>(expr: &Expr<'_, F>, ctx: &str) -> F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => *field_id,
            _ => panic!("expected Pred in {}", ctx),
        }
    }

    /// Children of an `And` node.
    fn and_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And in {}", ctx),
        }
    }

    /// Children of an `Or` node.
    fn or_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or in {}", ctx),
        }
    }

    /// Operand of a `Not` node.
    fn not_inner<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => &**inner,
            _ => panic!("expected Not in {}", ctx),
        }
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let all = Expr::all_of(vec![f1.clone(), f2.clone()]);
        let any = Expr::any_of(vec![f1, f2]);
        let not_all = Expr::not(all);

        assert_eq!(or_children(&any, "any_of").len(), 2);
        let negated = not_inner(&not_all, "not(all_of)");
        assert_eq!(and_children(negated, "not(all_of) operand").len(), 2);
    }

    #[test]
    fn builders_accept_empty_input() {
        // Empty input is passed through as a childless node; evaluation
        // semantics for empty And/Or are left to the consumer.
        let no_filters: Vec<Filter<'_, u32>> = Vec::new();
        assert!(and_children(&Expr::all_of(no_filters.clone()), "empty all_of").is_empty());
        assert!(or_children(&Expr::any_of(no_filters), "empty any_of").is_empty());
    }

    #[test]
    fn complex_nested_shape() {
        // f1: id=1 == "a"
        // f2: id=2 <  "zzz"
        // f3: id=3 in ["x","y","z"]
        // f4: id=4 starts_with "pre"
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values: [Literal; 3] = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // ( f1 AND ( f2 OR NOT f3 ) )  OR  ( NOT f1 AND f4 )
        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "top level");
        assert_eq!(branches.len(), 2);

        // Left branch: AND[ Pred(f1), OR[ Pred(f2), NOT Pred(f3) ] ]
        let left_and = and_children(&branches[0], "left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(pred_field(&left_and[0], "left-AND[0]"), 1);
        let left_or = or_children(&left_and[1], "left-AND[1]");
        assert_eq!(left_or.len(), 2);
        assert_eq!(pred_field(&left_or[0], "left-AND[1].OR[0]"), 2);
        let negated_f3 = not_inner(&left_or[1], "left-AND[1].OR[1]");
        assert_eq!(pred_field(negated_f3, "left-AND[1].OR[1] operand"), 3);

        // Right branch: AND[ NOT Pred(f1), Pred(f4) ]
        let right_and = and_children(&branches[1], "right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated_f1 = not_inner(&right_and[0], "right-AND[0]");
        assert_eq!(pred_field(negated_f1, "right-AND[0] operand"), 1);
        assert_eq!(pred_field(&right_and[1], "right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        // [aaa, bbb)
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(*lower, Bound::Included(Literal::String("aaa".to_string())));
                assert_eq!(*upper, Bound::Excluded(Literal::String("bbb".to_string())));
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids: Vec<u32> = and_children(&and_expr, "all_of")
            .iter()
            .map(|child| pred_field(child, "all_of child"))
            .collect();
        assert_eq!(and_ids, vec![1, 2, 3]);

        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids: Vec<u32> = or_children(&or_expr, "any_of")
            .iter()
            .map(|child| pred_field(child, "any_of child"))
            .collect();
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values: [Literal; 3] = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(arr) => {
                assert_eq!(arr.len(), 3);
                // Same address and length: the operator borrows, it does not copy.
                assert!(std::ptr::eq(arr, &in_values[..]));
                assert_eq!(arr[1], Literal::String("bb".to_string()));
            }
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            _ => panic!("expected StartsWith"),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            _ => panic!("expected EndsWith"),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            _ => panic!("expected Contains"),
        }

        // The case-sensitivity flag is stored as given, including `false`.
        let f_ci = Filter {
            field_id: 13u32,
            op: Operator::contains("MiD", false),
        };
        match f_ci.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "MiD");
                assert!(!case_sensitive);
            }
            _ => panic!("expected Contains"),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        // Demonstrate F = &'static str
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let children = and_children(&expr, "all_of");
        assert_eq!(children.len(), 2);
        assert_eq!(pred_field(&children[0], "all_of[0]"), "name");
        assert_eq!(pred_field(&children[1], "all_of[1]"), "status");
    }

    #[test]
    fn very_deep_not_chain() {
        // Build Not(Not(...Not(Pred)...)) of depth 64
        let mut expr = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        for _ in 0..64 {
            expr = Expr::not(expr);
        }

        // Walk down the chain, counting NOT layers until the leaf.
        let mut depth = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            depth += 1;
            cur = &**inner;
        }
        assert_eq!(depth, 64);
        assert_eq!(pred_field(cur, "bottom of NOT chain"), 42);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }

    #[test]
    fn scalar_expr_builders() {
        // col(1) + (col(2) * 10)
        let expr = ScalarExpr::binary(
            ScalarExpr::column(1u32),
            BinaryOp::Add,
            ScalarExpr::binary(
                ScalarExpr::column(2u32),
                BinaryOp::Multiply,
                ScalarExpr::literal(10),
            ),
        );

        match expr {
            ScalarExpr::Binary { left, op, right } => {
                assert_eq!(op, BinaryOp::Add);
                assert!(matches!(*left, ScalarExpr::Column(1)));
                match *right {
                    ScalarExpr::Binary { left, op, right } => {
                        assert_eq!(op, BinaryOp::Multiply);
                        assert!(matches!(*left, ScalarExpr::Column(2)));
                        assert!(matches!(
                            *right,
                            ScalarExpr::Literal(Literal::Integer(10))
                        ));
                    }
                    _ => panic!("expected nested Binary on the right"),
                }
            }
            _ => panic!("expected top-level Binary"),
        }
    }

    #[test]
    fn compare_expr_shape() {
        // NOT ( col(1) >= col(2) - 5 )
        let cmp: Expr<'_, u32> = Expr::Compare {
            left: ScalarExpr::column(1),
            op: CompareOp::GtEq,
            right: ScalarExpr::binary(
                ScalarExpr::column(2),
                BinaryOp::Subtract,
                ScalarExpr::literal(5),
            ),
        };
        let negated = Expr::not(cmp);

        match not_inner(&negated, "negated compare") {
            Expr::Compare { left, op, right } => {
                assert_eq!(*op, CompareOp::GtEq);
                assert!(matches!(left, ScalarExpr::Column(1)));
                assert!(matches!(
                    right,
                    ScalarExpr::Binary {
                        op: BinaryOp::Subtract,
                        ..
                    }
                ));
            }
            _ => panic!("expected Compare inside Not"),
        }
    }
}