1use serde::{Deserialize, Serialize};
4
/// A constant value that can be embedded in an expression tree via [`ExprIr::Lit`].
///
/// The enum derives serde's `Serialize`/`Deserialize` with no customisation attributes,
/// so the variant names below are the ones serde sees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LiteralValue {
    /// 64-bit signed integer literal.
    I64(i64),
    /// 64-bit floating-point literal.
    F64(f64),
    /// 32-bit signed integer literal.
    I32(i32),
    /// Owned string literal.
    Str(String),
    /// Boolean literal.
    Bool(bool),
    /// Literal without a payload.
    /// NOTE(review): presumably a SQL-style NULL — confirm against the evaluator.
    Null,
}
15
/// Tree-shaped intermediate representation of an expression.
///
/// Each variant is one node. Child expressions are held behind `Box` (or in a `Vec` for
/// [`ExprIr::Call`]) so the type can refer to itself. This file only defines and builds the
/// tree; how a node is evaluated is decided by whatever consumes it.
///
/// The enum derives serde's `Serialize`/`Deserialize` without any renaming attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExprIr {
    /// Reference to a column, identified by name.
    Column(String),
    /// A constant value.
    Lit(LiteralValue),

    // Comparisons: each holds (first operand, second operand).
    /// Equality comparison.
    Eq(Box<ExprIr>, Box<ExprIr>),
    /// Inequality comparison.
    Ne(Box<ExprIr>, Box<ExprIr>),
    /// Greater-than comparison.
    Gt(Box<ExprIr>, Box<ExprIr>),
    /// Greater-than-or-equal comparison.
    Ge(Box<ExprIr>, Box<ExprIr>),
    /// Less-than comparison.
    Lt(Box<ExprIr>, Box<ExprIr>),
    /// Less-than-or-equal comparison.
    Le(Box<ExprIr>, Box<ExprIr>),
    /// Equality comparison flagged as "null safe".
    /// NOTE(review): presumably two nulls compare equal — confirm against the evaluator.
    EqNullSafe(Box<ExprIr>, Box<ExprIr>),

    // Boolean connectives.
    /// Logical conjunction of two sub-expressions.
    And(Box<ExprIr>, Box<ExprIr>),
    /// Logical disjunction of two sub-expressions.
    Or(Box<ExprIr>, Box<ExprIr>),
    /// Logical negation of one sub-expression.
    Not(Box<ExprIr>),

    // Arithmetic: each holds (first operand, second operand).
    /// Addition.
    Add(Box<ExprIr>, Box<ExprIr>),
    /// Subtraction.
    Sub(Box<ExprIr>, Box<ExprIr>),
    /// Multiplication.
    Mul(Box<ExprIr>, Box<ExprIr>),
    /// Division.
    Div(Box<ExprIr>, Box<ExprIr>),

    /// Range test of `left` against a `lower` and an `upper` bound.
    /// NOTE(review): whether the bounds are inclusive is not visible here — confirm.
    Between {
        /// Expression being tested.
        left: Box<ExprIr>,
        /// Lower bound expression.
        lower: Box<ExprIr>,
        /// Upper bound expression.
        upper: Box<ExprIr>,
    },
    /// Membership test over two sub-expressions.
    /// NOTE(review): presumably (value, candidates) — confirm the shape expected for the
    /// second operand.
    IsIn(Box<ExprIr>, Box<ExprIr>),

    /// Null check on one sub-expression.
    IsNull(Box<ExprIr>),
    /// Not-null check on one sub-expression.
    IsNotNull(Box<ExprIr>),

    /// Conditional with exactly one condition and two branches.
    When {
        /// Expression deciding which branch is taken.
        condition: Box<ExprIr>,
        /// Branch paired with `condition`.
        then_expr: Box<ExprIr>,
        /// Fallback branch.
        otherwise: Box<ExprIr>,
    },

    /// Call of a function identified by name, with positional arguments.
    Call {
        /// Function name; stored verbatim, resolved by the consumer of the IR.
        name: String,
        /// Arguments in call order.
        args: Vec<ExprIr>,
    },
}
69
70pub fn col(name: &str) -> ExprIr {
74 ExprIr::Column(name.to_string())
75}
76
77pub fn lit_i64(n: i64) -> ExprIr {
78 ExprIr::Lit(LiteralValue::I64(n))
79}
80
81pub fn lit_i32(n: i32) -> ExprIr {
82 ExprIr::Lit(LiteralValue::I32(n))
83}
84
85pub fn lit_f64(n: f64) -> ExprIr {
86 ExprIr::Lit(LiteralValue::F64(n))
87}
88
89pub fn lit_str(s: &str) -> ExprIr {
90 ExprIr::Lit(LiteralValue::Str(s.to_string()))
91}
92
93pub fn lit_bool(b: bool) -> ExprIr {
94 ExprIr::Lit(LiteralValue::Bool(b))
95}
96
97pub fn lit_null() -> ExprIr {
98 ExprIr::Lit(LiteralValue::Null)
99}
100
101pub fn call(name: &str, args: Vec<ExprIr>) -> ExprIr {
103 ExprIr::Call {
104 name: name.to_string(),
105 args,
106 }
107}
108
109pub struct WhenBuilder {
111 condition: ExprIr,
112}
113
114impl WhenBuilder {
115 pub fn then(self, then_expr: ExprIr) -> WhenThenBuilder {
116 WhenThenBuilder {
117 condition: self.condition,
118 then_expr,
119 }
120 }
121}
122
123pub struct WhenThenBuilder {
124 condition: ExprIr,
125 then_expr: ExprIr,
126}
127
128impl WhenThenBuilder {
129 pub fn otherwise(self, otherwise: ExprIr) -> ExprIr {
130 ExprIr::When {
131 condition: Box::new(self.condition),
132 then_expr: Box::new(self.then_expr),
133 otherwise: Box::new(otherwise),
134 }
135 }
136}
137
/// Starts a conditional expression with the given `condition`.
///
/// Returns a [`WhenBuilder`]; chain `.then(..)` and `.otherwise(..)` on it to obtain the
/// final [`ExprIr::When`] node.
pub fn when(condition: ExprIr) -> WhenBuilder {
    WhenBuilder { condition }
}
142
143pub fn eq(a: ExprIr, b: ExprIr) -> ExprIr {
146 ExprIr::Eq(Box::new(a), Box::new(b))
147}
148
149pub fn ne(a: ExprIr, b: ExprIr) -> ExprIr {
150 ExprIr::Ne(Box::new(a), Box::new(b))
151}
152
153pub fn gt(a: ExprIr, b: ExprIr) -> ExprIr {
154 ExprIr::Gt(Box::new(a), Box::new(b))
155}
156
157pub fn ge(a: ExprIr, b: ExprIr) -> ExprIr {
158 ExprIr::Ge(Box::new(a), Box::new(b))
159}
160
161pub fn lt(a: ExprIr, b: ExprIr) -> ExprIr {
162 ExprIr::Lt(Box::new(a), Box::new(b))
163}
164
165pub fn le(a: ExprIr, b: ExprIr) -> ExprIr {
166 ExprIr::Le(Box::new(a), Box::new(b))
167}
168
169pub fn and_(a: ExprIr, b: ExprIr) -> ExprIr {
170 ExprIr::And(Box::new(a), Box::new(b))
171}
172
173pub fn or_(a: ExprIr, b: ExprIr) -> ExprIr {
174 ExprIr::Or(Box::new(a), Box::new(b))
175}
176
177pub fn not_(a: ExprIr) -> ExprIr {
178 ExprIr::Not(Box::new(a))
179}
180
181pub fn is_null(a: ExprIr) -> ExprIr {
182 ExprIr::IsNull(Box::new(a))
183}
184
185pub fn between(left: ExprIr, lower: ExprIr, upper: ExprIr) -> ExprIr {
186 ExprIr::Between {
187 left: Box::new(left),
188 lower: Box::new(lower),
189 upper: Box::new(upper),
190 }
191}
192
193pub fn is_in(left: ExprIr, right: ExprIr) -> ExprIr {
194 ExprIr::IsIn(Box::new(left), Box::new(right))
195}
196
197pub fn sum(expr: ExprIr) -> ExprIr {
200 ExprIr::Call {
201 name: "sum".to_string(),
202 args: vec![expr],
203 }
204}
205
206pub fn count(expr: ExprIr) -> ExprIr {
207 ExprIr::Call {
208 name: "count".to_string(),
209 args: vec![expr],
210 }
211}
212
213pub fn min(expr: ExprIr) -> ExprIr {
214 ExprIr::Call {
215 name: "min".to_string(),
216 args: vec![expr],
217 }
218}
219
220pub fn max(expr: ExprIr) -> ExprIr {
221 ExprIr::Call {
222 name: "max".to_string(),
223 args: vec![expr],
224 }
225}
226
227pub fn mean(expr: ExprIr) -> ExprIr {
228 ExprIr::Call {
229 name: "mean".to_string(),
230 args: vec![expr],
231 }
232}
233
234pub fn first(expr: ExprIr) -> ExprIr {
235 ExprIr::Call {
236 name: "first".to_string(),
237 args: vec![expr],
238 }
239}
240
241pub fn last(expr: ExprIr) -> ExprIr {
242 ExprIr::Call {
243 name: "last".to_string(),
244 args: vec![expr],
245 }
246}
247
248pub fn stddev(expr: ExprIr) -> ExprIr {
249 ExprIr::Call {
250 name: "stddev".to_string(),
251 args: vec![expr],
252 }
253}
254
255pub fn stddev_pop(expr: ExprIr) -> ExprIr {
256 ExprIr::Call {
257 name: "stddev_pop".to_string(),
258 args: vec![expr],
259 }
260}
261
262pub fn std(expr: ExprIr) -> ExprIr {
263 ExprIr::Call {
264 name: "std".to_string(),
265 args: vec![expr],
266 }
267}
268
269pub fn stddev_samp(expr: ExprIr) -> ExprIr {
270 ExprIr::Call {
271 name: "stddev_samp".to_string(),
272 args: vec![expr],
273 }
274}
275
276pub fn variance(expr: ExprIr) -> ExprIr {
277 ExprIr::Call {
278 name: "variance".to_string(),
279 args: vec![expr],
280 }
281}
282
283pub fn var_pop(expr: ExprIr) -> ExprIr {
284 ExprIr::Call {
285 name: "var_pop".to_string(),
286 args: vec![expr],
287 }
288}
289
290pub fn var_samp(expr: ExprIr) -> ExprIr {
291 ExprIr::Call {
292 name: "var_samp".to_string(),
293 args: vec![expr],
294 }
295}
296
297pub fn count_distinct(expr: ExprIr) -> ExprIr {
298 ExprIr::Call {
299 name: "count_distinct".to_string(),
300 args: vec![expr],
301 }
302}
303
304pub fn approx_count_distinct(expr: ExprIr) -> ExprIr {
305 ExprIr::Call {
306 name: "approx_count_distinct".to_string(),
307 args: vec![expr],
308 }
309}
310
311pub fn collect_list(expr: ExprIr) -> ExprIr {
312 ExprIr::Call {
313 name: "collect_list".to_string(),
314 args: vec![expr],
315 }
316}
317
318pub fn collect_set(expr: ExprIr) -> ExprIr {
319 ExprIr::Call {
320 name: "collect_set".to_string(),
321 args: vec![expr],
322 }
323}
324
325pub fn bool_and(expr: ExprIr) -> ExprIr {
326 ExprIr::Call {
327 name: "bool_and".to_string(),
328 args: vec![expr],
329 }
330}
331
332pub fn every(expr: ExprIr) -> ExprIr {
333 ExprIr::Call {
334 name: "every".to_string(),
335 args: vec![expr],
336 }
337}
338
339pub fn median(expr: ExprIr) -> ExprIr {
340 ExprIr::Call {
341 name: "median".to_string(),
342 args: vec![expr],
343 }
344}
345
346pub fn try_sum(expr: ExprIr) -> ExprIr {
347 ExprIr::Call {
348 name: "try_sum".to_string(),
349 args: vec![expr],
350 }
351}
352
353pub fn try_avg(expr: ExprIr) -> ExprIr {
354 ExprIr::Call {
355 name: "try_avg".to_string(),
356 args: vec![expr],
357 }
358}
359
360pub fn count_if(expr: ExprIr) -> ExprIr {
361 ExprIr::Call {
362 name: "count_if".to_string(),
363 args: vec![expr],
364 }
365}
366
367pub fn mode(expr: ExprIr) -> ExprIr {
368 ExprIr::Call {
369 name: "mode".to_string(),
370 args: vec![expr],
371 }
372}
373
374pub fn kurtosis(expr: ExprIr) -> ExprIr {
375 ExprIr::Call {
376 name: "kurtosis".to_string(),
377 args: vec![expr],
378 }
379}
380
381pub fn skewness(expr: ExprIr) -> ExprIr {
382 ExprIr::Call {
383 name: "skewness".to_string(),
384 args: vec![expr],
385 }
386}
387
388pub fn alias(expr: ExprIr, name: &str) -> ExprIr {
390 ExprIr::Call {
391 name: "alias".to_string(),
392 args: vec![expr, lit_str(name)],
393 }
394}
395
#[cfg(test)]
mod tests {
    //! Checks that the builders exercised here yield the expected `ExprIr` shape.

