Skip to main content

ruby_ir/
traversal.rs

1//! IR traversal and manipulation utilities
2
3use crate::{BasicBlock, BlockId, Class, ClassId, Expression, Function, FunctionId, Module, ModuleId, Program, Statement};
4
5/// Trait for visiting IR nodes
6pub trait Visitor {
7    /// Visit an expression
8    fn visit_expression(&mut self, expr: &Expression) {
9        match expr {
10            Expression::Literal(_) => {}
11            Expression::Variable(_) => {}
12            Expression::GlobalVariable(_) => {}
13            Expression::InstanceVariable(_) => {}
14            Expression::ClassVariable(_) => {}
15            Expression::MethodCall { receiver, method: _, arguments } => {
16                self.visit_expression(receiver);
17                for arg in arguments {
18                    self.visit_expression(arg);
19                }
20            }
21            Expression::BinaryOp { left, op: _, right } => {
22                self.visit_expression(left);
23                self.visit_expression(right);
24            }
25            Expression::UnaryOp { op: _, operand } => {
26                self.visit_expression(operand);
27            }
28            Expression::ArrayLiteral(elements) => {
29                for elem in elements {
30                    self.visit_expression(elem);
31                }
32            }
33            Expression::HashLiteral(pairs) => {
34                for (_, value) in pairs {
35                    self.visit_expression(value);
36                }
37            }
38            Expression::Block { parameters: _, body } => {
39                for stmt in body {
40                    self.visit_statement(stmt);
41                }
42            }
43            Expression::SelfRef => {}
44            Expression::SuperCall { arguments } => {
45                for arg in arguments {
46                    self.visit_expression(arg);
47                }
48            }
49        }
50    }
51
52    /// Visit a statement
53    fn visit_statement(&mut self, stmt: &Statement) {
54        match stmt {
55            Statement::Expression(expr) => {
56                self.visit_expression(expr);
57            }
58            Statement::Assignment { name: _, value } => {
59                self.visit_expression(value);
60            }
61            Statement::GlobalAssignment { name: _, value } => {
62                self.visit_expression(value);
63            }
64            Statement::InstanceAssignment { name: _, value } => {
65                self.visit_expression(value);
66            }
67            Statement::ClassAssignment { name: _, value } => {
68                self.visit_expression(value);
69            }
70            Statement::If { condition, then_branch, else_branch } => {
71                self.visit_expression(condition);
72                for stmt in then_branch {
73                    self.visit_statement(stmt);
74                }
75                for stmt in else_branch {
76                    self.visit_statement(stmt);
77                }
78            }
79            Statement::While { condition, body } => {
80                self.visit_expression(condition);
81                for stmt in body {
82                    self.visit_statement(stmt);
83                }
84            }
85            Statement::Until { condition, body } => {
86                self.visit_expression(condition);
87                for stmt in body {
88                    self.visit_statement(stmt);
89                }
90            }
91            Statement::Case { value, when_clauses, else_clause } => {
92                self.visit_expression(value);
93                for (cond, body) in when_clauses {
94                    self.visit_expression(cond);
95                    for stmt in body {
96                        self.visit_statement(stmt);
97                    }
98                }
99                for stmt in else_clause {
100                    self.visit_statement(stmt);
101                }
102            }
103            Statement::For { variable: _, iterator, body } => {
104                self.visit_expression(iterator);
105                for stmt in body {
106                    self.visit_statement(stmt);
107                }
108            }
109            Statement::Break => {}
110            Statement::Next => {}
111            Statement::Redo => {}
112            Statement::Return(expr) => {
113                if let Some(expr) = expr {
114                    self.visit_expression(expr);
115                }
116            }
117            Statement::MethodDefinition { name: _, parameters: _, body } => {
118                for stmt in body {
119                    self.visit_statement(stmt);
120                }
121            }
122            Statement::ClassDefinition { name: _, superclass: _, body } => {
123                for stmt in body {
124                    self.visit_statement(stmt);
125                }
126            }
127            Statement::ModuleDefinition { name: _, body } => {
128                for stmt in body {
129                    self.visit_statement(stmt);
130                }
131            }
132            Statement::Require(expr) => {
133                self.visit_expression(expr);
134            }
135            Statement::Load(expr) => {
136                self.visit_expression(expr);
137            }
138        }
139    }
140
141    /// Visit a basic block
142    fn visit_basic_block(&mut self, block: &BasicBlock) {
143        for stmt in &block.statements {
144            self.visit_statement(stmt);
145        }
146    }
147
148    /// Visit a function
149    fn visit_function(&mut self, function: &Function) {
150        for block in function.blocks.values() {
151            self.visit_basic_block(block);
152        }
153    }
154
155    /// Visit a class
156    fn visit_class(&mut self, _class: &Class) {
157        // Classes themselves don't have direct statements, but their methods do
158    }
159
160    /// Visit a module
161    fn visit_module(&mut self, _module: &Module) {
162        // Modules themselves don't have direct statements, but their methods do
163    }
164
165    /// Visit a program
166    fn visit_program(&mut self, program: &Program) {
167        for stmt in &program.global_statements {
168            self.visit_statement(stmt);
169        }
170        for function in program.functions.values() {
171            self.visit_function(function);
172        }
173        for class in program.classes.values() {
174            self.visit_class(class);
175        }
176        for module in program.modules.values() {
177            self.visit_module(module);
178        }
179    }
180}
181
182/// Trait for mutating IR nodes
183pub trait Mutator {
184    /// Mutate an expression
185    fn mutate_expression(&mut self, expr: &mut Expression) {
186        match expr {
187            Expression::Literal(_) => {}
188            Expression::Variable(_) => {}
189            Expression::GlobalVariable(_) => {}
190            Expression::InstanceVariable(_) => {}
191            Expression::ClassVariable(_) => {}
192            Expression::MethodCall { receiver, method: _, arguments } => {
193                self.mutate_expression(receiver);
194                for arg in arguments {
195                    self.mutate_expression(arg);
196                }
197            }
198            Expression::BinaryOp { left, op: _, right } => {
199                self.mutate_expression(left);
200                self.mutate_expression(right);
201            }
202            Expression::UnaryOp { op: _, operand } => {
203                self.mutate_expression(operand);
204            }
205            Expression::ArrayLiteral(elements) => {
206                for elem in elements {
207                    self.mutate_expression(elem);
208                }
209            }
210            Expression::HashLiteral(pairs) => {
211                for (_, value) in pairs {
212                    self.mutate_expression(value);
213                }
214            }
215            Expression::Block { parameters: _, body } => {
216                for stmt in body {
217                    self.mutate_statement(stmt);
218                }
219            }
220            Expression::SelfRef => {}
221            Expression::SuperCall { arguments } => {
222                for arg in arguments {
223                    self.mutate_expression(arg);
224                }
225            }
226        }
227    }
228
229    /// Mutate a statement
230    fn mutate_statement(&mut self, stmt: &mut Statement) {
231        match stmt {
232            Statement::Expression(expr) => {
233                self.mutate_expression(expr);
234            }
235            Statement::Assignment { name: _, value } => {
236                self.mutate_expression(value);
237            }
238            Statement::GlobalAssignment { name: _, value } => {
239                self.mutate_expression(value);
240            }
241            Statement::InstanceAssignment { name: _, value } => {
242                self.mutate_expression(value);
243            }
244            Statement::ClassAssignment { name: _, value } => {
245                self.mutate_expression(value);
246            }
247            Statement::If { condition, then_branch, else_branch } => {
248                self.mutate_expression(condition);
249                for stmt in then_branch {
250                    self.mutate_statement(stmt);
251                }
252                for stmt in else_branch {
253                    self.mutate_statement(stmt);
254                }
255            }
256            Statement::While { condition, body } => {
257                self.mutate_expression(condition);
258                for stmt in body {
259                    self.mutate_statement(stmt);
260                }
261            }
262            Statement::Until { condition, body } => {
263                self.mutate_expression(condition);
264                for stmt in body {
265                    self.mutate_statement(stmt);
266                }
267            }
268            Statement::Case { value, when_clauses, else_clause } => {
269                self.mutate_expression(value);
270                for (cond, body) in when_clauses {
271                    self.mutate_expression(cond);
272                    for stmt in body {
273                        self.mutate_statement(stmt);
274                    }
275                }
276                for stmt in else_clause {
277                    self.mutate_statement(stmt);
278                }
279            }
280            Statement::For { variable: _, iterator, body } => {
281                self.mutate_expression(iterator);
282                for stmt in body {
283                    self.mutate_statement(stmt);
284                }
285            }
286            Statement::Break => {}
287            Statement::Next => {}
288            Statement::Redo => {}
289            Statement::Return(expr) => {
290                if let Some(expr) = expr {
291                    self.mutate_expression(expr);
292                }
293            }
294            Statement::MethodDefinition { name: _, parameters: _, body } => {
295                for stmt in body {
296                    self.mutate_statement(stmt);
297                }
298            }
299            Statement::ClassDefinition { name: _, superclass: _, body } => {
300                for stmt in body {
301                    self.mutate_statement(stmt);
302                }
303            }
304            Statement::ModuleDefinition { name: _, body } => {
305                for stmt in body {
306                    self.mutate_statement(stmt);
307                }
308            }
309            Statement::Require(expr) => {
310                self.mutate_expression(expr);
311            }
312            Statement::Load(expr) => {
313                self.mutate_expression(expr);
314            }
315        }
316    }
317
318    /// Mutate a basic block
319    fn mutate_basic_block(&mut self, block: &mut BasicBlock) {
320        for stmt in &mut block.statements {
321            self.mutate_statement(stmt);
322        }
323    }
324
325    /// Mutate a function
326    fn mutate_function(&mut self, function: &mut Function) {
327        for block in function.blocks.values_mut() {
328            self.mutate_basic_block(block);
329        }
330    }
331
332    /// Mutate a class
333    fn mutate_class(&mut self, _class: &mut Class) {
334        // Classes themselves don't have direct statements, but their methods do
335    }
336
337    /// Mutate a module
338    fn mutate_module(&mut self, _module: &mut Module) {
339        // Modules themselves don't have direct statements, but their methods do
340    }
341
342    /// Mutate a program
343    fn mutate_program(&mut self, program: &mut Program) {
344        for stmt in &mut program.global_statements {
345            self.mutate_statement(stmt);
346        }
347        for function in program.functions.values_mut() {
348            self.mutate_function(function);
349        }
350        for class in program.classes.values_mut() {
351            self.mutate_class(class);
352        }
353        for module in program.modules.values_mut() {
354            self.mutate_module(module);
355        }
356    }
357}
358
359/// Builder for creating IR programs
360pub struct ProgramBuilder {
361    program: Program,
362    next_function_id: FunctionId,
363    next_class_id: ClassId,
364    next_module_id: ModuleId,
365    next_block_id: BlockId,
366}
367
368impl ProgramBuilder {
369    /// Create a new program builder
370    pub fn new() -> Self {
371        Self { program: Program::new(), next_function_id: 0, next_class_id: 0, next_module_id: 0, next_block_id: 0 }
372    }
373
374    /// Build the program
375    pub fn build(self) -> Program {
376        self.program
377    }
378
379    /// Add a global statement
380    pub fn add_global_statement(&mut self, statement: Statement) -> &mut Self {
381        self.program.add_global_statement(statement);
382        self
383    }
384
385    /// Create a new function
386    pub fn new_function(&mut self, name: String, parameters: Vec<String>) -> FunctionBuilder<'_> {
387        let function_id = self.next_function_id;
388        self.next_function_id += 1;
389
390        FunctionBuilder {
391            function: Function { id: function_id, name, parameters, blocks: std::collections::HashMap::new(), entry_block: 0 },
392            next_block_id: 0,
393            program_builder: self,
394        }
395    }
396
397    /// Create a new class
398    pub fn new_class(&mut self, name: String, superclass: Option<String>) -> ClassBuilder<'_> {
399        let class_id = self.next_class_id;
400        self.next_class_id += 1;
401
402        ClassBuilder { class: Class { id: class_id, name, superclass, methods: std::collections::HashMap::new() }, program_builder: self }
403    }
404
405    /// Create a new module
406    pub fn new_module(&mut self, name: String) -> ModuleBuilder<'_> {
407        let module_id = self.next_module_id;
408        self.next_module_id += 1;
409
410        ModuleBuilder { module: Module { id: module_id, name, methods: std::collections::HashMap::new() }, program_builder: self }
411    }
412
413    /// Get the next block ID
414    fn next_block_id(&mut self) -> BlockId {
415        let id = self.next_block_id;
416        self.next_block_id += 1;
417        id
418    }
419}
420
421/// Builder for creating IR functions
422pub struct FunctionBuilder<'a> {
423    function: Function,
424    next_block_id: BlockId,
425    program_builder: &'a mut ProgramBuilder,
426}
427
428impl<'a> FunctionBuilder<'a> {
429    /// Create a new basic block
430    pub fn new_block(&mut self) -> BlockBuilder<'_, 'a> {
431        let block_id = self.next_block_id;
432        self.next_block_id += 1;
433
434        BlockBuilder { block: BasicBlock { id: block_id, statements: Vec::new(), successors: Vec::new() }, function_builder: self }
435    }
436
437    /// Set the entry block
438    pub fn entry_block(&mut self, block_id: BlockId) -> &mut Self {
439        self.function.entry_block = block_id;
440        self
441    }
442
443    /// Build the function and add it to the program
444    pub fn build(self) -> &'a mut ProgramBuilder {
445        self.program_builder.program.add_function(self.function);
446        self.program_builder
447    }
448}
449
450/// Builder for creating IR basic blocks
451pub struct BlockBuilder<'b, 'a: 'b> {
452    block: BasicBlock,
453    function_builder: &'b mut FunctionBuilder<'a>,
454}
455
456impl<'b, 'a: 'b> BlockBuilder<'b, 'a> {
457    /// Add a statement to the block
458    pub fn add_statement(&mut self, statement: Statement) -> &mut Self {
459        self.block.statements.push(statement);
460        self
461    }
462
463    /// Add a successor block
464    pub fn add_successor(&mut self, successor_id: BlockId) -> &mut Self {
465        self.block.successors.push(successor_id);
466        self
467    }
468
469    /// Build the block and add it to the function
470    pub fn build(self) -> &'b mut FunctionBuilder<'a> {
471        self.function_builder.function.blocks.insert(self.block.id, self.block);
472        self.function_builder
473    }
474}
475
476/// Builder for creating IR classes
477pub struct ClassBuilder<'a> {
478    class: Class,
479    program_builder: &'a mut ProgramBuilder,
480}
481
482impl<'a> ClassBuilder<'a> {
483    /// Add a method to the class
484    pub fn add_method(&mut self, method_name: String, function_id: FunctionId) -> &mut Self {
485        self.class.methods.insert(method_name, function_id);
486        self
487    }
488
489    /// Build the class and add it to the program
490    pub fn build(self) -> &'a mut ProgramBuilder {
491        self.program_builder.program.add_class(self.class);
492        self.program_builder
493    }
494}
495
496/// Builder for creating IR modules
497pub struct ModuleBuilder<'a> {
498    module: Module,
499    program_builder: &'a mut ProgramBuilder,
500}
501
502impl<'a> ModuleBuilder<'a> {
503    /// Add a method to the module
504    pub fn add_method(&mut self, method_name: String, function_id: FunctionId) -> &mut Self {
505        self.module.methods.insert(method_name, function_id);
506        self
507    }
508
509    /// Build the module and add it to the program
510    pub fn build(self) -> &'a mut ProgramBuilder {
511        self.program_builder.program.add_module(self.module);
512        self.program_builder
513    }
514}