llkv_expr/
typed_predicate.rs

1//! Build and evaluate fully typed predicates derived from logical expressions.
2//!
3//! The conversion utilities bridge the logical [`crate::expr::Operator`] values that
4//! operate on untyped [`crate::literal::Literal`] instances and the concrete predicate
5//! evaluators needed by execution code.
6
7use std::cmp::Ordering;
8use std::fmt;
9use std::ops::Bound;
10
11use arrow::datatypes::ArrowPrimitiveType;
12
13use crate::expr::Operator;
14use crate::literal::{FromLiteral, Literal, LiteralCastError, LiteralExt};
15
/// A value type that typed predicates can be built from and evaluated against.
///
/// `Predicate::matches` receives values through the [`PredicateValue::Borrowed`] view
/// (for example `str` for `String`). The string-pattern hooks (`contains`, `starts_with`,
/// `ends_with`) default to `false`. Types that do not override them therefore never satisfy
/// pattern predicates.
pub trait PredicateValue: Clone {
    /// Borrowed form of the value that predicates are matched against.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// View an owned value through its borrowed form.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Return `true` when `value` equals `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Order `value` relative to `target`, or `None` when the two cannot be compared.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Return `true` when `value` contains `target`; the default implementation never matches.
    fn contains(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }

    /// Return `true` when `value` begins with `target`; the default implementation never
    /// matches.
    fn starts_with(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }

    /// Return `true` when `value` finishes with `target`; the default implementation never
    /// matches.
    fn ends_with(_value: &Self::Borrowed<'_>, _target: &Self, _case_sensitive: bool) -> bool {
        false
    }
}
38
39/// Fully typed predicate ready to be matched against borrowed values.
40#[derive(Debug, Clone)]
41pub enum Predicate<V>
42where
43    V: PredicateValue,
44{
45    All,
46    Equals(V),
47    GreaterThan(V),
48    GreaterThanOrEquals(V),
49    LessThan(V),
50    LessThanOrEquals(V),
51    Range {
52        lower: Option<Bound<V>>,
53        upper: Option<Bound<V>>,
54    },
55    In(Vec<V>),
56    StartsWith {
57        pattern: V,
58        case_sensitive: bool,
59    },
60    EndsWith {
61        pattern: V,
62        case_sensitive: bool,
63    },
64    Contains {
65        pattern: V,
66        case_sensitive: bool,
67    },
68}
69
70impl<V> Predicate<V>
71where
72    V: PredicateValue,
73{
74    /// Return `true` when `value` satisfies the predicate variant.
75    pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
76        match self {
77            Predicate::All => true,
78            Predicate::Equals(target) => V::equals(value, target),
79            Predicate::GreaterThan(target) => {
80                matches!(V::compare(value, target), Some(Ordering::Greater))
81            }
82            Predicate::GreaterThanOrEquals(target) => {
83                matches!(
84                    V::compare(value, target),
85                    Some(Ordering::Greater | Ordering::Equal)
86                )
87            }
88            Predicate::LessThan(target) => {
89                matches!(V::compare(value, target), Some(Ordering::Less))
90            }
91            Predicate::LessThanOrEquals(target) => matches!(
92                V::compare(value, target),
93                Some(Ordering::Less | Ordering::Equal)
94            ),
95            Predicate::Range { lower, upper } => {
96                if let Some(bound) = lower
97                    && !match bound {
98                        Bound::Included(target) => matches!(
99                            V::compare(value, target),
100                            Some(Ordering::Greater | Ordering::Equal)
101                        ),
102                        Bound::Excluded(target) => {
103                            matches!(V::compare(value, target), Some(Ordering::Greater))
104                        }
105                        Bound::Unbounded => true,
106                    }
107                {
108                    return false;
109                }
110
111                if let Some(bound) = upper
112                    && !match bound {
113                        Bound::Included(target) => matches!(
114                            V::compare(value, target),
115                            Some(Ordering::Less | Ordering::Equal)
116                        ),
117                        Bound::Excluded(target) => {
118                            matches!(V::compare(value, target), Some(Ordering::Less))
119                        }
120                        Bound::Unbounded => true,
121                    }
122                {
123                    return false;
124                }
125
126                true
127            }
128            Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
129            Predicate::StartsWith {
130                pattern,
131                case_sensitive,
132            } => V::starts_with(value, pattern, *case_sensitive),
133            Predicate::EndsWith {
134                pattern,
135                case_sensitive,
136            } => V::ends_with(value, pattern, *case_sensitive),
137            Predicate::Contains {
138                pattern,
139                case_sensitive,
140            } => V::contains(value, pattern, *case_sensitive),
141        }
142    }
143}
144
145macro_rules! impl_predicate_value_for_primitive {
146    ($($ty:ty),+ $(,)?) => {
147        $(
148            impl PredicateValue for $ty {
149                type Borrowed<'a> = Self where Self: 'a;
150
151                fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
152                    value
153                }
154
155                fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
156                    *value == *target
157                }
158
159                fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
160                    value.partial_cmp(target)
161                }
162            }
163        )+
164    };
165}
166
167impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
168
169impl PredicateValue for String {
170    type Borrowed<'a>
171        = str
172    where
173        Self: 'a;
174
175    fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
176        value.as_str()
177    }
178
179    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
180        value == target.as_str()
181    }
182
183    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
184        Some(value.cmp(target.as_str()))
185    }
186
187    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
188        if case_sensitive {
189            value.contains(target.as_str())
190        } else {
191            value.to_lowercase().contains(&target.to_lowercase())
192        }
193    }
194
195    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
196        if case_sensitive {
197            value.starts_with(target.as_str())
198        } else {
199            value.to_lowercase().starts_with(&target.to_lowercase())
200        }
201    }
202
203    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
204        if case_sensitive {
205            value.ends_with(target.as_str())
206        } else {
207            value.to_lowercase().ends_with(&target.to_lowercase())
208        }
209    }
210}
211
212/// Error building a typed predicate from a logical operator.
213#[derive(Debug, Clone)]
214pub enum PredicateBuildError {
215    LiteralCast(LiteralCastError),
216    UnsupportedOperator(&'static str),
217}
218
219impl fmt::Display for PredicateBuildError {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        match self {
222            PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
223            PredicateBuildError::UnsupportedOperator(op) => {
224                write!(f, "unsupported operator for typed predicate: {op}")
225            }
226        }
227    }
228}
229
230impl std::error::Error for PredicateBuildError {
231    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232        match self {
233            PredicateBuildError::LiteralCast(err) => Some(err),
234            PredicateBuildError::UnsupportedOperator(_) => None,
235        }
236    }
237}
238
239impl From<LiteralCastError> for PredicateBuildError {
240    fn from(err: LiteralCastError) -> Self {
241        PredicateBuildError::LiteralCast(err)
242    }
243}
244
245/// Convert a logical operator into a predicate for fixed-width Arrow types.
246///
247/// # Errors
248///
249/// Returns [`PredicateBuildError::LiteralCast`] when the provided literal cannot be coerced into
250/// the target native type or [`PredicateBuildError::UnsupportedOperator`] when the operator is not
251/// supported for fixed-width values.
252pub fn build_fixed_width_predicate<T>(
253    op: &Operator<'_>,
254) -> Result<Predicate<T::Native>, PredicateBuildError>
255where
256    T: ArrowPrimitiveType,
257    T::Native: FromLiteral + Copy + PredicateValue,
258{
259    match op {
260        Operator::Equals(lit) => Ok(Predicate::Equals(
261            lit.to_native::<T::Native>()
262                .map_err(PredicateBuildError::from)?,
263        )),
264        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
265            lit.to_native::<T::Native>()
266                .map_err(PredicateBuildError::from)?,
267        )),
268        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
269            lit.to_native::<T::Native>()
270                .map_err(PredicateBuildError::from)?,
271        )),
272        Operator::LessThan(lit) => Ok(Predicate::LessThan(
273            lit.to_native::<T::Native>()
274                .map_err(PredicateBuildError::from)?,
275        )),
276        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
277            lit.to_native::<T::Native>()
278                .map_err(PredicateBuildError::from)?,
279        )),
280        Operator::Range { lower, upper } => {
281            let lb =
282                match Literal::bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
283                    Bound::Unbounded => None,
284                    other => Some(other),
285                };
286            let ub =
287                match Literal::bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
288                    Bound::Unbounded => None,
289                    other => Some(other),
290                };
291
292            if lb.is_none() && ub.is_none() {
293                Ok(Predicate::All)
294            } else {
295                Ok(Predicate::Range {
296                    lower: lb,
297                    upper: ub,
298                })
299            }
300        }
301        Operator::In(values) => {
302            let mut natives = Vec::with_capacity(values.len());
303            for lit in *values {
304                natives.push(
305                    lit.to_native::<T::Native>()
306                        .map_err(PredicateBuildError::from)?,
307                );
308            }
309            Ok(Predicate::In(natives))
310        }
311        _ => Err(PredicateBuildError::UnsupportedOperator(
312            "operator lacks typed literal support",
313        )),
314    }
315}
316
317fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
318    Ok(match bound {
319        Bound::Unbounded => None,
320        Bound::Included(lit) => Some(Bound::Included(
321            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
322        )),
323        Bound::Excluded(lit) => Some(Bound::Excluded(
324            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
325        )),
326    })
327}
328
329/// Convert a logical operator into a predicate over boolean values.
330///
331/// # Errors
332///
333/// Returns [`PredicateBuildError::LiteralCast`] when the literal cannot be interpreted as a
334/// boolean or [`PredicateBuildError::UnsupportedOperator`] when string-specific predicates are
335/// attempted.
336pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
337    match op {
338        Operator::Equals(lit) => Ok(Predicate::Equals(
339            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
340        )),
341        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
342            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
343        )),
344        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
345            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
346        )),
347        Operator::LessThan(lit) => Ok(Predicate::LessThan(
348            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
349        )),
350        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
351            lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
352        )),
353        Operator::Range { lower, upper } => {
354            let lb = parse_bool_bound(lower)?;
355            let ub = parse_bool_bound(upper)?;
356            if lb.is_none() && ub.is_none() {
357                Ok(Predicate::All)
358            } else {
359                Ok(Predicate::Range {
360                    lower: lb,
361                    upper: ub,
362                })
363            }
364        }
365        Operator::In(values) => {
366            let mut natives = Vec::with_capacity(values.len());
367            for lit in *values {
368                natives.push(lit.to_native::<bool>().map_err(PredicateBuildError::from)?);
369            }
370            Ok(Predicate::In(natives))
371        }
372        _ => Err(PredicateBuildError::UnsupportedOperator(
373            "operator lacks boolean literal support",
374        )),
375    }
376}
377
378fn parse_string_bound(
379    bound: &Bound<Literal>,
380) -> Result<Option<Bound<String>>, PredicateBuildError> {
381    match bound {
382        Bound::Unbounded => Ok(None),
383        Bound::Included(lit) => lit
384            .to_string_owned()
385            .map(|s| Some(Bound::Included(s)))
386            .map_err(PredicateBuildError::from),
387        Bound::Excluded(lit) => lit
388            .to_string_owned()
389            .map(|s| Some(Bound::Excluded(s)))
390            .map_err(PredicateBuildError::from),
391    }
392}
393
394/// Convert a logical operator into a predicate over UTF-8 string values.
395///
396/// # Errors
397///
398/// Returns [`PredicateBuildError::LiteralCast`] when literals cannot be converted into strings or
399/// [`PredicateBuildError::UnsupportedOperator`] when the operator is not yet implemented for
400/// strings.
401pub fn build_var_width_predicate(
402    op: &Operator<'_>,
403) -> Result<Predicate<String>, PredicateBuildError> {
404    match op {
405        Operator::Equals(lit) => Ok(Predicate::Equals(
406            lit.to_string_owned().map_err(PredicateBuildError::from)?,
407        )),
408        Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
409            lit.to_string_owned().map_err(PredicateBuildError::from)?,
410        )),
411        Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
412            lit.to_string_owned().map_err(PredicateBuildError::from)?,
413        )),
414        Operator::LessThan(lit) => Ok(Predicate::LessThan(
415            lit.to_string_owned().map_err(PredicateBuildError::from)?,
416        )),
417        Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
418            lit.to_string_owned().map_err(PredicateBuildError::from)?,
419        )),
420        Operator::Range { lower, upper } => {
421            let lb = parse_string_bound(lower)?;
422            let ub = parse_string_bound(upper)?;
423            if lb.is_none() && ub.is_none() {
424                Ok(Predicate::All)
425            } else {
426                Ok(Predicate::Range {
427                    lower: lb,
428                    upper: ub,
429                })
430            }
431        }
432        Operator::In(values) => {
433            let mut out = Vec::with_capacity(values.len());
434            for lit in *values {
435                out.push(lit.to_string_owned().map_err(PredicateBuildError::from)?);
436            }
437            Ok(Predicate::In(out))
438        }
439        Operator::StartsWith {
440            pattern,
441            case_sensitive,
442        } => Ok(Predicate::StartsWith {
443            pattern: pattern.to_string(),
444            case_sensitive: *case_sensitive,
445        }),
446        Operator::EndsWith {
447            pattern,
448            case_sensitive,
449        } => Ok(Predicate::EndsWith {
450            pattern: pattern.to_string(),
451            case_sensitive: *case_sensitive,
452        }),
453        Operator::Contains {
454            pattern,
455            case_sensitive,
456        } => Ok(Predicate::Contains {
457            pattern: pattern.to_string(),
458            case_sensitive: *case_sensitive,
459        }),
460        Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
461            "operator lacks string literal support",
462        )),
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_range_exclusive_lower_unbounded_upper() {
        let op = Operator::Range {
            lower: Bound::Excluded(10.into()),
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(!predicate.matches(&10));
        assert!(predicate.matches(&11));
        assert!(predicate.matches(&i32::MAX));
    }

    #[test]
    fn range_with_explicit_unbounded_side_is_unconstrained() {
        // Builders never emit `Some(Bound::Unbounded)`, but `matches` must still accept it.
        let predicate: Predicate<i32> = Predicate::Range {
            lower: Some(Bound::Unbounded),
            upper: Some(Bound::Included(5)),
        };
        assert!(predicate.matches(&i32::MIN));
        assert!(predicate.matches(&5));
        assert!(!predicate.matches(&6));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo".to_string(), true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn null_checks_are_unsupported_by_typed_builders() {
        assert!(matches!(
            build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&Operator::IsNull),
            Err(PredicateBuildError::UnsupportedOperator(_))
        ));
        assert!(matches!(
            build_bool_predicate(&Operator::IsNotNull),
            Err(PredicateBuildError::UnsupportedOperator(_))
        ));
        assert!(matches!(
            build_var_width_predicate(&Operator::IsNull),
            Err(PredicateBuildError::UnsupportedOperator(_))
        ));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn bool_unbounded_range_maps_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_bool_predicate(&op).unwrap();
        assert!(matches!(predicate, Predicate::All));
        assert!(predicate.matches(&true));
        assert!(predicate.matches(&false));
    }

    #[test]
    fn bool_rejects_string_patterns() {
        let op = Operator::contains("x".to_string(), true);
        let err = build_bool_predicate(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive =
            build_var_width_predicate(&Operator::starts_with("pre".to_string(), true))
                .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive =
            build_var_width_predicate(&Operator::starts_with("Pre".to_string(), false))
                .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf".to_string(), true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive =
            build_var_width_predicate(&Operator::ends_with("SUF".to_string(), false))
                .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid".to_string(), true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive =
            build_var_width_predicate(&Operator::contains("MiD".to_string(), false))
                .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }
}