    use super::*;

    /// `col` stores the given name in `ExprIr::Column`.
    #[test]
    fn col_builds_column_expr() {
        assert!(matches!(col("x"), ExprIr::Column(name) if name == "x"));
    }

    /// Each `lit_*` helper wraps its argument in the matching `LiteralValue` variant.
    #[test]
    fn lit_builders() {
        assert!(matches!(lit_i64(42), ExprIr::Lit(LiteralValue::I64(42))));
        assert!(matches!(lit_i32(1), ExprIr::Lit(LiteralValue::I32(1))));

        // The float payload is compared with a tolerance rather than by pattern.
        let float_ok = match lit_f64(1.5) {
            ExprIr::Lit(LiteralValue::F64(x)) => (x - 1.5).abs() < 1e-9,
            _ => false,
        };
        assert!(float_ok);

        let str_ok = match lit_str("a") {
            ExprIr::Lit(LiteralValue::Str(s)) => s == "a",
            _ => false,
        };
        assert!(str_ok);

        assert!(matches!(lit_bool(true), ExprIr::Lit(LiteralValue::Bool(true))));
        assert!(matches!(lit_null(), ExprIr::Lit(LiteralValue::Null)));
    }

    /// `call` keeps the name and the argument list as given.
    #[test]
    fn call_builds_call_expr() {
        let built = call("upper", vec![col("name")]);
        if let ExprIr::Call { name, args } = &built {
            assert_eq!(name, "upper");
            assert_eq!(args.len(), 1);
            assert!(matches!(&args[0], ExprIr::Column(s) if s == "name"));
        } else {
            panic!("expected Call");
        }
    }

