1use std::cmp::Ordering;
8use std::fmt;
9use std::ops::Bound;
10
11use arrow::datatypes::ArrowPrimitiveType;
12
13use crate::expr::Operator;
14use crate::literal::{FromLiteral, Literal, LiteralCastError, LiteralExt};
15
/// Comparison hooks that let a literal of type `Self` be tested against
/// borrowed values.
///
/// The associated `Borrowed` type is the (possibly unsized) view through which
/// candidate values are inspected, while the literal itself is always passed
/// as `&Self`.
pub trait PredicateValue: Clone {
    /// Borrowed view of a value that the comparison hooks receive.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Returns the borrowed view of an owned value.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Returns `true` when `value` equals `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`; `None` means the two cannot be
    /// ordered.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. The default implementation ignores its arguments and
    /// never matches.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }

    /// Prefix test. The default implementation ignores its arguments and
    /// never matches.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }

    /// Suffix test. The default implementation ignores its arguments and
    /// never matches.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }
}
38
39#[derive(Debug, Clone)]
41pub enum Predicate<V>
42where
43 V: PredicateValue,
44{
45 All,
46 Equals(V),
47 GreaterThan(V),
48 GreaterThanOrEquals(V),
49 LessThan(V),
50 LessThanOrEquals(V),
51 Range {
52 lower: Option<Bound<V>>,
53 upper: Option<Bound<V>>,
54 },
55 In(Vec<V>),
56 StartsWith {
57 pattern: V,
58 case_sensitive: bool,
59 },
60 EndsWith {
61 pattern: V,
62 case_sensitive: bool,
63 },
64 Contains {
65 pattern: V,
66 case_sensitive: bool,
67 },
68}
69
70impl<V> Predicate<V>
71where
72 V: PredicateValue,
73{
74 pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
76 match self {
77 Predicate::All => true,
78 Predicate::Equals(target) => V::equals(value, target),
79 Predicate::GreaterThan(target) => {
80 matches!(V::compare(value, target), Some(Ordering::Greater))
81 }
82 Predicate::GreaterThanOrEquals(target) => {
83 matches!(
84 V::compare(value, target),
85 Some(Ordering::Greater | Ordering::Equal)
86 )
87 }
88 Predicate::LessThan(target) => {
89 matches!(V::compare(value, target), Some(Ordering::Less))
90 }
91 Predicate::LessThanOrEquals(target) => matches!(
92 V::compare(value, target),
93 Some(Ordering::Less | Ordering::Equal)
94 ),
95 Predicate::Range { lower, upper } => {
96 if let Some(bound) = lower
97 && !match bound {
98 Bound::Included(target) => matches!(
99 V::compare(value, target),
100 Some(Ordering::Greater | Ordering::Equal)
101 ),
102 Bound::Excluded(target) => {
103 matches!(V::compare(value, target), Some(Ordering::Greater))
104 }
105 Bound::Unbounded => true,
106 }
107 {
108 return false;
109 }
110
111 if let Some(bound) = upper
112 && !match bound {
113 Bound::Included(target) => matches!(
114 V::compare(value, target),
115 Some(Ordering::Less | Ordering::Equal)
116 ),
117 Bound::Excluded(target) => {
118 matches!(V::compare(value, target), Some(Ordering::Less))
119 }
120 Bound::Unbounded => true,
121 }
122 {
123 return false;
124 }
125
126 true
127 }
128 Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
129 Predicate::StartsWith {
130 pattern,
131 case_sensitive,
132 } => V::starts_with(value, pattern, *case_sensitive),
133 Predicate::EndsWith {
134 pattern,
135 case_sensitive,
136 } => V::ends_with(value, pattern, *case_sensitive),
137 Predicate::Contains {
138 pattern,
139 case_sensitive,
140 } => V::contains(value, pattern, *case_sensitive),
141 }
142 }
143}
144
145macro_rules! impl_predicate_value_for_primitive {
146 ($($ty:ty),+ $(,)?) => {
147 $(
148 impl PredicateValue for $ty {
149 type Borrowed<'a> = Self where Self: 'a;
150
151 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
152 value
153 }
154
155 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
156 *value == *target
157 }
158
159 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
160 value.partial_cmp(target)
161 }
162 }
163 )+
164 };
165}
166
167impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
168
169impl PredicateValue for String {
170 type Borrowed<'a>
171 = str
172 where
173 Self: 'a;
174
175 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
176 value.as_str()
177 }
178
179 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
180 value == target.as_str()
181 }
182
183 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
184 Some(value.cmp(target.as_str()))
185 }
186
187 fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
188 if case_sensitive {
189 value.contains(target.as_str())
190 } else {
191 value.to_lowercase().contains(&target.to_lowercase())
192 }
193 }
194
195 fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
196 if case_sensitive {
197 value.starts_with(target.as_str())
198 } else {
199 value.to_lowercase().starts_with(&target.to_lowercase())
200 }
201 }
202
203 fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
204 if case_sensitive {
205 value.ends_with(target.as_str())
206 } else {
207 value.to_lowercase().ends_with(&target.to_lowercase())
208 }
209 }
210}
211
212#[derive(Debug, Clone)]
214pub enum PredicateBuildError {
215 LiteralCast(LiteralCastError),
216 UnsupportedOperator(&'static str),
217}
218
219impl fmt::Display for PredicateBuildError {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 match self {
222 PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
223 PredicateBuildError::UnsupportedOperator(op) => {
224 write!(f, "unsupported operator for typed predicate: {op}")
225 }
226 }
227 }
228}
229
230impl std::error::Error for PredicateBuildError {
231 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232 match self {
233 PredicateBuildError::LiteralCast(err) => Some(err),
234 PredicateBuildError::UnsupportedOperator(_) => None,
235 }
236 }
237}
238
239impl From<LiteralCastError> for PredicateBuildError {
240 fn from(err: LiteralCastError) -> Self {
241 PredicateBuildError::LiteralCast(err)
242 }
243}
244
245pub fn build_fixed_width_predicate<T>(
253 op: &Operator<'_>,
254) -> Result<Predicate<T::Native>, PredicateBuildError>
255where
256 T: ArrowPrimitiveType,
257 T::Native: FromLiteral + Copy + PredicateValue,
258{
259 match op {
260 Operator::Equals(lit) => Ok(Predicate::Equals(
261 lit.to_native::<T::Native>()
262 .map_err(PredicateBuildError::from)?,
263 )),
264 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
265 lit.to_native::<T::Native>()
266 .map_err(PredicateBuildError::from)?,
267 )),
268 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
269 lit.to_native::<T::Native>()
270 .map_err(PredicateBuildError::from)?,
271 )),
272 Operator::LessThan(lit) => Ok(Predicate::LessThan(
273 lit.to_native::<T::Native>()
274 .map_err(PredicateBuildError::from)?,
275 )),
276 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
277 lit.to_native::<T::Native>()
278 .map_err(PredicateBuildError::from)?,
279 )),
280 Operator::Range { lower, upper } => {
281 let lb =
282 match Literal::bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
283 Bound::Unbounded => None,
284 other => Some(other),
285 };
286 let ub =
287 match Literal::bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
288 Bound::Unbounded => None,
289 other => Some(other),
290 };
291
292 if lb.is_none() && ub.is_none() {
293 Ok(Predicate::All)
294 } else {
295 Ok(Predicate::Range {
296 lower: lb,
297 upper: ub,
298 })
299 }
300 }
301 Operator::In(values) => {
302 let mut natives = Vec::with_capacity(values.len());
303 for lit in *values {
304 natives.push(
305 lit.to_native::<T::Native>()
306 .map_err(PredicateBuildError::from)?,
307 );
308 }
309 Ok(Predicate::In(natives))
310 }
311 _ => Err(PredicateBuildError::UnsupportedOperator(
312 "operator lacks typed literal support",
313 )),
314 }
315}
316
317fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
318 Ok(match bound {
319 Bound::Unbounded => None,
320 Bound::Included(lit) => Some(Bound::Included(
321 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
322 )),
323 Bound::Excluded(lit) => Some(Bound::Excluded(
324 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
325 )),
326 })
327}
328
329pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
337 match op {
338 Operator::Equals(lit) => Ok(Predicate::Equals(
339 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
340 )),
341 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
342 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
343 )),
344 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
345 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
346 )),
347 Operator::LessThan(lit) => Ok(Predicate::LessThan(
348 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
349 )),
350 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
351 lit.to_native::<bool>().map_err(PredicateBuildError::from)?,
352 )),
353 Operator::Range { lower, upper } => {
354 let lb = parse_bool_bound(lower)?;
355 let ub = parse_bool_bound(upper)?;
356 if lb.is_none() && ub.is_none() {
357 Ok(Predicate::All)
358 } else {
359 Ok(Predicate::Range {
360 lower: lb,
361 upper: ub,
362 })
363 }
364 }
365 Operator::In(values) => {
366 let mut natives = Vec::with_capacity(values.len());
367 for lit in *values {
368 natives.push(lit.to_native::<bool>().map_err(PredicateBuildError::from)?);
369 }
370 Ok(Predicate::In(natives))
371 }
372 _ => Err(PredicateBuildError::UnsupportedOperator(
373 "operator lacks boolean literal support",
374 )),
375 }
376}
377
378fn parse_string_bound(
379 bound: &Bound<Literal>,
380) -> Result<Option<Bound<String>>, PredicateBuildError> {
381 match bound {
382 Bound::Unbounded => Ok(None),
383 Bound::Included(lit) => lit
384 .to_string_owned()
385 .map(|s| Some(Bound::Included(s)))
386 .map_err(PredicateBuildError::from),
387 Bound::Excluded(lit) => lit
388 .to_string_owned()
389 .map(|s| Some(Bound::Excluded(s)))
390 .map_err(PredicateBuildError::from),
391 }
392}
393
394pub fn build_var_width_predicate(
402 op: &Operator<'_>,
403) -> Result<Predicate<String>, PredicateBuildError> {
404 match op {
405 Operator::Equals(lit) => Ok(Predicate::Equals(
406 lit.to_string_owned().map_err(PredicateBuildError::from)?,
407 )),
408 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
409 lit.to_string_owned().map_err(PredicateBuildError::from)?,
410 )),
411 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
412 lit.to_string_owned().map_err(PredicateBuildError::from)?,
413 )),
414 Operator::LessThan(lit) => Ok(Predicate::LessThan(
415 lit.to_string_owned().map_err(PredicateBuildError::from)?,
416 )),
417 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
418 lit.to_string_owned().map_err(PredicateBuildError::from)?,
419 )),
420 Operator::Range { lower, upper } => {
421 let lb = parse_string_bound(lower)?;
422 let ub = parse_string_bound(upper)?;
423 if lb.is_none() && ub.is_none() {
424 Ok(Predicate::All)
425 } else {
426 Ok(Predicate::Range {
427 lower: lb,
428 upper: ub,
429 })
430 }
431 }
432 Operator::In(values) => {
433 let mut out = Vec::with_capacity(values.len());
434 for lit in *values {
435 out.push(lit.to_string_owned().map_err(PredicateBuildError::from)?);
436 }
437 Ok(Predicate::In(out))
438 }
439 Operator::StartsWith {
440 pattern,
441 case_sensitive,
442 } => Ok(Predicate::StartsWith {
443 pattern: pattern.to_string(),
444 case_sensitive: *case_sensitive,
445 }),
446 Operator::EndsWith {
447 pattern,
448 case_sensitive,
449 } => Ok(Predicate::EndsWith {
450 pattern: pattern.to_string(),
451 case_sensitive: *case_sensitive,
452 }),
453 Operator::Contains {
454 pattern,
455 case_sensitive,
456 } => Ok(Predicate::Contains {
457 pattern: pattern.to_string(),
458 case_sensitive: *case_sensitive,
459 }),
460 Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
461 "operator lacks string literal support",
462 )),
463 }
464}
465
#[cfg(test)]
mod tests {
    //! Unit tests for the predicate builders and for `Predicate::matches`.

    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        // Inclusive lower end, exclusive upper end.
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        // Text operators have no meaning for fixed-width values.
        let op = Operator::starts_with("foo".to_string(), true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    // Previously named `matches_all_for_empty_in_list`, which contradicted the
    // assertion: an empty `In` list accepts no value at all.
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        // Each text operator is checked in both case-sensitivity modes.
        let sw_sensitive =
            build_var_width_predicate(&Operator::starts_with("pre".to_string(), true))
                .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive =
            build_var_width_predicate(&Operator::starts_with("Pre".to_string(), false))
                .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf".to_string(), true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive =
            build_var_width_predicate(&Operator::ends_with("SUF".to_string(), false))
                .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid".to_string(), true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive =
            build_var_width_predicate(&Operator::contains("MiD".to_string(), false))
                .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }
}