Skip to main content

tl_data/
translate.rs

1use datafusion::prelude::*;
2use std::collections::HashMap;
3use tl_ast::{BinOp, Expr as AstExpr, UnaryOp};
4
/// A runtime value of a local variable that can be inlined as a literal in a
/// translated expression.
///
/// `PartialEq` is derived so values can be compared directly (for example with
/// `assert_eq!`); `Eq` is intentionally not derived because `Float` carries an
/// `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum LocalValue {
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// An owned UTF-8 string.
    String(String),
    /// A boolean.
    Bool(bool),
}
13
14/// Context for translating TL AST expressions to DataFusion expressions.
15/// `locals` maps variable names to their runtime values.
16/// Names not in `locals` are treated as column references.
17pub struct TranslateContext {
18    pub locals: HashMap<String, LocalValue>,
19}
20
21impl Default for TranslateContext {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl TranslateContext {
28    pub fn new() -> Self {
29        TranslateContext {
30            locals: HashMap::new(),
31        }
32    }
33}
34
35/// Translate a TL AST expression into a DataFusion Expr.
36///
37/// Resolution rules:
38/// - Identifiers present in `ctx.locals` → `lit(value)`
39/// - Identifiers NOT in `ctx.locals` → `col(name)` (column reference)
40/// - Binary ops → DataFusion binary expressions
41/// - Function calls → aggregate functions (count, sum, avg, min, max) or DataFusion built-in functions
42pub fn translate_expr(ast: &AstExpr, ctx: &TranslateContext) -> Result<Expr, String> {
43    match ast {
44        AstExpr::Int(n) => Ok(lit(*n)),
45        AstExpr::Float(f) => Ok(lit(*f)),
46        AstExpr::String(s) => Ok(lit(s.clone())),
47        AstExpr::Bool(b) => Ok(lit(*b)),
48        AstExpr::None => Ok(lit(datafusion::scalar::ScalarValue::Null)),
49
50        AstExpr::Ident(name) => {
51            if let Some(local) = ctx.locals.get(name) {
52                match local {
53                    LocalValue::Int(n) => Ok(lit(*n)),
54                    LocalValue::Float(f) => Ok(lit(*f)),
55                    LocalValue::String(s) => Ok(lit(s.clone())),
56                    LocalValue::Bool(b) => Ok(lit(*b)),
57                }
58            } else {
59                Ok(col(name.as_str()))
60            }
61        }
62
63        AstExpr::BinOp { left, op, right } => {
64            let l = translate_expr(left, ctx)?;
65            let r = translate_expr(right, ctx)?;
66            match op {
67                BinOp::Add => Ok(l + r),
68                BinOp::Sub => Ok(l - r),
69                BinOp::Mul => Ok(l * r),
70                BinOp::Div => Ok(l / r),
71                BinOp::Mod => Ok(l % r),
72                BinOp::Eq => Ok(l.eq(r)),
73                BinOp::Neq => Ok(l.not_eq(r)),
74                BinOp::Lt => Ok(l.lt(r)),
75                BinOp::Gt => Ok(l.gt(r)),
76                BinOp::Lte => Ok(l.lt_eq(r)),
77                BinOp::Gte => Ok(l.gt_eq(r)),
78                BinOp::And => Ok(l.and(r)),
79                BinOp::Or => Ok(l.or(r)),
80                BinOp::Pow => Err("Power operator not supported in table expressions".into()),
81            }
82        }
83
84        AstExpr::UnaryOp { op, expr } => {
85            let e = translate_expr(expr, ctx)?;
86            match op {
87                UnaryOp::Neg => Ok(Expr::Negative(Box::new(e))),
88                UnaryOp::Not => Ok(e.not()),
89                UnaryOp::Ref => Ok(e), // References are transparent in DataFusion expressions
90            }
91        }
92
93        AstExpr::Call { function, args } => {
94            if let AstExpr::Ident(fname) = function.as_ref() {
95                translate_aggregate_or_function(fname, args, ctx)
96            } else {
97                Err("Only named function calls supported in table expressions".into())
98            }
99        }
100
101        AstExpr::Member { object, field } => {
102            // object.field → col("object.field") — for qualified column names
103            if let AstExpr::Ident(obj_name) = object.as_ref() {
104                Ok(col(format!("{obj_name}.{field}").as_str()))
105            } else {
106                Err("Complex member access not supported in table expressions".into())
107            }
108        }
109
110        _ => Err(format!(
111            "Expression type not supported in table context: {:?}",
112            std::mem::discriminant(ast)
113        )),
114    }
115}
116
117/// Translate aggregate and scalar function calls.
118fn translate_aggregate_or_function(
119    name: &str,
120    args: &[AstExpr],
121    ctx: &TranslateContext,
122) -> Result<Expr, String> {
123    match name {
124        "count" => {
125            if args.is_empty() {
126                Ok(datafusion::functions_aggregate::expr_fn::count(lit(1)))
127            } else {
128                let arg = translate_expr(&args[0], ctx)?;
129                Ok(datafusion::functions_aggregate::expr_fn::count(arg))
130            }
131        }
132        "sum" => {
133            if args.len() != 1 {
134                return Err("sum() requires exactly 1 argument".into());
135            }
136            let arg = translate_expr(&args[0], ctx)?;
137            Ok(datafusion::functions_aggregate::expr_fn::sum(arg))
138        }
139        "avg" => {
140            if args.len() != 1 {
141                return Err("avg() requires exactly 1 argument".into());
142            }
143            let arg = translate_expr(&args[0], ctx)?;
144            Ok(datafusion::functions_aggregate::expr_fn::avg(arg))
145        }
146        "min" => {
147            if args.len() != 1 {
148                return Err("min() requires exactly 1 argument".into());
149            }
150            let arg = translate_expr(&args[0], ctx)?;
151            Ok(datafusion::functions_aggregate::expr_fn::min(arg))
152        }
153        "max" => {
154            if args.len() != 1 {
155                return Err("max() requires exactly 1 argument".into());
156            }
157            let arg = translate_expr(&args[0], ctx)?;
158            Ok(datafusion::functions_aggregate::expr_fn::max(arg))
159        }
160        _ => Err(format!("Unknown function in table expression: {name}")),
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// An identifier with no matching local renders as a bare column name.
    #[test]
    fn test_translate_column_ref() {
        let empty_ctx = TranslateContext::new();
        let translated = translate_expr(&AstExpr::Ident("age".into()), &empty_ctx).unwrap();
        assert_eq!(translated.to_string(), "age");
    }

    /// An identifier bound in `locals` is inlined as a literal of its value.
    #[test]
    fn test_translate_local_literal() {
        let mut with_local = TranslateContext::new();
        with_local
            .locals
            .insert("threshold".into(), LocalValue::Int(25));
        let translated =
            translate_expr(&AstExpr::Ident("threshold".into()), &with_local).unwrap();
        assert_eq!(translated.to_string(), "Int64(25)");
    }

    /// A comparison between a column and an integer literal keeps its operator.
    #[test]
    fn test_translate_binop() {
        let comparison = AstExpr::BinOp {
            left: Box::new(AstExpr::Ident("age".into())),
            op: BinOp::Gt,
            right: Box::new(AstExpr::Int(25)),
        };
        let translated = translate_expr(&comparison, &TranslateContext::new()).unwrap();
        assert_eq!(translated.to_string(), "age > Int64(25)");
    }

    /// A `sum(column)` call translates to an expression whose rendering
    /// mentions the aggregate name in either case.
    #[test]
    fn test_translate_aggregate() {
        let call = AstExpr::Call {
            function: Box::new(AstExpr::Ident("sum".into())),
            args: vec![AstExpr::Ident("amount".into())],
        };
        let translated = translate_expr(&call, &TranslateContext::new()).unwrap();
        let s = translated.to_string();
        assert!(s.contains("sum") || s.contains("SUM"), "Got: {s}");
    }
}