1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16 And(Vec<Expr<'a, F>>),
17 Or(Vec<Expr<'a, F>>),
18 Not(Box<Expr<'a, F>>),
19 Pred(Filter<'a, F>),
20 Compare {
21 left: ScalarExpr<F>,
22 op: CompareOp,
23 right: ScalarExpr<F>,
24 },
25 InList {
26 expr: ScalarExpr<F>,
27 list: Vec<ScalarExpr<F>>,
28 negated: bool,
29 },
30 Literal(bool),
33}
34
35impl<'a, F> Expr<'a, F> {
36 #[inline]
38 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
39 Expr::And(fs.into_iter().map(Expr::Pred).collect())
40 }
41
42 #[inline]
44 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
45 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
46 }
47
48 #[allow(clippy::should_implement_trait)]
50 #[inline]
51 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
52 Expr::Not(Box::new(e))
53 }
54}
55
56#[derive(Clone, Debug)]
58pub enum ScalarExpr<F> {
59 Column(F),
60 Literal(Literal),
61 Binary {
62 left: Box<ScalarExpr<F>>,
63 op: BinaryOp,
64 right: Box<ScalarExpr<F>>,
65 },
66 Aggregate(AggregateCall<F>),
69 GetField {
73 base: Box<ScalarExpr<F>>,
74 field_name: String,
75 },
76}
77
/// An aggregate function, applied to a column identified by `F` where the
/// variant carries one.
#[derive(Debug, Clone)]
pub enum AggregateCall<F> {
    /// Row count that takes no column (presumably `COUNT(*)`).
    CountStar,
    /// Count over column `F`; null handling is up to the evaluator.
    Count(F),
    /// Sum of column `F`.
    Sum(F),
    /// Minimum of column `F`.
    Min(F),
    /// Maximum of column `F`.
    Max(F),
    /// Number of null entries in column `F`.
    CountNulls(F),
}
88
89impl<F> ScalarExpr<F> {
90 #[inline]
91 pub fn column(field: F) -> Self {
92 Self::Column(field)
93 }
94
95 #[inline]
96 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
97 Self::Literal(lit.into())
98 }
99
100 #[inline]
101 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
102 Self::Binary {
103 left: Box::new(left),
104 op,
105 right: Box::new(right),
106 }
107 }
108
109 #[inline]
110 pub fn aggregate(call: AggregateCall<F>) -> Self {
111 Self::Aggregate(call)
112 }
113
114 #[inline]
115 pub fn get_field(base: Self, field_name: String) -> Self {
116 Self::GetField {
117 base: Box::new(base),
118 field_name,
119 }
120 }
121}
122
/// Arithmetic operator used by [`ScalarExpr::Binary`].
// Variant order is kept stable: fieldless enums can be `as`-cast to integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Subtract,
    /// Multiplication (`*`).
    Multiply,
    /// Division (`/`).
    Divide,
    /// Remainder (`%`).
    Modulo,
}
132
/// Comparison operator used by [`Expr::Compare`].
// Variant order is kept stable: fieldless enums can be `as`-cast to integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal to.
    Eq,
    /// Not equal to.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal to.
    LtEq,
    /// Greater than.
    Gt,
    /// Greater than or equal to.
    GtEq,
}
143
144#[derive(Debug, Clone)]
146pub struct Filter<'a, F> {
147 pub field_id: F,
148 pub op: Operator<'a>,
149}
150
151#[derive(Debug, Clone)]
156pub enum Operator<'a> {
157 Equals(Literal),
158 Range {
159 lower: Bound<Literal>,
160 upper: Bound<Literal>,
161 },
162 GreaterThan(Literal),
163 GreaterThanOrEquals(Literal),
164 LessThan(Literal),
165 LessThanOrEquals(Literal),
166 In(&'a [Literal]),
167 StartsWith {
168 pattern: &'a str,
169 case_sensitive: bool,
170 },
171 EndsWith {
172 pattern: &'a str,
173 case_sensitive: bool,
174 },
175 Contains {
176 pattern: &'a str,
177 case_sensitive: bool,
178 },
179 IsNull,
180 IsNotNull,
181}
182
183impl<'a> Operator<'a> {
184 #[inline]
185 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
186 Operator::StartsWith {
187 pattern,
188 case_sensitive,
189 }
190 }
191
192 #[inline]
193 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
194 Operator::EndsWith {
195 pattern,
196 case_sensitive,
197 }
198 }
199
200 #[inline]
201 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
202 Operator::Contains {
203 pattern,
204 case_sensitive,
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    // Shape-unwrapping helpers. Each one extracts the payload of one expected
    // `Expr` variant and panics with a message naming `ctx` otherwise. This lets
    // the tests walk a tree without a pyramid of nested `match` statements.

    /// Returns the children of an `Expr::And`, panicking otherwise.
    fn and_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::And(children) => children.as_slice(),
            _ => panic!("expected And {}", ctx),
        }
    }

    /// Returns the children of an `Expr::Or`, panicking otherwise.
    fn or_children<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e [Expr<'a, F>] {
        match expr {
            Expr::Or(children) => children.as_slice(),
            _ => panic!("expected Or {}", ctx),
        }
    }

    /// Returns the operand of an `Expr::Not`, panicking otherwise.
    fn not_operand<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e Expr<'a, F> {
        match expr {
            Expr::Not(inner) => inner.as_ref(),
            _ => panic!("expected Not {}", ctx),
        }
    }

    /// Returns the `field_id` of an `Expr::Pred`, panicking otherwise.
    fn pred_field<'e, 'a, F>(expr: &'e Expr<'a, F>, ctx: &str) -> &'e F {
        match expr {
            Expr::Pred(Filter { field_id, .. }) => field_id,
            _ => panic!("expected Pred {}", ctx),
        }
    }

    /// Collects the `field_id`s of a run of `Expr::Pred` nodes, in order.
    fn pred_fields<F: Clone>(exprs: &[Expr<'_, F>], ctx: &str) -> Vec<F> {
        exprs.iter().map(|e| pred_field(e, ctx).clone()).collect()
    }

    #[test]
    fn build_simple_exprs() {
        let f1 = Filter {
            field_id: 1,
            op: Operator::Equals("abc".into()),
        };
        let f2 = Filter {
            field_id: 2,
            op: Operator::LessThan("zzz".into()),
        };
        let not_all = Expr::not(Expr::all_of(vec![f1.clone(), f2.clone()]));
        let any = Expr::any_of(vec![f1, f2]);

        assert_eq!(or_children(&any, "from any_of").len(), 2);
        let negated = not_operand(&not_all, "from Expr::not");
        assert_eq!(and_children(negated, "inside Not").len(), 2);
    }

    #[test]
    fn complex_nested_shape() {
        // Shape under test: (f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4)
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let in_values: [Literal; 3] = ["x".into(), "y".into(), "z".into()];
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        let left = Expr::And(vec![
            Expr::Pred(f1.clone()),
            Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
        ]);
        let right = Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]);
        let top = Expr::Or(vec![left, right]);

        let branches = or_children(&top, "at the top level");
        assert_eq!(branches.len(), 2);

        // Left branch: f1 AND (f2 OR NOT f3)
        let left_and = and_children(&branches[0], "on left branch of top OR");
        assert_eq!(left_and.len(), 2);
        assert_eq!(*pred_field(&left_and[0], "in left-AND[0]"), 1);
        let left_or = or_children(&left_and[1], "in left-AND[1]");
        assert_eq!(left_or.len(), 2);
        assert_eq!(*pred_field(&left_or[0], "in left-AND[1].OR[0]"), 2);
        let negated_f3 = not_operand(&left_or[1], "in left-AND[1].OR[1]");
        assert_eq!(*pred_field(negated_f3, "inside left-AND[1].OR[1]"), 3);

        // Right branch: NOT f1 AND f4
        let right_and = and_children(&branches[1], "on right branch of top OR");
        assert_eq!(right_and.len(), 2);
        let negated_f1 = not_operand(&right_and[0], "in right-AND[0]");
        assert_eq!(*pred_field(negated_f1, "inside right-AND[0]"), 1);
        assert_eq!(*pred_field(&right_and[1], "in right-AND[1]"), 4);
    }

    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        match &f.op {
            Operator::Range { lower, upper } => {
                // Comparing whole `Bound`s checks both the bound kind and its value.
                assert_eq!(*lower, Bound::Included(Literal::String("aaa".to_string())));
                assert_eq!(*upper, Bound::Excluded(Literal::String("bbb".to_string())));
            }
            _ => panic!("expected Range operator"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::Equals("b".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::Equals("c".into()),
        };

        let and_expr = Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]);
        let and_ids = pred_fields(and_children(&and_expr, "from all_of"), "inside all_of");
        assert_eq!(and_ids, vec![1, 2, 3]);

        // any_of must keep the caller's (here reversed) order.
        let or_expr = Expr::any_of(vec![f3, f2, f1]);
        let or_ids = pred_fields(or_children(&or_expr, "from any_of"), "inside any_of");
        assert_eq!(or_ids, vec![3, 2, 1]);
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values: [Literal; 3] = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        match f_in.op {
            Operator::In(values) => {
                assert_eq!(values.len(), 3);
                // The operator must alias the caller's array, not hold a copy.
                assert!(std::ptr::eq(values, &in_values[..]));
            }
            _ => panic!("expected In"),
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };

        match f_sw.op {
            Operator::StartsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "pre");
                assert!(case_sensitive);
            }
            _ => panic!("expected StartsWith"),
        }
        match f_ew.op {
            Operator::EndsWith {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "suf");
                assert!(case_sensitive);
            }
            _ => panic!("expected EndsWith"),
        }
        match f_ct.op {
            Operator::Contains {
                pattern,
                case_sensitive,
            } => {
                assert_eq!(pattern, "mid");
                assert!(case_sensitive);
            }
            _ => panic!("expected Contains"),
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };
        let expr = Expr::all_of(vec![f1, f2]);

        let ids = pred_fields(and_children(&expr, "from all_of"), "inside all_of");
        assert_eq!(ids, vec!["name", "status"]);
    }

    #[test]
    fn very_deep_not_chain() {
        const DEPTH: usize = 64;
        let mut expr = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        for _ in 0..DEPTH {
            expr = Expr::not(expr);
        }

        // Peel the NOT wrappers iteratively (no recursion), counting them.
        let mut count = 0usize;
        let mut cur = &expr;
        while let Expr::Not(inner) = cur {
            count += 1;
            cur = inner.as_ref();
        }
        assert_eq!(*pred_field(cur, "at the bottom of the NOT chain"), 42);
        assert_eq!(count, DEPTH);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };

        match f.op {
            Operator::Range { lower, upper } => {
                assert_eq!(lower, Bound::Included(Literal::Integer(150)));
                assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
            }
            _ => panic!("Expected a range operator"),
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };

        match f2.op {
            Operator::Equals(lit) => {
                assert_eq!(lit, Literal::String("hello".to_string()));
            }
            _ => panic!("Expected an equals operator"),
        }
    }
}
597}