1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::Bound;
4
5use arrow::datatypes::ArrowPrimitiveType;
6
7use crate::expr::Operator;
8use crate::literal::{
9 FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
10};
11
/// Value types that a [`Predicate`] can store as operands and be evaluated against.
///
/// Predicates hold owned values of `Self`, but are evaluated against
/// `Self::Borrowed`, so callers can test borrowed data (for example `str` for
/// `String`) without first building an owned copy.
pub trait PredicateValue: Clone {
    /// The borrowed form a predicate is evaluated against; may be unsized.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Views an owned value in its borrowed form.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Returns `true` when `value` is equal to `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`, or `None` when they are incomparable.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. The default never matches; only types with a notion of
    /// substrings need to override it.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }

    /// Prefix test. The default never matches.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }

    /// Suffix test. The default never matches.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let _ = value;
        let _ = target;
        let _ = case_sensitive;
        false
    }
}
33
34#[derive(Debug, Clone)]
35pub enum Predicate<V>
36where
37 V: PredicateValue,
38{
39 All,
40 Equals(V),
41 GreaterThan(V),
42 GreaterThanOrEquals(V),
43 LessThan(V),
44 LessThanOrEquals(V),
45 Range {
46 lower: Option<Bound<V>>,
47 upper: Option<Bound<V>>,
48 },
49 In(Vec<V>),
50 StartsWith {
51 pattern: V,
52 case_sensitive: bool,
53 },
54 EndsWith {
55 pattern: V,
56 case_sensitive: bool,
57 },
58 Contains {
59 pattern: V,
60 case_sensitive: bool,
61 },
62}
63
64impl<V> Predicate<V>
65where
66 V: PredicateValue,
67{
68 pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
69 match self {
70 Predicate::All => true,
71 Predicate::Equals(target) => V::equals(value, target),
72 Predicate::GreaterThan(target) => {
73 matches!(V::compare(value, target), Some(Ordering::Greater))
74 }
75 Predicate::GreaterThanOrEquals(target) => {
76 matches!(
77 V::compare(value, target),
78 Some(Ordering::Greater | Ordering::Equal)
79 )
80 }
81 Predicate::LessThan(target) => {
82 matches!(V::compare(value, target), Some(Ordering::Less))
83 }
84 Predicate::LessThanOrEquals(target) => matches!(
85 V::compare(value, target),
86 Some(Ordering::Less | Ordering::Equal)
87 ),
88 Predicate::Range { lower, upper } => {
89 if let Some(bound) = lower
90 && !match bound {
91 Bound::Included(target) => matches!(
92 V::compare(value, target),
93 Some(Ordering::Greater | Ordering::Equal)
94 ),
95 Bound::Excluded(target) => {
96 matches!(V::compare(value, target), Some(Ordering::Greater))
97 }
98 Bound::Unbounded => true,
99 }
100 {
101 return false;
102 }
103
104 if let Some(bound) = upper
105 && !match bound {
106 Bound::Included(target) => matches!(
107 V::compare(value, target),
108 Some(Ordering::Less | Ordering::Equal)
109 ),
110 Bound::Excluded(target) => {
111 matches!(V::compare(value, target), Some(Ordering::Less))
112 }
113 Bound::Unbounded => true,
114 }
115 {
116 return false;
117 }
118
119 true
120 }
121 Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
122 Predicate::StartsWith {
123 pattern,
124 case_sensitive,
125 } => V::starts_with(value, pattern, *case_sensitive),
126 Predicate::EndsWith {
127 pattern,
128 case_sensitive,
129 } => V::ends_with(value, pattern, *case_sensitive),
130 Predicate::Contains {
131 pattern,
132 case_sensitive,
133 } => V::contains(value, pattern, *case_sensitive),
134 }
135 }
136}
137
138macro_rules! impl_predicate_value_for_primitive {
139 ($($ty:ty),+ $(,)?) => {
140 $(
141 impl PredicateValue for $ty {
142 type Borrowed<'a> = Self where Self: 'a;
143
144 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
145 value
146 }
147
148 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
149 *value == *target
150 }
151
152 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
153 value.partial_cmp(target)
154 }
155 }
156 )+
157 };
158}
159
160impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
161
162impl PredicateValue for String {
163 type Borrowed<'a>
164 = str
165 where
166 Self: 'a;
167
168 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
169 value.as_str()
170 }
171
172 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
173 value == target.as_str()
174 }
175
176 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
177 Some(value.cmp(target.as_str()))
178 }
179
180 fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
181 if case_sensitive {
182 value.contains(target.as_str())
183 } else {
184 value.to_lowercase().contains(&target.to_lowercase())
185 }
186 }
187
188 fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
189 if case_sensitive {
190 value.starts_with(target.as_str())
191 } else {
192 value.to_lowercase().starts_with(&target.to_lowercase())
193 }
194 }
195
196 fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
197 if case_sensitive {
198 value.ends_with(target.as_str())
199 } else {
200 value.to_lowercase().ends_with(&target.to_lowercase())
201 }
202 }
203}
204
205#[derive(Debug, Clone)]
206pub enum PredicateBuildError {
207 LiteralCast(LiteralCastError),
208 UnsupportedOperator(&'static str),
209}
210
211impl fmt::Display for PredicateBuildError {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
215 PredicateBuildError::UnsupportedOperator(op) => {
216 write!(f, "unsupported operator for typed predicate: {op}")
217 }
218 }
219 }
220}
221
222impl std::error::Error for PredicateBuildError {
223 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
224 match self {
225 PredicateBuildError::LiteralCast(err) => Some(err),
226 PredicateBuildError::UnsupportedOperator(_) => None,
227 }
228 }
229}
230
231impl From<LiteralCastError> for PredicateBuildError {
232 fn from(err: LiteralCastError) -> Self {
233 PredicateBuildError::LiteralCast(err)
234 }
235}
236
237pub fn build_fixed_width_predicate<T>(
238 op: &Operator<'_>,
239) -> Result<Predicate<T::Native>, PredicateBuildError>
240where
241 T: ArrowPrimitiveType,
242 T::Native: FromLiteral + Copy + PredicateValue,
243{
244 match op {
245 Operator::Equals(lit) => Ok(Predicate::Equals(
246 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
247 )),
248 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
249 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
250 )),
251 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
252 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
253 )),
254 Operator::LessThan(lit) => Ok(Predicate::LessThan(
255 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
256 )),
257 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
258 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
259 )),
260 Operator::Range { lower, upper } => {
261 let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
262 Bound::Unbounded => None,
263 other => Some(other),
264 };
265 let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
266 Bound::Unbounded => None,
267 other => Some(other),
268 };
269
270 if lb.is_none() && ub.is_none() {
271 Ok(Predicate::All)
272 } else {
273 Ok(Predicate::Range {
274 lower: lb,
275 upper: ub,
276 })
277 }
278 }
279 Operator::In(values) => {
280 let mut natives = Vec::with_capacity(values.len());
281 for lit in *values {
282 natives
283 .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
284 }
285 Ok(Predicate::In(natives))
286 }
287 _ => Err(PredicateBuildError::UnsupportedOperator(
288 "operator lacks typed literal support",
289 )),
290 }
291}
292
293fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
294 Ok(match bound {
295 Bound::Unbounded => None,
296 Bound::Included(lit) => Some(Bound::Included(
297 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
298 )),
299 Bound::Excluded(lit) => Some(Bound::Excluded(
300 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
301 )),
302 })
303}
304
305pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
306 match op {
307 Operator::Equals(lit) => Ok(Predicate::Equals(
308 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
309 )),
310 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
311 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
312 )),
313 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
314 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315 )),
316 Operator::LessThan(lit) => Ok(Predicate::LessThan(
317 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318 )),
319 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
320 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
321 )),
322 Operator::Range { lower, upper } => {
323 let lb = parse_bool_bound(lower)?;
324 let ub = parse_bool_bound(upper)?;
325 if lb.is_none() && ub.is_none() {
326 Ok(Predicate::All)
327 } else {
328 Ok(Predicate::Range {
329 lower: lb,
330 upper: ub,
331 })
332 }
333 }
334 Operator::In(values) => {
335 let mut natives = Vec::with_capacity(values.len());
336 for lit in *values {
337 natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
338 }
339 Ok(Predicate::In(natives))
340 }
341 _ => Err(PredicateBuildError::UnsupportedOperator(
342 "operator lacks boolean literal support",
343 )),
344 }
345}
346
347fn parse_string_bound(
348 bound: &Bound<Literal>,
349) -> Result<Option<Bound<String>>, PredicateBuildError> {
350 match bound {
351 Bound::Unbounded => Ok(None),
352 Bound::Included(lit) => literal_to_string(lit)
353 .map(|s| Some(Bound::Included(s)))
354 .map_err(PredicateBuildError::from),
355 Bound::Excluded(lit) => literal_to_string(lit)
356 .map(|s| Some(Bound::Excluded(s)))
357 .map_err(PredicateBuildError::from),
358 }
359}
360
361pub fn build_var_width_predicate(
362 op: &Operator<'_>,
363) -> Result<Predicate<String>, PredicateBuildError> {
364 match op {
365 Operator::Equals(lit) => Ok(Predicate::Equals(
366 literal_to_string(lit).map_err(PredicateBuildError::from)?,
367 )),
368 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
369 literal_to_string(lit).map_err(PredicateBuildError::from)?,
370 )),
371 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
372 literal_to_string(lit).map_err(PredicateBuildError::from)?,
373 )),
374 Operator::LessThan(lit) => Ok(Predicate::LessThan(
375 literal_to_string(lit).map_err(PredicateBuildError::from)?,
376 )),
377 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
378 literal_to_string(lit).map_err(PredicateBuildError::from)?,
379 )),
380 Operator::Range { lower, upper } => {
381 let lb = parse_string_bound(lower)?;
382 let ub = parse_string_bound(upper)?;
383 if lb.is_none() && ub.is_none() {
384 Ok(Predicate::All)
385 } else {
386 Ok(Predicate::Range {
387 lower: lb,
388 upper: ub,
389 })
390 }
391 }
392 Operator::In(values) => {
393 let mut out = Vec::with_capacity(values.len());
394 for lit in *values {
395 out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
396 }
397 Ok(Predicate::In(out))
398 }
399 Operator::StartsWith {
400 pattern,
401 case_sensitive,
402 } => Ok(Predicate::StartsWith {
403 pattern: pattern.to_string(),
404 case_sensitive: *case_sensitive,
405 }),
406 Operator::EndsWith {
407 pattern,
408 case_sensitive,
409 } => Ok(Predicate::EndsWith {
410 pattern: pattern.to_string(),
411 case_sensitive: *case_sensitive,
412 }),
413 Operator::Contains {
414 pattern,
415 case_sensitive,
416 } => Ok(Predicate::Contains {
417 pattern: pattern.to_string(),
418 case_sensitive: *case_sensitive,
419 }),
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    // An empty `In` list is a predicate that nothing satisfies.
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
        // Inclusive lower bound admits the endpoint; exclusive upper bound rejects it.
        assert!(predicate.matches("alpha"));
        assert!(!predicate.matches("omega"));
    }

    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }

    #[test]
    fn string_case_insensitive_edge_cases() {
        // Non-ASCII input is compared through Unicode lowercase forms.
        let sw_unicode = build_var_width_predicate(&Operator::starts_with("ÉC", false))
            .expect("starts with predicate");
        assert!(sw_unicode.matches("école"));
        assert!(!sw_unicode.matches("ecole"));

        // Every string, including the empty one, contains the empty pattern.
        let ct_empty = build_var_width_predicate(&Operator::contains("", false))
            .expect("contains predicate");
        assert!(ct_empty.matches("anything"));
        assert!(ct_empty.matches(""));

        // A pattern longer than the value can never match.
        let ew_long = build_var_width_predicate(&Operator::ends_with("TOOLONG", false))
            .expect("ends with predicate");
        assert!(!ew_long.matches("long"));
        let sw_long = build_var_width_predicate(&Operator::starts_with("TOOLONG", false))
            .expect("starts with predicate");
        assert!(!sw_long.matches("too"));
    }
}