1use std::cmp::Ordering;
8use std::fmt;
9use std::ops::Bound;
10
11use arrow::datatypes::ArrowPrimitiveType;
12
13use crate::expr::Operator;
14use crate::literal::{
15 FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
16};
17
/// Owned value type that a [`Predicate`] can be evaluated against.
///
/// Every hook is an associated function (no `self`): the candidate is passed
/// in its borrowed form and the predicate operand in its owned form.
pub trait PredicateValue: Clone {
    /// Borrowed view of the value handed to the comparison hooks.
    ///
    /// May be unsized; the `String` implementation in this file uses `str`.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Returns the borrowed view of an owned value.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Reports whether `value` is equal to `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`.
    ///
    /// `None` means no ordering is available for the pair.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. The default implementation ignores its arguments and
    /// returns `false`.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_value, _target, _case_sensitive) = (value, target, case_sensitive);
        false
    }

    /// Prefix test. The default implementation ignores its arguments and
    /// returns `false`.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_value, _target, _case_sensitive) = (value, target, case_sensitive);
        false
    }

    /// Suffix test. The default implementation ignores its arguments and
    /// returns `false`.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_value, _target, _case_sensitive) = (value, target, case_sensitive);
        false
    }
}
40
41#[derive(Debug, Clone)]
43pub enum Predicate<V>
44where
45 V: PredicateValue,
46{
47 All,
48 Equals(V),
49 GreaterThan(V),
50 GreaterThanOrEquals(V),
51 LessThan(V),
52 LessThanOrEquals(V),
53 Range {
54 lower: Option<Bound<V>>,
55 upper: Option<Bound<V>>,
56 },
57 In(Vec<V>),
58 StartsWith {
59 pattern: V,
60 case_sensitive: bool,
61 },
62 EndsWith {
63 pattern: V,
64 case_sensitive: bool,
65 },
66 Contains {
67 pattern: V,
68 case_sensitive: bool,
69 },
70}
71
72impl<V> Predicate<V>
73where
74 V: PredicateValue,
75{
76 pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
78 match self {
79 Predicate::All => true,
80 Predicate::Equals(target) => V::equals(value, target),
81 Predicate::GreaterThan(target) => {
82 matches!(V::compare(value, target), Some(Ordering::Greater))
83 }
84 Predicate::GreaterThanOrEquals(target) => {
85 matches!(
86 V::compare(value, target),
87 Some(Ordering::Greater | Ordering::Equal)
88 )
89 }
90 Predicate::LessThan(target) => {
91 matches!(V::compare(value, target), Some(Ordering::Less))
92 }
93 Predicate::LessThanOrEquals(target) => matches!(
94 V::compare(value, target),
95 Some(Ordering::Less | Ordering::Equal)
96 ),
97 Predicate::Range { lower, upper } => {
98 if let Some(bound) = lower
99 && !match bound {
100 Bound::Included(target) => matches!(
101 V::compare(value, target),
102 Some(Ordering::Greater | Ordering::Equal)
103 ),
104 Bound::Excluded(target) => {
105 matches!(V::compare(value, target), Some(Ordering::Greater))
106 }
107 Bound::Unbounded => true,
108 }
109 {
110 return false;
111 }
112
113 if let Some(bound) = upper
114 && !match bound {
115 Bound::Included(target) => matches!(
116 V::compare(value, target),
117 Some(Ordering::Less | Ordering::Equal)
118 ),
119 Bound::Excluded(target) => {
120 matches!(V::compare(value, target), Some(Ordering::Less))
121 }
122 Bound::Unbounded => true,
123 }
124 {
125 return false;
126 }
127
128 true
129 }
130 Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
131 Predicate::StartsWith {
132 pattern,
133 case_sensitive,
134 } => V::starts_with(value, pattern, *case_sensitive),
135 Predicate::EndsWith {
136 pattern,
137 case_sensitive,
138 } => V::ends_with(value, pattern, *case_sensitive),
139 Predicate::Contains {
140 pattern,
141 case_sensitive,
142 } => V::contains(value, pattern, *case_sensitive),
143 }
144 }
145}
146
147macro_rules! impl_predicate_value_for_primitive {
148 ($($ty:ty),+ $(,)?) => {
149 $(
150 impl PredicateValue for $ty {
151 type Borrowed<'a> = Self where Self: 'a;
152
153 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
154 value
155 }
156
157 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
158 *value == *target
159 }
160
161 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
162 value.partial_cmp(target)
163 }
164 }
165 )+
166 };
167}
168
169impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
170
171impl PredicateValue for String {
172 type Borrowed<'a>
173 = str
174 where
175 Self: 'a;
176
177 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
178 value.as_str()
179 }
180
181 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
182 value == target.as_str()
183 }
184
185 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
186 Some(value.cmp(target.as_str()))
187 }
188
189 fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
190 if case_sensitive {
191 value.contains(target.as_str())
192 } else {
193 value.to_lowercase().contains(&target.to_lowercase())
194 }
195 }
196
197 fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
198 if case_sensitive {
199 value.starts_with(target.as_str())
200 } else {
201 value.to_lowercase().starts_with(&target.to_lowercase())
202 }
203 }
204
205 fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
206 if case_sensitive {
207 value.ends_with(target.as_str())
208 } else {
209 value.to_lowercase().ends_with(&target.to_lowercase())
210 }
211 }
212}
213
214#[derive(Debug, Clone)]
216pub enum PredicateBuildError {
217 LiteralCast(LiteralCastError),
218 UnsupportedOperator(&'static str),
219}
220
221impl fmt::Display for PredicateBuildError {
222 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223 match self {
224 PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
225 PredicateBuildError::UnsupportedOperator(op) => {
226 write!(f, "unsupported operator for typed predicate: {op}")
227 }
228 }
229 }
230}
231
232impl std::error::Error for PredicateBuildError {
233 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
234 match self {
235 PredicateBuildError::LiteralCast(err) => Some(err),
236 PredicateBuildError::UnsupportedOperator(_) => None,
237 }
238 }
239}
240
241impl From<LiteralCastError> for PredicateBuildError {
242 fn from(err: LiteralCastError) -> Self {
243 PredicateBuildError::LiteralCast(err)
244 }
245}
246
247pub fn build_fixed_width_predicate<T>(
255 op: &Operator<'_>,
256) -> Result<Predicate<T::Native>, PredicateBuildError>
257where
258 T: ArrowPrimitiveType,
259 T::Native: FromLiteral + Copy + PredicateValue,
260{
261 match op {
262 Operator::Equals(lit) => Ok(Predicate::Equals(
263 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
264 )),
265 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
266 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
267 )),
268 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
269 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
270 )),
271 Operator::LessThan(lit) => Ok(Predicate::LessThan(
272 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
273 )),
274 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
275 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
276 )),
277 Operator::Range { lower, upper } => {
278 let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
279 Bound::Unbounded => None,
280 other => Some(other),
281 };
282 let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
283 Bound::Unbounded => None,
284 other => Some(other),
285 };
286
287 if lb.is_none() && ub.is_none() {
288 Ok(Predicate::All)
289 } else {
290 Ok(Predicate::Range {
291 lower: lb,
292 upper: ub,
293 })
294 }
295 }
296 Operator::In(values) => {
297 let mut natives = Vec::with_capacity(values.len());
298 for lit in *values {
299 natives
300 .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
301 }
302 Ok(Predicate::In(natives))
303 }
304 _ => Err(PredicateBuildError::UnsupportedOperator(
305 "operator lacks typed literal support",
306 )),
307 }
308}
309
310fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
311 Ok(match bound {
312 Bound::Unbounded => None,
313 Bound::Included(lit) => Some(Bound::Included(
314 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315 )),
316 Bound::Excluded(lit) => Some(Bound::Excluded(
317 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318 )),
319 })
320}
321
322pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
330 match op {
331 Operator::Equals(lit) => Ok(Predicate::Equals(
332 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
333 )),
334 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
335 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
336 )),
337 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
338 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
339 )),
340 Operator::LessThan(lit) => Ok(Predicate::LessThan(
341 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
342 )),
343 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
344 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
345 )),
346 Operator::Range { lower, upper } => {
347 let lb = parse_bool_bound(lower)?;
348 let ub = parse_bool_bound(upper)?;
349 if lb.is_none() && ub.is_none() {
350 Ok(Predicate::All)
351 } else {
352 Ok(Predicate::Range {
353 lower: lb,
354 upper: ub,
355 })
356 }
357 }
358 Operator::In(values) => {
359 let mut natives = Vec::with_capacity(values.len());
360 for lit in *values {
361 natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
362 }
363 Ok(Predicate::In(natives))
364 }
365 _ => Err(PredicateBuildError::UnsupportedOperator(
366 "operator lacks boolean literal support",
367 )),
368 }
369}
370
371fn parse_string_bound(
372 bound: &Bound<Literal>,
373) -> Result<Option<Bound<String>>, PredicateBuildError> {
374 match bound {
375 Bound::Unbounded => Ok(None),
376 Bound::Included(lit) => literal_to_string(lit)
377 .map(|s| Some(Bound::Included(s)))
378 .map_err(PredicateBuildError::from),
379 Bound::Excluded(lit) => literal_to_string(lit)
380 .map(|s| Some(Bound::Excluded(s)))
381 .map_err(PredicateBuildError::from),
382 }
383}
384
385pub fn build_var_width_predicate(
393 op: &Operator<'_>,
394) -> Result<Predicate<String>, PredicateBuildError> {
395 match op {
396 Operator::Equals(lit) => Ok(Predicate::Equals(
397 literal_to_string(lit).map_err(PredicateBuildError::from)?,
398 )),
399 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
400 literal_to_string(lit).map_err(PredicateBuildError::from)?,
401 )),
402 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
403 literal_to_string(lit).map_err(PredicateBuildError::from)?,
404 )),
405 Operator::LessThan(lit) => Ok(Predicate::LessThan(
406 literal_to_string(lit).map_err(PredicateBuildError::from)?,
407 )),
408 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
409 literal_to_string(lit).map_err(PredicateBuildError::from)?,
410 )),
411 Operator::Range { lower, upper } => {
412 let lb = parse_string_bound(lower)?;
413 let ub = parse_string_bound(upper)?;
414 if lb.is_none() && ub.is_none() {
415 Ok(Predicate::All)
416 } else {
417 Ok(Predicate::Range {
418 lower: lb,
419 upper: ub,
420 })
421 }
422 }
423 Operator::In(values) => {
424 let mut out = Vec::with_capacity(values.len());
425 for lit in *values {
426 out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
427 }
428 Ok(Predicate::In(out))
429 }
430 Operator::StartsWith {
431 pattern,
432 case_sensitive,
433 } => Ok(Predicate::StartsWith {
434 pattern: pattern.to_string(),
435 case_sensitive: *case_sensitive,
436 }),
437 Operator::EndsWith {
438 pattern,
439 case_sensitive,
440 } => Ok(Predicate::EndsWith {
441 pattern: pattern.to_string(),
442 case_sensitive: *case_sensitive,
443 }),
444 Operator::Contains {
445 pattern,
446 case_sensitive,
447 } => Ok(Predicate::Contains {
448 pattern: pattern.to_string(),
449 case_sensitive: *case_sensitive,
450 }),
451 Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
452 "operator lacks string literal support",
453 )),
454 }
455}
456
// Unit tests for the predicate builders and `Predicate::matches`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    // Included lower end and excluded upper end: 10 is in, 20 is out.
    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    // Pattern operators are not handled by the fixed-width builder.
    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    // An empty `In` list has no candidate to equal, so it rejects every value.
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    // Covers membership plus each pattern operator in both case modes.
    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }
}