use serde::{Deserialize, Serialize};
/// A constant value that can appear in an expression tree.
///
/// `PartialEq` is derived so literals (and expressions containing them) can be
/// compared directly. Because of the `F64` variant, `NaN` literals never compare
/// equal to each other, which is also why `Eq` is not derived.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LiteralValue {
    /// 64-bit signed integer literal.
    I64(i64),
    /// 64-bit floating-point literal.
    F64(f64),
    /// 32-bit signed integer literal.
    I32(i32),
    /// String literal.
    Str(String),
    /// Boolean literal.
    Bool(bool),
    /// The null literal.
    Null,
}

/// Serializable intermediate representation of a column expression.
///
/// Values are normally built with the free builder functions in this module
/// (`col`, `lit_*`, `eq`, `when`, `sum`, ...). This type only records structure.
/// How each node is evaluated is decided by whatever consumes the IR, so the
/// variant notes below describe shape rather than runtime semantics.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExprIr {
    /// Reference to a column by name.
    Column(String),
    /// A constant value.
    Lit(LiteralValue),
    /// Equality comparison of the first operand with the second.
    Eq(Box<ExprIr>, Box<ExprIr>),
    /// Inequality comparison of the first operand with the second.
    Ne(Box<ExprIr>, Box<ExprIr>),
    /// Greater-than comparison (first operand on the left).
    Gt(Box<ExprIr>, Box<ExprIr>),
    /// Greater-than-or-equal comparison (first operand on the left).
    Ge(Box<ExprIr>, Box<ExprIr>),
    /// Less-than comparison (first operand on the left).
    Lt(Box<ExprIr>, Box<ExprIr>),
    /// Less-than-or-equal comparison (first operand on the left).
    Le(Box<ExprIr>, Box<ExprIr>),
    /// Null-safe equality of the two operands.
    ///
    /// NOTE(review): presumably SQL `<=>` semantics (null equals null) — confirm
    /// against the evaluator.
    EqNullSafe(Box<ExprIr>, Box<ExprIr>),
    /// Logical conjunction of the two operands.
    And(Box<ExprIr>, Box<ExprIr>),
    /// Logical disjunction of the two operands.
    Or(Box<ExprIr>, Box<ExprIr>),
    /// Logical negation of the operand.
    Not(Box<ExprIr>),
    /// Addition of the two operands.
    Add(Box<ExprIr>, Box<ExprIr>),
    /// Subtraction of the second operand from the first.
    Sub(Box<ExprIr>, Box<ExprIr>),
    /// Multiplication of the two operands.
    Mul(Box<ExprIr>, Box<ExprIr>),
    /// Division of the first operand by the second.
    Div(Box<ExprIr>, Box<ExprIr>),
    /// Range test of `left` against `lower` and `upper`.
    ///
    /// NOTE(review): bound inclusivity is decided by the evaluator (presumably
    /// inclusive, as in SQL `BETWEEN`) — confirm.
    Between {
        left: Box<ExprIr>,
        lower: Box<ExprIr>,
        upper: Box<ExprIr>,
    },
    /// Membership test of the first operand in the second.
    ///
    /// NOTE(review): the second operand is presumably list-valued — confirm with
    /// the evaluator.
    IsIn(Box<ExprIr>, Box<ExprIr>),
    /// Null test of the operand.
    IsNull(Box<ExprIr>),
    /// Not-null test of the operand.
    IsNotNull(Box<ExprIr>),
    /// Conditional: `then_expr` where `condition` holds, `otherwise` elsewhere.
    When {
        condition: Box<ExprIr>,
        then_expr: Box<ExprIr>,
        otherwise: Box<ExprIr>,
    },
    /// Call of a named function with positional arguments.
    ///
    /// The builders in this module encode aggregations (`"sum"`, `"count"`, ...)
    /// and `"alias"` this way. The set of recognised names is defined by the
    /// evaluator, not by this type.
    Call {
        name: String,
        args: Vec<ExprIr>,
    },
}
/// Builds a reference to the column called `name` (`ExprIr::Column`).
pub fn col(name: &str) -> ExprIr {
    let owned_name = String::from(name);
    ExprIr::Column(owned_name)
}
/// Builds a 64-bit integer literal expression.
pub fn lit_i64(n: i64) -> ExprIr {
    let value = LiteralValue::I64(n);
    ExprIr::Lit(value)
}

