Skip to main content

ruby_ir/
optimization.rs

1//! IR optimization utilities
2
3use crate::{BinaryOperator, Expression, Program, Statement, UnaryOperator, traversal::Mutator};
4use ruby_types::RubyValue;
5
/// Constant folder for expressions.
///
/// A stateless [`Mutator`] pass that simplifies expressions and statements whose
/// operands or conditions are literals.
#[derive(Debug, Default, Clone, Copy)]
pub struct ConstantFolder;
8
9impl Mutator for ConstantFolder {
10    fn mutate_expression(&mut self, expr: &mut Expression) {
11        // First mutate sub-expressions
12        match expr {
13            Expression::MethodCall { receiver, method: _, arguments } => {
14                self.mutate_expression(receiver);
15                for arg in arguments {
16                    self.mutate_expression(arg);
17                }
18            }
19            Expression::BinaryOp { left, op, right } => {
20                self.mutate_expression(left);
21                self.mutate_expression(right);
22
23                // Try to fold binary operations with constant operands
24                if let (Expression::Literal(left_val), Expression::Literal(right_val)) = (left.as_ref(), right.as_ref()) {
25                    if let Some(result) = self.evaluate_binary_op(op.clone(), left_val, right_val) {
26                        *expr = Expression::Literal(result);
27                    }
28                }
29            }
30            Expression::UnaryOp { op, operand } => {
31                self.mutate_expression(operand);
32
33                // Try to fold unary operations with constant operands
34                if let Expression::Literal(val) = operand.as_ref() {
35                    if let Some(result) = self.evaluate_unary_op(op.clone(), val) {
36                        *expr = Expression::Literal(result);
37                    }
38                }
39            }
40            Expression::ArrayLiteral(elements) => {
41                for elem in elements {
42                    self.mutate_expression(elem);
43                }
44            }
45            Expression::HashLiteral(pairs) => {
46                for (_, value) in pairs {
47                    self.mutate_expression(value);
48                }
49            }
50            Expression::Block { parameters: _, body } => {
51                for stmt in body {
52                    self.mutate_statement(stmt);
53                }
54            }
55            Expression::SuperCall { arguments } => {
56                for arg in arguments {
57                    self.mutate_expression(arg);
58                }
59            }
60            _ => {}
61        }
62    }
63
64    fn mutate_statement(&mut self, stmt: &mut Statement) {
65        // First mutate sub-expressions
66        match stmt {
67            Statement::Expression(expr) => {
68                self.mutate_expression(expr);
69            }
70            Statement::Assignment { name: _, value } => {
71                self.mutate_expression(value);
72            }
73            Statement::GlobalAssignment { name: _, value } => {
74                self.mutate_expression(value);
75            }
76            Statement::InstanceAssignment { name: _, value } => {
77                self.mutate_expression(value);
78            }
79            Statement::ClassAssignment { name: _, value } => {
80                self.mutate_expression(value);
81            }
82            Statement::If { condition, then_branch, else_branch } => {
83                self.mutate_expression(condition);
84
85                // Try to fold if statements with constant conditions
86                if let Expression::Literal(val) = condition {
87                    if val.to_bool() {
88                        // If condition is true, replace with then branch
89                        let mut new_body = Vec::new();
90                        new_body.extend(then_branch.iter().cloned());
91                        *stmt = Statement::Expression(Expression::Literal(RubyValue::Nil));
92                        // TODO: Replace with the then branch statements
93                    }
94                    else {
95                        // If condition is false, replace with else branch
96                        let mut new_body = Vec::new();
97                        new_body.extend(else_branch.iter().cloned());
98                        *stmt = Statement::Expression(Expression::Literal(RubyValue::Nil));
99                        // TODO: Replace with the else branch statements
100                    }
101                }
102                else {
103                    for stmt in then_branch {
104                        self.mutate_statement(stmt);
105                    }
106                    for stmt in else_branch {
107                        self.mutate_statement(stmt);
108                    }
109                }
110            }
111            Statement::While { condition, body } => {
112                self.mutate_expression(condition);
113
114                // Try to fold while statements with constant conditions
115                if let Expression::Literal(val) = condition {
116                    if !val.to_bool() {
117                        // If condition is false, remove the loop
118                        *stmt = Statement::Expression(Expression::Literal(RubyValue::Nil));
119                    }
120                }
121                else {
122                    for stmt in body {
123                        self.mutate_statement(stmt);
124                    }
125                }
126            }
127            Statement::For { variable: _, iterator, body } => {
128                self.mutate_expression(iterator);
129                for stmt in body {
130                    self.mutate_statement(stmt);
131                }
132            }
133            Statement::Return(expr) => {
134                if let Some(expr) = expr {
135                    self.mutate_expression(expr);
136                }
137            }
138            Statement::MethodDefinition { name: _, parameters: _, body } => {
139                for stmt in body {
140                    self.mutate_statement(stmt);
141                }
142            }
143            Statement::ClassDefinition { name: _, superclass: _, body } => {
144                for stmt in body {
145                    self.mutate_statement(stmt);
146                }
147            }
148            Statement::ModuleDefinition { name: _, body } => {
149                for stmt in body {
150                    self.mutate_statement(stmt);
151                }
152            }
153            Statement::Require(expr) => {
154                self.mutate_expression(expr);
155            }
156            Statement::Load(expr) => {
157                self.mutate_expression(expr);
158            }
159            _ => {}
160        }
161    }
162}
163
164impl ConstantFolder {
165    /// Evaluate a binary operation with constant operands
166    fn evaluate_binary_op(&self, op: BinaryOperator, left: &RubyValue, right: &RubyValue) -> Option<RubyValue> {
167        match op {
168            BinaryOperator::Add => Some(RubyValue::Float(left.to_f64() + right.to_f64())),
169            BinaryOperator::Sub => Some(RubyValue::Float(left.to_f64() - right.to_f64())),
170            BinaryOperator::Mul => Some(RubyValue::Float(left.to_f64() * right.to_f64())),
171            BinaryOperator::Div => {
172                if right.to_f64() != 0.0 {
173                    Some(RubyValue::Float(left.to_f64() / right.to_f64()))
174                }
175                else {
176                    None // Division by zero
177                }
178            }
179            BinaryOperator::Mod => Some(RubyValue::Float(left.to_f64() % right.to_f64())),
180            BinaryOperator::Exp => Some(RubyValue::Float(left.to_f64().powf(right.to_f64()))),
181            BinaryOperator::Eq => Some(RubyValue::Boolean(left == right)),
182            BinaryOperator::Neq => Some(RubyValue::Boolean(left != right)),
183            BinaryOperator::Lt => Some(RubyValue::Boolean(left.to_f64() < right.to_f64())),
184            BinaryOperator::Lte => Some(RubyValue::Boolean(left.to_f64() <= right.to_f64())),
185            BinaryOperator::Gt => Some(RubyValue::Boolean(left.to_f64() > right.to_f64())),
186            BinaryOperator::Gte => Some(RubyValue::Boolean(left.to_f64() >= right.to_f64())),
187            BinaryOperator::And => Some(RubyValue::Boolean(left.to_bool() && right.to_bool())),
188            BinaryOperator::Or => Some(RubyValue::Boolean(left.to_bool() || right.to_bool())),
189            BinaryOperator::Assign => {
190                None // Assignment can't be folded
191            }
192        }
193    }
194
195    /// Evaluate a unary operation with constant operands
196    fn evaluate_unary_op(&self, op: UnaryOperator, operand: &RubyValue) -> Option<RubyValue> {
197        match op {
198            UnaryOperator::Not => Some(RubyValue::Boolean(!operand.to_bool())),
199            UnaryOperator::Neg => Some(RubyValue::Float(-operand.to_f64())),
200            UnaryOperator::Plus => Some(RubyValue::Float(operand.to_f64())),
201            UnaryOperator::BitNot => Some(RubyValue::Integer(!operand.to_i32())),
202        }
203    }
204}
205
/// Dead code eliminator.
///
/// A stateless [`Mutator`] pass that removes expression statements it considers free of
/// side effects.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeadCodeEliminator;
208
209impl Mutator for DeadCodeEliminator {
210    fn mutate_statement(&mut self, stmt: &mut Statement) {
211        // First mutate sub-expressions and statements
212        match stmt {
213            Statement::Expression(expr) => {
214                self.mutate_expression(expr);
215                // If the expression is pure, replace it with a nil expression
216                if self.is_pure_expression(expr) {
217                    *stmt = Statement::Expression(Expression::Literal(RubyValue::Nil));
218                }
219            }
220            Statement::Assignment { name: _, value } => {
221                self.mutate_expression(value);
222            }
223            Statement::GlobalAssignment { name: _, value } => {
224                self.mutate_expression(value);
225            }
226            Statement::InstanceAssignment { name: _, value } => {
227                self.mutate_expression(value);
228            }
229            Statement::ClassAssignment { name: _, value } => {
230                self.mutate_expression(value);
231            }
232            Statement::If { condition, then_branch, else_branch } => {
233                self.mutate_expression(condition);
234
235                // Mutate then branch
236                let mut new_then_branch = Vec::new();
237                for s in &mut *then_branch {
238                    let mut s = s.clone();
239                    self.mutate_statement(&mut s);
240                    // Only keep non-dead statements
241                    if !self.is_dead_statement(&s) {
242                        new_then_branch.push(s);
243                    }
244                }
245                *then_branch = new_then_branch;
246
247                // Mutate else branch
248                let mut new_else_branch = Vec::new();
249                for s in &mut *else_branch {
250                    let mut s = s.clone();
251                    self.mutate_statement(&mut s);
252                    // Only keep non-dead statements
253                    if !self.is_dead_statement(&s) {
254                        new_else_branch.push(s);
255                    }
256                }
257                *else_branch = new_else_branch;
258            }
259            Statement::While { condition, body } => {
260                self.mutate_expression(condition);
261
262                // Mutate body
263                let mut new_body = Vec::new();
264                for s in &mut *body {
265                    let mut s = s.clone();
266                    self.mutate_statement(&mut s);
267                    // Only keep non-dead statements
268                    if !self.is_dead_statement(&s) {
269                        new_body.push(s);
270                    }
271                }
272                *body = new_body;
273            }
274            Statement::For { variable: _, iterator, body } => {
275                self.mutate_expression(iterator);
276
277                // Mutate body
278                let mut new_body = Vec::new();
279                for s in &mut *body {
280                    let mut s = s.clone();
281                    self.mutate_statement(&mut s);
282                    // Only keep non-dead statements
283                    if !self.is_dead_statement(&s) {
284                        new_body.push(s);
285                    }
286                }
287                *body = new_body;
288            }
289            Statement::Return(expr) => {
290                if let Some(expr) = expr {
291                    self.mutate_expression(expr);
292                }
293            }
294            Statement::MethodDefinition { name: _, parameters: _, body } => {
295                // Mutate body
296                let mut new_body = Vec::new();
297                for s in &mut *body {
298                    let mut s = s.clone();
299                    self.mutate_statement(&mut s);
300                    // Only keep non-dead statements
301                    if !self.is_dead_statement(&s) {
302                        new_body.push(s);
303                    }
304                }
305                *body = new_body;
306            }
307            Statement::ClassDefinition { name: _, superclass: _, body } => {
308                // Mutate body
309                let mut new_body = Vec::new();
310                for s in &mut *body {
311                    let mut s = s.clone();
312                    self.mutate_statement(&mut s);
313                    // Only keep non-dead statements
314                    if !self.is_dead_statement(&s) {
315                        new_body.push(s);
316                    }
317                }
318                *body = new_body;
319            }
320            Statement::ModuleDefinition { name: _, body } => {
321                // Mutate body
322                let mut new_body = Vec::new();
323                for s in &mut *body {
324                    let mut s = s.clone();
325                    self.mutate_statement(&mut s);
326                    // Only keep non-dead statements
327                    if !self.is_dead_statement(&s) {
328                        new_body.push(s);
329                    }
330                }
331                *body = new_body;
332            }
333            Statement::Require(expr) => {
334                self.mutate_expression(expr);
335            }
336            Statement::Load(expr) => {
337                self.mutate_expression(expr);
338            }
339            _ => {}
340        }
341    }
342}
343
344impl DeadCodeEliminator {
345    /// Check if a statement is dead
346    fn is_dead_statement(&self, stmt: &Statement) -> bool {
347        match stmt {
348            // Expression statements with no side effects are dead
349            Statement::Expression(expr) => self.is_pure_expression(expr),
350            _ => {
351                false // Other statements have side effects
352            }
353        }
354    }
355
356    /// Check if an expression is pure (has no side effects)
357    fn is_pure_expression(&self, expr: &Expression) -> bool {
358        match expr {
359            Expression::Literal(_) => true,
360            Expression::Variable(_) => true,
361            Expression::GlobalVariable(_) => false,   // Global variables have side effects
362            Expression::InstanceVariable(_) => false, // Instance variables have side effects
363            Expression::ClassVariable(_) => false,    // Class variables have side effects
364            Expression::MethodCall { .. } => false,   // Method calls may have side effects
365            Expression::BinaryOp { left, right, .. } => self.is_pure_expression(left) && self.is_pure_expression(right),
366            Expression::UnaryOp { operand, .. } => self.is_pure_expression(operand),
367            Expression::ArrayLiteral(elements) => elements.iter().all(|e| self.is_pure_expression(e)),
368            Expression::HashLiteral(pairs) => pairs.values().all(|e| self.is_pure_expression(e)),
369            Expression::Block { .. } => false, // Blocks may have side effects
370            Expression::SelfRef => true,
371            Expression::SuperCall { .. } => false, // Super calls may have side effects
372        }
373    }
374}
375
376/// Optimize a program
377pub fn optimize_program(program: &mut Program) {
378    // Apply constant folding
379    let mut constant_folder = ConstantFolder;
380    constant_folder.mutate_program(program);
381
382    // Apply dead code elimination
383    let mut dead_code_eliminator = DeadCodeEliminator;
384    dead_code_eliminator.mutate_program(program);
385}