    /// The three-step chain ends in an `ExprIr::When` carrying all three parts.
    #[test]
    fn when_then_otherwise_builds_when_expr() {
        let built = when(col("a")).then(lit_i64(1)).otherwise(lit_i64(0));
        if let ExprIr::When {
            condition,
            then_expr,
            otherwise,
        } = &built
        {
            assert!(matches!(condition.as_ref(), ExprIr::Column(s) if s == "a"));
            assert!(matches!(then_expr.as_ref(), ExprIr::Lit(LiteralValue::I64(1))));
            assert!(matches!(otherwise.as_ref(), ExprIr::Lit(LiteralValue::I64(0))));
        } else {
            panic!("expected When");
        }
    }

    /// A sample of the operator builders maps onto the variant of the same name.
    #[test]
    fn binary_ops_build_correct_variants() {
        let lhs = col("a");
        let rhs = lit_i64(2);

        assert!(matches!(eq(lhs.clone(), rhs.clone()), ExprIr::Eq(..)));
        assert!(matches!(gt(lhs.clone(), rhs.clone()), ExprIr::Gt(..)));
        assert!(matches!(and_(lhs.clone(), rhs.clone()), ExprIr::And(..)));
        assert!(matches!(or_(lhs.clone(), rhs), ExprIr::Or(..)));
        assert!(matches!(not_(lhs.clone()), ExprIr::Not(..)));
        assert!(matches!(is_null(lhs), ExprIr::IsNull(..)));
    }

    /// `between` places each argument in the field of the same name.
    #[test]
    fn between_builds_between_expr() {
        let built = between(col("x"), lit_i64(0), lit_i64(10));
        if let ExprIr::Between { left, lower, upper } = &built {
            assert!(matches!(left.as_ref(), ExprIr::Column(s) if s == "x"));
            assert!(matches!(lower.as_ref(), ExprIr::Lit(LiteralValue::I64(0))));
            assert!(matches!(upper.as_ref(), ExprIr::Lit(LiteralValue::I64(10))));
        } else {
            panic!("expected Between");
        }
    }

    /// `sum`, `count` and `alias` all produce `Call` nodes with the expected name.
    #[test]
    fn agg_builders_build_call() {
        assert!(matches!(sum(col("v")), ExprIr::Call { name, .. } if name == "sum"));
        assert!(matches!(count(col("v")), ExprIr::Call { name, .. } if name == "count"));

        let aliased = alias(col("x"), "my_col");
        assert!(
            matches!(aliased, ExprIr::Call { name, args } if name == "alias" && args.len() == 2)
        );
    }
}