1#![forbid(unsafe_code)]
9
10pub use crate::literal::*;
11use std::ops::Bound;
12
13#[derive(Clone, Debug)]
15pub enum Expr<'a, F> {
16 And(Vec<Expr<'a, F>>),
17 Or(Vec<Expr<'a, F>>),
18 Not(Box<Expr<'a, F>>),
19 Pred(Filter<'a, F>),
20 Compare {
21 left: ScalarExpr<F>,
22 op: CompareOp,
23 right: ScalarExpr<F>,
24 },
25}
26
27impl<'a, F> Expr<'a, F> {
28 #[inline]
30 pub fn all_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
31 Expr::And(fs.into_iter().map(Expr::Pred).collect())
32 }
33
34 #[inline]
36 pub fn any_of(fs: Vec<Filter<'a, F>>) -> Expr<'a, F> {
37 Expr::Or(fs.into_iter().map(Expr::Pred).collect())
38 }
39
40 #[allow(clippy::should_implement_trait)]
42 #[inline]
43 pub fn not(e: Expr<'a, F>) -> Expr<'a, F> {
44 Expr::Not(Box::new(e))
45 }
46}
47
48#[derive(Clone, Debug)]
50pub enum ScalarExpr<F> {
51 Column(F),
52 Literal(Literal),
53 Binary {
54 left: Box<ScalarExpr<F>>,
55 op: BinaryOp,
56 right: Box<ScalarExpr<F>>,
57 },
58 Aggregate(AggregateCall<F>),
61 GetField {
65 base: Box<ScalarExpr<F>>,
66 field_name: String,
67 },
68}
69
/// An aggregate function call; every variant except `CountStar` names the
/// column `F` it applies to.
#[derive(Debug, Clone)]
pub enum AggregateCall<F> {
    /// Takes no column argument (presumably `COUNT(*)`).
    CountStar,
    /// Count over column `F`.
    Count(F),
    /// Sum over column `F`.
    Sum(F),
    /// Minimum over column `F`.
    Min(F),
    /// Maximum over column `F`.
    Max(F),
    /// Null count over column `F`.
    CountNulls(F),
}
80
81impl<F> ScalarExpr<F> {
82 #[inline]
83 pub fn column(field: F) -> Self {
84 Self::Column(field)
85 }
86
87 #[inline]
88 pub fn literal<L: Into<Literal>>(lit: L) -> Self {
89 Self::Literal(lit.into())
90 }
91
92 #[inline]
93 pub fn binary(left: Self, op: BinaryOp, right: Self) -> Self {
94 Self::Binary {
95 left: Box::new(left),
96 op,
97 right: Box::new(right),
98 }
99 }
100
101 #[inline]
102 pub fn aggregate(call: AggregateCall<F>) -> Self {
103 Self::Aggregate(call)
104 }
105
106 #[inline]
107 pub fn get_field(base: Self, field_name: String) -> Self {
108 Self::GetField {
109 base: Box::new(base),
110 field_name,
111 }
112 }
113}
114
/// Arithmetic operator combining the two operands of a
/// [`ScalarExpr::Binary`].
///
/// `Hash` is derived alongside `Eq` so operators can key hash-based
/// collections (for example, per-operator dispatch tables).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
    /// `left + right`.
    Add,
    /// `left - right`.
    Subtract,
    /// `left * right`.
    Multiply,
    /// `left / right`.
    Divide,
    /// `left % right`.
    Modulo,
}
124
/// Comparison operator relating the two sides of an [`Expr::Compare`].
///
/// `Hash` is derived alongside `Eq` so operators can key hash-based
/// collections.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CompareOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    NotEq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    GtEq,
}
135
136#[derive(Debug, Clone)]
138pub struct Filter<'a, F> {
139 pub field_id: F,
140 pub op: Operator<'a>,
141}
142
143#[derive(Debug, Clone)]
148pub enum Operator<'a> {
149 Equals(Literal),
150 Range {
151 lower: Bound<Literal>,
152 upper: Bound<Literal>,
153 },
154 GreaterThan(Literal),
155 GreaterThanOrEquals(Literal),
156 LessThan(Literal),
157 LessThanOrEquals(Literal),
158 In(&'a [Literal]),
159 StartsWith {
160 pattern: &'a str,
161 case_sensitive: bool,
162 },
163 EndsWith {
164 pattern: &'a str,
165 case_sensitive: bool,
166 },
167 Contains {
168 pattern: &'a str,
169 case_sensitive: bool,
170 },
171 IsNull,
172 IsNotNull,
173}
174
175impl<'a> Operator<'a> {
176 #[inline]
177 pub fn starts_with(pattern: &'a str, case_sensitive: bool) -> Self {
178 Operator::StartsWith {
179 pattern,
180 case_sensitive,
181 }
182 }
183
184 #[inline]
185 pub fn ends_with(pattern: &'a str, case_sensitive: bool) -> Self {
186 Operator::EndsWith {
187 pattern,
188 case_sensitive,
189 }
190 }
191
192 #[inline]
193 pub fn contains(pattern: &'a str, case_sensitive: bool) -> Self {
194 Operator::Contains {
195 pattern,
196 case_sensitive,
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    // Returns the `field_id` of a `Pred` leaf, or `None` for any other node.
    fn pred_field<'e, 'a, F>(e: &'e Expr<'a, F>) -> Option<&'e F> {
        match e {
            Expr::Pred(Filter { field_id, .. }) => Some(field_id),
            _ => None,
        }
    }

    // Returns the child of a `Not` node, or `None` for any other node.
    fn not_child<'e, 'a, F>(e: &'e Expr<'a, F>) -> Option<&'e Expr<'a, F>> {
        match e {
            Expr::Not(inner) => Some(&**inner),
            _ => None,
        }
    }

    #[test]
    fn build_simple_exprs() {
        let filters = vec![
            Filter {
                field_id: 1,
                op: Operator::Equals("abc".into()),
            },
            Filter {
                field_id: 2,
                op: Operator::LessThan("zzz".into()),
            },
        ];
        let any = Expr::any_of(filters.clone());
        let not_all = Expr::not(Expr::all_of(filters));

        if let Expr::Or(children) = any {
            assert_eq!(children.len(), 2);
        } else {
            panic!("expected Or");
        }

        let negated = match not_all {
            Expr::Not(inner) => inner,
            _ => panic!("expected Not"),
        };
        if let Expr::And(children) = *negated {
            assert_eq!(children.len(), 2);
        } else {
            panic!("expected And inside Not");
        }
    }

    #[test]
    fn complex_nested_shape() {
        let in_values = ["x".into(), "y".into(), "z".into()];
        let f1 = Filter {
            field_id: 1u32,
            op: Operator::Equals("a".into()),
        };
        let f2 = Filter {
            field_id: 2u32,
            op: Operator::LessThan("zzz".into()),
        };
        let f3 = Filter {
            field_id: 3u32,
            op: Operator::In(&in_values),
        };
        let f4 = Filter {
            field_id: 4u32,
            op: Operator::starts_with("pre", true),
        };

        // (f1 AND (f2 OR NOT f3)) OR (NOT f1 AND f4)
        let top = Expr::Or(vec![
            Expr::And(vec![
                Expr::Pred(f1.clone()),
                Expr::Or(vec![Expr::Pred(f2), Expr::not(Expr::Pred(f3))]),
            ]),
            Expr::And(vec![Expr::not(Expr::Pred(f1)), Expr::Pred(f4)]),
        ]);

        let branches = match &top {
            Expr::Or(b) => b,
            _ => panic!("expected top-level OR"),
        };
        assert_eq!(branches.len(), 2);

        // Left branch: f1 AND (f2 OR NOT f3)
        let left_and = match &branches[0] {
            Expr::And(v) => v,
            _ => panic!("expected AND on left branch of top OR"),
        };
        assert_eq!(left_and.len(), 2);
        assert_eq!(
            *pred_field(&left_and[0]).expect("expected Pred(f1) in left-AND[0]"),
            1
        );
        let or_vec = match &left_and[1] {
            Expr::Or(o) => o,
            _ => panic!("expected OR in left-AND[1]"),
        };
        assert_eq!(or_vec.len(), 2);
        assert_eq!(
            *pred_field(&or_vec[0]).expect("expected Pred(f2) in left-AND[1].OR[0]"),
            2
        );
        let negated = not_child(&or_vec[1]).expect("expected Not(...) in left-AND[1].OR[1]");
        assert_eq!(
            *pred_field(negated).expect("expected Not(Pred(f3)) in left-AND[1].OR[1]"),
            3
        );

        // Right branch: NOT f1 AND f4
        let right_and = match &branches[1] {
            Expr::And(v) => v,
            _ => panic!("expected AND on right branch of top OR"),
        };
        assert_eq!(right_and.len(), 2);
        let negated = not_child(&right_and[0]).expect("expected Not(...) in right-AND[0]");
        assert_eq!(
            *pred_field(negated).expect("expected Not(Pred(f1)) in right-AND[0]"),
            1
        );
        assert_eq!(
            *pred_field(&right_and[1]).expect("expected Pred(f4) in right-AND[1]"),
            4
        );
    }

    #[test]
    fn range_bounds_roundtrip() {
        let f = Filter {
            field_id: 7u32,
            op: Operator::Range {
                lower: Bound::Included("aaa".into()),
                upper: Bound::Excluded("bbb".into()),
            },
        };

        let (lower, upper) = match &f.op {
            Operator::Range { lower, upper } => (lower, upper),
            _ => panic!("expected Range operator"),
        };
        match lower {
            Bound::Included(l) => assert_eq!(*l, Literal::String("aaa".to_string())),
            _ => panic!("lower bound should be Included"),
        }
        match upper {
            Bound::Excluded(u) => assert_eq!(*u, Literal::String("bbb".to_string())),
            _ => panic!("upper bound should be Excluded"),
        }
    }

    #[test]
    fn helper_builders_preserve_structure_and_order() {
        let eq = |id: u32, lit: Literal| Filter {
            field_id: id,
            op: Operator::Equals(lit),
        };
        let f1 = eq(1, "a".into());
        let f2 = eq(2, "b".into());
        let f3 = eq(3, "c".into());

        match Expr::all_of(vec![f1.clone(), f2.clone(), f3.clone()]) {
            Expr::And(children) => {
                assert_eq!(children.len(), 3);
                for (child, want) in children.iter().zip(&[1u32, 2, 3]) {
                    match child {
                        Expr::Pred(Filter { field_id, .. }) => assert_eq!(field_id, want),
                        _ => panic!(),
                    }
                }
            }
            _ => panic!("expected And"),
        }

        match Expr::any_of(vec![f3, f2, f1]) {
            Expr::Or(children) => {
                assert_eq!(children.len(), 3);
                for (child, want) in children.iter().zip(&[3u32, 2, 1]) {
                    match child {
                        Expr::Pred(Filter { field_id, .. }) => assert_eq!(field_id, want),
                        _ => panic!(),
                    }
                }
            }
            _ => panic!("expected Or"),
        }
    }

    #[test]
    fn set_and_pattern_ops_hold_borrowed_slices() {
        let in_values = ["aa".into(), "bb".into(), "cc".into()];
        let f_in = Filter {
            field_id: 9u32,
            op: Operator::In(&in_values),
        };
        if let Operator::In(values) = f_in.op {
            assert_eq!(values.len(), 3);
        } else {
            panic!("expected In");
        }

        let f_sw = Filter {
            field_id: 10u32,
            op: Operator::starts_with("pre", true),
        };
        if let Operator::StartsWith {
            pattern,
            case_sensitive,
        } = f_sw.op
        {
            assert_eq!(pattern, "pre");
            assert!(case_sensitive);
        } else {
            panic!();
        }

        let f_ew = Filter {
            field_id: 11u32,
            op: Operator::ends_with("suf", true),
        };
        if let Operator::EndsWith {
            pattern,
            case_sensitive,
        } = f_ew.op
        {
            assert_eq!(pattern, "suf");
            assert!(case_sensitive);
        } else {
            panic!();
        }

        let f_ct = Filter {
            field_id: 12u32,
            op: Operator::contains("mid", true),
        };
        if let Operator::Contains {
            pattern,
            case_sensitive,
        } = f_ct.op
        {
            assert_eq!(pattern, "mid");
            assert!(case_sensitive);
        } else {
            panic!();
        }
    }

    #[test]
    fn generic_field_id_works_with_strings() {
        let f1 = Filter {
            field_id: "name",
            op: Operator::Equals("alice".into()),
        };
        let f2 = Filter {
            field_id: "status",
            op: Operator::GreaterThanOrEquals("active".into()),
        };

        let children = match Expr::all_of(vec![f1.clone(), f2.clone()]) {
            Expr::And(v) => v,
            _ => panic!("expected And"),
        };
        assert_eq!(children.len(), 2);
        assert_eq!(
            *pred_field(&children[0]).expect("expected Pred(name)"),
            "name"
        );
        assert_eq!(
            *pred_field(&children[1]).expect("expected Pred(status)"),
            "status"
        );
    }

    #[test]
    fn very_deep_not_chain() {
        let leaf = Expr::Pred(Filter {
            field_id: 42u32,
            op: Operator::Equals("x".into()),
        });
        let expr = (0..64).fold(leaf, |acc, _| Expr::not(acc));

        // Peel off every `Not` layer, counting them, then inspect the leaf.
        let mut depth = 0usize;
        let mut node = &expr;
        while let Expr::Not(inner) = node {
            depth += 1;
            node = inner;
        }
        match node {
            Expr::Pred(Filter { field_id, .. }) => assert_eq!(*field_id, 42),
            _ => panic!("unexpected node inside deep NOT chain"),
        }
        assert_eq!(depth, 64);
    }

    #[test]
    fn literal_construction() {
        let f = Filter {
            field_id: "my_u64_col",
            op: Operator::Range {
                lower: Bound::Included(150.into()),
                upper: Bound::Excluded(300.into()),
            },
        };
        if let Operator::Range { lower, upper } = f.op {
            assert_eq!(lower, Bound::Included(Literal::Integer(150)));
            assert_eq!(upper, Bound::Excluded(Literal::Integer(300)));
        } else {
            panic!("Expected a range operator");
        }

        let f2 = Filter {
            field_id: "my_str_col",
            op: Operator::Equals("hello".into()),
        };
        if let Operator::Equals(lit) = f2.op {
            assert_eq!(lit, Literal::String("hello".to_string()));
        } else {
            panic!("Expected an equals operator");
        }
    }
}