/// Builds a 32-bit integer literal expression.
pub fn lit_i32(n: i32) -> ExprIr {
    let value = LiteralValue::I32(n);
    ExprIr::Lit(value)
}

/// Builds a 64-bit floating-point literal expression.
pub fn lit_f64(n: f64) -> ExprIr {
    let value = LiteralValue::F64(n);
    ExprIr::Lit(value)
}

/// Builds a string literal expression holding an owned copy of `s`.
pub fn lit_str(s: &str) -> ExprIr {
    let value = LiteralValue::Str(s.to_owned());
    ExprIr::Lit(value)
}

/// Builds a boolean literal expression.
pub fn lit_bool(b: bool) -> ExprIr {
    let value = LiteralValue::Bool(b);
    ExprIr::Lit(value)
}

/// Builds the null literal expression.
pub fn lit_null() -> ExprIr {
    let value = LiteralValue::Null;
    ExprIr::Lit(value)
}
/// Builds a call of the function `name` with the positional arguments `args`
/// (`ExprIr::Call`).
pub fn call(name: &str, args: Vec<ExprIr>) -> ExprIr {
    let name = name.to_owned();
    ExprIr::Call { name, args }
}
/// First stage of a `when(..).then(..).otherwise(..)` chain; holds the condition.
///
/// Created by `when`. Call [`WhenBuilder::then`] to supply the value used where
/// the condition holds. Marked `#[must_use]` because dropping it discards the
/// half-built expression without any effect.
#[derive(Debug, Clone)]
#[must_use = "a `when` builder does nothing until completed with `.then(..).otherwise(..)`"]
pub struct WhenBuilder {
    condition: ExprIr,
}

impl WhenBuilder {
    /// Records `then_expr` as the value for rows where the condition holds and
    /// moves on to the final stage.
    pub fn then(self, then_expr: ExprIr) -> WhenThenBuilder {
        WhenThenBuilder {
            condition: self.condition,
            then_expr,
        }
    }
}

/// Second stage of a `when(..).then(..).otherwise(..)` chain; holds the condition
/// and the `then` value.
///
/// Call [`WhenThenBuilder::otherwise`] to produce the finished `ExprIr::When`.
#[derive(Debug, Clone)]
#[must_use = "a `when(..).then(..)` builder does nothing until completed with `.otherwise(..)`"]
pub struct WhenThenBuilder {
    condition: ExprIr,
    then_expr: ExprIr,
}

impl WhenThenBuilder {
    /// Supplies the fallback value and returns the completed `ExprIr::When` node.
    #[must_use]
    pub fn otherwise(self, otherwise: ExprIr) -> ExprIr {
        ExprIr::When {
            condition: Box::new(self.condition),
            then_expr: Box::new(self.then_expr),
            otherwise: Box::new(otherwise),
        }
    }
}

/// Starts a conditional expression on `condition`.
///
/// Complete it with `.then(value).otherwise(fallback)`.
pub fn when(condition: ExprIr) -> WhenBuilder {
    WhenBuilder { condition }
}
/// Builds an equality comparison of `a` with `b` (`ExprIr::Eq`).
pub fn eq(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Eq(Box::new(a), Box::new(b))
}

/// Builds an inequality comparison of `a` with `b` (`ExprIr::Ne`).
pub fn ne(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Ne(Box::new(a), Box::new(b))
}

