1use crate::ast::{BinOp, Expr, Program, Stmt, UnaryOp};
5
/// Configuration for the AST optimizer: one switch per optimization pass.
///
/// `Optimizer::new()` (and `Default`) enable every pass; clear a flag to skip
/// that pass when calling `Optimizer::optimize_program`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Optimizer {
    /// Rewrite self tail calls in top-level function definitions into a loop.
    pub tail_recursion: bool,
    /// Evaluate operators whose operands are literals at optimization time.
    pub constant_folding: bool,
    /// Remove loop bodies and `if` branches that literal conditions make unreachable.
    pub dead_code_elimination: bool,
}
15
16impl Optimizer {
17 pub fn new() -> Self {
19 Optimizer {
20 tail_recursion: true,
21 constant_folding: true,
22 dead_code_elimination: true,
23 }
24 }
25
26 pub fn optimize_program(&self, program: &Program) -> Program {
28 let mut optimized = program.clone();
29
30 if self.constant_folding {
32 optimized = self.fold_constants(optimized);
33 }
34
35 if self.dead_code_elimination {
37 optimized = self.eliminate_dead_code(optimized);
38 }
39
40 if self.tail_recursion {
42 optimized = self.optimize_tail_recursion(optimized);
43 }
44
45 optimized
46 }
47
48 fn fold_constants(&self, program: Program) -> Program {
50 program
51 .into_iter()
52 .map(|stmt| self.fold_stmt(stmt))
53 .collect()
54 }
55
56 fn fold_stmt(&self, stmt: Stmt) -> Stmt {
58 match stmt {
59 Stmt::Set { name, value } => Stmt::Set {
60 name,
61 value: self.fold_expr(value),
62 },
63 Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
64 name,
65 params,
66 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
67 },
68 Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
69 name,
70 params,
71 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
72 },
73 Stmt::Return(expr) => Stmt::Return(self.fold_expr(expr)),
74 Stmt::Yield(expr) => Stmt::Yield(self.fold_expr(expr)),
75 Stmt::While { condition, body } => Stmt::While {
76 condition: self.fold_expr(condition),
77 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
78 },
79 Stmt::For {
80 var,
81 iterable,
82 body,
83 } => Stmt::For {
84 var,
85 iterable: self.fold_expr(iterable),
86 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
87 },
88 Stmt::ForIndexed {
89 index_var,
90 value_var,
91 iterable,
92 body,
93 } => Stmt::ForIndexed {
94 index_var,
95 value_var,
96 iterable: self.fold_expr(iterable),
97 body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
98 },
99 Stmt::Expression(expr) => Stmt::Expression(self.fold_expr(expr)),
100 other => other,
101 }
102 }
103
104 #[allow(clippy::only_used_in_recursion)]
106 fn fold_expr(&self, expr: Expr) -> Expr {
107 match expr {
108 Expr::Binary { left, op, right } => {
110 let left = self.fold_expr(*left);
111 let right = self.fold_expr(*right);
112
113 if let (Expr::Number(l), Expr::Number(r)) = (&left, &right)
115 && let Some(result) = Self::eval_const_binary(*l, &op, *r)
116 {
117 return Expr::Number(result);
118 }
119
120 Expr::Binary {
121 left: Box::new(left),
122 op,
123 right: Box::new(right),
124 }
125 }
126
127 Expr::Unary { op, expr } => {
129 let expr = self.fold_expr(*expr);
130
131 if let Expr::Number(n) = expr {
132 match op {
133 UnaryOp::Minus => return Expr::Number(-n),
134 UnaryOp::Not => return Expr::Boolean(n == 0.0),
135 }
136 }
137
138 if let (UnaryOp::Not, Expr::Boolean(b)) = (&op, &expr) {
139 return Expr::Boolean(!b);
140 }
141
142 Expr::Unary {
143 op,
144 expr: Box::new(expr),
145 }
146 }
147
148 Expr::Call { func, args } => Expr::Call {
150 func: Box::new(self.fold_expr(*func)),
151 args: args.into_iter().map(|e| self.fold_expr(e)).collect(),
152 },
153
154 Expr::Array(elements) => {
155 Expr::Array(elements.into_iter().map(|e| self.fold_expr(e)).collect())
156 }
157
158 Expr::Index { object, index } => Expr::Index {
159 object: Box::new(self.fold_expr(*object)),
160 index: Box::new(self.fold_expr(*index)),
161 },
162
163 other => other,
164 }
165 }
166
167 fn eval_const_binary(left: f64, op: &BinOp, right: f64) -> Option<f64> {
169 match op {
170 BinOp::Add => Some(left + right),
171 BinOp::Subtract => Some(left - right),
172 BinOp::Multiply => Some(left * right),
173 BinOp::Divide if right != 0.0 => Some(left / right),
174 BinOp::Modulo if right != 0.0 => Some(left % right),
175 _ => None,
176 }
177 }
178
179 fn eliminate_dead_code(&self, program: Program) -> Program {
181 program
182 .into_iter()
183 .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
184 .collect()
185 }
186
187 fn eliminate_dead_stmt(&self, stmt: Stmt) -> Option<Stmt> {
189 match stmt {
190 Stmt::While { condition, body } => {
192 if let Expr::Boolean(false) = condition {
193 return None;
195 }
196
197 Some(Stmt::While {
198 condition,
199 body: body
200 .into_iter()
201 .filter_map(|s| self.eliminate_dead_stmt(s))
202 .collect(),
203 })
204 }
205
206 Stmt::FuncDef { name, params, body } => Some(Stmt::FuncDef {
208 name,
209 params,
210 body: body
211 .into_iter()
212 .filter_map(|s| self.eliminate_dead_stmt(s))
213 .collect(),
214 }),
215
216 Stmt::GeneratorDef { name, params, body } => Some(Stmt::GeneratorDef {
217 name,
218 params,
219 body: body
220 .into_iter()
221 .filter_map(|s| self.eliminate_dead_stmt(s))
222 .collect(),
223 }),
224
225 Stmt::Expression(expr) => Some(Stmt::Expression(self.eliminate_dead_expr(expr))),
227
228 other => Some(other),
229 }
230 }
231
232 fn eliminate_dead_expr(&self, expr: Expr) -> Expr {
234 match expr {
235 Expr::If {
236 condition,
237 then_branch,
238 elif_branches,
239 else_branch,
240 } => {
241 if let Expr::Boolean(true) = *condition {
242 return Expr::If {
244 condition: Box::new(Expr::Boolean(true)),
245 then_branch,
246 elif_branches: vec![],
247 else_branch: None,
248 };
249 }
250
251 if let Expr::Boolean(false) = *condition {
252 if let Some(else_body) = else_branch {
254 return Expr::If {
256 condition: Box::new(Expr::Boolean(true)),
257 then_branch: else_body,
258 elif_branches: vec![],
259 else_branch: None,
260 };
261 }
262 return Expr::Null;
264 }
265
266 Expr::If {
268 condition,
269 then_branch: then_branch
270 .into_iter()
271 .filter_map(|s| self.eliminate_dead_stmt(s))
272 .collect(),
273 elif_branches: elif_branches
274 .into_iter()
275 .map(|(c, b)| {
276 (
277 self.eliminate_dead_expr(c),
278 b.into_iter()
279 .filter_map(|s| self.eliminate_dead_stmt(s))
280 .collect(),
281 )
282 })
283 .collect(),
284 else_branch: else_branch.map(|b| {
285 b.into_iter()
286 .filter_map(|s| self.eliminate_dead_stmt(s))
287 .collect()
288 }),
289 }
290 }
291 other => other,
292 }
293 }
294
295 fn optimize_tail_recursion(&self, program: Program) -> Program {
297 program
298 .into_iter()
299 .map(|stmt| self.optimize_tail_recursive_stmt(stmt))
300 .collect()
301 }
302
303 fn optimize_tail_recursive_stmt(&self, stmt: Stmt) -> Stmt {
305 match stmt {
306 Stmt::FuncDef { name, params, body } => {
307 if self.is_tail_recursive(&name, &body) {
309 Stmt::FuncDef {
311 name: name.clone(),
312 params: params.clone(),
313 body: self.convert_tail_recursion_to_loop(&name, ¶ms, body),
314 }
315 } else {
316 Stmt::FuncDef { name, params, body }
317 }
318 }
319 other => other,
320 }
321 }
322
323 fn is_tail_recursive(&self, func_name: &str, body: &[Stmt]) -> bool {
325 if body.is_empty() {
326 return false;
327 }
328
329 self.has_tail_recursion_in_body(func_name, body)
331 }
332
333 fn has_tail_recursion_in_body(&self, func_name: &str, body: &[Stmt]) -> bool {
335 body.iter()
337 .any(|stmt| self.stmt_has_tail_recursion(func_name, stmt))
338 }
339
340 fn stmt_has_tail_recursion(&self, func_name: &str, stmt: &Stmt) -> bool {
342 match stmt {
343 Stmt::Return(expr) => self.is_tail_call(func_name, expr),
344 Stmt::Expression(expr) => self.expr_has_tail_recursion(func_name, expr),
345 Stmt::While { body, .. } => self.has_tail_recursion_in_body(func_name, body),
346 Stmt::For { body, .. } => self.has_tail_recursion_in_body(func_name, body),
347 Stmt::ForIndexed { body, .. } => self.has_tail_recursion_in_body(func_name, body),
348 _ => false,
349 }
350 }
351
352 fn expr_has_tail_recursion(&self, func_name: &str, expr: &Expr) -> bool {
354 match expr {
355 Expr::If {
356 then_branch,
357 elif_branches,
358 else_branch,
359 ..
360 } => {
361 let then_tail = self.has_tail_recursion_in_body(func_name, then_branch);
363 let elif_tail = elif_branches
364 .iter()
365 .any(|(_, body)| self.has_tail_recursion_in_body(func_name, body));
366 let else_tail = else_branch
367 .as_ref()
368 .map(|body| self.has_tail_recursion_in_body(func_name, body))
369 .unwrap_or(false);
370
371 then_tail || elif_tail || else_tail
372 }
373 _ => false,
374 }
375 }
376
377 fn is_tail_call(&self, func_name: &str, expr: &Expr) -> bool {
379 match expr {
380 Expr::Call { func, .. } => {
382 if let Expr::Identifier(name) = &**func {
383 name == func_name
384 } else {
385 false
386 }
387 }
388 Expr::If {
390 then_branch,
391 elif_branches,
392 else_branch,
393 ..
394 } => {
395 let then_is_tail = self.branch_ends_with_tail_call(func_name, then_branch);
397
398 let elif_all_tail = elif_branches
399 .iter()
400 .all(|(_, body)| self.branch_ends_with_tail_call(func_name, body));
401
402 let else_is_tail = else_branch
403 .as_ref()
404 .map(|body| self.branch_ends_with_tail_call(func_name, body))
405 .unwrap_or(true);
406
407 then_is_tail && elif_all_tail && else_is_tail
408 }
409 _ => false,
410 }
411 }
412
413 fn branch_ends_with_tail_call(&self, func_name: &str, branch: &[Stmt]) -> bool {
415 if let Some(last_stmt) = branch.last() {
416 match last_stmt {
417 Stmt::Return(expr) => self.is_tail_call(func_name, expr),
418 Stmt::Expression(expr) => {
419 self.is_tail_call(func_name, expr)
421 }
422 _ => false,
423 }
424 } else {
425 false
426 }
427 }
428
429 fn convert_tail_recursion_to_loop(
431 &self,
432 func_name: &str,
433 params: &[String],
434 body: Vec<Stmt>,
435 ) -> Vec<Stmt> {
436 let mut new_body = Vec::new();
438
439 for param in params {
441 new_body.push(Stmt::Set {
442 name: format!("_loop_{}", param),
443 value: Expr::Identifier(param.clone()),
444 });
445 }
446
447 new_body.push(Stmt::Set {
449 name: "_loop_continue".to_string(),
450 value: Expr::Boolean(true),
451 });
452
453 let loop_body = self.transform_body_to_loop(func_name, params, body);
455
456 new_body.push(Stmt::While {
458 condition: Expr::Identifier("_loop_continue".to_string()),
459 body: loop_body,
460 });
461
462 new_body
463 }
464
465 fn transform_body_to_loop(
467 &self,
468 func_name: &str,
469 params: &[String],
470 body: Vec<Stmt>,
471 ) -> Vec<Stmt> {
472 let mut loop_body = Vec::new();
473
474 for stmt in body {
475 match stmt {
476 Stmt::Return(expr) => {
477 if let Some(new_args) = self.extract_tail_call_args(func_name, &expr) {
479 for (i, param) in params.iter().enumerate() {
481 if let Some(arg) = new_args.get(i) {
482 loop_body.push(Stmt::Set {
483 name: format!("_loop_{}", param),
484 value: arg.clone(),
485 });
486 }
487 }
488
489 for param in params {
491 loop_body.push(Stmt::Set {
492 name: param.clone(),
493 value: Expr::Identifier(format!("_loop_{}", param)),
494 });
495 }
496
497 } else {
499 loop_body.push(Stmt::Set {
501 name: "_loop_continue".to_string(),
502 value: Expr::Boolean(false),
503 });
504 loop_body.push(Stmt::Return(expr));
505 }
506 }
507 _ => {
508 loop_body.push(self.transform_stmt_for_loop(func_name, params, stmt));
510 }
511 }
512 }
513
514 loop_body
515 }
516
517 fn extract_tail_call_args(&self, func_name: &str, expr: &Expr) -> Option<Vec<Expr>> {
519 match expr {
520 Expr::Call { func, args } => {
521 if let Expr::Identifier(name) = &**func
522 && name == func_name
523 {
524 return Some(args.clone());
525 }
526 None
527 }
528 _ => None,
529 }
530 }
531
532 fn transform_stmt_for_loop(&self, func_name: &str, params: &[String], stmt: Stmt) -> Stmt {
534 match stmt {
535 Stmt::Expression(expr) => {
536 Stmt::Expression(self.transform_expr_for_loop(func_name, params, expr))
538 }
539 Stmt::While { condition, body } => Stmt::While {
540 condition,
541 body: self.transform_body_to_loop(func_name, params, body),
542 },
543 Stmt::For {
544 var,
545 iterable,
546 body,
547 } => Stmt::For {
548 var,
549 iterable,
550 body: self.transform_body_to_loop(func_name, params, body),
551 },
552 Stmt::ForIndexed {
553 index_var,
554 value_var,
555 iterable,
556 body,
557 } => Stmt::ForIndexed {
558 index_var,
559 value_var,
560 iterable,
561 body: self.transform_body_to_loop(func_name, params, body),
562 },
563 other => other,
564 }
565 }
566
567 fn transform_expr_for_loop(&self, func_name: &str, params: &[String], expr: Expr) -> Expr {
569 match expr {
570 Expr::If {
571 condition,
572 then_branch,
573 elif_branches,
574 else_branch,
575 } => Expr::If {
576 condition,
577 then_branch: self.transform_body_to_loop(func_name, params, then_branch),
578 elif_branches: elif_branches
579 .into_iter()
580 .map(|(cond, body)| {
581 (cond, self.transform_body_to_loop(func_name, params, body))
582 })
583 .collect(),
584 else_branch: else_branch
585 .map(|body| self.transform_body_to_loop(func_name, params, body)),
586 },
587 other => other,
588 }
589 }
590}
591
592impl Default for Optimizer {
593 fn default() -> Self {
594 Self::new()
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Builds a numeric literal.
    fn num(value: f64) -> Expr {
        Expr::Number(value)
    }

    /// Builds `left <op> right`.
    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds a call of the function named `name`.
    fn call(name: &str, args: Vec<Expr>) -> Expr {
        Expr::Call {
            func: Box::new(ident(name)),
            args,
        }
    }

    /// Builds `name(n - 1, acc <acc_op> n)`, the accumulator-style recursive
    /// call shared by several tests.
    fn accumulating_call(name: &str, acc_op: BinOp) -> Expr {
        call(
            name,
            vec![
                binary(ident("n"), BinOp::Subtract, num(1.0)),
                binary(ident("acc"), acc_op, ident("n")),
            ],
        )
    }

    #[test]
    fn test_constant_folding() {
        let optimizer = Optimizer::new();
        let folded = optimizer.fold_expr(binary(num(2.0), BinOp::Add, num(3.0)));
        assert_eq!(folded, Expr::Number(5.0));
    }

    #[test]
    fn test_dead_code_elimination() {
        let optimizer = Optimizer::new();
        let never_runs = Stmt::While {
            condition: Expr::Boolean(false),
            body: vec![Stmt::Set {
                name: "x".to_string(),
                value: num(10.0),
            }],
        };
        assert!(optimizer.eliminate_dead_stmt(never_runs).is_none());
    }

    #[test]
    fn test_tail_recursion_detection() {
        let body = [Stmt::Return(accumulating_call("factorial", BinOp::Multiply))];
        assert!(Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_non_tail_recursion_detection() {
        // `n * factorial(n - 1)`: the call's result is still used afterwards.
        let recursive_call = call(
            "factorial",
            vec![binary(ident("n"), BinOp::Subtract, num(1.0))],
        );
        let body = [Stmt::Return(binary(ident("n"), BinOp::Multiply, recursive_call))];
        assert!(!Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_tail_recursion_in_if() {
        let body = [Stmt::Expression(Expr::If {
            condition: Box::new(binary(ident("n"), BinOp::LessEqual, num(0.0))),
            then_branch: vec![Stmt::Return(ident("acc"))],
            elif_branches: vec![],
            else_branch: Some(vec![Stmt::Return(accumulating_call("sum", BinOp::Add))]),
        })];
        assert!(Optimizer::new().is_tail_recursive("sum", &body));
    }

    #[test]
    fn test_tail_recursion_optimization_transform() {
        let func_def = Stmt::FuncDef {
            name: "factorial".to_string(),
            params: vec!["n".to_string(), "acc".to_string()],
            body: vec![Stmt::Return(accumulating_call("factorial", BinOp::Multiply))],
        };

        let Stmt::FuncDef { body, .. } = Optimizer::new().optimize_tail_recursive_stmt(func_def)
        else {
            panic!("Expected FuncDef");
        };

        assert!(
            body.len() >= 3,
            "Expected at least 3 statements, got {}",
            body.len()
        );
        assert!(
            matches!(body.last(), Some(Stmt::While { .. })),
            "Expected While loop at the end of optimized function body"
        );
    }
}