Skip to main content

elo_rust/codegen/
ast_to_code.rs

1//! AST-to-Rust code generation visitor
2//!
3//! This module implements the Visitor trait to transform ELO AST nodes into
4//! Rust TokenStreams that can be compiled.
5
6use crate::ast::visitor::Visitor;
7use crate::ast::{BinaryOperator, Expr, Literal, TemporalKeyword, UnaryOperator};
8use proc_macro2::TokenStream;
9use quote::quote;
10
11use super::{
12    functions::FunctionGenerator,
13    operators::{BinaryOp, OperatorGenerator, UnaryOp},
14    temporal::TemporalGenerator,
15};
16
17/// Visitor that generates Rust code from ELO AST
18#[derive(Debug)]
19pub struct CodegenVisitor {
20    operator_gen: OperatorGenerator,
21    function_gen: FunctionGenerator,
22    temporal_gen: TemporalGenerator,
23}
24
25impl CodegenVisitor {
26    /// Create a new code generation visitor
27    pub fn new() -> Self {
28        CodegenVisitor {
29            operator_gen: OperatorGenerator::new(),
30            function_gen: FunctionGenerator::new(),
31            temporal_gen: TemporalGenerator::new(),
32        }
33    }
34
35    /// Convert AST BinaryOperator to codegen BinaryOp
36    fn convert_binary_op(op: BinaryOperator) -> BinaryOp {
37        match op {
38            BinaryOperator::Add => BinaryOp::Add,
39            BinaryOperator::Sub => BinaryOp::Subtract,
40            BinaryOperator::Mul => BinaryOp::Multiply,
41            BinaryOperator::Div => BinaryOp::Divide,
42            BinaryOperator::Mod => BinaryOp::Modulo,
43            BinaryOperator::Pow => BinaryOp::Multiply, // Fallback, would need special handling
44            BinaryOperator::Eq => BinaryOp::Equal,
45            BinaryOperator::Neq => BinaryOp::NotEqual,
46            BinaryOperator::Lt => BinaryOp::Less,
47            BinaryOperator::Lte => BinaryOp::LessEqual,
48            BinaryOperator::Gt => BinaryOp::Greater,
49            BinaryOperator::Gte => BinaryOp::GreaterEqual,
50            BinaryOperator::And => BinaryOp::And,
51            BinaryOperator::Or => BinaryOp::Or,
52        }
53    }
54
55    /// Convert AST UnaryOperator to codegen UnaryOp
56    fn convert_unary_op(op: UnaryOperator) -> UnaryOp {
57        match op {
58            UnaryOperator::Not => UnaryOp::Not,
59            UnaryOperator::Neg => UnaryOp::Negate,
60            UnaryOperator::Plus => UnaryOp::Negate, // Identity, treat as no-op via negate
61        }
62    }
63}
64
65impl Default for CodegenVisitor {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl Visitor<TokenStream> for CodegenVisitor {
72    fn visit_expr(&mut self, expr: &Expr) -> TokenStream {
73        match expr {
74            Expr::Literal(lit) => self.visit_literal(lit),
75            Expr::Null => self.visit_null(),
76            Expr::Identifier(name) => self.visit_identifier(name),
77            Expr::String(value) => self.visit_string(value),
78            Expr::FieldAccess { receiver, field } => self.visit_field_access(receiver, field),
79            Expr::BinaryOp { op, left, right } => self.visit_binary_op(*op, left, right),
80            Expr::UnaryOp { op, operand } => self.visit_unary_op(*op, operand),
81            Expr::FunctionCall { name, args } => self.visit_function_call(name, args),
82            Expr::Lambda { param, body } => self.visit_lambda(param, body),
83            Expr::Let { name, value, body } => self.visit_let(name, value, body),
84            Expr::If {
85                condition,
86                then_branch,
87                else_branch,
88            } => self.visit_if(condition, then_branch, else_branch),
89            Expr::Array(elements) => self.visit_array(elements),
90            Expr::Object(fields) => self.visit_object(fields),
91            Expr::Pipe { value, functions } => self.visit_pipe(value, functions),
92            Expr::Alternative {
93                primary,
94                alternative,
95            } => self.visit_alternative(primary, alternative),
96            Expr::Guard { condition, body } => self.visit_guard(condition, body),
97            Expr::Date(date) => self.visit_date(date),
98            Expr::DateTime(datetime) => self.visit_datetime(datetime),
99            Expr::Duration(duration) => self.visit_duration(duration),
100            Expr::TemporalKeyword(keyword) => self.visit_temporal_keyword(*keyword),
101        }
102    }
103
104    fn visit_literal(&mut self, lit: &Literal) -> TokenStream {
105        match lit {
106            Literal::Integer(n) => quote! { #n },
107            Literal::Float(f) => quote! { #f },
108            Literal::Boolean(b) => quote! { #b },
109        }
110    }
111
112    fn visit_null(&mut self) -> TokenStream {
113        quote! { None::<()> }
114    }
115
116    fn visit_identifier(&mut self, name: &str) -> TokenStream {
117        let ident = quote::format_ident!("{}", name);
118        quote! { #ident }
119    }
120
121    fn visit_field_access(&mut self, receiver: &Expr, field: &str) -> TokenStream {
122        let recv = self.visit_expr(receiver);
123        let field_ident = quote::format_ident!("{}", field);
124        quote! { #recv.#field_ident }
125    }
126
127    fn visit_binary_op(&mut self, op: BinaryOperator, left: &Expr, right: &Expr) -> TokenStream {
128        let l = self.visit_expr(left);
129        let r = self.visit_expr(right);
130        let codegen_op = Self::convert_binary_op(op);
131        self.operator_gen.binary(codegen_op, l, r)
132    }
133
134    fn visit_unary_op(&mut self, op: UnaryOperator, operand: &Expr) -> TokenStream {
135        let operand = self.visit_expr(operand);
136        let codegen_op = Self::convert_unary_op(op);
137        self.operator_gen.unary(codegen_op, operand)
138    }
139
140    fn visit_function_call(&mut self, name: &str, args: &[Expr]) -> TokenStream {
141        let arg_tokens: Vec<TokenStream> = args.iter().map(|a| self.visit_expr(a)).collect();
142
143        // Use the unified function generator interface
144        self.function_gen.call(name, arg_tokens)
145    }
146
147    fn visit_lambda(&mut self, param: &str, body: &Expr) -> TokenStream {
148        let param_ident = quote::format_ident!("{}", param);
149        let body = self.visit_expr(body);
150        quote! {
151            |#param_ident| {
152                #body
153            }
154        }
155    }
156
157    fn visit_let(&mut self, name: &str, value: &Expr, body: &Expr) -> TokenStream {
158        let var_ident = quote::format_ident!("{}", name);
159        let val = self.visit_expr(value);
160        let bod = self.visit_expr(body);
161        quote! {
162            {
163                let #var_ident = #val;
164                #bod
165            }
166        }
167    }
168
169    fn visit_if(
170        &mut self,
171        condition: &Expr,
172        then_branch: &Expr,
173        else_branch: &Expr,
174    ) -> TokenStream {
175        let cond = self.visit_expr(condition);
176        let then_b = self.visit_expr(then_branch);
177        let else_b = self.visit_expr(else_branch);
178        quote! {
179            if #cond { #then_b } else { #else_b }
180        }
181    }
182
183    fn visit_array(&mut self, elements: &[Expr]) -> TokenStream {
184        let elems: Vec<TokenStream> = elements.iter().map(|e| self.visit_expr(e)).collect();
185        quote! {
186            vec![#(#elems),*]
187        }
188    }
189
190    fn visit_object(&mut self, fields: &[(String, Expr)]) -> TokenStream {
191        // For objects, we can't easily generate code without knowing the target type
192        // For now, generate a tuple struct representing the key-value pairs
193        let pairs: Vec<TokenStream> = fields
194            .iter()
195            .map(|(k, v)| {
196                let val = self.visit_expr(v);
197                let key_str = k.clone();
198                quote! { (#key_str, #val) }
199            })
200            .collect();
201        quote! {
202            vec![#(#pairs),*]
203        }
204    }
205
206    fn visit_pipe(&mut self, value: &Expr, functions: &[Expr]) -> TokenStream {
207        let mut result = self.visit_expr(value);
208
209        for func in functions {
210            // For each function in the pipe, we need to apply it to the previous result
211            // If it's a function call, inject result as first arg
212            // Otherwise, create a function call with result as argument
213            match func {
214                Expr::FunctionCall { name, args } => {
215                    // Insert result as first argument
216                    let mut new_args = vec![result.clone()];
217                    for arg in args {
218                        new_args.push(self.visit_expr(arg));
219                    }
220
221                    // Generate the function call with the new arguments
222                    let arg_tokens: Vec<TokenStream> = new_args;
223                    result = self.function_gen.call(name, arg_tokens);
224                }
225                Expr::Identifier(name) => {
226                    // Simple identifier - treat as a function call with one argument
227                    result = self.function_gen.call(name, vec![result]);
228                }
229                _ => {
230                    // Other expressions - try to apply them
231                    result = self.visit_expr(func);
232                }
233            }
234        }
235        result
236    }
237
238    fn visit_alternative(&mut self, primary: &Expr, alternative: &Expr) -> TokenStream {
239        let prim = self.visit_expr(primary);
240        let alt = self.visit_expr(alternative);
241        quote! {
242            #prim.or_else(|| #alt)
243        }
244    }
245
246    fn visit_guard(&mut self, condition: &Expr, body: &Expr) -> TokenStream {
247        let cond = self.visit_expr(condition);
248        let bod = self.visit_expr(body);
249        quote! {
250            if #cond { #bod } else { panic!("Guard failed") }
251        }
252    }
253
254    fn visit_date(&mut self, date: &str) -> TokenStream {
255        self.temporal_gen.date(date)
256    }
257
258    fn visit_datetime(&mut self, datetime: &str) -> TokenStream {
259        self.temporal_gen.datetime(datetime)
260    }
261
262    fn visit_duration(&mut self, duration: &str) -> TokenStream {
263        self.temporal_gen.duration(duration)
264    }
265
266    fn visit_temporal_keyword(&mut self, keyword: TemporalKeyword) -> TokenStream {
267        let keyword_str = match keyword {
268            TemporalKeyword::Now => "NOW",
269            TemporalKeyword::Today => "TODAY",
270            TemporalKeyword::Tomorrow => "TOMORROW",
271            TemporalKeyword::Yesterday => "YESTERDAY",
272            TemporalKeyword::StartOfDay => "START_OF_DAY",
273            TemporalKeyword::EndOfDay => "END_OF_DAY",
274            TemporalKeyword::StartOfWeek => "START_OF_WEEK",
275            TemporalKeyword::EndOfWeek => "END_OF_WEEK",
276            TemporalKeyword::StartOfMonth => "START_OF_MONTH",
277            TemporalKeyword::EndOfMonth => "END_OF_MONTH",
278            TemporalKeyword::StartOfQuarter => "START_OF_QUARTER",
279            TemporalKeyword::EndOfQuarter => "END_OF_QUARTER",
280            TemporalKeyword::StartOfYear => "START_OF_YEAR",
281            TemporalKeyword::EndOfYear => "END_OF_YEAR",
282            TemporalKeyword::BeginningOfTime => "BEGINNING_OF_TIME",
283            TemporalKeyword::EndOfTime => "END_OF_TIME",
284        };
285        self.temporal_gen.keyword(keyword_str)
286    }
287
288    fn visit_string(&mut self, value: &str) -> TokenStream {
289        quote! { #value }
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::BinaryOperator;

    /// Generates code for `expr` with a fresh visitor and renders it as a string.
    fn codegen(expr: &Expr) -> String {
        CodegenVisitor::new().visit_expr(expr).to_string()
    }

    #[test]
    fn test_codegen_literal_integer() {
        let expr = Expr::Literal(Literal::Integer(42));
        assert!(codegen(&expr).contains("42"));
    }

    #[test]
    fn test_codegen_literal_boolean() {
        let expr = Expr::Literal(Literal::Boolean(true));
        assert!(codegen(&expr).contains("true"));
    }

    #[test]
    fn test_codegen_identifier() {
        let expr = Expr::Identifier("age".to_string());
        assert!(codegen(&expr).contains("age"));
    }

    #[test]
    fn test_codegen_field_access() {
        let expr = Expr::FieldAccess {
            receiver: Box::new(Expr::Identifier("user".to_string())),
            field: "age".to_string(),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("user"));
        assert!(tokens_str.contains("age"));
    }

    #[test]
    fn test_codegen_binary_op() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Integer(1))),
            right: Box::new(Expr::Literal(Literal::Integer(2))),
        };
        let tokens_str = codegen(&expr);
        // Both operands must survive into the generated expression.
        assert!(tokens_str.contains('1'));
        assert!(tokens_str.contains('2'));
    }

    #[test]
    fn test_codegen_let_expr() {
        let expr = Expr::Let {
            name: "x".to_string(),
            value: Box::new(Expr::Literal(Literal::Integer(42))),
            body: Box::new(Expr::Identifier("x".to_string())),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("let"));
        assert!(tokens_str.contains("x"));
        assert!(tokens_str.contains("42"));
    }

    #[test]
    fn test_codegen_if_expr() {
        let expr = Expr::If {
            condition: Box::new(Expr::Literal(Literal::Boolean(true))),
            then_branch: Box::new(Expr::Literal(Literal::Integer(1))),
            else_branch: Box::new(Expr::Literal(Literal::Integer(0))),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("if"));
        assert!(tokens_str.contains("else"));
    }

    #[test]
    fn test_codegen_array() {
        let expr = Expr::Array(vec![
            Expr::Literal(Literal::Integer(1)),
            Expr::Literal(Literal::Integer(2)),
        ]);
        assert!(codegen(&expr).contains("vec"));
    }

    #[test]
    fn test_codegen_object() {
        let expr = Expr::Object(vec![(
            "name".to_string(),
            Expr::String("bob".to_string()),
        )]);
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("vec"));
        assert!(tokens_str.contains("\"name\""));
        assert!(tokens_str.contains("\"bob\""));
    }

    #[test]
    fn test_codegen_lambda() {
        let expr = Expr::Lambda {
            param: "v".to_string(),
            body: Box::new(Expr::Identifier("v".to_string())),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains('|'));
        assert!(tokens_str.contains('v'));
    }

    #[test]
    fn test_codegen_alternative() {
        let expr = Expr::Alternative {
            primary: Box::new(Expr::Null),
            alternative: Box::new(Expr::Literal(Literal::Integer(0))),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("None"));
        assert!(tokens_str.contains("or_else"));
    }

    #[test]
    fn test_codegen_guard() {
        let expr = Expr::Guard {
            condition: Box::new(Expr::Literal(Literal::Boolean(true))),
            body: Box::new(Expr::Literal(Literal::Integer(1))),
        };
        let tokens_str = codegen(&expr);
        assert!(tokens_str.contains("if"));
        assert!(tokens_str.contains("Guard failed"));
    }

    #[test]
    fn test_codegen_string() {
        let expr = Expr::String("hello".to_string());
        assert!(codegen(&expr).contains("hello"));
    }

    #[test]
    fn test_codegen_null() {
        assert!(codegen(&Expr::Null).contains("None"));
    }
}