/// Builds a null-safe equality comparison of `a` with `b` (`ExprIr::EqNullSafe`).
///
/// NOTE(review): the null handling itself is up to the evaluator (presumably SQL
/// `<=>` semantics) — confirm.
pub fn eq_null_safe(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::EqNullSafe(Box::new(a), Box::new(b))
}

/// Builds a greater-than comparison with `a` on the left (`ExprIr::Gt`).
pub fn gt(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Gt(Box::new(a), Box::new(b))
}

/// Builds a greater-than-or-equal comparison with `a` on the left (`ExprIr::Ge`).
pub fn ge(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Ge(Box::new(a), Box::new(b))
}

/// Builds a less-than comparison with `a` on the left (`ExprIr::Lt`).
pub fn lt(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Lt(Box::new(a), Box::new(b))
}

/// Builds a less-than-or-equal comparison with `a` on the left (`ExprIr::Le`).
pub fn le(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Le(Box::new(a), Box::new(b))
}

/// Builds the logical conjunction of `a` and `b` (`ExprIr::And`).
///
/// The trailing underscore avoids clashing with the `and` keyword-like naming.
pub fn and_(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::And(Box::new(a), Box::new(b))
}

/// Builds the logical disjunction of `a` and `b` (`ExprIr::Or`).
pub fn or_(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Or(Box::new(a), Box::new(b))
}

/// Builds the logical negation of `a` (`ExprIr::Not`).
pub fn not_(a: ExprIr) -> ExprIr {
    ExprIr::Not(Box::new(a))
}

/// Builds the sum `a + b` (`ExprIr::Add`).
pub fn add(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Add(Box::new(a), Box::new(b))
}

/// Builds the difference `a - b` (`ExprIr::Sub`).
pub fn sub(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Sub(Box::new(a), Box::new(b))
}

/// Builds the product `a * b` (`ExprIr::Mul`).
pub fn mul(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Mul(Box::new(a), Box::new(b))
}

/// Builds the quotient `a / b` (`ExprIr::Div`).
///
/// NOTE(review): division-by-zero behaviour is defined by the evaluator — confirm.
pub fn div(a: ExprIr, b: ExprIr) -> ExprIr {
    ExprIr::Div(Box::new(a), Box::new(b))
}
/// Builds a null test on `a` (`ExprIr::IsNull`).
pub fn is_null(a: ExprIr) -> ExprIr {
    ExprIr::IsNull(Box::new(a))
}

/// Builds a not-null test on `a` (`ExprIr::IsNotNull`); the counterpart of
/// `is_null`.
pub fn is_not_null(a: ExprIr) -> ExprIr {
    ExprIr::IsNotNull(Box::new(a))
}
/// Builds a range test of `left` against the bounds `lower` and `upper`
/// (`ExprIr::Between`).
///
/// NOTE(review): bound inclusivity is decided by the evaluator (presumably
/// inclusive, as in SQL `BETWEEN`) — confirm.
pub fn between(left: ExprIr, lower: ExprIr, upper: ExprIr) -> ExprIr {
    let (left, lower, upper) = (Box::new(left), Box::new(lower), Box::new(upper));
    ExprIr::Between { left, lower, upper }
}
/// Builds a membership test of `left` in `right` (`ExprIr::IsIn`).
///
/// NOTE(review): `right` is presumably a list-valued expression — confirm with
/// the evaluator.
pub fn is_in(left: ExprIr, right: ExprIr) -> ExprIr {
    let needle = Box::new(left);
    let haystack = Box::new(right);
    ExprIr::IsIn(needle, haystack)
}
/// Wraps `expr` as the only argument of an `ExprIr::Call` named `name`.
///
/// Shared by the one-argument aggregation builders below. The names are plain
/// strings here and are only interpreted by whatever evaluates the IR.
fn single_arg_call(name: &str, expr: ExprIr) -> ExprIr {
    ExprIr::Call {
        name: name.to_owned(),
        args: vec![expr],
    }
}

