Skip to main content

robin_sparkless_core/
expr.rs

1//! Engine-agnostic expression IR. All backends interpret this; root and core only use ExprIr.
2
3use serde::{Deserialize, Serialize};
4
5/// Literal value in an expression (engine-agnostic).
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum LiteralValue {
8    I64(i64),
9    F64(f64),
10    I32(i32),
11    Str(String),
12    Bool(bool),
13    Null,
14}
15
16/// Expression IR: a single, serializable tree that backends convert to their native Expr.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub enum ExprIr {
19    /// Column reference: `col("name")`
20    Column(String),
21    /// Literal value
22    Lit(LiteralValue),
23
24    // --- Binary comparison (left, right) ---
25    Eq(Box<ExprIr>, Box<ExprIr>),
26    Ne(Box<ExprIr>, Box<ExprIr>),
27    Gt(Box<ExprIr>, Box<ExprIr>),
28    Ge(Box<ExprIr>, Box<ExprIr>),
29    Lt(Box<ExprIr>, Box<ExprIr>),
30    Le(Box<ExprIr>, Box<ExprIr>),
31    EqNullSafe(Box<ExprIr>, Box<ExprIr>),
32
33    // --- Logical ---
34    And(Box<ExprIr>, Box<ExprIr>),
35    Or(Box<ExprIr>, Box<ExprIr>),
36    Not(Box<ExprIr>),
37
38    // --- Arithmetic ---
39    Add(Box<ExprIr>, Box<ExprIr>),
40    Sub(Box<ExprIr>, Box<ExprIr>),
41    Mul(Box<ExprIr>, Box<ExprIr>),
42    Div(Box<ExprIr>, Box<ExprIr>),
43
44    // --- Other binary ---
45    Between {
46        left: Box<ExprIr>,
47        lower: Box<ExprIr>,
48        upper: Box<ExprIr>,
49    },
50    IsIn(Box<ExprIr>, Box<ExprIr>),
51
52    // --- Unary ---
53    IsNull(Box<ExprIr>),
54    IsNotNull(Box<ExprIr>),
55
56    // --- Conditional ---
57    When {
58        condition: Box<ExprIr>,
59        then_expr: Box<ExprIr>,
60        otherwise: Box<ExprIr>,
61    },
62
63    /// Function call: name and args (e.g. sum, count, upper, substring, cast).
64    Call {
65        name: String,
66        args: Vec<ExprIr>,
67    },
68}
69
70// ---------- Builder helpers ----------
71
72/// Column reference.
73pub fn col(name: &str) -> ExprIr {
74    ExprIr::Column(name.to_string())
75}
76
77pub fn lit_i64(n: i64) -> ExprIr {
78    ExprIr::Lit(LiteralValue::I64(n))
79}
80
81pub fn lit_i32(n: i32) -> ExprIr {
82    ExprIr::Lit(LiteralValue::I32(n))
83}
84
85pub fn lit_f64(n: f64) -> ExprIr {
86    ExprIr::Lit(LiteralValue::F64(n))
87}
88
89pub fn lit_str(s: &str) -> ExprIr {
90    ExprIr::Lit(LiteralValue::Str(s.to_string()))
91}
92
93pub fn lit_bool(b: bool) -> ExprIr {
94    ExprIr::Lit(LiteralValue::Bool(b))
95}
96
97pub fn lit_null() -> ExprIr {
98    ExprIr::Lit(LiteralValue::Null)
99}
100
101/// Generic function call (for the long tail of functions).
102pub fn call(name: &str, args: Vec<ExprIr>) -> ExprIr {
103    ExprIr::Call {
104        name: name.to_string(),
105        args,
106    }
107}
108
109/// When-then-otherwise builder.
110pub struct WhenBuilder {
111    condition: ExprIr,
112}
113
114impl WhenBuilder {
115    pub fn then(self, then_expr: ExprIr) -> WhenThenBuilder {
116        WhenThenBuilder {
117            condition: self.condition,
118            then_expr,
119        }
120    }
121}
122
123pub struct WhenThenBuilder {
124    condition: ExprIr,
125    then_expr: ExprIr,
126}
127
128impl WhenThenBuilder {
129    pub fn otherwise(self, otherwise: ExprIr) -> ExprIr {
130        ExprIr::When {
131            condition: Box::new(self.condition),
132            then_expr: Box::new(self.then_expr),
133            otherwise: Box::new(otherwise),
134        }
135    }
136}
137
138/// Start a when(condition).then(...).otherwise(...) chain.
139pub fn when(condition: ExprIr) -> WhenBuilder {
140    WhenBuilder { condition }
141}
142
143// ---------- Common binary ops as ExprIr builders ----------
144
145pub fn eq(a: ExprIr, b: ExprIr) -> ExprIr {
146    ExprIr::Eq(Box::new(a), Box::new(b))
147}
148
149pub fn ne(a: ExprIr, b: ExprIr) -> ExprIr {
150    ExprIr::Ne(Box::new(a), Box::new(b))
151}
152
153pub fn gt(a: ExprIr, b: ExprIr) -> ExprIr {
154    ExprIr::Gt(Box::new(a), Box::new(b))
155}
156
157pub fn ge(a: ExprIr, b: ExprIr) -> ExprIr {
158    ExprIr::Ge(Box::new(a), Box::new(b))
159}
160
161pub fn lt(a: ExprIr, b: ExprIr) -> ExprIr {
162    ExprIr::Lt(Box::new(a), Box::new(b))
163}
164
165pub fn le(a: ExprIr, b: ExprIr) -> ExprIr {
166    ExprIr::Le(Box::new(a), Box::new(b))
167}
168
169pub fn and_(a: ExprIr, b: ExprIr) -> ExprIr {
170    ExprIr::And(Box::new(a), Box::new(b))
171}
172
173pub fn or_(a: ExprIr, b: ExprIr) -> ExprIr {
174    ExprIr::Or(Box::new(a), Box::new(b))
175}
176
177pub fn not_(a: ExprIr) -> ExprIr {
178    ExprIr::Not(Box::new(a))
179}
180
181pub fn is_null(a: ExprIr) -> ExprIr {
182    ExprIr::IsNull(Box::new(a))
183}
184
185pub fn between(left: ExprIr, lower: ExprIr, upper: ExprIr) -> ExprIr {
186    ExprIr::Between {
187        left: Box::new(left),
188        lower: Box::new(lower),
189        upper: Box::new(upper),
190    }
191}
192
193pub fn is_in(left: ExprIr, right: ExprIr) -> ExprIr {
194    ExprIr::IsIn(Box::new(left), Box::new(right))
195}
196
197// ---------- Aggregation builders (ExprIr::Call) ----------
198
199pub fn sum(expr: ExprIr) -> ExprIr {
200    ExprIr::Call {
201        name: "sum".to_string(),
202        args: vec![expr],
203    }
204}
205
206pub fn count(expr: ExprIr) -> ExprIr {
207    ExprIr::Call {
208        name: "count".to_string(),
209        args: vec![expr],
210    }
211}
212
213pub fn min(expr: ExprIr) -> ExprIr {
214    ExprIr::Call {
215        name: "min".to_string(),
216        args: vec![expr],
217    }
218}
219
220pub fn max(expr: ExprIr) -> ExprIr {
221    ExprIr::Call {
222        name: "max".to_string(),
223        args: vec![expr],
224    }
225}
226
227pub fn mean(expr: ExprIr) -> ExprIr {
228    ExprIr::Call {
229        name: "mean".to_string(),
230        args: vec![expr],
231    }
232}
233
234pub fn first(expr: ExprIr) -> ExprIr {
235    ExprIr::Call {
236        name: "first".to_string(),
237        args: vec![expr],
238    }
239}
240
241pub fn last(expr: ExprIr) -> ExprIr {
242    ExprIr::Call {
243        name: "last".to_string(),
244        args: vec![expr],
245    }
246}
247
248pub fn stddev(expr: ExprIr) -> ExprIr {
249    ExprIr::Call {
250        name: "stddev".to_string(),
251        args: vec![expr],
252    }
253}
254
255pub fn stddev_pop(expr: ExprIr) -> ExprIr {
256    ExprIr::Call {
257        name: "stddev_pop".to_string(),
258        args: vec![expr],
259    }
260}
261
262pub fn std(expr: ExprIr) -> ExprIr {
263    ExprIr::Call {
264        name: "std".to_string(),
265        args: vec![expr],
266    }
267}
268
269pub fn stddev_samp(expr: ExprIr) -> ExprIr {
270    ExprIr::Call {
271        name: "stddev_samp".to_string(),
272        args: vec![expr],
273    }
274}
275
276pub fn variance(expr: ExprIr) -> ExprIr {
277    ExprIr::Call {
278        name: "variance".to_string(),
279        args: vec![expr],
280    }
281}
282
283pub fn var_pop(expr: ExprIr) -> ExprIr {
284    ExprIr::Call {
285        name: "var_pop".to_string(),
286        args: vec![expr],
287    }
288}
289
290pub fn var_samp(expr: ExprIr) -> ExprIr {
291    ExprIr::Call {
292        name: "var_samp".to_string(),
293        args: vec![expr],
294    }
295}
296
297pub fn count_distinct(expr: ExprIr) -> ExprIr {
298    ExprIr::Call {
299        name: "count_distinct".to_string(),
300        args: vec![expr],
301    }
302}
303
304pub fn approx_count_distinct(expr: ExprIr) -> ExprIr {
305    ExprIr::Call {
306        name: "approx_count_distinct".to_string(),
307        args: vec![expr],
308    }
309}
310
311pub fn collect_list(expr: ExprIr) -> ExprIr {
312    ExprIr::Call {
313        name: "collect_list".to_string(),
314        args: vec![expr],
315    }
316}
317
318pub fn collect_set(expr: ExprIr) -> ExprIr {
319    ExprIr::Call {
320        name: "collect_set".to_string(),
321        args: vec![expr],
322    }
323}
324
325pub fn bool_and(expr: ExprIr) -> ExprIr {
326    ExprIr::Call {
327        name: "bool_and".to_string(),
328        args: vec![expr],
329    }
330}
331
332pub fn every(expr: ExprIr) -> ExprIr {
333    ExprIr::Call {
334        name: "every".to_string(),
335        args: vec![expr],
336    }
337}
338
339pub fn median(expr: ExprIr) -> ExprIr {
340    ExprIr::Call {
341        name: "median".to_string(),
342        args: vec![expr],
343    }
344}
345
346pub fn try_sum(expr: ExprIr) -> ExprIr {
347    ExprIr::Call {
348        name: "try_sum".to_string(),
349        args: vec![expr],
350    }
351}
352
353pub fn try_avg(expr: ExprIr) -> ExprIr {
354    ExprIr::Call {
355        name: "try_avg".to_string(),
356        args: vec![expr],
357    }
358}
359
360pub fn count_if(expr: ExprIr) -> ExprIr {
361    ExprIr::Call {
362        name: "count_if".to_string(),
363        args: vec![expr],
364    }
365}
366
367pub fn mode(expr: ExprIr) -> ExprIr {
368    ExprIr::Call {
369        name: "mode".to_string(),
370        args: vec![expr],
371    }
372}
373
374pub fn kurtosis(expr: ExprIr) -> ExprIr {
375    ExprIr::Call {
376        name: "kurtosis".to_string(),
377        args: vec![expr],
378    }
379}
380
381pub fn skewness(expr: ExprIr) -> ExprIr {
382    ExprIr::Call {
383        name: "skewness".to_string(),
384        args: vec![expr],
385    }
386}
387
388/// Alias an expression with a new output name.
389pub fn alias(expr: ExprIr, name: &str) -> ExprIr {
390    ExprIr::Call {
391        name: "alias".to_string(),
392        args: vec![expr, lit_str(name)],
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `build(col("v"))` is a one-argument `Call` named `expected` whose only
    /// argument is that column.
    fn assert_unary_call(build: fn(ExprIr) -> ExprIr, expected: &str) {
        match build(col("v")) {
            ExprIr::Call { name, args } => {
                assert_eq!(name, expected);
                assert_eq!(args.len(), 1, "{expected} should take exactly one argument");
                assert!(matches!(&args[0], ExprIr::Column(s) if s == "v"));
            }
            other => panic!("expected Call for {expected}, got {other:?}"),
        }
    }

    #[test]
    fn col_builds_column_expr() {
        let e = col("x");
        assert!(matches!(e, ExprIr::Column(s) if s == "x"));
    }

    #[test]
    fn lit_builders() {
        assert!(matches!(lit_i64(42), ExprIr::Lit(LiteralValue::I64(42))));
        assert!(matches!(lit_i32(1), ExprIr::Lit(LiteralValue::I32(1))));
        assert!(
            matches!(lit_f64(1.5), ExprIr::Lit(LiteralValue::F64(x)) if (x - 1.5).abs() < 1e-9)
        );
        assert!(matches!(lit_str("a"), ExprIr::Lit(LiteralValue::Str(s)) if s == "a"));
        assert!(matches!(
            lit_bool(true),
            ExprIr::Lit(LiteralValue::Bool(true))
        ));
        assert!(matches!(lit_null(), ExprIr::Lit(LiteralValue::Null)));
    }

    #[test]
    fn call_builds_call_expr() {
        let e = call("upper", vec![col("name")]);
        match &e {
            ExprIr::Call { name, args } => {
                assert_eq!(name, "upper");
                assert_eq!(args.len(), 1);
                assert!(matches!(&args[0], ExprIr::Column(s) if s == "name"));
            }
            _ => panic!("expected Call"),
        }
    }

    #[test]
    fn when_then_otherwise_builds_when_expr() {
        let e = when(col("a")).then(lit_i64(1)).otherwise(lit_i64(0));
        match &e {
            ExprIr::When {
                condition,
                then_expr,
                otherwise,
            } => {
                assert!(matches!(condition.as_ref(), ExprIr::Column(s) if s == "a"));
                assert!(matches!(
                    then_expr.as_ref(),
                    ExprIr::Lit(LiteralValue::I64(1))
                ));
                assert!(matches!(
                    otherwise.as_ref(),
                    ExprIr::Lit(LiteralValue::I64(0))
                ));
            }
            _ => panic!("expected When"),
        }
    }

    #[test]
    fn binary_ops_build_correct_variants() {
        let a = col("a");
        let b = lit_i64(2);
        assert!(matches!(eq(a.clone(), b.clone()), ExprIr::Eq(_, _)));
        assert!(matches!(gt(a.clone(), b.clone()), ExprIr::Gt(_, _)));
        assert!(matches!(and_(a.clone(), b.clone()), ExprIr::And(_, _)));
        assert!(matches!(or_(a.clone(), b.clone()), ExprIr::Or(_, _)));
        assert!(matches!(not_(a.clone()), ExprIr::Not(_)));
        assert!(matches!(is_null(a.clone()), ExprIr::IsNull(_)));
    }

    #[test]
    fn remaining_comparison_and_membership_builders() {
        let a = col("a");
        let b = lit_i64(2);
        assert!(matches!(ne(a.clone(), b.clone()), ExprIr::Ne(_, _)));
        assert!(matches!(ge(a.clone(), b.clone()), ExprIr::Ge(_, _)));
        assert!(matches!(lt(a.clone(), b.clone()), ExprIr::Lt(_, _)));
        assert!(matches!(le(a.clone(), b.clone()), ExprIr::Le(_, _)));
        assert!(matches!(is_in(a, b), ExprIr::IsIn(_, _)));
    }

    #[test]
    fn binary_builders_keep_operand_order() {
        match lt(col("left"), col("right")) {
            ExprIr::Lt(l, r) => {
                assert!(matches!(l.as_ref(), ExprIr::Column(s) if s == "left"));
                assert!(matches!(r.as_ref(), ExprIr::Column(s) if s == "right"));
            }
            other => panic!("expected Lt, got {other:?}"),
        }
    }

    #[test]
    fn between_builds_between_expr() {
        let e = between(col("x"), lit_i64(0), lit_i64(10));
        match &e {
            ExprIr::Between { left, lower, upper } => {
                assert!(matches!(left.as_ref(), ExprIr::Column(s) if s == "x"));
                assert!(matches!(lower.as_ref(), ExprIr::Lit(LiteralValue::I64(0))));
                assert!(matches!(upper.as_ref(), ExprIr::Lit(LiteralValue::I64(10))));
            }
            _ => panic!("expected Between"),
        }
    }

    #[test]
    fn agg_builders_build_call() {
        let e = sum(col("v"));
        assert!(matches!(e, ExprIr::Call { name, .. } if name == "sum"));
        let e = count(col("v"));
        assert!(matches!(e, ExprIr::Call { name, .. } if name == "count"));
        let e = alias(col("x"), "my_col");
        assert!(matches!(e, ExprIr::Call { name, args } if name == "alias" && args.len() == 2));
    }

    #[test]
    fn every_agg_builder_emits_its_own_name() {
        assert_unary_call(sum, "sum");
        assert_unary_call(count, "count");
        assert_unary_call(min, "min");
        assert_unary_call(max, "max");
        assert_unary_call(mean, "mean");
        assert_unary_call(first, "first");
        assert_unary_call(last, "last");
        assert_unary_call(stddev, "stddev");
        assert_unary_call(stddev_pop, "stddev_pop");
        assert_unary_call(std, "std");
        assert_unary_call(stddev_samp, "stddev_samp");
        assert_unary_call(variance, "variance");
        assert_unary_call(var_pop, "var_pop");
        assert_unary_call(var_samp, "var_samp");
        assert_unary_call(count_distinct, "count_distinct");
        assert_unary_call(approx_count_distinct, "approx_count_distinct");
        assert_unary_call(collect_list, "collect_list");
        assert_unary_call(collect_set, "collect_set");
        assert_unary_call(bool_and, "bool_and");
        assert_unary_call(every, "every");
        assert_unary_call(median, "median");
        assert_unary_call(try_sum, "try_sum");
        assert_unary_call(try_avg, "try_avg");
        assert_unary_call(count_if, "count_if");
        assert_unary_call(mode, "mode");
        assert_unary_call(kurtosis, "kurtosis");
        assert_unary_call(skewness, "skewness");
    }

    #[test]
    fn alias_carries_name_as_string_literal() {
        match alias(col("x"), "my_col") {
            ExprIr::Call { name, args } => {
                assert_eq!(name, "alias");
                assert_eq!(args.len(), 2);
                assert!(matches!(&args[0], ExprIr::Column(s) if s == "x"));
                assert!(matches!(&args[1], ExprIr::Lit(LiteralValue::Str(s)) if s == "my_col"));
            }
            other => panic!("expected Call, got {other:?}"),
        }
    }
}