1use crate::ast::{BinOp, Expr, Program, Stmt, UnaryOp};
5
/// Switches selecting which optimization passes run over a program.
///
/// Every flag is independent; a disabled pass is skipped entirely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Optimizer {
    /// Rewrite self-recursive tail calls into loops.
    pub tail_recursion: bool,
    /// Evaluate constant sub-expressions ahead of time.
    pub constant_folding: bool,
    /// Drop statements and branches that can never execute.
    pub dead_code_elimination: bool,
}
15
16impl Optimizer {
17 pub fn new() -> Self {
19 Optimizer {
20 tail_recursion: true,
21 constant_folding: true,
22 dead_code_elimination: true,
23 }
24 }
25
26 pub fn optimize_program(&self, program: &Program) -> Program {
28 let mut optimized = program.clone();
29
30 if self.constant_folding {
32 optimized = self.fold_constants(optimized);
33 }
34
35 if self.dead_code_elimination {
37 optimized = self.eliminate_dead_code(optimized);
38 }
39
40 if self.tail_recursion {
42 optimized = self.optimize_tail_recursion(optimized);
43 }
44
45 optimized
46 }
47
48 fn fold_constants(&self, program: Program) -> Program {
50 program
51 .into_iter()
52 .map(|stmt| self.fold_stmt(stmt))
53 .collect()
54 }
55
56 fn fold_stmt(&self, stmt: Stmt) -> Stmt {
58 match stmt {
59 Stmt::Set { name, value } => Stmt::Set {
60 name,
61 value: self.fold_expr(value),
62 },
63 Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
64 name,
65 params,
66 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
67 },
68 Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
69 name,
70 params,
71 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
72 },
73 Stmt::Return(expr) => Stmt::Return(self.fold_expr(expr)),
74 Stmt::Yield(expr) => Stmt::Yield(self.fold_expr(expr)),
75 Stmt::While { condition, body } => Stmt::While {
76 condition: self.fold_expr(condition),
77 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
78 },
79 Stmt::For {
80 var,
81 iterable,
82 body,
83 } => Stmt::For {
84 var,
85 iterable: self.fold_expr(iterable),
86 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
87 },
88 Stmt::ForIndexed {
89 index_var,
90 value_var,
91 iterable,
92 body,
93 } => Stmt::ForIndexed {
94 index_var,
95 value_var,
96 iterable: self.fold_expr(iterable),
97 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
98 },
99 Stmt::Expression(expr) => Stmt::Expression(self.fold_expr(expr)),
100 other => other,
101 }
102 }
103
104 fn fold_expr(&self, expr: Expr) -> Expr {
106 match expr {
107 Expr::Binary { left, op, right } => {
109 let left = self.fold_expr(*left);
110 let right = self.fold_expr(*right);
111
112 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
114 if let Some(result) = Self::eval_const_binary(*l, &op, *r) {
115 return Expr::Number(result);
116 }
117 }
118
119 Expr::Binary {
120 left: Box::new(left),
121 op,
122 right: Box::new(right),
123 }
124 }
125
126 Expr::Unary { op, expr } => {
128 let expr = self.fold_expr(*expr);
129
130 if let Expr::Number(n) = expr {
131 match op {
132 UnaryOp::Minus => return Expr::Number(-n),
133 UnaryOp::Not => return Expr::Boolean(n == 0.0),
134 }
135 }
136
137 if let (UnaryOp::Not, Expr::Boolean(b)) = (&op, &expr) {
138 return Expr::Boolean(!b);
139 }
140
141 Expr::Unary {
142 op,
143 expr: Box::new(expr),
144 }
145 }
146
147 Expr::Call { func, args } => Expr::Call {
149 func: Box::new(self.fold_expr(*func)),
150 args: args.into_iter().map(|e| self.fold_expr(e)).collect(),
151 },
152
153 Expr::Array(elements) => {
154 Expr::Array(elements.into_iter().map(|e| self.fold_expr(e)).collect())
155 }
156
157 Expr::Index { object, index } => Expr::Index {
158 object: Box::new(self.fold_expr(*object)),
159 index: Box::new(self.fold_expr(*index)),
160 },
161
162 other => other,
163 }
164 }
165
166 fn eval_const_binary(left: f64, op: &BinOp, right: f64) -> Option<f64> {
168 match op {
169 BinOp::Add => Some(left + right),
170 BinOp::Subtract => Some(left - right),
171 BinOp::Multiply => Some(left * right),
172 BinOp::Divide if right != 0.0 => Some(left / right),
173 BinOp::Modulo if right != 0.0 => Some(left % right),
174 _ => None,
175 }
176 }
177
178 fn eliminate_dead_code(&self, program: Program) -> Program {
180 program
181 .into_iter()
182 .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
183 .collect()
184 }
185
186 fn eliminate_dead_stmt(&self, stmt: Stmt) -> Option<Stmt> {
188 match stmt {
189 Stmt::While { condition, body } => {
191 if let Expr::Boolean(false) = condition {
192 return None;
194 }
195
196 Some(Stmt::While {
197 condition,
198 body: body
199 .into_iter()
200 .filter_map(|s| self.eliminate_dead_stmt(s))
201 .collect(),
202 })
203 }
204
205 Stmt::FuncDef { name, params, body } => Some(Stmt::FuncDef {
207 name,
208 params,
209 body: body
210 .into_iter()
211 .filter_map(|s| self.eliminate_dead_stmt(s))
212 .collect(),
213 }),
214
215 Stmt::GeneratorDef { name, params, body } => Some(Stmt::GeneratorDef {
216 name,
217 params,
218 body: body
219 .into_iter()
220 .filter_map(|s| self.eliminate_dead_stmt(s))
221 .collect(),
222 }),
223
224 Stmt::Expression(expr) => Some(Stmt::Expression(self.eliminate_dead_expr(expr))),
226
227 other => Some(other),
228 }
229 }
230
231 fn eliminate_dead_expr(&self, expr: Expr) -> Expr {
233 match expr {
234 Expr::If {
235 condition,
236 then_branch,
237 elif_branches,
238 else_branch,
239 } => {
240 if let Expr::Boolean(true) = *condition {
241 return Expr::If {
243 condition: Box::new(Expr::Boolean(true)),
244 then_branch,
245 elif_branches: vec![],
246 else_branch: None,
247 };
248 }
249
250 if let Expr::Boolean(false) = *condition {
251 if let Some(else_body) = else_branch {
253 return Expr::If {
255 condition: Box::new(Expr::Boolean(true)),
256 then_branch: else_body,
257 elif_branches: vec![],
258 else_branch: None,
259 };
260 }
261 return Expr::Null;
263 }
264
265 Expr::If {
267 condition,
268 then_branch: then_branch
269 .into_iter()
270 .filter_map(|s| self.eliminate_dead_stmt(s))
271 .collect(),
272 elif_branches: elif_branches
273 .into_iter()
274 .map(|(c, b)| {
275 (
276 self.eliminate_dead_expr(c),
277 b.into_iter()
278 .filter_map(|s| self.eliminate_dead_stmt(s))
279 .collect(),
280 )
281 })
282 .collect(),
283 else_branch: else_branch.map(|b| {
284 b.into_iter()
285 .filter_map(|s| self.eliminate_dead_stmt(s))
286 .collect()
287 }),
288 }
289 }
290 other => other,
291 }
292 }
293
294 fn optimize_tail_recursion(&self, program: Program) -> Program {
296 program
297 .into_iter()
298 .map(|stmt| self.optimize_tail_recursive_stmt(stmt))
299 .collect()
300 }
301
302 fn optimize_tail_recursive_stmt(&self, stmt: Stmt) -> Stmt {
304 match stmt {
305 Stmt::FuncDef { name, params, body } => {
306 if self.is_tail_recursive(&name, &body) {
308 Stmt::FuncDef {
310 name: name.clone(),
311 params: params.clone(),
312 body: self.convert_tail_recursion_to_loop(&name, ¶ms, body),
313 }
314 } else {
315 Stmt::FuncDef { name, params, body }
316 }
317 }
318 other => other,
319 }
320 }
321
322 fn is_tail_recursive(&self, func_name: &str, body: &[Stmt]) -> bool {
324 if body.is_empty() {
325 return false;
326 }
327
328 self.has_tail_recursion_in_body(func_name, body)
330 }
331
332 fn has_tail_recursion_in_body(&self, func_name: &str, body: &[Stmt]) -> bool {
334 body.iter()
336 .any(|stmt| self.stmt_has_tail_recursion(func_name, stmt))
337 }
338
339 fn stmt_has_tail_recursion(&self, func_name: &str, stmt: &Stmt) -> bool {
341 match stmt {
342 Stmt::Return(expr) => self.is_tail_call(func_name, expr),
343 Stmt::Expression(expr) => self.expr_has_tail_recursion(func_name, expr),
344 Stmt::While { body, .. } => self.has_tail_recursion_in_body(func_name, body),
345 Stmt::For { body, .. } => self.has_tail_recursion_in_body(func_name, body),
346 Stmt::ForIndexed { body, .. } => self.has_tail_recursion_in_body(func_name, body),
347 _ => false,
348 }
349 }
350
351 fn expr_has_tail_recursion(&self, func_name: &str, expr: &Expr) -> bool {
353 match expr {
354 Expr::If {
355 then_branch,
356 elif_branches,
357 else_branch,
358 ..
359 } => {
360 let then_tail = self.has_tail_recursion_in_body(func_name, then_branch);
362 let elif_tail = elif_branches
363 .iter()
364 .any(|(_, body)| self.has_tail_recursion_in_body(func_name, body));
365 let else_tail = else_branch
366 .as_ref()
367 .map(|body| self.has_tail_recursion_in_body(func_name, body))
368 .unwrap_or(false);
369
370 then_tail || elif_tail || else_tail
371 }
372 _ => false,
373 }
374 }
375
376 fn is_tail_call(&self, func_name: &str, expr: &Expr) -> bool {
378 match expr {
379 Expr::Call { func, .. } => {
381 if let Expr::Identifier(name) = &**func {
382 name == func_name
383 } else {
384 false
385 }
386 }
387 Expr::If {
389 then_branch,
390 elif_branches,
391 else_branch,
392 ..
393 } => {
394 let then_is_tail = self.branch_ends_with_tail_call(func_name, then_branch);
396
397 let elif_all_tail = elif_branches
398 .iter()
399 .all(|(_, body)| self.branch_ends_with_tail_call(func_name, body));
400
401 let else_is_tail = else_branch
402 .as_ref()
403 .map(|body| self.branch_ends_with_tail_call(func_name, body))
404 .unwrap_or(true);
405
406 then_is_tail && elif_all_tail && else_is_tail
407 }
408 _ => false,
409 }
410 }
411
412 fn branch_ends_with_tail_call(&self, func_name: &str, branch: &[Stmt]) -> bool {
414 if let Some(last_stmt) = branch.last() {
415 match last_stmt {
416 Stmt::Return(expr) => self.is_tail_call(func_name, expr),
417 Stmt::Expression(expr) => {
418 self.is_tail_call(func_name, expr)
420 }
421 _ => false,
422 }
423 } else {
424 false
425 }
426 }
427
428 fn convert_tail_recursion_to_loop(
430 &self,
431 func_name: &str,
432 params: &[String],
433 body: Vec<Stmt>,
434 ) -> Vec<Stmt> {
435 let mut new_body = Vec::new();
437
438 for param in params {
440 new_body.push(Stmt::Set {
441 name: format!("_loop_{}", param),
442 value: Expr::Identifier(param.clone()),
443 });
444 }
445
446 new_body.push(Stmt::Set {
448 name: "_loop_continue".to_string(),
449 value: Expr::Boolean(true),
450 });
451
452 let loop_body = self.transform_body_to_loop(func_name, params, body);
454
455 new_body.push(Stmt::While {
457 condition: Expr::Identifier("_loop_continue".to_string()),
458 body: loop_body,
459 });
460
461 new_body
462 }
463
464 fn transform_body_to_loop(
466 &self,
467 func_name: &str,
468 params: &[String],
469 body: Vec<Stmt>,
470 ) -> Vec<Stmt> {
471 let mut loop_body = Vec::new();
472
473 for stmt in body {
474 match stmt {
475 Stmt::Return(expr) => {
476 if let Some(new_args) = self.extract_tail_call_args(func_name, &expr) {
478 for (i, param) in params.iter().enumerate() {
480 if let Some(arg) = new_args.get(i) {
481 loop_body.push(Stmt::Set {
482 name: format!("_loop_{}", param),
483 value: arg.clone(),
484 });
485 }
486 }
487
488 for param in params {
490 loop_body.push(Stmt::Set {
491 name: param.clone(),
492 value: Expr::Identifier(format!("_loop_{}", param)),
493 });
494 }
495
496 } else {
498 loop_body.push(Stmt::Set {
500 name: "_loop_continue".to_string(),
501 value: Expr::Boolean(false),
502 });
503 loop_body.push(Stmt::Return(expr));
504 }
505 }
506 _ => {
507 loop_body.push(self.transform_stmt_for_loop(func_name, params, stmt));
509 }
510 }
511 }
512
513 loop_body
514 }
515
516 fn extract_tail_call_args(&self, func_name: &str, expr: &Expr) -> Option<Vec<Expr>> {
518 match expr {
519 Expr::Call { func, args } => {
520 if let Expr::Identifier(name) = &**func {
521 if name == func_name {
522 return Some(args.clone());
523 }
524 }
525 None
526 }
527 _ => None,
528 }
529 }
530
531 fn transform_stmt_for_loop(&self, func_name: &str, params: &[String], stmt: Stmt) -> Stmt {
533 match stmt {
534 Stmt::Expression(expr) => {
535 Stmt::Expression(self.transform_expr_for_loop(func_name, params, expr))
537 }
538 Stmt::While { condition, body } => Stmt::While {
539 condition,
540 body: self.transform_body_to_loop(func_name, params, body),
541 },
542 Stmt::For {
543 var,
544 iterable,
545 body,
546 } => Stmt::For {
547 var,
548 iterable,
549 body: self.transform_body_to_loop(func_name, params, body),
550 },
551 Stmt::ForIndexed {
552 index_var,
553 value_var,
554 iterable,
555 body,
556 } => Stmt::ForIndexed {
557 index_var,
558 value_var,
559 iterable,
560 body: self.transform_body_to_loop(func_name, params, body),
561 },
562 other => other,
563 }
564 }
565
566 fn transform_expr_for_loop(&self, func_name: &str, params: &[String], expr: Expr) -> Expr {
568 match expr {
569 Expr::If {
570 condition,
571 then_branch,
572 elif_branches,
573 else_branch,
574 } => Expr::If {
575 condition,
576 then_branch: self.transform_body_to_loop(func_name, params, then_branch),
577 elif_branches: elif_branches
578 .into_iter()
579 .map(|(cond, body)| {
580 (cond, self.transform_body_to_loop(func_name, params, body))
581 })
582 .collect(),
583 else_branch: else_branch
584 .map(|body| self.transform_body_to_loop(func_name, params, body)),
585 },
586 other => other,
587 }
588 }
589}
590
591impl Default for Optimizer {
592 fn default() -> Self {
593 Self::new()
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Builds a numeric literal.
    fn num(value: f64) -> Expr {
        Expr::Number(value)
    }

    /// Builds a binary expression from already-built operands.
    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds a call to the function named `callee`.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::Call {
            func: Box::new(ident(callee)),
            args,
        }
    }

    /// `return factorial(n - 1, acc * n)`
    fn factorial_tail_return() -> Stmt {
        Stmt::Return(call(
            "factorial",
            vec![
                binary(ident("n"), BinOp::Subtract, num(1.0)),
                binary(ident("acc"), BinOp::Multiply, ident("n")),
            ],
        ))
    }

    #[test]
    fn test_constant_folding() {
        let optimizer = Optimizer::new();

        let folded = optimizer.fold_expr(binary(num(2.0), BinOp::Add, num(3.0)));
        assert_eq!(folded, Expr::Number(5.0));
    }

    #[test]
    fn test_dead_code_elimination() {
        let optimizer = Optimizer::new();

        let never_runs = Stmt::While {
            condition: Expr::Boolean(false),
            body: vec![Stmt::Set {
                name: "x".to_string(),
                value: num(10.0),
            }],
        };

        assert!(optimizer.eliminate_dead_stmt(never_runs).is_none());
    }

    #[test]
    fn test_optimize_program() {
        let optimizer = Optimizer::new();

        let program: Program = vec![Stmt::Set {
            name: "x".to_string(),
            value: binary(num(2.0), BinOp::Add, num(3.0)),
        }];

        let optimized = optimizer.optimize_program(&program);
        assert_eq!(optimized.len(), 1);

        // Fail loudly if the statement shape changed; a bare `if let` would
        // let this test pass without checking anything.
        match optimized.first() {
            Some(Stmt::Set { value, .. }) => assert_eq!(*value, Expr::Number(5.0)),
            _ => panic!("Expected the optimized program to start with a Set statement"),
        }
    }

    #[test]
    fn test_tail_recursion_detection() {
        let optimizer = Optimizer::new();

        let body = vec![factorial_tail_return()];

        assert!(optimizer.is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_non_tail_recursion_detection() {
        let optimizer = Optimizer::new();

        // `return n * factorial(n - 1)`: the multiplication happens after
        // the call returns, so this is not a tail call.
        let body = vec![Stmt::Return(binary(
            ident("n"),
            BinOp::Multiply,
            call(
                "factorial",
                vec![binary(ident("n"), BinOp::Subtract, num(1.0))],
            ),
        ))];

        assert!(!optimizer.is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_tail_recursion_in_if() {
        let optimizer = Optimizer::new();

        let body = vec![Stmt::Expression(Expr::If {
            condition: Box::new(binary(ident("n"), BinOp::LessEqual, num(0.0))),
            then_branch: vec![Stmt::Return(ident("acc"))],
            elif_branches: vec![],
            else_branch: Some(vec![Stmt::Return(call(
                "sum",
                vec![
                    binary(ident("n"), BinOp::Subtract, num(1.0)),
                    binary(ident("acc"), BinOp::Add, ident("n")),
                ],
            ))]),
        })];

        assert!(optimizer.is_tail_recursive("sum", &body));
    }

    #[test]
    fn test_tail_recursion_optimization_transform() {
        let optimizer = Optimizer::new();

        let func_def = Stmt::FuncDef {
            name: "factorial".to_string(),
            params: vec!["n".to_string(), "acc".to_string()],
            body: vec![factorial_tail_return()],
        };

        let body = match optimizer.optimize_tail_recursive_stmt(func_def) {
            Stmt::FuncDef { body, .. } => body,
            _ => panic!("Expected FuncDef"),
        };

        assert!(
            body.len() >= 3,
            "Expected at least 3 statements, got {}",
            body.len()
        );
        assert!(
            matches!(body.last(), Some(Stmt::While { .. })),
            "Expected While loop at the end of optimized function body"
        );
    }
}