1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::Bound;
4
5use arrow::datatypes::ArrowPrimitiveType;
6
7use crate::expr::Operator;
8use crate::literal::{
9 FromLiteral, Literal, LiteralCastError, bound_to_native, literal_to_native, literal_to_string,
10};
11
/// A value type that can be stored inside a [`Predicate`] and compared against candidates.
///
/// The owned form (`Self`) is what the predicate keeps; candidates are presented through the
/// associated borrowed form, which may be unsized.
pub trait PredicateValue: Clone {
    /// Borrowed view of the value that the comparison functions receive.
    type Borrowed<'a>: ?Sized
    where
        Self: 'a;

    /// Returns the borrowed view of an owned value.
    fn borrowed(value: &Self) -> &Self::Borrowed<'_>;

    /// Returns `true` when `value` equals `target`.
    fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool;

    /// Orders `value` relative to `target`; `None` means no ordering is reported.
    fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering>;

    /// Substring test. The default ignores every argument and reports no match.
    fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }

    /// Prefix test. The default ignores every argument and reports no match.
    fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }

    /// Suffix test. The default ignores every argument and reports no match.
    fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
        let (_, _, _) = (value, target, case_sensitive);
        false
    }
}
33
34#[derive(Debug, Clone)]
35pub enum Predicate<V>
36where
37 V: PredicateValue,
38{
39 All,
40 Equals(V),
41 GreaterThan(V),
42 GreaterThanOrEquals(V),
43 LessThan(V),
44 LessThanOrEquals(V),
45 Range {
46 lower: Option<Bound<V>>,
47 upper: Option<Bound<V>>,
48 },
49 In(Vec<V>),
50 StartsWith {
51 pattern: V,
52 case_sensitive: bool,
53 },
54 EndsWith {
55 pattern: V,
56 case_sensitive: bool,
57 },
58 Contains {
59 pattern: V,
60 case_sensitive: bool,
61 },
62}
63
64impl<V> Predicate<V>
65where
66 V: PredicateValue,
67{
68 pub fn matches(&self, value: &V::Borrowed<'_>) -> bool {
69 match self {
70 Predicate::All => true,
71 Predicate::Equals(target) => V::equals(value, target),
72 Predicate::GreaterThan(target) => {
73 matches!(V::compare(value, target), Some(Ordering::Greater))
74 }
75 Predicate::GreaterThanOrEquals(target) => {
76 matches!(
77 V::compare(value, target),
78 Some(Ordering::Greater | Ordering::Equal)
79 )
80 }
81 Predicate::LessThan(target) => {
82 matches!(V::compare(value, target), Some(Ordering::Less))
83 }
84 Predicate::LessThanOrEquals(target) => matches!(
85 V::compare(value, target),
86 Some(Ordering::Less | Ordering::Equal)
87 ),
88 Predicate::Range { lower, upper } => {
89 if let Some(bound) = lower
90 && !match bound {
91 Bound::Included(target) => matches!(
92 V::compare(value, target),
93 Some(Ordering::Greater | Ordering::Equal)
94 ),
95 Bound::Excluded(target) => {
96 matches!(V::compare(value, target), Some(Ordering::Greater))
97 }
98 Bound::Unbounded => true,
99 }
100 {
101 return false;
102 }
103
104 if let Some(bound) = upper
105 && !match bound {
106 Bound::Included(target) => matches!(
107 V::compare(value, target),
108 Some(Ordering::Less | Ordering::Equal)
109 ),
110 Bound::Excluded(target) => {
111 matches!(V::compare(value, target), Some(Ordering::Less))
112 }
113 Bound::Unbounded => true,
114 }
115 {
116 return false;
117 }
118
119 true
120 }
121 Predicate::In(values) => values.iter().any(|target| V::equals(value, target)),
122 Predicate::StartsWith {
123 pattern,
124 case_sensitive,
125 } => V::starts_with(value, pattern, *case_sensitive),
126 Predicate::EndsWith {
127 pattern,
128 case_sensitive,
129 } => V::ends_with(value, pattern, *case_sensitive),
130 Predicate::Contains {
131 pattern,
132 case_sensitive,
133 } => V::contains(value, pattern, *case_sensitive),
134 }
135 }
136}
137
138macro_rules! impl_predicate_value_for_primitive {
139 ($($ty:ty),+ $(,)?) => {
140 $(
141 impl PredicateValue for $ty {
142 type Borrowed<'a> = Self where Self: 'a;
143
144 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
145 value
146 }
147
148 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
149 *value == *target
150 }
151
152 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
153 value.partial_cmp(target)
154 }
155 }
156 )+
157 };
158}
159
160impl_predicate_value_for_primitive!(u64, u32, u16, u8, i64, i32, i16, i8, f64, f32, bool);
161
162impl PredicateValue for String {
163 type Borrowed<'a>
164 = str
165 where
166 Self: 'a;
167
168 fn borrowed(value: &Self) -> &Self::Borrowed<'_> {
169 value.as_str()
170 }
171
172 fn equals(value: &Self::Borrowed<'_>, target: &Self) -> bool {
173 value == target.as_str()
174 }
175
176 fn compare(value: &Self::Borrowed<'_>, target: &Self) -> Option<Ordering> {
177 Some(value.cmp(target.as_str()))
178 }
179
180 fn contains(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
181 if case_sensitive {
182 value.contains(target.as_str())
183 } else {
184 value.to_lowercase().contains(&target.to_lowercase())
185 }
186 }
187
188 fn starts_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
189 if case_sensitive {
190 value.starts_with(target.as_str())
191 } else {
192 value.to_lowercase().starts_with(&target.to_lowercase())
193 }
194 }
195
196 fn ends_with(value: &Self::Borrowed<'_>, target: &Self, case_sensitive: bool) -> bool {
197 if case_sensitive {
198 value.ends_with(target.as_str())
199 } else {
200 value.to_lowercase().ends_with(&target.to_lowercase())
201 }
202 }
203}
204
205#[derive(Debug, Clone)]
206pub enum PredicateBuildError {
207 LiteralCast(LiteralCastError),
208 UnsupportedOperator(&'static str),
209}
210
211impl fmt::Display for PredicateBuildError {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 PredicateBuildError::LiteralCast(err) => write!(f, "literal cast error: {err}"),
215 PredicateBuildError::UnsupportedOperator(op) => {
216 write!(f, "unsupported operator for typed predicate: {op}")
217 }
218 }
219 }
220}
221
222impl std::error::Error for PredicateBuildError {
223 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
224 match self {
225 PredicateBuildError::LiteralCast(err) => Some(err),
226 PredicateBuildError::UnsupportedOperator(_) => None,
227 }
228 }
229}
230
231impl From<LiteralCastError> for PredicateBuildError {
232 fn from(err: LiteralCastError) -> Self {
233 PredicateBuildError::LiteralCast(err)
234 }
235}
236
237pub fn build_fixed_width_predicate<T>(
238 op: &Operator<'_>,
239) -> Result<Predicate<T::Native>, PredicateBuildError>
240where
241 T: ArrowPrimitiveType,
242 T::Native: FromLiteral + Copy + PredicateValue,
243{
244 match op {
245 Operator::Equals(lit) => Ok(Predicate::Equals(
246 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
247 )),
248 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
249 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
250 )),
251 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
252 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
253 )),
254 Operator::LessThan(lit) => Ok(Predicate::LessThan(
255 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
256 )),
257 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
258 literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?,
259 )),
260 Operator::Range { lower, upper } => {
261 let lb = match bound_to_native::<T>(lower).map_err(PredicateBuildError::from)? {
262 Bound::Unbounded => None,
263 other => Some(other),
264 };
265 let ub = match bound_to_native::<T>(upper).map_err(PredicateBuildError::from)? {
266 Bound::Unbounded => None,
267 other => Some(other),
268 };
269
270 if lb.is_none() && ub.is_none() {
271 Ok(Predicate::All)
272 } else {
273 Ok(Predicate::Range {
274 lower: lb,
275 upper: ub,
276 })
277 }
278 }
279 Operator::In(values) => {
280 let mut natives = Vec::with_capacity(values.len());
281 for lit in *values {
282 natives
283 .push(literal_to_native::<T::Native>(lit).map_err(PredicateBuildError::from)?);
284 }
285 Ok(Predicate::In(natives))
286 }
287 _ => Err(PredicateBuildError::UnsupportedOperator(
288 "operator lacks typed literal support",
289 )),
290 }
291}
292
293fn parse_bool_bound(bound: &Bound<Literal>) -> Result<Option<Bound<bool>>, PredicateBuildError> {
294 Ok(match bound {
295 Bound::Unbounded => None,
296 Bound::Included(lit) => Some(Bound::Included(
297 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
298 )),
299 Bound::Excluded(lit) => Some(Bound::Excluded(
300 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
301 )),
302 })
303}
304
305pub fn build_bool_predicate(op: &Operator<'_>) -> Result<Predicate<bool>, PredicateBuildError> {
306 match op {
307 Operator::Equals(lit) => Ok(Predicate::Equals(
308 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
309 )),
310 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
311 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
312 )),
313 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
314 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
315 )),
316 Operator::LessThan(lit) => Ok(Predicate::LessThan(
317 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
318 )),
319 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
320 literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?,
321 )),
322 Operator::Range { lower, upper } => {
323 let lb = parse_bool_bound(lower)?;
324 let ub = parse_bool_bound(upper)?;
325 if lb.is_none() && ub.is_none() {
326 Ok(Predicate::All)
327 } else {
328 Ok(Predicate::Range {
329 lower: lb,
330 upper: ub,
331 })
332 }
333 }
334 Operator::In(values) => {
335 let mut natives = Vec::with_capacity(values.len());
336 for lit in *values {
337 natives.push(literal_to_native::<bool>(lit).map_err(PredicateBuildError::from)?);
338 }
339 Ok(Predicate::In(natives))
340 }
341 _ => Err(PredicateBuildError::UnsupportedOperator(
342 "operator lacks boolean literal support",
343 )),
344 }
345}
346
347fn parse_string_bound(
348 bound: &Bound<Literal>,
349) -> Result<Option<Bound<String>>, PredicateBuildError> {
350 match bound {
351 Bound::Unbounded => Ok(None),
352 Bound::Included(lit) => literal_to_string(lit)
353 .map(|s| Some(Bound::Included(s)))
354 .map_err(PredicateBuildError::from),
355 Bound::Excluded(lit) => literal_to_string(lit)
356 .map(|s| Some(Bound::Excluded(s)))
357 .map_err(PredicateBuildError::from),
358 }
359}
360
361pub fn build_var_width_predicate(
362 op: &Operator<'_>,
363) -> Result<Predicate<String>, PredicateBuildError> {
364 match op {
365 Operator::Equals(lit) => Ok(Predicate::Equals(
366 literal_to_string(lit).map_err(PredicateBuildError::from)?,
367 )),
368 Operator::GreaterThan(lit) => Ok(Predicate::GreaterThan(
369 literal_to_string(lit).map_err(PredicateBuildError::from)?,
370 )),
371 Operator::GreaterThanOrEquals(lit) => Ok(Predicate::GreaterThanOrEquals(
372 literal_to_string(lit).map_err(PredicateBuildError::from)?,
373 )),
374 Operator::LessThan(lit) => Ok(Predicate::LessThan(
375 literal_to_string(lit).map_err(PredicateBuildError::from)?,
376 )),
377 Operator::LessThanOrEquals(lit) => Ok(Predicate::LessThanOrEquals(
378 literal_to_string(lit).map_err(PredicateBuildError::from)?,
379 )),
380 Operator::Range { lower, upper } => {
381 let lb = parse_string_bound(lower)?;
382 let ub = parse_string_bound(upper)?;
383 if lb.is_none() && ub.is_none() {
384 Ok(Predicate::All)
385 } else {
386 Ok(Predicate::Range {
387 lower: lb,
388 upper: ub,
389 })
390 }
391 }
392 Operator::In(values) => {
393 let mut out = Vec::with_capacity(values.len());
394 for lit in *values {
395 out.push(literal_to_string(lit).map_err(PredicateBuildError::from)?);
396 }
397 Ok(Predicate::In(out))
398 }
399 Operator::StartsWith {
400 pattern,
401 case_sensitive,
402 } => Ok(Predicate::StartsWith {
403 pattern: pattern.to_string(),
404 case_sensitive: *case_sensitive,
405 }),
406 Operator::EndsWith {
407 pattern,
408 case_sensitive,
409 } => Ok(Predicate::EndsWith {
410 pattern: pattern.to_string(),
411 case_sensitive: *case_sensitive,
412 }),
413 Operator::Contains {
414 pattern,
415 case_sensitive,
416 } => Ok(Predicate::Contains {
417 pattern: pattern.to_string(),
418 case_sensitive: *case_sensitive,
419 }),
420 Operator::IsNull | Operator::IsNotNull => Err(PredicateBuildError::UnsupportedOperator(
421 "operator lacks string literal support",
422 )),
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Literal;
    use std::ops::Bound;

    /// Equality on a fixed-width integer column.
    #[test]
    fn predicate_matches_equals() {
        let op = Operator::Equals(42_i64.into());
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int64Type>(&op).unwrap();
        let forty_two: i64 = 42;
        let seven: i64 = 7;
        assert!(predicate.matches(&forty_two));
        assert!(!predicate.matches(&seven));
    }

    /// Inclusive lower bound is accepted, exclusive upper bound is rejected.
    #[test]
    fn predicate_range_limits() {
        let op = Operator::Range {
            lower: Bound::Included(10.into()),
            upper: Bound::Excluded(20.into()),
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Int32Type>(&op).unwrap();
        assert!(predicate.matches(&10));
        assert!(predicate.matches(&19));
        assert!(!predicate.matches(&9));
        assert!(!predicate.matches(&20));
    }

    /// A range side given as `Some(Bound::Unbounded)` places no constraint on the value.
    #[test]
    fn range_treats_explicit_unbounded_as_open() {
        let predicate = Predicate::<i32>::Range {
            lower: Some(Bound::Unbounded),
            upper: Some(Bound::Included(3)),
        };
        assert!(predicate.matches(&i32::MIN));
        assert!(predicate.matches(&3));
        assert!(!predicate.matches(&4));
    }

    #[test]
    fn predicate_in_operator() {
        let values = [1.into(), 2.into(), 3.into()];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt8Type>(&op).unwrap();
        let two: u8 = 2;
        let five: u8 = 5;
        assert!(predicate.matches(&two));
        assert!(!predicate.matches(&five));
    }

    /// Pattern operators are rejected by the fixed-width builder.
    #[test]
    fn unsupported_operator_errors() {
        let op = Operator::starts_with("foo", true);
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::UnsupportedOperator(_)));
    }

    #[test]
    fn literal_cast_error_propagates() {
        let op = Operator::Equals("foo".into());
        let err = build_fixed_width_predicate::<arrow::datatypes::UInt16Type>(&op).unwrap_err();
        assert!(matches!(err, PredicateBuildError::LiteralCast(_)));
    }

    #[test]
    fn empty_bounds_map_to_all() {
        let op = Operator::Range {
            lower: Bound::Unbounded,
            upper: Bound::Unbounded,
        };
        let predicate = build_fixed_width_predicate::<arrow::datatypes::UInt32Type>(&op).unwrap();
        assert!(predicate.matches(&123u32));
    }

    /// An empty `In` list has no target to equal, so it accepts nothing.
    #[test]
    fn empty_in_list_matches_nothing() {
        let values: [Literal; 0] = [];
        let op = Operator::In(&values);
        let predicate = build_fixed_width_predicate::<arrow::datatypes::Float32Type>(&op).unwrap();
        assert!(!predicate.matches(&1.23f32));
    }

    #[test]
    fn string_predicate_equals() {
        let op = Operator::Equals("foo".into());
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("foo"));
        assert!(!predicate.matches("bar"));
    }

    #[test]
    fn string_predicate_range() {
        let op = Operator::Range {
            lower: Bound::Included("alpha".into()),
            upper: Bound::Excluded("omega".into()),
        };
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("delta"));
        assert!(!predicate.matches("zzz"));
    }

    /// `In` plus each pattern operator in both case-sensitive and case-insensitive modes.
    #[test]
    fn string_predicate_in_and_patterns() {
        let vals = ["x".into(), "y".into()];
        let op = Operator::In(&vals);
        let predicate = build_var_width_predicate(&op).unwrap();
        assert!(predicate.matches("x"));
        assert!(!predicate.matches("z"));

        let sw_sensitive = build_var_width_predicate(&Operator::starts_with("pre", true))
            .expect("starts with predicate");
        assert!(sw_sensitive.matches("prefix"));
        assert!(!sw_sensitive.matches("Prefix"));

        let sw_insensitive = build_var_width_predicate(&Operator::starts_with("Pre", false))
            .expect("starts with predicate");
        assert!(sw_insensitive.matches("prefix"));
        assert!(sw_insensitive.matches("Prefix"));

        let ew_sensitive = build_var_width_predicate(&Operator::ends_with("suf", true))
            .expect("ends with predicate");
        assert!(ew_sensitive.matches("datsuf"));
        assert!(!ew_sensitive.matches("datSuf"));

        let ew_insensitive = build_var_width_predicate(&Operator::ends_with("SUF", false))
            .expect("ends with predicate");
        assert!(ew_insensitive.matches("datsuf"));
        assert!(ew_insensitive.matches("datSuf"));

        let ct_sensitive = build_var_width_predicate(&Operator::contains("mid", true))
            .expect("contains predicate");
        assert!(ct_sensitive.matches("amidst"));
        assert!(!ct_sensitive.matches("aMidst"));

        let ct_insensitive = build_var_width_predicate(&Operator::contains("MiD", false))
            .expect("contains predicate");
        assert!(ct_insensitive.matches("amidst"));
        assert!(ct_insensitive.matches("aMidst"));
    }
}