1use datafusion::prelude::*;
2use std::collections::HashMap;
3use tl_ast::{BinOp, Expr as AstExpr, UnaryOp};
4
/// A scalar value bound to a TL local variable.
///
/// When an identifier inside a table expression names one of these locals,
/// `translate_expr` inlines the value as a DataFusion literal instead of
/// resolving the identifier as a column reference.
#[derive(Debug, Clone, PartialEq)]
pub enum LocalValue {
    /// Integer local (inlined via `lit(i64)`).
    Int(i64),
    /// Floating-point local (inlined via `lit(f64)`).
    Float(f64),
    /// String local (inlined via `lit(String)`).
    String(String),
    /// Boolean local (inlined via `lit(bool)`).
    Bool(bool),
}
13
14pub struct TranslateContext {
18 pub locals: HashMap<String, LocalValue>,
19}
20
21impl Default for TranslateContext {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl TranslateContext {
28 pub fn new() -> Self {
29 TranslateContext {
30 locals: HashMap::new(),
31 }
32 }
33}
34
35pub fn translate_expr(ast: &AstExpr, ctx: &TranslateContext) -> Result<Expr, String> {
43 match ast {
44 AstExpr::Int(n) => Ok(lit(*n)),
45 AstExpr::Float(f) => Ok(lit(*f)),
46 AstExpr::String(s) => Ok(lit(s.clone())),
47 AstExpr::Bool(b) => Ok(lit(*b)),
48 AstExpr::None => Ok(lit(datafusion::scalar::ScalarValue::Null)),
49
50 AstExpr::Ident(name) => {
51 if let Some(local) = ctx.locals.get(name) {
52 match local {
53 LocalValue::Int(n) => Ok(lit(*n)),
54 LocalValue::Float(f) => Ok(lit(*f)),
55 LocalValue::String(s) => Ok(lit(s.clone())),
56 LocalValue::Bool(b) => Ok(lit(*b)),
57 }
58 } else {
59 Ok(col(name.as_str()))
60 }
61 }
62
63 AstExpr::BinOp { left, op, right } => {
64 let l = translate_expr(left, ctx)?;
65 let r = translate_expr(right, ctx)?;
66 match op {
67 BinOp::Add => Ok(l + r),
68 BinOp::Sub => Ok(l - r),
69 BinOp::Mul => Ok(l * r),
70 BinOp::Div => Ok(l / r),
71 BinOp::Mod => Ok(l % r),
72 BinOp::Eq => Ok(l.eq(r)),
73 BinOp::Neq => Ok(l.not_eq(r)),
74 BinOp::Lt => Ok(l.lt(r)),
75 BinOp::Gt => Ok(l.gt(r)),
76 BinOp::Lte => Ok(l.lt_eq(r)),
77 BinOp::Gte => Ok(l.gt_eq(r)),
78 BinOp::And => Ok(l.and(r)),
79 BinOp::Or => Ok(l.or(r)),
80 BinOp::Pow => Err("Power operator not supported in table expressions".into()),
81 }
82 }
83
84 AstExpr::UnaryOp { op, expr } => {
85 let e = translate_expr(expr, ctx)?;
86 match op {
87 UnaryOp::Neg => Ok(Expr::Negative(Box::new(e))),
88 UnaryOp::Not => Ok(e.not()),
89 UnaryOp::Ref => Ok(e), }
91 }
92
93 AstExpr::Call { function, args } => {
94 if let AstExpr::Ident(fname) = function.as_ref() {
95 translate_aggregate_or_function(fname, args, ctx)
96 } else {
97 Err("Only named function calls supported in table expressions".into())
98 }
99 }
100
101 AstExpr::Member { object, field } => {
102 if let AstExpr::Ident(obj_name) = object.as_ref() {
104 Ok(col(format!("{obj_name}.{field}").as_str()))
105 } else {
106 Err("Complex member access not supported in table expressions".into())
107 }
108 }
109
110 _ => Err(format!(
111 "Expression type not supported in table context: {:?}",
112 std::mem::discriminant(ast)
113 )),
114 }
115}
116
117fn translate_aggregate_or_function(
119 name: &str,
120 args: &[AstExpr],
121 ctx: &TranslateContext,
122) -> Result<Expr, String> {
123 match name {
124 "count" => {
125 if args.is_empty() {
126 Ok(datafusion::functions_aggregate::expr_fn::count(lit(1)))
127 } else {
128 let arg = translate_expr(&args[0], ctx)?;
129 Ok(datafusion::functions_aggregate::expr_fn::count(arg))
130 }
131 }
132 "sum" => {
133 if args.len() != 1 {
134 return Err("sum() requires exactly 1 argument".into());
135 }
136 let arg = translate_expr(&args[0], ctx)?;
137 Ok(datafusion::functions_aggregate::expr_fn::sum(arg))
138 }
139 "avg" => {
140 if args.len() != 1 {
141 return Err("avg() requires exactly 1 argument".into());
142 }
143 let arg = translate_expr(&args[0], ctx)?;
144 Ok(datafusion::functions_aggregate::expr_fn::avg(arg))
145 }
146 "min" => {
147 if args.len() != 1 {
148 return Err("min() requires exactly 1 argument".into());
149 }
150 let arg = translate_expr(&args[0], ctx)?;
151 Ok(datafusion::functions_aggregate::expr_fn::min(arg))
152 }
153 "max" => {
154 if args.len() != 1 {
155 return Err("max() requires exactly 1 argument".into());
156 }
157 let arg = translate_expr(&args[0], ctx)?;
158 Ok(datafusion::functions_aggregate::expr_fn::max(arg))
159 }
160 _ => Err(format!("Unknown function in table expression: {name}")),
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a call node `fname(args...)` whose callee is a plain identifier.
    fn call(fname: &str, args: Vec<AstExpr>) -> AstExpr {
        AstExpr::Call {
            function: Box::new(AstExpr::Ident(fname.into())),
            args,
        }
    }

    #[test]
    fn test_translate_column_ref() {
        let ctx = TranslateContext::new();
        let ast = AstExpr::Ident("age".into());
        let expr = translate_expr(&ast, &ctx).unwrap();
        assert_eq!(format!("{expr}"), "age");
    }

    #[test]
    fn test_translate_local_literal() {
        let mut ctx = TranslateContext::new();
        ctx.locals.insert("threshold".into(), LocalValue::Int(25));
        let ast = AstExpr::Ident("threshold".into());
        let expr = translate_expr(&ast, &ctx).unwrap();
        assert_eq!(format!("{expr}"), "Int64(25)");
    }

    #[test]
    fn test_translate_binop() {
        let ctx = TranslateContext::new();
        let ast = AstExpr::BinOp {
            left: Box::new(AstExpr::Ident("age".into())),
            op: BinOp::Gt,
            right: Box::new(AstExpr::Int(25)),
        };
        let expr = translate_expr(&ast, &ctx).unwrap();
        assert_eq!(format!("{expr}"), "age > Int64(25)");
    }

    #[test]
    fn test_translate_aggregate() {
        let ctx = TranslateContext::new();
        let ast = call("sum", vec![AstExpr::Ident("amount".into())]);
        let expr = translate_expr(&ast, &ctx).unwrap();
        let s = format!("{expr}");
        assert!(s.contains("sum") || s.contains("SUM"), "Got: {s}");
    }

    #[test]
    fn test_pow_is_rejected() {
        let ctx = TranslateContext::new();
        let ast = AstExpr::BinOp {
            left: Box::new(AstExpr::Int(2)),
            op: BinOp::Pow,
            right: Box::new(AstExpr::Int(3)),
        };
        let err = translate_expr(&ast, &ctx).unwrap_err();
        assert_eq!(err, "Power operator not supported in table expressions");
    }

    #[test]
    fn test_unknown_function_is_rejected() {
        let ctx = TranslateContext::new();
        let err = translate_expr(&call("frobnicate", vec![]), &ctx).unwrap_err();
        assert_eq!(err, "Unknown function in table expression: frobnicate");
    }

    #[test]
    fn test_sum_requires_one_argument() {
        let ctx = TranslateContext::new();
        let err = translate_expr(&call("sum", vec![]), &ctx).unwrap_err();
        assert_eq!(err, "sum() requires exactly 1 argument");
    }

    #[test]
    fn test_non_identifier_callee_is_rejected() {
        let ctx = TranslateContext::new();
        let ast = AstExpr::Call {
            function: Box::new(AstExpr::Int(1)),
            args: vec![],
        };
        let err = translate_expr(&ast, &ctx).unwrap_err();
        assert_eq!(
            err,
            "Only named function calls supported in table expressions"
        );
    }
}