llkv_expr/
typed_predicate.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::Bound;
4
5use arrow::datatypes::ArrowPrimitiveType;
6
7use crate::expr::Operator;
8use crate::literal::{
9    FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
10};
11
/// A value type that a [`Predicate`] can be built over and evaluated against.
///
/// Predicates own their targets as `Self`, but are evaluated against a borrowed view,
/// [`Self::Borrowed`], which may be unsized (for example `str` for `String`).
pub trait PredicateValue: Clone {
    /// The borrowed form handed to predicates; `?Sized` so it can be a slice type.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Returns the borrowed view of an owned value.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Returns `true` when `value` equals `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`; `None` means the pair is incomparable.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. The default never matches; pattern-capable types override it.
    fn contains(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }

    /// Prefix test. The default never matches; pattern-capable types override it.
    fn starts_with(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }

    /// Suffix test. The default never matches; pattern-capable types override it.
    fn ends_with(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }
}
33
34#[derive(Debug, Clone)]
35pub enum Predicate<V>
36where
37    V: PredicateValue,
38{
39    All,
40    Equals(V),
41    GreaterThan(V),
42    GreaterThanOrEquals(V),
43    LessThan(V),
44    LessThanOrEquals(V),
45    Range {
46        lower: Option<Bound<V>>,
47        upper: Option<Bound<V>>,
48    },
49    In(Vec<V>),
50    StartsWith {
51        pattern: V,
52        case_sensitive: bool,
53    },
54    EndsWith {
55        pattern: V,
56        case_sensitive: bool,
57    },
58    Contains {
59        pattern: V,
60        case_sensitive: bool,
61    },
62}
63
64impl<V> Predicate<V>
65where
66    V: PredicateValue,
67{
68    pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
69        match self {
70            Predicate::All => true,
71            Predicate::Equals(target) => V::equals(value, target),
72            Predicate::GreaterThan(target) => {
73                matches!(V::compare(value, target), Some(Ordering::Greater))
74            }
75            Predicate::GreaterThanOrEquals(target) => {
76                matches!(
77                    V::compare(value, target),
78                    Some(Ordering::Greater | Ordering::Equal)
79                )
80            }
81            Predicate::LessThan(target) => {
82                matches!(V::compare(value, target), Some(Ordering::Less))
83            }
84            Predicate::LessThanOrEquals(target) => matches!(
85                V::compare(value, target),
86                Some(Ordering::Less | Ordering::Equal)
87            ),
88            Predicate::Range { lower, upper } => {
89                if let Some(bound) = lower
90                    && !match bound {
91                        Bound::Included(target) => matches!(
92                            V::compare(value, target),
93                            Some(Ordering::Greater | Ordering::Equal)
94                        ),
95                        Bound::Excluded(target) => {
96                            matches!(V::compare(value, target), Some(Ordering::Greater))
97                        }
98                        Bound::Unbounded => true,
99                    }
100                {
101                    return false;
102                }
103
104                if let Some(bound) = upper
105                    && !match bound {
106                        Bound::Included(target) => matches!(
107                            V::compare(value, target),
108                            Some(Ordering::Less | Ordering::Equal)
109                        ),
110                        Bound::Excluded(target) => {
111                            matches!(V::compare(value, target), Some(Ordering::Less))
112                        }
113                        Bound::Unbounded => true,
114                    }
115                {
116                    return false;
117                }
118
119                true
120            }
121            Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
122            Predicate::StartsWith {
123                pattern,
124                case_sensitive,
125            } => V::starts_with(value, pattern, *case_sensitive),
126            Predicate::EndsWith {
127                pattern,
128                case_sensitive,
129            } => V::ends_with(value, pattern, *case_sensitive),
130            Predicate::Contains {
131                pattern,
132                case_sensitive,
133            } => V::contains(value, pattern, *case_sensitive),
134        }
135    }
136}
137
138macro_rules! impl_predicate_value_for_primitive {
139    ($($ty:ty),+ $(,)?) => {
140        $(
141            impl PredicateValue for $ty {
142                type Borrowed<'a> = Self where Self: 'a;
143
144                fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
145                    value
146                }
147
148                fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
149                    *value == *target
150                }
151
152                fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
153                    value.partial_cmp(target)
154                }
155            }
156        )+
157    };
158}
159
160impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
161
162impl PredicateValue for String {
163    type Borrowed<'a>
164        = str
165    where
166        Self: 'a;
167
168    fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
169        value.as_str()
170    }
171
172    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
173        value == target.as_str()
174    }
175
176    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
177        Some(value.cmp(target.as_str()))
178    }
179
180    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
181        if case_sensitive {
182            value.contains(target.as_str())
183        } else {
184            value.to_lowercase().contains(&target.to_lowercase())
185        }
186    }
187
188    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
189        if case_sensitive {
190            value.starts_with(target.as_str())
191        } else {
192            value.to_lowercase().starts_with(&target.to_lowercase())
193        }
194    }
195
196    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
197        if case_sensitive {
198            value.ends_with(target.as_str())
199        } else {
200            value.to_lowercase().ends_with(&target.to_lowercase())
201        }
202    }
203}
204
205#[derive(Debug, Clone)]
206pub enum PredicateBuildError {
207    LiteralCast(LiteralCastError),
208    UnsupportedOperator(&'static str),
209}
210
211impl fmt::Display for PredicateBuildError {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
215            PredicateBuildError::UnsupportedOperator(op) => {
216                write!(f, "unsupported operator for typed predicate: {op}")
217            }
218        }
219    }
220}
221
222impl std::error::Error for PredicateBuildError {
223    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
224        match self {
225            PredicateBuildError::LiteralCast(err) => Some(err),
226            PredicateBuildError::UnsupportedOperator(_) => None,
227        }
228    }
229}
230
231impl From<LiteralCastError> for PredicateBuildError {
232    fn from(err: LiteralCastError) -> Self {
233        PredicateBuildError::LiteralCast(err)
234    }
235}
236
237pub fn build_fixed_width_predicate<T>(
238    op: &Operator<'_>,
239) -> Result<Predicate<T::Native>, PredicateBuildError>
240where
241    T: ArrowPrimitiveType,
242    T::Native: FromLiteral + Copy + PredicateValue,
243{
244    match op {
245        Operator::Equals(lit) => Ok(Predicate::Equals(
246            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
247        )),
248        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
249            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
250        )),
251        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
252            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
253        )),
254        Operator::LessThan(lit) => Ok(Predicate::LessThan(
255            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
256        )),
257        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
258            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
259        )),
260        Operator::Range { lower, upper } => {
261            let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
262                Bound::Unbounded => None,
263                other => Some(other),
264            };
265            let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
266                Bound::Unbounded => None,
267                other => Some(other),
268            };
269
270            if lb.is_none() && ub.is_none() {
271                Ok(Predicate::All)
272            } else {
273                Ok(Predicate::Range {
274                    lower: lb,
275                    upper: ub,
276                })
277            }
278        }
279        Operator::In(values) => {
280            let mut natives = Vec::with_capacity(values.len());
281            for lit in *values {
282                natives
283                    .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
284            }
285            Ok(Predicate::In(natives))
286        }
287        _ => Err(PredicateBuildError::UnsupportedOperator(
288            "operator lacks typed literal support",
289        )),
290    }
291}
292
293fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
294    Ok(match bound {
295        Bound::Unbounded => None,
296        Bound::Included(lit) => Some(Bound::Included(
297            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
298        )),
299        Bound::Excluded(lit) => Some(Bound::Excluded(
300            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
301        )),
302    })
303}
304
305pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
306    match op {
307        Operator::Equals(lit) => Ok(Predicate::Equals(
308            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
309        )),
310        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
311            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
312        )),
313        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
314            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315        )),
316        Operator::LessThan(lit) => Ok(Predicate::LessThan(
317            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318        )),
319        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
320            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
321        )),
322        Operator::Range { lower, upper } => {
323            let lb = parse_bool_bound(lower)?;
324            let ub = parse_bool_bound(upper)?;
325            if lb.is_none() && ub.is_none() {
326                Ok(Predicate::All)
327            } else {
328                Ok(Predicate::Range {
329                    lower: lb,
330                    upper: ub,
331                })
332            }
333        }
334        Operator::In(values) => {
335            let mut natives = Vec::with_capacity(values.len());
336            for lit in *values {
337                natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
338            }
339            Ok(Predicate::In(natives))
340        }
341        _ => Err(PredicateBuildError::UnsupportedOperator(
342            "operator lacks boolean literal support",
343        )),
344    }
345}
346
347fn parse_string_bound(
348    bound: &Bound<Literal>,
349) -> Result<Option<Bound<String>>, PredicateBuildError> {
350    match bound {
351        Bound::Unbounded => Ok(None),
352        Bound::Included(lit) => literal_to_string(lit)
353            .map(|s| Some(Bound::Included(s)))
354            .map_err(PredicateBuildError::from),
355        Bound::Excluded(lit) => literal_to_string(lit)
356            .map(|s| Some(Bound::Excluded(s)))
357            .map_err(PredicateBuildError::from),
358    }
359}
360
361pub fn build_var_width_predicate(
362    op: &Operator<'_>,
363) -> Result<Predicate<String>, PredicateBuildError> {
364    match op {
365        Operator::Equals(lit) => Ok(Predicate::Equals(
366            literal_to_string(lit).map_err(PredicateBuildError::from)?,
367        )),
368        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
369            literal_to_string(lit).map_err(PredicateBuildError::from)?,
370        )),
371        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
372            literal_to_string(lit).map_err(PredicateBuildError::from)?,
373        )),
374        Operator::LessThan(lit) => Ok(Predicate::LessThan(
375            literal_to_string(lit).map_err(PredicateBuildError::from)?,
376        )),
377        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
378            literal_to_string(lit).map_err(PredicateBuildError::from)?,
379        )),
380        Operator::Range { lower, upper } => {
381            let lb = parse_string_bound(lower)?;
382            let ub = parse_string_bound(upper)?;
383            if lb.is_none() && ub.is_none() {
384                Ok(Predicate::All)
385            } else {
386                Ok(Predicate::Range {
387                    lower: lb,
388                    upper: ub,
389                })
390            }
391        }
392        Operator::In(values) => {
393            let mut out = Vec::with_capacity(values.len());
394            for lit in *values {
395                out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
396            }
397            Ok(Predicate::In(out))
398        }
399        Operator::StartsWith {
400            pattern,
401            case_sensitive,
402        } => Ok(Predicate::StartsWith {
403            pattern: pattern.to_string(),
404            case_sensitive: *case_sensitive,
405        }),
406        Operator::EndsWith {
407            pattern,
408            case_sensitive,
409        } => Ok(Predicate::EndsWith {
410            pattern: pattern.to_string(),
411            case_sensitive: *case_sensitive,
412        }),
413        Operator::Contains {
414            pattern,
415            case_sensitive,
416        } => Ok(Predicate::Contains {
417            pattern: pattern.to_string(),
418            case_sensitive: *case_sensitive,
419        }),
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn bool_unsupported_operator_errors() {
        let err = build_bool_predicate(&Operator::contains("x", true)).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    // An empty IN list contains no candidates, so it must match nothing.
    #[test]
    fn matches_nothing_for_empty_in_list() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    #[test]
    fn string_unbounded_range_maps_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches(""));
        assert!(predicate.matches("anything"));
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }

    #[test]
    fn case_insensitive_patterns_handle_edge_inputs() {
        // The empty pattern matches everything.
        let empty = build_var_width_predicate(&Operator::contains("", false))
            .expect("contains predicate");
        assert!(empty.matches(""));
        assert!(empty.matches("anything"));

        // A pattern longer than the value never matches.
        let too_long = build_var_width_predicate(&Operator::ends_with("longer", false))
            .expect("ends with predicate");
        assert!(!too_long.matches("long"));
        let too_long = build_var_width_predicate(&Operator::starts_with("longer", false))
            .expect("starts with predicate");
        assert!(!too_long.matches("long"));

        // Non-ASCII pattern: Unicode lowercasing applies.
        let accented = build_var_width_predicate(&Operator::starts_with("ÉCO", false))
            .expect("starts with predicate");
        assert!(accented.matches("école"));
        assert!(!accented.matches("ecole"));

        // Non-ASCII value that lowercases to ASCII (U+212A KELVIN SIGN -> 'k').
        let kelvin = build_var_width_predicate(&Operator::starts_with("kel", false))
            .expect("starts with predicate");
        assert!(kelvin.matches("\u{212A}elvin"));
    }
}