/// `"sum"` aggregation call over `expr`.
pub fn sum(expr: ExprIr) -> ExprIr {
    single_arg_call("sum", expr)
}

/// `"count"` aggregation call over `expr`.
pub fn count(expr: ExprIr) -> ExprIr {
    single_arg_call("count", expr)
}

/// `"min"` aggregation call over `expr`.
pub fn min(expr: ExprIr) -> ExprIr {
    single_arg_call("min", expr)
}

/// `"max"` aggregation call over `expr`.
pub fn max(expr: ExprIr) -> ExprIr {
    single_arg_call("max", expr)
}

/// `"mean"` aggregation call over `expr`.
pub fn mean(expr: ExprIr) -> ExprIr {
    single_arg_call("mean", expr)
}

/// `"first"` aggregation call over `expr`.
pub fn first(expr: ExprIr) -> ExprIr {
    single_arg_call("first", expr)
}

/// `"last"` aggregation call over `expr`.
pub fn last(expr: ExprIr) -> ExprIr {
    single_arg_call("last", expr)
}

/// `"stddev"` aggregation call over `expr`.
pub fn stddev(expr: ExprIr) -> ExprIr {
    single_arg_call("stddev", expr)
}

/// `"stddev_pop"` aggregation call over `expr`.
pub fn stddev_pop(expr: ExprIr) -> ExprIr {
    single_arg_call("stddev_pop", expr)
}

/// `"std"` aggregation call over `expr`.
///
/// NOTE(review): whether this is an alias of `"stddev"` / `"stddev_samp"` is
/// decided by the evaluator — confirm.
pub fn std(expr: ExprIr) -> ExprIr {
    single_arg_call("std", expr)
}

/// `"stddev_samp"` aggregation call over `expr`.
pub fn stddev_samp(expr: ExprIr) -> ExprIr {
    single_arg_call("stddev_samp", expr)
}

/// `"variance"` aggregation call over `expr`.
pub fn variance(expr: ExprIr) -> ExprIr {
    single_arg_call("variance", expr)
}

/// `"var_pop"` aggregation call over `expr`.
pub fn var_pop(expr: ExprIr) -> ExprIr {
    single_arg_call("var_pop", expr)
}

/// `"var_samp"` aggregation call over `expr`.
pub fn var_samp(expr: ExprIr) -> ExprIr {
    single_arg_call("var_samp", expr)
}

/// `"count_distinct"` aggregation call over `expr`.
pub fn count_distinct(expr: ExprIr) -> ExprIr {
    single_arg_call("count_distinct", expr)
}

/// `"approx_count_distinct"` aggregation call over `expr`.
pub fn approx_count_distinct(expr: ExprIr) -> ExprIr {
    single_arg_call("approx_count_distinct", expr)
}

/// `"collect_list"` aggregation call over `expr`.
pub fn collect_list(expr: ExprIr) -> ExprIr {
    single_arg_call("collect_list", expr)
}

/// `"collect_set"` aggregation call over `expr`.
pub fn collect_set(expr: ExprIr) -> ExprIr {
    single_arg_call("collect_set", expr)
}

/// `"bool_and"` aggregation call over `expr`.
pub fn bool_and(expr: ExprIr) -> ExprIr {
    single_arg_call("bool_and", expr)
}

/// `"every"` aggregation call over `expr`.
///
/// NOTE(review): whether this is an alias of `"bool_and"` is decided by the
/// evaluator — confirm.
pub fn every(expr: ExprIr) -> ExprIr {
    single_arg_call("every", expr)
}

/// `"median"` aggregation call over `expr`.
pub fn median(expr: ExprIr) -> ExprIr {
    single_arg_call("median", expr)
}

/// `"try_sum"` aggregation call over `expr`.
pub fn try_sum(expr: ExprIr) -> ExprIr {
    single_arg_call("try_sum", expr)
}

