use std::collections::HashSet;
use polars::prelude::*;
use serde_json::Value as Json;
/// Error returned whenever a formula AST is malformed or steps outside the whitelist.
///
/// The wrapped `String` is a human-readable description; `Display` renders it as
/// `formula error: <description>`.
#[derive(Debug, thiserror::Error)]
#[error("formula error: {0}")]
pub struct FormulaError(pub String);

/// Builds an `Err(FormulaError(..))` from anything convertible into a message string.
fn err<T>(msg: impl Into<String>) -> Result<T, FormulaError> {
    let message: String = msg.into();
    Err(FormulaError(message))
}
/// Arithmetic operators allowed in `{"op": ..}` nodes; each takes two or more args,
/// folded left to right. Any operator not listed in one of these tables is rejected.
const OPS: &[&str] = &["+", "-", "*", "/"];
/// Comparison operators; each takes exactly two args.
const CMP_OPS: &[&str] = &[">=", ">", "<=", "<", "==", "!="];
/// Boolean connectives: `and`/`or` take two or more args, `not` takes exactly one.
const BOOL_OPS: &[&str] = &["and", "or", "not"];
/// Null tests; each takes exactly one arg.
const NULL_OPS: &[&str] = &["is_null", "is_not_null"];
/// Returns the allowed argument count `(min, max)` for a whitelisted function.
///
/// `max == None` means the function is variadic. Returns `None` for any name that is not
/// on the whitelist, which callers treat as "function not allowed".
fn func_arity(name: &str) -> Option<(usize, Option<usize>)> {
    let arity = match name {
        "abs" | "ln" | "log" | "exp" | "floor" | "ceil" | "sqrt" => (1, Some(1)),
        "round" => (1, Some(2)),
        "nullif" => (2, Some(2)),
        "coalesce" => (1, None),
        "least" | "greatest" => (2, None),
        _ => return None,
    };
    Some(arity)
}
pub fn compile_formula(node: &Json, columns: &HashSet<String>) -> Result<Expr, FormulaError> {
let obj = match node.as_object() {
Some(o) => o,
None => return err(format!("expected an AST node (object), got {node}")),
};
if let Some(c) = obj.get("col") {
let name = c
.as_str()
.ok_or_else(|| FormulaError("col must be a string".into()))?;
if !columns.contains(name) {
return err(format!("unknown column {name:?} (not in source schema)"));
}
return Ok(col(name));
}
if let Some(lit_node) = obj.get("lit") {
return lit_expr(lit_node);
}
if let Some(op) = obj.get("op") {
let op = op.as_str().unwrap_or("");
if OPS.contains(&op) {
let args = compile_args(obj, columns)?;
let mut it = args.into_iter();
let Some(mut acc) = it.next() else {
return err(format!("operator {op:?} needs ≥2 args"));
};
let mut n = 1;
for a in it {
acc = match op {
"+" => acc + a,
"-" => acc - a,
"*" => acc * a,
_ => acc / a,
};
n += 1;
}
if n < 2 {
return err(format!("operator {op:?} needs ≥2 args"));
}
return Ok(acc);
}
if CMP_OPS.contains(&op) {
let args = compile_args(obj, columns)?;
let mut it = args.into_iter();
let (Some(a), Some(b), None) = (it.next(), it.next(), it.next()) else {
return err(format!("comparison {op:?} needs exactly 2 args"));
};
return Ok(match op {
">=" => a.gt_eq(b),
">" => a.gt(b),
"<=" => a.lt_eq(b),
"<" => a.lt(b),
"==" => a.eq(b),
_ => a.neq(b),
});
}
if BOOL_OPS.contains(&op) {
let args = compile_args(obj, columns)?;
if op == "not" {
let mut it = args.into_iter();
let (Some(a), None) = (it.next(), it.next()) else {
return err("`not` needs exactly 1 arg");
};
return Ok(a.not());
}
let mut it = args.into_iter();
let Some(mut acc) = it.next() else {
return err(format!("operator {op:?} needs ≥2 args"));
};
let mut n = 1;
for a in it {
acc = if op == "and" { acc.and(a) } else { acc.or(a) };
n += 1;
}
if n < 2 {
return err(format!("operator {op:?} needs ≥2 args"));
}
return Ok(acc);
}
if NULL_OPS.contains(&op) {
let args = compile_args(obj, columns)?;
let mut it = args.into_iter();
let (Some(a), None) = (it.next(), it.next()) else {
return err(format!("{op:?} needs exactly 1 arg"));
};
return Ok(if op == "is_null" {
a.is_null()
} else {
a.is_not_null()
});
}
return err(format!("operator {op:?} not allowed"));
}
if let Some(f) = obj.get("fn") {
let f = f.as_str().unwrap_or("");
let (lo, hi) = match func_arity(f) {
Some(a) => a,
None => return err(format!("function {f:?} not allowed")),
};
let raw = args_of(obj);
if raw.len() < lo || hi.map(|h| raw.len() > h).unwrap_or(false) {
return err(format!("function {f:?} arity {} out of range", raw.len()));
}
if f == "round" {
let decimals: u32 = match raw.get(1) {
None => 0,
Some(p) => p.get("lit").and_then(|v| v.as_u64()).ok_or_else(|| {
FormulaError("round's 2nd arg must be a literal integer".into())
})? as u32,
};
let first = raw
.first()
.ok_or_else(|| FormulaError("round needs ≥1 arg".into()))?;
let inner = compile_formula(first, columns)?;
return Ok(inner.round(decimals, RoundMode::HalfToEven));
}
let args = raw
.iter()
.map(|n| compile_formula(n, columns))
.collect::<Result<Vec<Expr>, FormulaError>>()?;
return apply_fn(f, args);
}
if let Some(inner) = obj.get("cast") {
let t = obj
.get("as")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_lowercase();
let dt = cast_dtype(&t)?;
let e = compile_formula(inner, columns)?;
return Ok(e.cast(dt));
}
let mut keys: Vec<&String> = obj.keys().collect();
keys.sort();
err(format!("unrecognized AST node: {keys:?}"))
}
/// Returns the name of every `{"col": "<name>"}` reference found anywhere in `node`.
///
/// The whole JSON tree is walked, not just the keys `compile_formula` interprets. The result
/// is therefore a superset of the columns a successful compile would read, and the AST does
/// not have to be valid.
pub fn referenced_columns(node: &Json) -> HashSet<String> {
    let mut found: HashSet<String> = HashSet::new();
    collect_columns(node, &mut found);
    found
}
/// Inserts into `out` every string value stored under a `"col"` key in any object reachable
/// from `node`, descending through all object values and array elements.
///
/// Uses an explicit work stack instead of recursion; the visiting order does not matter
/// because results go into a set.
fn collect_columns(node: &Json, out: &mut HashSet<String>) {
    let mut pending: Vec<&Json> = vec![node];
    while let Some(current) = pending.pop() {
        match current {
            Json::Object(map) => {
                if let Some(Json::String(name)) = map.get("col") {
                    out.insert(name.clone());
                }
                pending.extend(map.values());
            }
            Json::Array(items) => pending.extend(items),
            _ => {}
        }
    }
}
fn compile_args(
obj: &serde_json::Map<String, Json>,
columns: &HashSet<String>,
) -> Result<Vec<Expr>, FormulaError> {
obj.get("args")
.and_then(|a| a.as_array())
.map(|a| a.iter().map(|n| compile_formula(n, columns)).collect())
.unwrap_or_else(|| Ok(vec![]))
}
fn args_of(obj: &serde_json::Map<String, Json>) -> Vec<Json> {
obj.get("args")
.and_then(|a| a.as_array())
.cloned()
.unwrap_or_default()
}
/// Applies the whitelisted function `name` to already-compiled `args`.
///
/// The caller is expected to have checked the argument count against [`func_arity`]; the
/// errors raised here are a fallback for too few arguments. `round` is not handled here
/// (the caller special-cases it) and falls through to "not allowed".
///
/// - `nullif(a, b)` yields NULL where `a == b` is true, otherwise `a`.
/// - `ln` is the natural logarithm (base e); `log` is base 10.
fn apply_fn(name: &str, mut args: Vec<Expr>) -> Result<Expr, FormulaError> {
    // Unary functions operate on the first argument.
    let unary = |v: Vec<Expr>| -> Result<Expr, FormulaError> {
        v.into_iter()
            .next()
            .ok_or_else(|| FormulaError("function needs exactly 1 arg".into()))
    };
    let expr = match name {
        "coalesce" => coalesce(&args),
        "nullif" => {
            let missing = || FormulaError("`nullif` needs exactly 2 args".into());
            // Pop from the back: the last element is `b`, the one before it is `a`.
            let rhs = args.pop().ok_or_else(missing)?;
            let lhs = args.pop().ok_or_else(missing)?;
            let equal = lhs.clone().eq(rhs);
            when(equal).then(lit(NULL)).otherwise(lhs)
        }
        "least" | "greatest" => fold_minmax(args, name == "least")?,
        "abs" => unary(args)?.abs(),
        "ln" => unary(args)?.log(lit(std::f64::consts::E)),
        "log" => unary(args)?.log(lit(10.0)),
        "exp" => unary(args)?.exp(),
        "floor" => unary(args)?.floor(),
        "ceil" => unary(args)?.ceil(),
        "sqrt" => unary(args)?.sqrt(),
        other => return err(format!("function {other:?} not allowed")),
    };
    Ok(expr)
}
/// Folds `args` into a row-wise `least` (when `least` is true) or `greatest` expression.
///
/// NULL inputs are skipped, matching PostgreSQL/DuckDB `LEAST`/`GREATEST`: the result is
/// NULL only when every input is NULL, and it does not depend on argument order.
///
/// # Errors
///
/// Returns [`FormulaError`] when `args` is empty.
fn fold_minmax(args: Vec<Expr>, least: bool) -> Result<Expr, FormulaError> {
    let mut it = args.into_iter();
    let mut acc = it
        .next()
        .ok_or_else(|| FormulaError("`least`/`greatest` needs ≥1 arg".into()))?;
    for a in it {
        let a_beats_acc = if least {
            a.clone().lt(acc.clone())
        } else {
            a.clone().gt(acc.clone())
        };
        // A comparison involving NULL is itself NULL, and `when` treats that as false.
        // Without resolving it, a NULL first argument would stick (`least(NULL, 5)` = NULL)
        // while a NULL later argument was skipped (`least(5, NULL)` = 5). Resolve it so
        // that a non-NULL `a` replaces a NULL accumulator and a NULL `a` never wins.
        let take_a = a_beats_acc.fill_null(a.clone().is_not_null());
        acc = when(take_a).then(a).otherwise(acc);
    }
    Ok(acc)
}
/// Maps a whitelisted SQL-style type name (already lower-cased by the caller) to a polars
/// [`DataType`].
///
/// `decimal` is mapped to `Float64`, the same as `double`. Any other name is rejected.
fn cast_dtype(t: &str) -> Result<DataType, FormulaError> {
    let dtype = match t {
        "boolean" => DataType::Boolean,
        "integer" => DataType::Int32,
        "bigint" => DataType::Int64,
        "double" | "decimal" => DataType::Float64,
        "varchar" => DataType::String,
        other => return err(format!("cast type {other:?} not allowed")),
    };
    Ok(dtype)
}
/// Converts a scalar JSON value into a polars literal expression.
///
/// Numbers that fit in `i64` become integer literals; other numbers become `f64` literals.
/// Strings, booleans and `null` map to the corresponding literal.
///
/// # Errors
///
/// Returns [`FormulaError`] for arrays and objects. It also errors for a number that has no
/// `f64` representation, which can happen with serde_json's `arbitrary_precision` feature
/// (e.g. `1e400`). Such a number used to be silently replaced by `0.0`.
fn lit_expr(v: &Json) -> Result<Expr, FormulaError> {
    let expr = match v {
        Json::Null => lit(NULL),
        Json::Bool(b) => lit(*b),
        Json::Number(n) => match (n.as_i64(), n.as_f64()) {
            (Some(i), _) => lit(i),
            (None, Some(f)) => lit(f),
            (None, None) => return err(format!("unsupported literal {v}")),
        },
        Json::String(s) => lit(s.clone()),
        other => return err(format!("unsupported literal {other}")),
    };
    Ok(expr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The schema every test compiles against: the columns `stars` and `forks`.
    fn cols() -> HashSet<String> {
        HashSet::from(["stars".to_string(), "forks".to_string()])
    }

    /// Asserts that `node` compiles successfully against [`cols`].
    fn assert_ok(node: &Json) {
        assert!(compile_formula(node, &cols()).is_ok(), "expected Ok for {node}");
    }

    /// Asserts that `node` is rejected with an `Err` (rather than a panic) against [`cols`].
    fn assert_err(node: &Json) {
        assert!(
            compile_formula(node, &cols()).is_err(),
            "expected Err (not a panic) for {node}"
        );
    }

    #[test]
    fn validates_columns_and_funcs() {
        let accepted = [
            json!({"col": "stars"}),
            json!({"op": "/", "args": [{"col": "forks"}, {"col": "stars"}]}),
            json!({"fn": "coalesce", "args": [{"col": "stars"}, {"lit": 0}]}),
            json!({"cast": {"col": "stars"}, "as": "double"}),
        ];
        let rejected = [
            json!({"col": "evil; DROP TABLE"}),
            json!({"fn": "system", "args": []}),
            json!({"op": "%", "args": [{"col": "stars"}, {"lit": 2}]}),
            json!({"cast": {"col": "stars"}, "as": "evil"}),
        ];
        accepted.iter().for_each(assert_ok);
        rejected.iter().for_each(assert_err);
    }

    #[test]
    fn malformed_arity_errors_never_panics() {
        let malformed = [
            json!({"op": "+", "args": [{"col": "stars"}]}),
            json!({"op": ">=", "args": [{"col": "stars"}]}),
            json!({"op": ">=", "args": [{"col": "stars"}, {"lit": 1}, {"lit": 2}]}),
            json!({"op": "and", "args": [{"col": "stars"}]}),
            json!({"op": "not", "args": [{"col": "stars"}, {"col": "forks"}]}),
            json!({"op": "is_null", "args": [{"col": "stars"}, {"col": "forks"}]}),
            json!({"fn": "abs", "args": []}),
            json!({"fn": "least", "args": []}),
            json!({"fn": "round", "args": []}),
            json!({"fn": "nullif", "args": [{"col": "stars"}]}),
        ];
        for node in &malformed {
            assert_err(node);
        }
    }

    #[test]
    fn comparison_and_boolean_predicates_compile() {
        let accepted = [
            json!({"op": ">=", "args": [{"col": "stars"}, {"lit": 1e7}]}),
            json!({"op": "<", "args": [{"col": "forks"}, {"lit": 5}]}),
            json!({"op": "==", "args": [{"col": "stars"}, {"col": "forks"}]}),
            json!({"op": "and", "args": [
                {"op": ">=", "args": [{"col": "stars"}, {"lit": 100}]},
                {"op": "<", "args": [{"col": "forks"}, {"lit": 1000}]}
            ]}),
            json!({"op": "not", "args": [{"op": ">", "args": [{"col": "stars"}, {"lit": 0}]}]}),
        ];
        let rejected = [
            json!({"op": ">=", "args": [{"col": "stars"}]}),
            json!({"op": "not", "args": [{"col": "stars"}, {"col": "forks"}]}),
            json!({"op": "%", "args": [{"col": "stars"}, {"lit": 2}]}),
        ];
        accepted.iter().for_each(assert_ok);
        rejected.iter().for_each(assert_err);
    }

    #[test]
    fn referenced_columns_walks_the_ast() {
        let ast = json!({"op": "and", "args": [
            {"op": ">=", "args": [{"col": "mcap_usd"}, {"lit": 1e7}]},
            {"op": "<", "args": [{"col": "founding_year"}, {"lit": 2000}]}
        ]});
        let mut names: Vec<String> = referenced_columns(&ast).into_iter().collect();
        names.sort_unstable();
        assert_eq!(
            names,
            vec!["founding_year".to_string(), "mcap_usd".to_string()]
        );

        let only_x: HashSet<String> = HashSet::from(["x".to_string()]);
        assert_eq!(referenced_columns(&json!({"col": "x"})), only_x);
        assert!(referenced_columns(&json!({"lit": 1})).is_empty());
    }

    #[test]
    fn null_predicates_compile_and_filter() {
        let not_null_stars = json!({"op": "is_not_null", "args": [{"col": "stars"}]});
        assert_ok(&json!({"op": "is_null", "args": [{"col": "stars"}]}));
        assert_ok(&not_null_stars);
        assert_err(&json!({"op": "is_null", "args": [{"col": "stars"}, {"col": "forks"}]}));
        assert_err(&json!({"op": "is_not_null", "args": []}));

        let frame = df!["stars" => &[Some(1i64), None, Some(3)]].unwrap();
        let pred = compile_formula(&not_null_stars, &cols()).unwrap();
        let filtered = frame.lazy().filter(pred).collect().unwrap();
        let kept: Vec<i64> = filtered
            .column("stars")
            .unwrap()
            .i64()
            .unwrap()
            .into_no_null_iter()
            .collect();
        assert_eq!(kept, vec![1, 3]);
    }

    #[test]
    fn round_arg_handling() {
        assert_ok(&json!({"fn": "round", "args": [{"col": "stars"}]}));
        assert_ok(&json!({"fn": "round", "args": [{"col": "stars"}, {"lit": 2}]}));
        assert_err(&json!({"fn": "round", "args": [{"col": "stars"}, {"col": "forks"}]}));
    }
}