Skip to main content

elo_rust/codegen/
optimization.rs

1//! Optimization passes for ELO code generation
2//!
//! Provides constant folding of literal arithmetic, comparison, and boolean
//! expressions. Dead code elimination and other expression simplifications are
//! not implemented in this module yet.
5
6use crate::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
7
/// Entry point for the expression optimization passes used during code generation.
///
/// This is a unit struct that holds no state; the passes themselves are
/// associated functions on the type.
#[derive(Clone, Debug)]
pub struct Optimizer;
11
12impl Optimizer {
13    /// Create a new optimizer
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Apply all optimizations to an expression
19    pub fn optimize(expr: &Expr) -> Expr {
20        Self::fold_constants(expr)
21    }
22
23    /// Constant folding: evaluate constant expressions at compile time
24    pub fn fold_constants(expr: &Expr) -> Expr {
25        match expr {
26            // Binary operations on literals can be folded
27            Expr::BinaryOp { op, left, right } => {
28                let left_folded = Self::fold_constants(left);
29                let right_folded = Self::fold_constants(right);
30
31                // Try to fold if both sides are literals
32                if let (Expr::Literal(left_lit), Expr::Literal(right_lit)) =
33                    (&left_folded, &right_folded)
34                {
35                    if let Some(folded) = Self::fold_binary_op(*op, left_lit, right_lit) {
36                        return folded;
37                    }
38                }
39
40                Expr::BinaryOp {
41                    op: *op,
42                    left: Box::new(left_folded),
43                    right: Box::new(right_folded),
44                }
45            }
46
47            // Unary operations on literals can be folded
48            Expr::UnaryOp { op, operand } => {
49                let operand_folded = Self::fold_constants(operand);
50
51                if let Expr::Literal(lit) = &operand_folded {
52                    if let Some(folded) = Self::fold_unary_op(*op, lit) {
53                        return folded;
54                    }
55                }
56
57                Expr::UnaryOp {
58                    op: *op,
59                    operand: Box::new(operand_folded),
60                }
61            }
62
63            // Recursively fold expressions in containers
64            Expr::Array(elements) => {
65                let folded: Vec<Expr> = elements.iter().map(Self::fold_constants).collect();
66                Expr::Array(folded)
67            }
68
69            Expr::Object(fields) => {
70                let folded: Vec<(String, Expr)> = fields
71                    .iter()
72                    .map(|(k, v)| (k.clone(), Self::fold_constants(v)))
73                    .collect();
74                Expr::Object(folded)
75            }
76
77            Expr::FieldAccess { receiver, field } => Expr::FieldAccess {
78                receiver: Box::new(Self::fold_constants(receiver)),
79                field: field.clone(),
80            },
81
82            Expr::FunctionCall { name, args } => Expr::FunctionCall {
83                name: name.clone(),
84                args: args.iter().map(Self::fold_constants).collect(),
85            },
86
87            Expr::Lambda { param, body } => Expr::Lambda {
88                param: param.clone(),
89                body: Box::new(Self::fold_constants(body)),
90            },
91
92            Expr::Let { name, value, body } => Expr::Let {
93                name: name.clone(),
94                value: Box::new(Self::fold_constants(value)),
95                body: Box::new(Self::fold_constants(body)),
96            },
97
98            Expr::If {
99                condition,
100                then_branch,
101                else_branch,
102            } => Expr::If {
103                condition: Box::new(Self::fold_constants(condition)),
104                then_branch: Box::new(Self::fold_constants(then_branch)),
105                else_branch: Box::new(Self::fold_constants(else_branch)),
106            },
107
108            Expr::Pipe { value, functions } => Expr::Pipe {
109                value: Box::new(Self::fold_constants(value)),
110                functions: functions.iter().map(Self::fold_constants).collect(),
111            },
112
113            Expr::Alternative {
114                primary,
115                alternative,
116            } => Expr::Alternative {
117                primary: Box::new(Self::fold_constants(primary)),
118                alternative: Box::new(Self::fold_constants(alternative)),
119            },
120
121            Expr::Guard { condition, body } => Expr::Guard {
122                condition: Box::new(Self::fold_constants(condition)),
123                body: Box::new(Self::fold_constants(body)),
124            },
125
126            // Literals and identifiers cannot be folded further
127            expr => expr.clone(),
128        }
129    }
130
131    /// Fold a binary operation on two literals
132    fn fold_binary_op(op: BinaryOperator, left: &Literal, right: &Literal) -> Option<Expr> {
133        match (left, right) {
134            (Literal::Integer(l), Literal::Integer(r)) => {
135                match op {
136                    BinaryOperator::Add => {
137                        let result = l.checked_add(*r)?;
138                        Some(Expr::Literal(Literal::Integer(result)))
139                    }
140                    BinaryOperator::Sub => {
141                        let result = l.checked_sub(*r)?;
142                        Some(Expr::Literal(Literal::Integer(result)))
143                    }
144                    BinaryOperator::Mul => {
145                        let result = l.checked_mul(*r)?;
146                        Some(Expr::Literal(Literal::Integer(result)))
147                    }
148                    BinaryOperator::Div if *r != 0 => {
149                        let result = l.checked_div(*r)?;
150                        Some(Expr::Literal(Literal::Integer(result)))
151                    }
152                    BinaryOperator::Mod if *r != 0 => Some(Expr::Literal(Literal::Integer(l % r))),
153                    BinaryOperator::Pow => {
154                        if *r < 0 || *r > 31 {
155                            return None; // Can't fold negative or large exponents
156                        }
157                        let result = l.checked_pow(*r as u32)?;
158                        Some(Expr::Literal(Literal::Integer(result)))
159                    }
160                    BinaryOperator::Eq => Some(Expr::Literal(Literal::Boolean(l == r))),
161                    BinaryOperator::Neq => Some(Expr::Literal(Literal::Boolean(l != r))),
162                    BinaryOperator::Lt => Some(Expr::Literal(Literal::Boolean(l < r))),
163                    BinaryOperator::Lte => Some(Expr::Literal(Literal::Boolean(l <= r))),
164                    BinaryOperator::Gt => Some(Expr::Literal(Literal::Boolean(l > r))),
165                    BinaryOperator::Gte => Some(Expr::Literal(Literal::Boolean(l >= r))),
166                    #[allow(clippy::needless_return)]
167                    _ => return None,
168                }
169            }
170
171            (Literal::Float(l), Literal::Float(r)) => {
172                let result = match op {
173                    BinaryOperator::Add => l + r,
174                    BinaryOperator::Sub => l - r,
175                    BinaryOperator::Mul => l * r,
176                    BinaryOperator::Div if *r != 0.0 => l / r,
177                    BinaryOperator::Mod if *r != 0.0 => l % r,
178                    BinaryOperator::Pow => l.powf(*r),
179                    BinaryOperator::Eq => {
180                        return Some(Expr::Literal(Literal::Boolean(
181                            (l - r).abs() < f64::EPSILON,
182                        )))
183                    }
184                    BinaryOperator::Neq => {
185                        return Some(Expr::Literal(Literal::Boolean(
186                            (l - r).abs() >= f64::EPSILON,
187                        )))
188                    }
189                    BinaryOperator::Lt => return Some(Expr::Literal(Literal::Boolean(l < r))),
190                    BinaryOperator::Lte => return Some(Expr::Literal(Literal::Boolean(l <= r))),
191                    BinaryOperator::Gt => return Some(Expr::Literal(Literal::Boolean(l > r))),
192                    BinaryOperator::Gte => return Some(Expr::Literal(Literal::Boolean(l >= r))),
193                    _ => return None,
194                };
195                Some(Expr::Literal(Literal::Float(result)))
196            }
197
198            (Literal::Boolean(l), Literal::Boolean(r)) => match op {
199                BinaryOperator::And => Some(Expr::Literal(Literal::Boolean(*l && *r))),
200                BinaryOperator::Or => Some(Expr::Literal(Literal::Boolean(*l || *r))),
201                BinaryOperator::Eq => Some(Expr::Literal(Literal::Boolean(l == r))),
202                BinaryOperator::Neq => Some(Expr::Literal(Literal::Boolean(l != r))),
203                _ => None,
204            },
205
206            _ => None,
207        }
208    }
209
210    /// Fold a unary operation on a literal
211    fn fold_unary_op(op: UnaryOperator, lit: &Literal) -> Option<Expr> {
212        match op {
213            UnaryOperator::Not => {
214                if let Literal::Boolean(b) = lit {
215                    Some(Expr::Literal(Literal::Boolean(!b)))
216                } else {
217                    None
218                }
219            }
220            UnaryOperator::Neg => match lit {
221                Literal::Integer(n) => Some(Expr::Literal(Literal::Integer(-n))),
222                Literal::Float(f) => Some(Expr::Literal(Literal::Float(-f))),
223                _ => None,
224            },
225            UnaryOperator::Plus => match lit {
226                Literal::Integer(n) => Some(Expr::Literal(Literal::Integer(*n))),
227                Literal::Float(f) => Some(Expr::Literal(Literal::Float(*f))),
228                _ => None,
229            },
230        }
231    }
232}
233
234impl Default for Optimizer {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_creation() {
        let _opt = Optimizer::new();
    }

    #[test]
    fn test_fold_integer_addition() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Integer(5))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 8),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_subtraction() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Sub,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 7),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_multiplication() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Mul,
            left: Box::new(Expr::Literal(Literal::Integer(4))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 12),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_modulo() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Mod,
            left: Box::new(Expr::Literal(Literal::Integer(7))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 1),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_pow() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Pow,
            left: Box::new(Expr::Literal(Literal::Integer(2))),
            right: Box::new(Expr::Literal(Literal::Integer(10))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 1024),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_float_addition() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Float(2.5))),
            right: Box::new(Expr::Literal(Literal::Float(1.5))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Float(f)) => assert!((f - 4.0).abs() < f64::EPSILON),
            _ => panic!("Expected folded float literal"),
        }
    }

    #[test]
    fn test_fold_float_comparison() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Lt,
            left: Box::new(Expr::Literal(Literal::Float(1.5))),
            right: Box::new(Expr::Literal(Literal::Float(2.5))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_boolean_and() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::And,
            left: Box::new(Expr::Literal(Literal::Boolean(true))),
            right: Box::new(Expr::Literal(Literal::Boolean(false))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(!b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_boolean_or() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Or,
            left: Box::new(Expr::Literal(Literal::Boolean(false))),
            right: Box::new(Expr::Literal(Literal::Boolean(true))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_integer_comparison() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Gt,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(5))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_unary_not() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(Expr::Literal(Literal::Boolean(true))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(!b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_unary_negate() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Neg,
            operand: Box::new(Expr::Literal(Literal::Integer(42))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, -42),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_nested_unary_over_comparison() {
        // !(1 < 2): the inner comparison folds first, then the negation.
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(Expr::BinaryOp {
                op: BinaryOperator::Lt,
                left: Box::new(Expr::Literal(Literal::Integer(1))),
                right: Box::new(Expr::Literal(Literal::Integer(2))),
            }),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Boolean(b)) => assert!(!b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_no_fold_identifier() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Identifier("x".to_string())),
            right: Box::new(Expr::Literal(Literal::Integer(5))),
        };

        let folded = Optimizer::optimize(&expr);
        // Should not be folded since left is not a literal
        assert!(matches!(folded, Expr::BinaryOp { .. }));
    }

    #[test]
    fn test_no_fold_mixed_numeric_types() {
        // Integer and float literals are never combined by the folder.
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Integer(1))),
            right: Box::new(Expr::Literal(Literal::Float(2.0))),
        };

        let folded = Optimizer::optimize(&expr);
        assert!(matches!(folded, Expr::BinaryOp { .. }));
    }

    #[test]
    fn test_fold_nested_constants() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::BinaryOp {
                op: BinaryOperator::Mul,
                left: Box::new(Expr::Literal(Literal::Integer(2))),
                right: Box::new(Expr::Literal(Literal::Integer(3))),
            }),
            right: Box::new(Expr::Literal(Literal::Integer(4))),
        };

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 10), // (2*3)+4 = 10
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_array_constants() {
        let expr = Expr::Array(vec![
            Expr::Literal(Literal::Integer(1)),
            Expr::BinaryOp {
                op: BinaryOperator::Add,
                left: Box::new(Expr::Literal(Literal::Integer(2))),
                right: Box::new(Expr::Literal(Literal::Integer(3))),
            },
        ]);

        let folded = Optimizer::optimize(&expr);
        match folded {
            Expr::Array(elements) => {
                assert_eq!(elements.len(), 2);
                match &elements[1] {
                    Expr::Literal(Literal::Integer(n)) => assert_eq!(*n, 5),
                    _ => panic!("Expected folded constant in array"),
                }
            }
            _ => panic!("Expected array expression"),
        }
    }

    #[test]
    fn test_fold_division_by_zero() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Div,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(0))),
        };

        let folded = Optimizer::optimize(&expr);
        // Should not fold division by zero
        assert!(matches!(folded, Expr::BinaryOp { .. }));
    }

    #[test]
    fn test_no_fold_modulo_by_zero() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Mod,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(0))),
        };

        let folded = Optimizer::optimize(&expr);
        // Should not fold remainder by zero
        assert!(matches!(folded, Expr::BinaryOp { .. }));
    }
}