/// `"try_avg"` aggregation call over `expr`.
pub fn try_avg(expr: ExprIr) -> ExprIr {
    single_arg_call("try_avg", expr)
}

/// `"count_if"` aggregation call over `expr` (presumably a boolean predicate —
/// confirm with the evaluator).
pub fn count_if(expr: ExprIr) -> ExprIr {
    single_arg_call("count_if", expr)
}

/// `"mode"` aggregation call over `expr`.
pub fn mode(expr: ExprIr) -> ExprIr {
    single_arg_call("mode", expr)
}

/// `"kurtosis"` aggregation call over `expr`.
pub fn kurtosis(expr: ExprIr) -> ExprIr {
    single_arg_call("kurtosis", expr)
}

/// `"skewness"` aggregation call over `expr`.
pub fn skewness(expr: ExprIr) -> ExprIr {
    single_arg_call("skewness", expr)
}
/// Wraps `expr` in an `"alias"` call carrying `name` as a string-literal second
/// argument (presumably the output column name — confirm with the evaluator).
pub fn alias(expr: ExprIr, name: &str) -> ExprIr {
    let label = lit_str(name);
    ExprIr::Call {
        name: String::from("alias"),
        args: vec![expr, label],
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn col_builds_column_expr() {
        let e = col("x");
        assert!(matches!(e, ExprIr::Column(s) if s == "x"));
    }

    #[test]
    fn lit_builders() {
        assert!(matches!(lit_i64(42), ExprIr::Lit(LiteralValue::I64(42))));
        assert!(matches!(lit_i32(1), ExprIr::Lit(LiteralValue::I32(1))));
        assert!(
            matches!(lit_f64(1.5), ExprIr::Lit(LiteralValue::F64(x)) if (x - 1.5).abs() < 1e-9)
        );
        assert!(matches!(lit_str("a"), ExprIr::Lit(LiteralValue::Str(s)) if s == "a"));
        assert!(matches!(lit_bool(true), ExprIr::Lit(LiteralValue::Bool(true))));
        assert!(matches!(lit_null(), ExprIr::Lit(LiteralValue::Null)));
    }

    #[test]
    fn call_builds_call_expr() {
        let e = call("upper", vec![col("name")]);
        match &e {
            ExprIr::Call { name, args } => {
                assert_eq!(name, "upper");
                assert_eq!(args.len(), 1);
                assert!(matches!(&args[0], ExprIr::Column(s) if s == "name"));
            }
            _ => panic!("expected Call"),
        }
    }

    #[test]
    fn when_then_otherwise_builds_when_expr() {
        let e = when(col("a")).then(lit_i64(1)).otherwise(lit_i64(0));
        match &e {
            ExprIr::When {
                condition,
                then_expr,
                otherwise,
            } => {
                assert!(matches!(condition.as_ref(), ExprIr::Column(s) if s == "a"));
                assert!(matches!(then_expr.as_ref(), ExprIr::Lit(LiteralValue::I64(1))));
                assert!(matches!(otherwise.as_ref(), ExprIr::Lit(LiteralValue::I64(0))));
            }
            _ => panic!("expected When"),
        }
    }

    #[test]
    fn binary_ops_build_correct_variants() {
        let a = col("a");
        let b = lit_i64(2);
        assert!(matches!(eq(a.clone(), b.clone()), ExprIr::Eq(_, _)));
        assert!(matches!(ne(a.clone(), b.clone()), ExprIr::Ne(_, _)));
        assert!(matches!(gt(a.clone(), b.clone()), ExprIr::Gt(_, _)));
        assert!(matches!(ge(a.clone(), b.clone()), ExprIr::Ge(_, _)));
        assert!(matches!(lt(a.clone(), b.clone()), ExprIr::Lt(_, _)));
        assert!(matches!(le(a.clone(), b.clone()), ExprIr::Le(_, _)));
        assert!(matches!(and_(a.clone(), b.clone()), ExprIr::And(_, _)));
        assert!(matches!(or_(a.clone(), b.clone()), ExprIr::Or(_, _)));
        assert!(matches!(not_(a.clone()), ExprIr::Not(_)));
        assert!(matches!(is_null(a.clone()), ExprIr::IsNull(_)));
    }

    // Builders must not swap their operands; non-commutative ops depend on it.
    #[test]
    fn binary_ops_keep_operand_order() {
        match gt(col("a"), lit_i64(2)) {
            ExprIr::Gt(left, right) => {
                assert!(matches!(left.as_ref(), ExprIr::Column(s) if s == "a"));
                assert!(matches!(right.as_ref(), ExprIr::Lit(LiteralValue::I64(2))));
            }
            other => panic!("expected Gt, got {:?}", other),
        }
    }

    #[test]
    fn between_builds_between_expr() {
        let e = between(col("x"), lit_i64(0), lit_i64(10));
        match &e {
            ExprIr::Between { left, lower, upper } => {
                assert!(matches!(left.as_ref(), ExprIr::Column(s) if s == "x"));
                assert!(matches!(lower.as_ref(), ExprIr::Lit(LiteralValue::I64(0))));
                assert!(matches!(upper.as_ref(), ExprIr::Lit(LiteralValue::I64(10))));
            }
            _ => panic!("expected Between"),
        }
    }

    #[test]
    fn is_in_builds_is_in_expr() {
        match is_in(col("x"), lit_str("a")) {
            ExprIr::IsIn(left, right) => {
                assert!(matches!(left.as_ref(), ExprIr::Column(s) if s == "x"));
                assert!(matches!(right.as_ref(), ExprIr::Lit(LiteralValue::Str(s)) if s == "a"));
            }
            other => panic!("expected IsIn, got {:?}", other),
        }
    }

    // Every one-argument aggregation builder must emit a Call with its exact
    // name and the input as the single argument.
    #[test]
    fn agg_builders_build_single_arg_call() {
        let cases: &[(fn(ExprIr) -> ExprIr, &str)] = &[
            (sum, "sum"),
            (count, "count"),
            (min, "min"),
            (max, "max"),
            (mean, "mean"),
            (first, "first"),
            (last, "last"),
            (stddev, "stddev"),
            (stddev_pop, "stddev_pop"),
            (std, "std"),
            (stddev_samp, "stddev_samp"),
            (variance, "variance"),
            (var_pop, "var_pop"),
            (var_samp, "var_samp"),
            (count_distinct, "count_distinct"),
            (approx_count_distinct, "approx_count_distinct"),
            (collect_list, "collect_list"),
            (collect_set, "collect_set"),
            (bool_and, "bool_and"),
            (every, "every"),
            (median, "median"),
            (try_sum, "try_sum"),
            (try_avg, "try_avg"),
            (count_if, "count_if"),
            (mode, "mode"),
            (kurtosis, "kurtosis"),
            (skewness, "skewness"),
        ];
        for &(build, expected) in cases.iter() {
            match build(col("v")) {
                ExprIr::Call { name, args } => {
                    assert_eq!(name, expected);
                    assert_eq!(args.len(), 1);
                    assert!(matches!(&args[0], ExprIr::Column(s) if s == "v"));
                }
                other => panic!("expected Call for {}, got {:?}", expected, other),
            }
        }
    }

    #[test]
    fn alias_appends_name_literal() {
        match alias(col("x"), "my_col") {
            ExprIr::Call { name, args } => {
                assert_eq!(name, "alias");
                assert_eq!(args.len(), 2);
                assert!(matches!(&args[0], ExprIr::Column(s) if s == "x"));
                assert!(matches!(&args[1], ExprIr::Lit(LiteralValue::Str(s)) if s == "my_col"));
            }
            other => panic!("expected Call, got {:?}", other),
        }
    }
}