llkv_expr/typed_predicate.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::Bound;
4
5use arrow::datatypes::ArrowPrimitiveType;
6
7use crate::expr::Operator;
8use crate::literal::{
9    FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
10};
11
/// A value type that a [`Predicate`] can be built over and evaluated against.
///
/// Predicates store owned targets (`Self`) but are evaluated against a borrowed
/// view of the candidate value (`Self::Borrowed`), which lets e.g. `String`
/// predicates run directly on `str` slices without allocating.
///
/// The pattern operations (`contains`, `starts_with`, `ends_with`) have default
/// implementations that always return `false`, so types that do not override
/// them never satisfy pattern predicates.
pub trait PredicateValue: Clone {
    /// Borrowed representation of a candidate value (may be unsized, e.g. `str`).
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Views an owned value through its borrowed representation.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Returns `true` when the candidate `value` equals the predicate `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`; `None` means the two are unordered.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. Default: never matches.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        // Arguments are intentionally unused by the default implementation.
        let _ = (value, target, case_sensitive);
        false
    }

    /// Prefix test. Default: never matches.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        // Arguments are intentionally unused by the default implementation.
        let _ = (value, target, case_sensitive);
        false
    }

    /// Suffix test. Default: never matches.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        // Arguments are intentionally unused by the default implementation.
        let _ = (value, target, case_sensitive);
        false
    }
}
33
34#[derive(Debug, Clone)]
35pub enum Predicate<V>
36where
37    V: PredicateValue,
38{
39    All,
40    Equals(V),
41    GreaterThan(V),
42    GreaterThanOrEquals(V),
43    LessThan(V),
44    LessThanOrEquals(V),
45    Range {
46        lower: Option<Bound<V>>,
47        upper: Option<Bound<V>>,
48    },
49    In(Vec<V>),
50    StartsWith {
51        pattern: V,
52        case_sensitive: bool,
53    },
54    EndsWith {
55        pattern: V,
56        case_sensitive: bool,
57    },
58    Contains {
59        pattern: V,
60        case_sensitive: bool,
61    },
62}
63
64impl<V> Predicate<V>
65where
66    V: PredicateValue,
67{
68    pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
69        match self {
70            Predicate::All => true,
71            Predicate::Equals(target) => V::equals(value, target),
72            Predicate::GreaterThan(target) => {
73                matches!(V::compare(value, target), Some(Ordering::Greater))
74            }
75            Predicate::GreaterThanOrEquals(target) => {
76                matches!(
77                    V::compare(value, target),
78                    Some(Ordering::Greater | Ordering::Equal)
79                )
80            }
81            Predicate::LessThan(target) => {
82                matches!(V::compare(value, target), Some(Ordering::Less))
83            }
84            Predicate::LessThanOrEquals(target) => matches!(
85                V::compare(value, target),
86                Some(Ordering::Less | Ordering::Equal)
87            ),
88            Predicate::Range { lower, upper } => {
89                if let Some(bound) = lower
90                    && !match bound {
91                        Bound::Included(target) => matches!(
92                            V::compare(value, target),
93                            Some(Ordering::Greater | Ordering::Equal)
94                        ),
95                        Bound::Excluded(target) => {
96                            matches!(V::compare(value, target), Some(Ordering::Greater))
97                        }
98                        Bound::Unbounded => true,
99                    }
100                {
101                    return false;
102                }
103
104                if let Some(bound) = upper
105                    && !match bound {
106                        Bound::Included(target) => matches!(
107                            V::compare(value, target),
108                            Some(Ordering::Less | Ordering::Equal)
109                        ),
110                        Bound::Excluded(target) => {
111                            matches!(V::compare(value, target), Some(Ordering::Less))
112                        }
113                        Bound::Unbounded => true,
114                    }
115                {
116                    return false;
117                }
118
119                true
120            }
121            Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
122            Predicate::StartsWith {
123                pattern,
124                case_sensitive,
125            } => V::starts_with(value, pattern, *case_sensitive),
126            Predicate::EndsWith {
127                pattern,
128                case_sensitive,
129            } => V::ends_with(value, pattern, *case_sensitive),
130            Predicate::Contains {
131                pattern,
132                case_sensitive,
133            } => V::contains(value, pattern, *case_sensitive),
134        }
135    }
136}
137
138macro_rules! impl_predicate_value_for_primitive {
139    ($($ty:ty),+ $(,)?) => {
140        $(
141            impl PredicateValue for $ty {
142                type Borrowed<'a> = Self where Self: 'a;
143
144                fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
145                    value
146                }
147
148                fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
149                    *value == *target
150                }
151
152                fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
153                    value.partial_cmp(target)
154                }
155            }
156        )+
157    };
158}
159
160impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
161
162impl PredicateValue for String {
163    type Borrowed<'a>
164        = str
165    where
166        Self: 'a;
167
168    fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
169        value.as_str()
170    }
171
172    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
173        value == target.as_str()
174    }
175
176    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
177        Some(value.cmp(target.as_str()))
178    }
179
180    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
181        if case_sensitive {
182            value.contains(target.as_str())
183        } else {
184            value.to_lowercase().contains(&target.to_lowercase())
185        }
186    }
187
188    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
189        if case_sensitive {
190            value.starts_with(target.as_str())
191        } else {
192            value.to_lowercase().starts_with(&target.to_lowercase())
193        }
194    }
195
196    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
197        if case_sensitive {
198            value.ends_with(target.as_str())
199        } else {
200            value.to_lowercase().ends_with(&target.to_lowercase())
201        }
202    }
203}
204
205#[derive(Debug, Clone)]
206pub enum PredicateBuildError {
207    LiteralCast(LiteralCastError),
208    UnsupportedOperator(&'static str),
209}
210
211impl fmt::Display for PredicateBuildError {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
215            PredicateBuildError::UnsupportedOperator(op) => {
216                write!(f, "unsupported operator for typed predicate: {op}")
217            }
218        }
219    }
220}
221
222impl std::error::Error for PredicateBuildError {
223    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
224        match self {
225            PredicateBuildError::LiteralCast(err) => Some(err),
226            PredicateBuildError::UnsupportedOperator(_) => None,
227        }
228    }
229}
230
231impl From<LiteralCastError> for PredicateBuildError {
232    fn from(err: LiteralCastError) -> Self {
233        PredicateBuildError::LiteralCast(err)
234    }
235}
236
237pub fn build_fixed_width_predicate<T>(
238    op: &Operator<'_>,
239) -> Result<Predicate<T::Native>, PredicateBuildError>
240where
241    T: ArrowPrimitiveType,
242    T::Native: FromLiteral + Copy + PredicateValue,
243{
244    match op {
245        Operator::Equals(lit) => Ok(Predicate::Equals(
246            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
247        )),
248        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
249            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
250        )),
251        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
252            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
253        )),
254        Operator::LessThan(lit) => Ok(Predicate::LessThan(
255            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
256        )),
257        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
258            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
259        )),
260        Operator::Range { lower, upper } => {
261            let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
262                Bound::Unbounded => None,
263                other => Some(other),
264            };
265            let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
266                Bound::Unbounded => None,
267                other => Some(other),
268            };
269
270            if lb.is_none() && ub.is_none() {
271                Ok(Predicate::All)
272            } else {
273                Ok(Predicate::Range {
274                    lower: lb,
275                    upper: ub,
276                })
277            }
278        }
279        Operator::In(values) => {
280            let mut natives = Vec::with_capacity(values.len());
281            for lit in *values {
282                natives
283                    .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
284            }
285            Ok(Predicate::In(natives))
286        }
287        _ => Err(PredicateBuildError::UnsupportedOperator(
288            "operator lacks typed literal support",
289        )),
290    }
291}
292
293fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
294    Ok(match bound {
295        Bound::Unbounded => None,
296        Bound::Included(lit) => Some(Bound::Included(
297            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
298        )),
299        Bound::Excluded(lit) => Some(Bound::Excluded(
300            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
301        )),
302    })
303}
304
305pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
306    match op {
307        Operator::Equals(lit) => Ok(Predicate::Equals(
308            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
309        )),
310        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
311            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
312        )),
313        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
314            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315        )),
316        Operator::LessThan(lit) => Ok(Predicate::LessThan(
317            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318        )),
319        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
320            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
321        )),
322        Operator::Range { lower, upper } => {
323            let lb = parse_bool_bound(lower)?;
324            let ub = parse_bool_bound(upper)?;
325            if lb.is_none() && ub.is_none() {
326                Ok(Predicate::All)
327            } else {
328                Ok(Predicate::Range {
329                    lower: lb,
330                    upper: ub,
331                })
332            }
333        }
334        Operator::In(values) => {
335            let mut natives = Vec::with_capacity(values.len());
336            for lit in *values {
337                natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
338            }
339            Ok(Predicate::In(natives))
340        }
341        _ => Err(PredicateBuildError::UnsupportedOperator(
342            "operator lacks boolean literal support",
343        )),
344    }
345}
346
347fn parse_string_bound(
348    bound: &Bound<Literal>,
349) -> Result<Option<Bound<String>>, PredicateBuildError> {
350    match bound {
351        Bound::Unbounded => Ok(None),
352        Bound::Included(lit) => literal_to_string(lit)
353            .map(|s| Some(Bound::Included(s)))
354            .map_err(PredicateBuildError::from),
355        Bound::Excluded(lit) => literal_to_string(lit)
356            .map(|s| Some(Bound::Excluded(s)))
357            .map_err(PredicateBuildError::from),
358    }
359}
360
361pub fn build_var_width_predicate(
362    op: &Operator<'_>,
363) -> Result<Predicate<String>, PredicateBuildError> {
364    match op {
365        Operator::Equals(lit) => Ok(Predicate::Equals(
366            literal_to_string(lit).map_err(PredicateBuildError::from)?,
367        )),
368        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
369            literal_to_string(lit).map_err(PredicateBuildError::from)?,
370        )),
371        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
372            literal_to_string(lit).map_err(PredicateBuildError::from)?,
373        )),
374        Operator::LessThan(lit) => Ok(Predicate::LessThan(
375            literal_to_string(lit).map_err(PredicateBuildError::from)?,
376        )),
377        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
378            literal_to_string(lit).map_err(PredicateBuildError::from)?,
379        )),
380        Operator::Range { lower, upper } => {
381            let lb = parse_string_bound(lower)?;
382            let ub = parse_string_bound(upper)?;
383            if lb.is_none() && ub.is_none() {
384                Ok(Predicate::All)
385            } else {
386                Ok(Predicate::Range {
387                    lower: lb,
388                    upper: ub,
389                })
390            }
391        }
392        Operator::In(values) => {
393            let mut out = Vec::with_capacity(values.len());
394            for lit in *values {
395                out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
396            }
397            Ok(Predicate::In(out))
398        }
399        Operator::StartsWith {
400            pattern,
401            case_sensitive,
402        } => Ok(Predicate::StartsWith {
403            pattern: pattern.to_string(),
404            case_sensitive: *case_sensitive,
405        }),
406        Operator::EndsWith {
407            pattern,
408            case_sensitive,
409        } => Ok(Predicate::EndsWith {
410            pattern: pattern.to_string(),
411            case_sensitive: *case_sensitive,
412        }),
413        Operator::Contains {
414            pattern,
415            case_sensitive,
416        } => Ok(Predicate::Contains {
417            pattern: pattern.to_string(),
418            case_sensitive: *case_sensitive,
419        }),
420        Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
421            "operator lacks string literal support",
422        )),
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn fixed_width_comparison_operators() {
        let gt = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(
            &Operator::GreaterThan(5.into()),
        )
        .unwrap();
        assert!(gt.matches(&6));
        assert!(!gt.matches(&5));

        let le = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(
            &Operator::LessThanOrEquals(5.into()),
        )
        .unwrap();
        assert!(le.matches(&5));
        assert!(!le.matches(&6));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn null_checks_and_patterns_are_unsupported_where_expected() {
        let err = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&Operator::IsNull)
            .unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));

        let err = build_var_width_predicate(&Operator::IsNotNull).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));

        let err = build_bool_predicate(&Operator::contains("x", false)).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(matches!(predicate, Predicate::All));
        assert!(predicate.matches(&123u32));
    }

    // An empty IN list matches no value (it is not the same as `All`).
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }

    #[test]
    fn string_patterns_case_insensitive_edge_cases() {
        let ct = build_var_width_predicate(&Operator::contains("É", false))
            .expect("contains predicate");
        assert!(ct.matches("café"));
        assert!(!ct.matches("cafe"));

        let ew = build_var_width_predicate(&Operator::ends_with("É", false))
            .expect("ends with predicate");
        assert!(ew.matches("CAFÉ"));

        let empty = build_var_width_predicate(&Operator::contains("", false))
            .expect("contains predicate");
        assert!(empty.matches("anything"));
        assert!(empty.matches(""));

        let too_long = build_var_width_predicate(&Operator::starts_with("prefixes", false))
            .expect("starts with predicate");
        assert!(!too_long.matches("pre"));
    }
}