llkv_expr/
typed_predicate.rs

1//! Build and evaluate fully typed predicates derived from logical expressions.
2//!
3//! The conversion utilities bridge the logical [`crate::expr::Operator`] values that
4//! operate on untyped [`crate::literal::Literal`] instances and the concrete predicate
5//! evaluators needed by execution code.
6
7use std::cmp::Ordering;
8use std::fmt;
9use std::ops::Bound;
10
11use arrow::datatypes::ArrowPrimitiveType;
12
13use crate::expr::Operator;
14use crate::literal::{
15    FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
16};
17
/// A value type that typed predicates can be evaluated against.
///
/// Implementors provide an owned form (`Self`, stored inside the predicate) and a
/// borrowed form ([`PredicateValue::Borrowed`], handed to the comparison hooks). The
/// text-matching hooks have default implementations that always return `false`, so
/// only types with a notion of substrings need to override them.
pub trait PredicateValue: Clone {
    /// Borrowed view of the value; may be unsized.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Borrow an owned value as its [`PredicateValue::Borrowed`] form.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Report whether `value` equals `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Order `value` relative to `target`, or return `None` when no ordering is reported.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Report whether `value` contains `target`. The default never matches.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        // The default ignores every argument.
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }

    /// Report whether `value` starts with `target`. The default never matches.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }

    /// Report whether `value` ends with `target`. The default never matches.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }
}
40
41/// Fully typed predicate ready to be matched against borrowed values.
42#[derive(Debug, Clone)]
43pub enum Predicate<V>
44where
45    V: PredicateValue,
46{
47    All,
48    Equals(V),
49    GreaterThan(V),
50    GreaterThanOrEquals(V),
51    LessThan(V),
52    LessThanOrEquals(V),
53    Range {
54        lower: Option<Bound<V>>,
55        upper: Option<Bound<V>>,
56    },
57    In(Vec<V>),
58    StartsWith {
59        pattern: V,
60        case_sensitive: bool,
61    },
62    EndsWith {
63        pattern: V,
64        case_sensitive: bool,
65    },
66    Contains {
67        pattern: V,
68        case_sensitive: bool,
69    },
70}
71
72impl<V> Predicate<V>
73where
74    V: PredicateValue,
75{
76    /// Return `true` when `value` satisfies the predicate variant.
77    pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
78        match self {
79            Predicate::All => true,
80            Predicate::Equals(target) => V::equals(value, target),
81            Predicate::GreaterThan(target) => {
82                matches!(V::compare(value, target), Some(Ordering::Greater))
83            }
84            Predicate::GreaterThanOrEquals(target) => {
85                matches!(
86                    V::compare(value, target),
87                    Some(Ordering::Greater | Ordering::Equal)
88                )
89            }
90            Predicate::LessThan(target) => {
91                matches!(V::compare(value, target), Some(Ordering::Less))
92            }
93            Predicate::LessThanOrEquals(target) => matches!(
94                V::compare(value, target),
95                Some(Ordering::Less | Ordering::Equal)
96            ),
97            Predicate::Range { lower, upper } => {
98                if let Some(bound) = lower
99                    && !match bound {
100                        Bound::Included(target) => matches!(
101                            V::compare(value, target),
102                            Some(Ordering::Greater | Ordering::Equal)
103                        ),
104                        Bound::Excluded(target) => {
105                            matches!(V::compare(value, target), Some(Ordering::Greater))
106                        }
107                        Bound::Unbounded => true,
108                    }
109                {
110                    return false;
111                }
112
113                if let Some(bound) = upper
114                    && !match bound {
115                        Bound::Included(target) => matches!(
116                            V::compare(value, target),
117                            Some(Ordering::Less | Ordering::Equal)
118                        ),
119                        Bound::Excluded(target) => {
120                            matches!(V::compare(value, target), Some(Ordering::Less))
121                        }
122                        Bound::Unbounded => true,
123                    }
124                {
125                    return false;
126                }
127
128                true
129            }
130            Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
131            Predicate::StartsWith {
132                pattern,
133                case_sensitive,
134            } => V::starts_with(value, pattern, *case_sensitive),
135            Predicate::EndsWith {
136                pattern,
137                case_sensitive,
138            } => V::ends_with(value, pattern, *case_sensitive),
139            Predicate::Contains {
140                pattern,
141                case_sensitive,
142            } => V::contains(value, pattern, *case_sensitive),
143        }
144    }
145}
146
147macro_rules! impl_predicate_value_for_primitive {
148    ($($ty:ty),+ $(,)?) => {
149        $(
150            impl PredicateValue for $ty {
151                type Borrowed<'a> = Self where Self: 'a;
152
153                fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
154                    value
155                }
156
157                fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
158                    *value == *target
159                }
160
161                fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
162                    value.partial_cmp(target)
163                }
164            }
165        )+
166    };
167}
168
169impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
170
171impl PredicateValue for String {
172    type Borrowed<'a>
173        = str
174    where
175        Self: 'a;
176
177    fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
178        value.as_str()
179    }
180
181    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
182        value == target.as_str()
183    }
184
185    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
186        Some(value.cmp(target.as_str()))
187    }
188
189    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
190        if case_sensitive {
191            value.contains(target.as_str())
192        } else {
193            value.to_lowercase().contains(&target.to_lowercase())
194        }
195    }
196
197    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
198        if case_sensitive {
199            value.starts_with(target.as_str())
200        } else {
201            value.to_lowercase().starts_with(&target.to_lowercase())
202        }
203    }
204
205    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
206        if case_sensitive {
207            value.ends_with(target.as_str())
208        } else {
209            value.to_lowercase().ends_with(&target.to_lowercase())
210        }
211    }
212}
213
214/// Error building a typed predicate from a logical operator.
215#[derive(Debug, Clone)]
216pub enum PredicateBuildError {
217    LiteralCast(LiteralCastError),
218    UnsupportedOperator(&'static str),
219}
220
221impl fmt::Display for PredicateBuildError {
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        match self {
224            PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
225            PredicateBuildError::UnsupportedOperator(op) => {
226                write!(f, "unsupported operator for typed predicate: {op}")
227            }
228        }
229    }
230}
231
232impl std::error::Error for PredicateBuildError {
233    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
234        match self {
235            PredicateBuildError::LiteralCast(err) => Some(err),
236            PredicateBuildError::UnsupportedOperator(_) => None,
237        }
238    }
239}
240
241impl From<LiteralCastError> for PredicateBuildError {
242    fn from(err: LiteralCastError) -> Self {
243        PredicateBuildError::LiteralCast(err)
244    }
245}
246
247/// Convert a logical operator into a predicate for fixed-width Arrow types.
248///
249/// # Errors
250///
251/// Returns [`PredicateBuildError::LiteralCast`] when the provided literal cannot be coerced into
252/// the target native type or [`PredicateBuildError::UnsupportedOperator`] when the operator is not
253/// supported for fixed-width values.
254pub fn build_fixed_width_predicate<T>(
255    op: &Operator<'_>,
256) -> Result<Predicate<T::Native>, PredicateBuildError>
257where
258    T: ArrowPrimitiveType,
259    T::Native: FromLiteral + Copy + PredicateValue,
260{
261    match op {
262        Operator::Equals(lit) => Ok(Predicate::Equals(
263            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
264        )),
265        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
266            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
267        )),
268        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
269            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
270        )),
271        Operator::LessThan(lit) => Ok(Predicate::LessThan(
272            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
273        )),
274        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
275            literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
276        )),
277        Operator::Range { lower, upper } => {
278            let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
279                Bound::Unbounded => None,
280                other => Some(other),
281            };
282            let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
283                Bound::Unbounded => None,
284                other => Some(other),
285            };
286
287            if lb.is_none() && ub.is_none() {
288                Ok(Predicate::All)
289            } else {
290                Ok(Predicate::Range {
291                    lower: lb,
292                    upper: ub,
293                })
294            }
295        }
296        Operator::In(values) => {
297            let mut natives = Vec::with_capacity(values.len());
298            for lit in *values {
299                natives
300                    .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
301            }
302            Ok(Predicate::In(natives))
303        }
304        _ => Err(PredicateBuildError::UnsupportedOperator(
305            "operator lacks typed literal support",
306        )),
307    }
308}
309
310fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
311    Ok(match bound {
312        Bound::Unbounded => None,
313        Bound::Included(lit) => Some(Bound::Included(
314            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315        )),
316        Bound::Excluded(lit) => Some(Bound::Excluded(
317            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318        )),
319    })
320}
321
322/// Convert a logical operator into a predicate over boolean values.
323///
324/// # Errors
325///
326/// Returns [`PredicateBuildError::LiteralCast`] when the literal cannot be interpreted as a
327/// boolean or [`PredicateBuildError::UnsupportedOperator`] when string-specific predicates are
328/// attempted.
329pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
330    match op {
331        Operator::Equals(lit) => Ok(Predicate::Equals(
332            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
333        )),
334        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
335            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
336        )),
337        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
338            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
339        )),
340        Operator::LessThan(lit) => Ok(Predicate::LessThan(
341            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
342        )),
343        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
344            literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
345        )),
346        Operator::Range { lower, upper } => {
347            let lb = parse_bool_bound(lower)?;
348            let ub = parse_bool_bound(upper)?;
349            if lb.is_none() && ub.is_none() {
350                Ok(Predicate::All)
351            } else {
352                Ok(Predicate::Range {
353                    lower: lb,
354                    upper: ub,
355                })
356            }
357        }
358        Operator::In(values) => {
359            let mut natives = Vec::with_capacity(values.len());
360            for lit in *values {
361                natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
362            }
363            Ok(Predicate::In(natives))
364        }
365        _ => Err(PredicateBuildError::UnsupportedOperator(
366            "operator lacks boolean literal support",
367        )),
368    }
369}
370
371fn parse_string_bound(
372    bound: &Bound<Literal>,
373) -> Result<Option<Bound<String>>, PredicateBuildError> {
374    match bound {
375        Bound::Unbounded => Ok(None),
376        Bound::Included(lit) => literal_to_string(lit)
377            .map(|s| Some(Bound::Included(s)))
378            .map_err(PredicateBuildError::from),
379        Bound::Excluded(lit) => literal_to_string(lit)
380            .map(|s| Some(Bound::Excluded(s)))
381            .map_err(PredicateBuildError::from),
382    }
383}
384
385/// Convert a logical operator into a predicate over UTF-8 string values.
386///
387/// # Errors
388///
389/// Returns [`PredicateBuildError::LiteralCast`] when literals cannot be converted into strings or
390/// [`PredicateBuildError::UnsupportedOperator`] when the operator is not yet implemented for
391/// strings.
392pub fn build_var_width_predicate(
393    op: &Operator<'_>,
394) -> Result<Predicate<String>, PredicateBuildError> {
395    match op {
396        Operator::Equals(lit) => Ok(Predicate::Equals(
397            literal_to_string(lit).map_err(PredicateBuildError::from)?,
398        )),
399        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
400            literal_to_string(lit).map_err(PredicateBuildError::from)?,
401        )),
402        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
403            literal_to_string(lit).map_err(PredicateBuildError::from)?,
404        )),
405        Operator::LessThan(lit) => Ok(Predicate::LessThan(
406            literal_to_string(lit).map_err(PredicateBuildError::from)?,
407        )),
408        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
409            literal_to_string(lit).map_err(PredicateBuildError::from)?,
410        )),
411        Operator::Range { lower, upper } => {
412            let lb = parse_string_bound(lower)?;
413            let ub = parse_string_bound(upper)?;
414            if lb.is_none() && ub.is_none() {
415                Ok(Predicate::All)
416            } else {
417                Ok(Predicate::Range {
418                    lower: lb,
419                    upper: ub,
420                })
421            }
422        }
423        Operator::In(values) => {
424            let mut out = Vec::with_capacity(values.len());
425            for lit in *values {
426                out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
427            }
428            Ok(Predicate::In(out))
429        }
430        Operator::StartsWith {
431            pattern,
432            case_sensitive,
433        } => Ok(Predicate::StartsWith {
434            pattern: pattern.to_string(),
435            case_sensitive: *case_sensitive,
436        }),
437        Operator::EndsWith {
438            pattern,
439            case_sensitive,
440        } => Ok(Predicate::EndsWith {
441            pattern: pattern.to_string(),
442            case_sensitive: *case_sensitive,
443        }),
444        Operator::Contains {
445            pattern,
446            case_sensitive,
447        } => Ok(Predicate::Contains {
448            pattern: pattern.to_string(),
449            case_sensitive: *case_sensitive,
450        }),
451        Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
452            "operator lacks string literal support",
453        )),
454    }
455}
456
/// Unit tests for predicate construction and matching.
///
/// Covers the three builders (fixed-width, boolean, string), error propagation, and
/// the error type's `Display` / `source` behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        // Check the variant itself, not just that one sample value happens to match.
        assert!(matches!(predicate, Predicate::All));
        assert!(predicate.matches(&123u32));
    }

    // Previously named `matches_all_for_empty_in_list`, which contradicted the
    // assertion: an empty `In` list has no element to equal, so nothing matches.
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn bool_predicate_unbounded_range_maps_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_bool_predicate(&op).unwrap();
        assert!(matches!(predicate, Predicate::All));
        assert!(predicate.matches(&true));
        assert!(predicate.matches(&false));
    }

    #[test]
    fn bool_predicate_rejects_string_operators() {
        let op = Operator::starts_with("foo", true);
        let err = build_bool_predicate(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
        // Boundaries: the lower bound is inclusive, the upper bound exclusive.
        assert!(predicate.matches("alpha"));
        assert!(!predicate.matches("omega"));
    }

    #[test]
    fn string_predicate_rejects_null_checks() {
        let err = build_var_width_predicate(&Operator::IsNull).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));

        let err = build_var_width_predicate(&Operator::IsNotNull).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn unsupported_operator_error_formats_and_has_no_source() {
        use std::error::Error;

        let err = PredicateBuildError::UnsupportedOperator("reason");
        assert_eq!(
            err.to_string(),
            "unsupported operator for typed predicate: reason"
        );
        assert!(err.source().is_none());
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }
}