1use crate::ast::{self, BinOp, Expr, Ident, Item, Literal, Stmt, UnaryOp, Block, NumBase, TypePath, PathSegment, Pattern, Param, TypeExpr, Visibility, FunctionAttrs};
7use crate::span::Span;
8use std::collections::{HashMap, HashSet};
9
/// Optimization level; selects which passes `Optimizer::optimize_function` runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptLevel {
    /// No transformation: function bodies are cloned unchanged.
    None,
    /// Constant folding followed by dead-code elimination.
    Basic,
    /// One pass each of constant folding, inlining, strength reduction, LICM,
    /// CSE, dead-code elimination and branch simplification. Also enables the
    /// accumulator transform in `optimize_file`.
    Standard,
    /// The `Standard` sequence plus loop unrolling, repeated three times. Also
    /// enables the accumulator transform in `optimize_file`.
    Aggressive,
    /// Runs the same per-function pipeline as `Standard`, but does not enable
    /// the accumulator transform in `optimize_file`.
    Size,
}
28
/// Counters recording how many rewrites the optimizer performed.
///
/// `Default` yields all-zero counters.
#[derive(Clone, Debug, Default)]
pub struct OptStats {
    /// Operator expressions replaced by a literal during constant folding.
    pub constants_folded: usize,
    /// Statements or tail expressions dropped because they follow a `return`.
    pub dead_code_eliminated: usize,
    /// Expressions deduplicated.
    pub expressions_deduplicated: usize,
    /// Functions inlined.
    pub functions_inlined: usize,
    /// Arithmetic rewrites made by strength reduction.
    pub strength_reductions: usize,
    /// Branches simplified, e.g. `if`/`while` with a constant condition.
    pub branches_simplified: usize,
    /// Loops optimized.
    pub loops_optimized: usize,
    /// Functions rewritten by the accumulator transform in `optimize_file`.
    pub tail_recursion_transforms: usize,
    /// Memoization rewrites.
    pub memoization_transforms: usize,
}
42
43pub struct Optimizer {
45 level: OptLevel,
46 stats: OptStats,
47 functions: HashMap<String, ast::Function>,
49 recursive_functions: HashSet<String>,
51 cse_counter: usize,
53}
54
55impl Optimizer {
56 pub fn new(level: OptLevel) -> Self {
57 Self {
58 level,
59 stats: OptStats::default(),
60 functions: HashMap::new(),
61 recursive_functions: HashSet::new(),
62 cse_counter: 0,
63 }
64 }
65
66 pub fn stats(&self) -> &OptStats {
68 &self.stats
69 }
70
71 pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
73 for item in &file.items {
75 if let Item::Function(func) = &item.node {
76 self.functions.insert(func.name.name.clone(), func.clone());
77 if self.is_recursive(&func.name.name, func) {
78 self.recursive_functions.insert(func.name.name.clone());
79 }
80 }
81 }
82
83 let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
87 let mut transformed_functions: HashMap<String, String> = HashMap::new();
88
89 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
90 for item in &file.items {
91 if let Item::Function(func) = &item.node {
92 if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func) {
93 new_items.push(crate::span::Spanned {
95 node: Item::Function(helper_func),
96 span: item.span.clone(),
97 });
98 transformed_functions.insert(func.name.name.clone(), wrapper_func.name.name.clone());
99 self.stats.tail_recursion_transforms += 1;
100 }
101 }
102 }
103
104 }
110
111 let items: Vec<_> = file.items.iter().map(|item| {
113 let node = match &item.node {
114 Item::Function(func) => {
115 if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
117 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
118 && transformed_functions.contains_key(&func.name.name) {
119 Item::Function(self.optimize_function(&wrapper))
120 } else {
121 Item::Function(self.optimize_function(func))
122 }
123 } else {
124 Item::Function(self.optimize_function(func))
125 }
126 }
127 other => other.clone(),
128 };
129 crate::span::Spanned { node, span: item.span.clone() }
130 }).collect();
131
132 new_items.extend(items);
134
135 ast::SourceFile {
136 attrs: file.attrs.clone(),
137 config: file.config.clone(),
138 items: new_items,
139 }
140 }
141
142 fn try_accumulator_transform(&self, func: &ast::Function) -> Option<(ast::Function, ast::Function)> {
145 if func.params.len() != 1 {
147 return None;
148 }
149
150 if !self.recursive_functions.contains(&func.name.name) {
152 return None;
153 }
154
155 let body = func.body.as_ref()?;
156
157 if !self.is_fib_like_pattern(&func.name.name, body) {
161 return None;
162 }
163
164 let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
166 name.name.clone()
167 } else {
168 return None;
169 };
170
171 let helper_name = format!("{}_tail", func.name.name);
173
174 let helper_func = self.generate_fib_helper(&helper_name, ¶m_name);
176
177 let wrapper_func = self.generate_fib_wrapper(&func.name.name, &helper_name, ¶m_name, func);
179
180 Some((helper_func, wrapper_func))
181 }
182
183 fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
185 if body.stmts.is_empty() && body.expr.is_none() {
192 return false;
193 }
194
195 if let Some(expr) = &body.expr {
197 if let Expr::If { else_branch: Some(else_expr), .. } = expr.as_ref() {
198 return self.is_double_recursive_expr(func_name, else_expr);
201 }
202 }
203
204 if body.stmts.len() >= 1 {
206 if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
208 if let Expr::Return(Some(ret_expr)) = expr {
209 return self.is_double_recursive_expr(func_name, ret_expr);
210 }
211 }
212 if let Some(expr) = &body.expr {
213 return self.is_double_recursive_expr(func_name, expr);
214 }
215 }
216
217 false
218 }
219
220 fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
222 if let Expr::Binary { op: BinOp::Add, left, right } = expr {
223 let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
224 let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
225 return left_is_recursive && right_is_recursive;
226 }
227 false
228 }
229
230 fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
232 if let Expr::Call { func, args } = expr {
233 if let Expr::Path(path) = func.as_ref() {
234 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
235 if args.len() == 1 {
237 if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
238 return true;
239 }
240 }
241 }
242 }
243 }
244 false
245 }
246
247 fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
249 let span = Span { start: 0, end: 0 };
250
251 let n_ident = Ident { name: "n".to_string(), evidentiality: None, affect: None, span: span.clone() };
256 let a_ident = Ident { name: "a".to_string(), evidentiality: None, affect: None, span: span.clone() };
257 let b_ident = Ident { name: "b".to_string(), evidentiality: None, affect: None, span: span.clone() };
258
259 let params = vec![
260 Param {
261 pattern: Pattern::Ident { mutable: false, name: n_ident.clone(), evidentiality: None },
262 ty: TypeExpr::Infer,
263 },
264 Param {
265 pattern: Pattern::Ident { mutable: false, name: a_ident.clone(), evidentiality: None },
266 ty: TypeExpr::Infer,
267 },
268 Param {
269 pattern: Pattern::Ident { mutable: false, name: b_ident.clone(), evidentiality: None },
270 ty: TypeExpr::Infer,
271 },
272 ];
273
274 let condition = Expr::Binary {
276 op: BinOp::Le,
277 left: Box::new(Expr::Path(TypePath {
278 segments: vec![PathSegment { ident: n_ident.clone(), generics: None }],
279 })),
280 right: Box::new(Expr::Literal(Literal::Int { value: "0".to_string(), base: NumBase::Decimal, suffix: None })),
281 };
282
283 let then_branch = Block {
285 stmts: vec![],
286 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(TypePath {
287 segments: vec![PathSegment { ident: a_ident.clone(), generics: None }],
288 })))))),
289 };
290
291 let recursive_call = Expr::Call {
293 func: Box::new(Expr::Path(TypePath {
294 segments: vec![PathSegment {
295 ident: Ident { name: name.to_string(), evidentiality: None, affect: None, span: span.clone() },
296 generics: None,
297 }],
298 })),
299 args: vec![
300 Expr::Binary {
302 op: BinOp::Sub,
303 left: Box::new(Expr::Path(TypePath {
304 segments: vec![PathSegment { ident: n_ident.clone(), generics: None }],
305 })),
306 right: Box::new(Expr::Literal(Literal::Int { value: "1".to_string(), base: NumBase::Decimal, suffix: None })),
307 },
308 Expr::Path(TypePath {
310 segments: vec![PathSegment { ident: b_ident.clone(), generics: None }],
311 }),
312 Expr::Binary {
314 op: BinOp::Add,
315 left: Box::new(Expr::Path(TypePath {
316 segments: vec![PathSegment { ident: a_ident.clone(), generics: None }],
317 })),
318 right: Box::new(Expr::Path(TypePath {
319 segments: vec![PathSegment { ident: b_ident.clone(), generics: None }],
320 })),
321 },
322 ],
323 };
324
325 let body = Block {
327 stmts: vec![],
328 expr: Some(Box::new(Expr::If {
329 condition: Box::new(condition),
330 then_branch,
331 else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
332 })),
333 };
334
335 ast::Function {
336 visibility: Visibility::default(),
337 is_async: false,
338 attrs: FunctionAttrs::default(),
339 name: Ident { name: name.to_string(), evidentiality: None, affect: None, span: span.clone() },
340 aspect: None,
341 generics: None,
342 params,
343 return_type: None,
344 where_clause: None,
345 body: Some(body),
346 }
347 }
348
349 fn generate_fib_wrapper(&self, name: &str, helper_name: &str, param_name: &str, original: &ast::Function) -> ast::Function {
351 let span = Span { start: 0, end: 0 };
352
353 let call_helper = Expr::Call {
355 func: Box::new(Expr::Path(TypePath {
356 segments: vec![PathSegment {
357 ident: Ident { name: helper_name.to_string(), evidentiality: None, affect: None, span: span.clone() },
358 generics: None,
359 }],
360 })),
361 args: vec![
362 Expr::Path(TypePath {
364 segments: vec![PathSegment {
365 ident: Ident { name: param_name.to_string(), evidentiality: None, affect: None, span: span.clone() },
366 generics: None,
367 }],
368 }),
369 Expr::Literal(Literal::Int { value: "0".to_string(), base: NumBase::Decimal, suffix: None }),
371 Expr::Literal(Literal::Int { value: "1".to_string(), base: NumBase::Decimal, suffix: None }),
373 ],
374 };
375
376 let body = Block {
377 stmts: vec![],
378 expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
379 };
380
381 ast::Function {
382 visibility: original.visibility,
383 is_async: original.is_async,
384 attrs: original.attrs.clone(),
385 name: Ident { name: name.to_string(), evidentiality: None, affect: None, span: span.clone() },
386 aspect: original.aspect,
387 generics: original.generics.clone(),
388 params: original.params.clone(),
389 return_type: original.return_type.clone(),
390 where_clause: original.where_clause.clone(),
391 body: Some(body),
392 }
393 }
394
395 #[allow(dead_code)]
402 fn try_memoize_transform(&self, func: &ast::Function) -> Option<(ast::Function, ast::Function, ast::Function)> {
403 let param_count = func.params.len();
404 if param_count != 1 && param_count != 2 {
405 return None;
406 }
407
408 let span = Span { start: 0, end: 0 };
409 let func_name = &func.name.name;
410 let impl_name = format!("_memo_impl_{}", func_name);
411 let _cache_name = format!("_memo_cache_{}", func_name);
412 let init_name = format!("_memo_init_{}", func_name);
413
414 let param_names: Vec<String> = func.params.iter().filter_map(|p| {
416 if let Pattern::Ident { name, .. } = &p.pattern {
417 Some(name.name.clone())
418 } else {
419 None
420 }
421 }).collect();
422
423 if param_names.len() != param_count {
424 return None;
425 }
426
427 let impl_func = ast::Function {
429 visibility: Visibility::default(),
430 is_async: func.is_async,
431 attrs: func.attrs.clone(),
432 name: Ident { name: impl_name.clone(), evidentiality: None, affect: None, span: span.clone() },
433 aspect: func.aspect,
434 generics: func.generics.clone(),
435 params: func.params.clone(),
436 return_type: func.return_type.clone(),
437 where_clause: func.where_clause.clone(),
438 body: func.body.as_ref().map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
439 };
440
441 let cache_init_body = Block {
444 stmts: vec![],
445 expr: Some(Box::new(Expr::Call {
446 func: Box::new(Expr::Path(TypePath {
447 segments: vec![PathSegment {
448 ident: Ident { name: "sigil_memo_new".to_string(), evidentiality: None, affect: None, span: span.clone() },
449 generics: None,
450 }],
451 })),
452 args: vec![Expr::Literal(Literal::Int {
453 value: "65536".to_string(),
454 base: NumBase::Decimal,
455 suffix: None
456 })],
457 })),
458 };
459
460 let cache_init_func = ast::Function {
461 visibility: Visibility::default(),
462 is_async: false,
463 attrs: FunctionAttrs::default(),
464 name: Ident { name: init_name.clone(), evidentiality: None, affect: None, span: span.clone() },
465 aspect: None,
466 generics: None,
467 params: vec![],
468 return_type: None,
469 where_clause: None,
470 body: Some(cache_init_body),
471 };
472
473 let wrapper_func = self.generate_memo_wrapper(func, &impl_name, ¶m_names);
475
476 Some((impl_func, cache_init_func, wrapper_func))
477 }
478
479 #[allow(dead_code)]
481 fn generate_memo_wrapper(&self, original: &ast::Function, impl_name: &str, param_names: &[String]) -> ast::Function {
482 let span = Span { start: 0, end: 0 };
483 let param_count = param_names.len();
484
485 let cache_var = Ident { name: "__cache".to_string(), evidentiality: None, affect: None, span: span.clone() };
487 let result_var = Ident { name: "__result".to_string(), evidentiality: None, affect: None, span: span.clone() };
488 let cached_var = Ident { name: "__cached".to_string(), evidentiality: None, affect: None, span: span.clone() };
489
490 let mut stmts = vec![];
491
492 stmts.push(Stmt::Let {
494 pattern: Pattern::Ident { mutable: false, name: cache_var.clone(), evidentiality: None },
495 ty: None,
496 init: Some(Expr::Call {
497 func: Box::new(Expr::Path(TypePath {
498 segments: vec![PathSegment {
499 ident: Ident { name: "sigil_memo_new".to_string(), evidentiality: None, affect: None, span: span.clone() },
500 generics: None,
501 }],
502 })),
503 args: vec![Expr::Literal(Literal::Int {
504 value: "65536".to_string(),
505 base: NumBase::Decimal,
506 suffix: None
507 })],
508 }),
509 });
510
511 let get_fn_name = if param_count == 1 { "sigil_memo_get_1" } else { "sigil_memo_get_2" };
513 let mut get_args = vec![Expr::Path(TypePath {
514 segments: vec![PathSegment { ident: cache_var.clone(), generics: None }],
515 })];
516 for name in param_names {
517 get_args.push(Expr::Path(TypePath {
518 segments: vec![PathSegment {
519 ident: Ident { name: name.clone(), evidentiality: None, affect: None, span: span.clone() },
520 generics: None,
521 }],
522 }));
523 }
524
525 stmts.push(Stmt::Let {
526 pattern: Pattern::Ident { mutable: false, name: cached_var.clone(), evidentiality: None },
527 ty: None,
528 init: Some(Expr::Call {
529 func: Box::new(Expr::Path(TypePath {
530 segments: vec![PathSegment {
531 ident: Ident { name: get_fn_name.to_string(), evidentiality: None, affect: None, span: span.clone() },
532 generics: None,
533 }],
534 })),
535 args: get_args,
536 }),
537 });
538
539 let cache_check = Expr::If {
542 condition: Box::new(Expr::Binary {
543 op: BinOp::Ne,
544 left: Box::new(Expr::Path(TypePath {
545 segments: vec![PathSegment { ident: cached_var.clone(), generics: None }],
546 })),
547 right: Box::new(Expr::Unary {
548 op: UnaryOp::Neg,
549 expr: Box::new(Expr::Literal(Literal::Int {
550 value: "9223372036854775807".to_string(),
551 base: NumBase::Decimal,
552 suffix: None,
553 })),
554 }),
555 }),
556 then_branch: Block {
557 stmts: vec![],
558 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(TypePath {
559 segments: vec![PathSegment { ident: cached_var.clone(), generics: None }],
560 })))))),
561 },
562 else_branch: None,
563 };
564 stmts.push(Stmt::Semi(cache_check));
565
566 let mut impl_args = vec![];
568 for name in param_names {
569 impl_args.push(Expr::Path(TypePath {
570 segments: vec![PathSegment {
571 ident: Ident { name: name.clone(), evidentiality: None, affect: None, span: span.clone() },
572 generics: None,
573 }],
574 }));
575 }
576
577 stmts.push(Stmt::Let {
578 pattern: Pattern::Ident { mutable: false, name: result_var.clone(), evidentiality: None },
579 ty: None,
580 init: Some(Expr::Call {
581 func: Box::new(Expr::Path(TypePath {
582 segments: vec![PathSegment {
583 ident: Ident { name: impl_name.to_string(), evidentiality: None, affect: None, span: span.clone() },
584 generics: None,
585 }],
586 })),
587 args: impl_args,
588 }),
589 });
590
591 let set_fn_name = if param_count == 1 { "sigil_memo_set_1" } else { "sigil_memo_set_2" };
593 let mut set_args = vec![Expr::Path(TypePath {
594 segments: vec![PathSegment { ident: cache_var.clone(), generics: None }],
595 })];
596 for name in param_names {
597 set_args.push(Expr::Path(TypePath {
598 segments: vec![PathSegment {
599 ident: Ident { name: name.clone(), evidentiality: None, affect: None, span: span.clone() },
600 generics: None,
601 }],
602 }));
603 }
604 set_args.push(Expr::Path(TypePath {
605 segments: vec![PathSegment { ident: result_var.clone(), generics: None }],
606 }));
607
608 stmts.push(Stmt::Semi(Expr::Call {
609 func: Box::new(Expr::Path(TypePath {
610 segments: vec![PathSegment {
611 ident: Ident { name: set_fn_name.to_string(), evidentiality: None, affect: None, span: span.clone() },
612 generics: None,
613 }],
614 })),
615 args: set_args,
616 }));
617
618 let body = Block {
620 stmts,
621 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(TypePath {
622 segments: vec![PathSegment { ident: result_var.clone(), generics: None }],
623 })))))),
624 };
625
626 ast::Function {
627 visibility: original.visibility,
628 is_async: original.is_async,
629 attrs: original.attrs.clone(),
630 name: original.name.clone(),
631 aspect: original.aspect,
632 generics: original.generics.clone(),
633 params: original.params.clone(),
634 return_type: original.return_type.clone(),
635 where_clause: original.where_clause.clone(),
636 body: Some(body),
637 }
638 }
639
640 #[allow(dead_code)]
642 fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
643 block.clone()
645 }
646
647 fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
649 if let Some(body) = &func.body {
650 self.block_calls_function(name, body)
651 } else {
652 false
653 }
654 }
655
656 fn block_calls_function(&self, name: &str, block: &Block) -> bool {
657 for stmt in &block.stmts {
658 if self.stmt_calls_function(name, stmt) {
659 return true;
660 }
661 }
662 if let Some(expr) = &block.expr {
663 if self.expr_calls_function(name, expr) {
664 return true;
665 }
666 }
667 false
668 }
669
670 fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
671 match stmt {
672 Stmt::Let { init: Some(expr), .. } => self.expr_calls_function(name, expr),
673 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
674 _ => false,
675 }
676 }
677
678 fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
679 match expr {
680 Expr::Call { func, args } => {
681 if let Expr::Path(path) = func.as_ref() {
682 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
683 return true;
684 }
685 }
686 args.iter().any(|a| self.expr_calls_function(name, a))
687 }
688 Expr::Binary { left, right, .. } => {
689 self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
690 }
691 Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
692 Expr::If { condition, then_branch, else_branch } => {
693 self.expr_calls_function(name, condition)
694 || self.block_calls_function(name, then_branch)
695 || else_branch.as_ref().map(|e| self.expr_calls_function(name, e)).unwrap_or(false)
696 }
697 Expr::While { condition, body } => {
698 self.expr_calls_function(name, condition) || self.block_calls_function(name, body)
699 }
700 Expr::Block(block) => self.block_calls_function(name, block),
701 Expr::Return(Some(e)) => self.expr_calls_function(name, e),
702 _ => false,
703 }
704 }
705
706 fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
708 self.cse_counter = 0;
710
711 let body = if let Some(body) = &func.body {
712 let optimized = match self.level {
714 OptLevel::None => body.clone(),
715 OptLevel::Basic => {
716 let b = self.pass_constant_fold_block(body);
717 self.pass_dead_code_block(&b)
718 }
719 OptLevel::Standard | OptLevel::Size => {
720 let b = self.pass_constant_fold_block(body);
721 let b = self.pass_inline_block(&b); let b = self.pass_strength_reduce_block(&b);
723 let b = self.pass_licm_block(&b); let b = self.pass_cse_block(&b); let b = self.pass_dead_code_block(&b);
726 self.pass_simplify_branches_block(&b)
727 }
728 OptLevel::Aggressive => {
729 let mut b = body.clone();
731 for _ in 0..3 {
732 b = self.pass_constant_fold_block(&b);
733 b = self.pass_inline_block(&b); b = self.pass_strength_reduce_block(&b);
735 b = self.pass_loop_unroll_block(&b); b = self.pass_licm_block(&b); b = self.pass_cse_block(&b); b = self.pass_dead_code_block(&b);
739 b = self.pass_simplify_branches_block(&b);
740 }
741 b
742 }
743 };
744 Some(optimized)
745 } else {
746 None
747 };
748
749 ast::Function {
750 visibility: func.visibility.clone(),
751 is_async: func.is_async,
752 attrs: func.attrs.clone(),
753 name: func.name.clone(),
754 aspect: func.aspect,
755 generics: func.generics.clone(),
756 params: func.params.clone(),
757 return_type: func.return_type.clone(),
758 where_clause: func.where_clause.clone(),
759 body,
760 }
761 }
762
763 fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
768 let stmts = block.stmts.iter().map(|s| self.pass_constant_fold_stmt(s)).collect();
769 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_constant_fold_expr(e)));
770 Block { stmts, expr }
771 }
772
773 fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
774 match stmt {
775 Stmt::Let { pattern, ty, init, .. } => {
776 Stmt::Let {
777 pattern: pattern.clone(),
778 ty: ty.clone(),
779 init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
780 }
781 }
782 Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
783 Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
784 Stmt::Item(item) => Stmt::Item(item.clone()),
785 }
786 }
787
788 fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
789 match expr {
790 Expr::Binary { op, left, right } => {
791 let left = Box::new(self.pass_constant_fold_expr(left));
792 let right = Box::new(self.pass_constant_fold_expr(right));
793
794 if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
796 if let Some(result) = self.fold_binary(op.clone(), l, r) {
797 self.stats.constants_folded += 1;
798 return Expr::Literal(Literal::Int {
799 value: result.to_string(),
800 base: NumBase::Decimal,
801 suffix: None,
802 });
803 }
804 }
805
806 Expr::Binary { op: op.clone(), left, right }
807 }
808 Expr::Unary { op, expr: inner } => {
809 let inner = Box::new(self.pass_constant_fold_expr(inner));
810
811 if let Some(v) = self.as_int(&inner) {
812 if let Some(result) = self.fold_unary(*op, v) {
813 self.stats.constants_folded += 1;
814 return Expr::Literal(Literal::Int {
815 value: result.to_string(),
816 base: NumBase::Decimal,
817 suffix: None,
818 });
819 }
820 }
821
822 Expr::Unary { op: *op, expr: inner }
823 }
824 Expr::If { condition, then_branch, else_branch } => {
825 let condition = Box::new(self.pass_constant_fold_expr(condition));
826 let then_branch = self.pass_constant_fold_block(then_branch);
827 let else_branch = else_branch.as_ref().map(|e| Box::new(self.pass_constant_fold_expr(e)));
828
829 if let Some(cond) = self.as_bool(&condition) {
831 self.stats.branches_simplified += 1;
832 if cond {
833 return Expr::Block(then_branch);
834 } else if let Some(else_expr) = else_branch {
835 return *else_expr;
836 } else {
837 return Expr::Literal(Literal::Bool(false));
838 }
839 }
840
841 Expr::If { condition, then_branch, else_branch }
842 }
843 Expr::While { condition, body } => {
844 let condition = Box::new(self.pass_constant_fold_expr(condition));
845 let body = self.pass_constant_fold_block(body);
846
847 if let Some(false) = self.as_bool(&condition) {
849 self.stats.branches_simplified += 1;
850 return Expr::Block(Block { stmts: vec![], expr: None });
851 }
852
853 Expr::While { condition, body }
854 }
855 Expr::Block(block) => {
856 Expr::Block(self.pass_constant_fold_block(block))
857 }
858 Expr::Call { func, args } => {
859 let args = args.iter().map(|a| self.pass_constant_fold_expr(a)).collect();
860 Expr::Call { func: func.clone(), args }
861 }
862 Expr::Return(e) => {
863 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_constant_fold_expr(e))))
864 }
865 Expr::Assign { target, value } => {
866 let value = Box::new(self.pass_constant_fold_expr(value));
867 Expr::Assign { target: target.clone(), value }
868 }
869 Expr::Index { expr: e, index } => {
870 let e = Box::new(self.pass_constant_fold_expr(e));
871 let index = Box::new(self.pass_constant_fold_expr(index));
872 Expr::Index { expr: e, index }
873 }
874 Expr::Array(elements) => {
875 let elements = elements.iter().map(|e| self.pass_constant_fold_expr(e)).collect();
876 Expr::Array(elements)
877 }
878 other => other.clone(),
879 }
880 }
881
882 fn as_int(&self, expr: &Expr) -> Option<i64> {
883 match expr {
884 Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
885 Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
886 _ => None,
887 }
888 }
889
890 fn as_bool(&self, expr: &Expr) -> Option<bool> {
891 match expr {
892 Expr::Literal(Literal::Bool(b)) => Some(*b),
893 Expr::Literal(Literal::Int { value, .. }) => {
894 value.parse::<i64>().ok().map(|v| v != 0)
895 }
896 _ => None,
897 }
898 }
899
900 fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
901 match op {
902 BinOp::Add => Some(l.wrapping_add(r)),
903 BinOp::Sub => Some(l.wrapping_sub(r)),
904 BinOp::Mul => Some(l.wrapping_mul(r)),
905 BinOp::Div if r != 0 => Some(l / r),
906 BinOp::Rem if r != 0 => Some(l % r),
907 BinOp::BitAnd => Some(l & r),
908 BinOp::BitOr => Some(l | r),
909 BinOp::BitXor => Some(l ^ r),
910 BinOp::Shl => Some(l << (r & 63)),
911 BinOp::Shr => Some(l >> (r & 63)),
912 BinOp::Eq => Some(if l == r { 1 } else { 0 }),
913 BinOp::Ne => Some(if l != r { 1 } else { 0 }),
914 BinOp::Lt => Some(if l < r { 1 } else { 0 }),
915 BinOp::Le => Some(if l <= r { 1 } else { 0 }),
916 BinOp::Gt => Some(if l > r { 1 } else { 0 }),
917 BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
918 BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
919 BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
920 _ => None,
921 }
922 }
923
924 fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
925 match op {
926 UnaryOp::Neg => Some(-v),
927 UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
928 _ => None,
929 }
930 }
931
932 fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
937 let stmts = block.stmts.iter().map(|s| self.pass_strength_reduce_stmt(s)).collect();
938 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_strength_reduce_expr(e)));
939 Block { stmts, expr }
940 }
941
942 fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
943 match stmt {
944 Stmt::Let { pattern, ty, init, .. } => {
945 Stmt::Let {
946 pattern: pattern.clone(),
947 ty: ty.clone(),
948 init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
949 }
950 }
951 Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
952 Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
953 Stmt::Item(item) => Stmt::Item(item.clone()),
954 }
955 }
956
957 fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
958 match expr {
959 Expr::Binary { op, left, right } => {
960 let left = Box::new(self.pass_strength_reduce_expr(left));
961 let right = Box::new(self.pass_strength_reduce_expr(right));
962
963 if *op == BinOp::Mul {
965 if let Some(n) = self.as_int(&right) {
966 if n > 0 && (n as u64).is_power_of_two() {
967 self.stats.strength_reductions += 1;
968 let shift = (n as u64).trailing_zeros() as i64;
969 return Expr::Binary {
970 op: BinOp::Shl,
971 left,
972 right: Box::new(Expr::Literal(Literal::Int {
973 value: shift.to_string(),
974 base: NumBase::Decimal,
975 suffix: None,
976 })),
977 };
978 }
979 }
980 if let Some(n) = self.as_int(&left) {
981 if n > 0 && (n as u64).is_power_of_two() {
982 self.stats.strength_reductions += 1;
983 let shift = (n as u64).trailing_zeros() as i64;
984 return Expr::Binary {
985 op: BinOp::Shl,
986 left: right,
987 right: Box::new(Expr::Literal(Literal::Int {
988 value: shift.to_string(),
989 base: NumBase::Decimal,
990 suffix: None,
991 })),
992 };
993 }
994 }
995 }
996
997 if let Some(n) = self.as_int(&right) {
999 match (op, n) {
1000 (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1001 | (BinOp::Mul | BinOp::Div, 1)
1002 | (BinOp::Shl | BinOp::Shr, 0) => {
1003 self.stats.strength_reductions += 1;
1004 return *left;
1005 }
1006 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1007 self.stats.strength_reductions += 1;
1008 return Expr::Literal(Literal::Int {
1009 value: "0".to_string(),
1010 base: NumBase::Decimal,
1011 suffix: None,
1012 });
1013 }
1014 _ => {}
1015 }
1016 }
1017
1018 if let Some(n) = self.as_int(&left) {
1020 match (op, n) {
1021 (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0)
1022 | (BinOp::Mul, 1) => {
1023 self.stats.strength_reductions += 1;
1024 return *right;
1025 }
1026 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1027 self.stats.strength_reductions += 1;
1028 return Expr::Literal(Literal::Int {
1029 value: "0".to_string(),
1030 base: NumBase::Decimal,
1031 suffix: None,
1032 });
1033 }
1034 _ => {}
1035 }
1036 }
1037
1038 Expr::Binary { op: op.clone(), left, right }
1039 }
1040 Expr::Unary { op, expr: inner } => {
1041 let inner = Box::new(self.pass_strength_reduce_expr(inner));
1042
1043 if *op == UnaryOp::Neg {
1045 if let Expr::Unary { op: UnaryOp::Neg, expr: inner2 } = inner.as_ref() {
1046 self.stats.strength_reductions += 1;
1047 return *inner2.clone();
1048 }
1049 }
1050
1051 if *op == UnaryOp::Not {
1053 if let Expr::Unary { op: UnaryOp::Not, expr: inner2 } = inner.as_ref() {
1054 self.stats.strength_reductions += 1;
1055 return *inner2.clone();
1056 }
1057 }
1058
1059 Expr::Unary { op: *op, expr: inner }
1060 }
1061 Expr::If { condition, then_branch, else_branch } => {
1062 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1063 let then_branch = self.pass_strength_reduce_block(then_branch);
1064 let else_branch = else_branch.as_ref().map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1065 Expr::If { condition, then_branch, else_branch }
1066 }
1067 Expr::While { condition, body } => {
1068 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1069 let body = self.pass_strength_reduce_block(body);
1070 Expr::While { condition, body }
1071 }
1072 Expr::Block(block) => {
1073 Expr::Block(self.pass_strength_reduce_block(block))
1074 }
1075 Expr::Call { func, args } => {
1076 let args = args.iter().map(|a| self.pass_strength_reduce_expr(a)).collect();
1077 Expr::Call { func: func.clone(), args }
1078 }
1079 Expr::Return(e) => {
1080 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_strength_reduce_expr(e))))
1081 }
1082 Expr::Assign { target, value } => {
1083 let value = Box::new(self.pass_strength_reduce_expr(value));
1084 Expr::Assign { target: target.clone(), value }
1085 }
1086 other => other.clone(),
1087 }
1088 }
1089
1090 fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1095 let mut stmts = Vec::new();
1097 let mut found_return = false;
1098
1099 for stmt in &block.stmts {
1100 if found_return {
1101 self.stats.dead_code_eliminated += 1;
1102 continue;
1103 }
1104 let stmt = self.pass_dead_code_stmt(stmt);
1105 if self.stmt_returns(&stmt) {
1106 found_return = true;
1107 }
1108 stmts.push(stmt);
1109 }
1110
1111 let expr = if found_return {
1113 if block.expr.is_some() {
1114 self.stats.dead_code_eliminated += 1;
1115 }
1116 None
1117 } else {
1118 block.expr.as_ref().map(|e| Box::new(self.pass_dead_code_expr(e)))
1119 };
1120
1121 Block { stmts, expr }
1122 }
1123
1124 fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1125 match stmt {
1126 Stmt::Let { pattern, ty, init, .. } => {
1127 Stmt::Let {
1128 pattern: pattern.clone(),
1129 ty: ty.clone(),
1130 init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1131 }
1132 }
1133 Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1134 Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1135 Stmt::Item(item) => Stmt::Item(item.clone()),
1136 }
1137 }
1138
1139 fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1140 match expr {
1141 Expr::If { condition, then_branch, else_branch } => {
1142 let condition = Box::new(self.pass_dead_code_expr(condition));
1143 let then_branch = self.pass_dead_code_block(then_branch);
1144 let else_branch = else_branch.as_ref().map(|e| Box::new(self.pass_dead_code_expr(e)));
1145 Expr::If { condition, then_branch, else_branch }
1146 }
1147 Expr::While { condition, body } => {
1148 let condition = Box::new(self.pass_dead_code_expr(condition));
1149 let body = self.pass_dead_code_block(body);
1150 Expr::While { condition, body }
1151 }
1152 Expr::Block(block) => {
1153 Expr::Block(self.pass_dead_code_block(block))
1154 }
1155 other => other.clone(),
1156 }
1157 }
1158
1159 fn stmt_returns(&self, stmt: &Stmt) -> bool {
1160 match stmt {
1161 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1162 _ => false,
1163 }
1164 }
1165
1166 fn expr_returns(&self, expr: &Expr) -> bool {
1167 match expr {
1168 Expr::Return(_) => true,
1169 Expr::Block(block) => {
1170 block.stmts.iter().any(|s| self.stmt_returns(s))
1171 || block.expr.as_ref().map(|e| self.expr_returns(e)).unwrap_or(false)
1172 }
1173 _ => false,
1174 }
1175 }
1176
1177 fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1182 let stmts = block.stmts.iter().map(|s| self.pass_simplify_branches_stmt(s)).collect();
1183 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1184 Block { stmts, expr }
1185 }
1186
1187 fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1188 match stmt {
1189 Stmt::Let { pattern, ty, init, .. } => {
1190 Stmt::Let {
1191 pattern: pattern.clone(),
1192 ty: ty.clone(),
1193 init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1194 }
1195 }
1196 Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1197 Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1198 Stmt::Item(item) => Stmt::Item(item.clone()),
1199 }
1200 }
1201
1202 fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1203 match expr {
1204 Expr::If { condition, then_branch, else_branch } => {
1205 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1206 let then_branch = self.pass_simplify_branches_block(then_branch);
1207 let else_branch = else_branch.as_ref().map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1208
1209 if let Expr::Unary { op: UnaryOp::Not, expr: inner } = condition.as_ref() {
1211 if let Some(else_expr) = &else_branch {
1212 self.stats.branches_simplified += 1;
1213 let new_else = Some(Box::new(Expr::Block(then_branch)));
1214 let new_then = match else_expr.as_ref() {
1215 Expr::Block(b) => b.clone(),
1216 other => Block { stmts: vec![], expr: Some(Box::new(other.clone())) },
1217 };
1218 return Expr::If {
1219 condition: inner.clone(),
1220 then_branch: new_then,
1221 else_branch: new_else,
1222 };
1223 }
1224 }
1225
1226 Expr::If { condition, then_branch, else_branch }
1227 }
1228 Expr::While { condition, body } => {
1229 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1230 let body = self.pass_simplify_branches_block(body);
1231 Expr::While { condition, body }
1232 }
1233 Expr::Block(block) => {
1234 Expr::Block(self.pass_simplify_branches_block(block))
1235 }
1236 Expr::Binary { op, left, right } => {
1237 let left = Box::new(self.pass_simplify_branches_expr(left));
1238 let right = Box::new(self.pass_simplify_branches_expr(right));
1239 Expr::Binary { op: op.clone(), left, right }
1240 }
1241 Expr::Unary { op, expr: inner } => {
1242 let inner = Box::new(self.pass_simplify_branches_expr(inner));
1243 Expr::Unary { op: *op, expr: inner }
1244 }
1245 Expr::Call { func, args } => {
1246 let args = args.iter().map(|a| self.pass_simplify_branches_expr(a)).collect();
1247 Expr::Call { func: func.clone(), args }
1248 }
1249 Expr::Return(e) => {
1250 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_simplify_branches_expr(e))))
1251 }
1252 other => other.clone(),
1253 }
1254 }
1255
1256 fn should_inline(&self, func: &ast::Function) -> bool {
1262 if self.recursive_functions.contains(&func.name.name) {
1264 return false;
1265 }
1266
1267 if let Some(body) = &func.body {
1269 let stmt_count = self.count_stmts_in_block(body);
1270 stmt_count <= 10
1272 } else {
1273 false
1274 }
1275 }
1276
1277 fn count_stmts_in_block(&self, block: &Block) -> usize {
1279 let mut count = block.stmts.len();
1280 if block.expr.is_some() {
1281 count += 1;
1282 }
1283 for stmt in &block.stmts {
1285 count += self.count_stmts_in_stmt(stmt);
1286 }
1287 count
1288 }
1289
1290 fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1291 match stmt {
1292 Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1293 Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1294 _ => 0,
1295 }
1296 }
1297
1298 fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1299 match expr {
1300 Expr::If { then_branch, else_branch, .. } => {
1301 let mut count = self.count_stmts_in_block(then_branch);
1302 if let Some(else_expr) = else_branch {
1303 count += self.count_stmts_in_expr(else_expr);
1304 }
1305 count
1306 }
1307 Expr::While { body, .. } => self.count_stmts_in_block(body),
1308 Expr::Block(block) => self.count_stmts_in_block(block),
1309 _ => 0,
1310 }
1311 }
1312
1313 fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1315 let body = func.body.as_ref()?;
1316
1317 let mut param_map: HashMap<String, Expr> = HashMap::new();
1319 for (param, arg) in func.params.iter().zip(args.iter()) {
1320 if let Pattern::Ident { name, .. } = ¶m.pattern {
1321 param_map.insert(name.name.clone(), arg.clone());
1322 }
1323 }
1324
1325 let inlined_body = self.substitute_params_in_block(body, ¶m_map);
1327
1328 self.stats.functions_inlined += 1;
1329
1330 if inlined_body.stmts.is_empty() {
1333 if let Some(expr) = inlined_body.expr {
1334 if let Expr::Return(Some(inner)) = expr.as_ref() {
1336 return Some(inner.as_ref().clone());
1337 }
1338 return Some(*expr);
1339 }
1340 }
1341
1342 Some(Expr::Block(inlined_body))
1343 }
1344
1345 fn substitute_params_in_block(&self, block: &Block, param_map: &HashMap<String, Expr>) -> Block {
1347 let stmts = block.stmts.iter().map(|s| self.substitute_params_in_stmt(s, param_map)).collect();
1348 let expr = block.expr.as_ref().map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1349 Block { stmts, expr }
1350 }
1351
1352 fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1353 match stmt {
1354 Stmt::Let { pattern, ty, init } => Stmt::Let {
1355 pattern: pattern.clone(),
1356 ty: ty.clone(),
1357 init: init.as_ref().map(|e| self.substitute_params_in_expr(e, param_map)),
1358 },
1359 Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1360 Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1361 Stmt::Item(item) => Stmt::Item(item.clone()),
1362 }
1363 }
1364
1365 fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1366 match expr {
1367 Expr::Path(path) => {
1368 if path.segments.len() == 1 {
1370 let name = &path.segments[0].ident.name;
1371 if let Some(arg) = param_map.get(name) {
1372 return arg.clone();
1373 }
1374 }
1375 expr.clone()
1376 }
1377 Expr::Binary { op, left, right } => Expr::Binary {
1378 op: op.clone(),
1379 left: Box::new(self.substitute_params_in_expr(left, param_map)),
1380 right: Box::new(self.substitute_params_in_expr(right, param_map)),
1381 },
1382 Expr::Unary { op, expr: inner } => Expr::Unary {
1383 op: *op,
1384 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1385 },
1386 Expr::If { condition, then_branch, else_branch } => Expr::If {
1387 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1388 then_branch: self.substitute_params_in_block(then_branch, param_map),
1389 else_branch: else_branch.as_ref().map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1390 },
1391 Expr::While { condition, body } => Expr::While {
1392 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1393 body: self.substitute_params_in_block(body, param_map),
1394 },
1395 Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1396 Expr::Call { func, args } => Expr::Call {
1397 func: Box::new(self.substitute_params_in_expr(func, param_map)),
1398 args: args.iter().map(|a| self.substitute_params_in_expr(a, param_map)).collect(),
1399 },
1400 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.substitute_params_in_expr(e, param_map)))),
1401 Expr::Assign { target, value } => Expr::Assign {
1402 target: target.clone(),
1403 value: Box::new(self.substitute_params_in_expr(value, param_map)),
1404 },
1405 Expr::Index { expr: e, index } => Expr::Index {
1406 expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1407 index: Box::new(self.substitute_params_in_expr(index, param_map)),
1408 },
1409 Expr::Array(elements) => Expr::Array(
1410 elements.iter().map(|e| self.substitute_params_in_expr(e, param_map)).collect()
1411 ),
1412 other => other.clone(),
1413 }
1414 }
1415
1416 fn pass_inline_block(&mut self, block: &Block) -> Block {
1417 let stmts = block.stmts.iter().map(|s| self.pass_inline_stmt(s)).collect();
1418 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_inline_expr(e)));
1419 Block { stmts, expr }
1420 }
1421
1422 fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1423 match stmt {
1424 Stmt::Let { pattern, ty, init } => Stmt::Let {
1425 pattern: pattern.clone(),
1426 ty: ty.clone(),
1427 init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1428 },
1429 Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1430 Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1431 Stmt::Item(item) => Stmt::Item(item.clone()),
1432 }
1433 }
1434
1435 fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1436 match expr {
1437 Expr::Call { func, args } => {
1438 let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1440
1441 if let Expr::Path(path) = func.as_ref() {
1443 if path.segments.len() == 1 {
1444 let func_name = &path.segments[0].ident.name;
1445 if let Some(target_func) = self.functions.get(func_name).cloned() {
1446 if self.should_inline(&target_func) && args.len() == target_func.params.len() {
1447 if let Some(inlined) = self.inline_call(&target_func, &args) {
1448 return inlined;
1449 }
1450 }
1451 }
1452 }
1453 }
1454
1455 Expr::Call { func: func.clone(), args }
1456 }
1457 Expr::Binary { op, left, right } => Expr::Binary {
1458 op: op.clone(),
1459 left: Box::new(self.pass_inline_expr(left)),
1460 right: Box::new(self.pass_inline_expr(right)),
1461 },
1462 Expr::Unary { op, expr: inner } => Expr::Unary {
1463 op: *op,
1464 expr: Box::new(self.pass_inline_expr(inner)),
1465 },
1466 Expr::If { condition, then_branch, else_branch } => Expr::If {
1467 condition: Box::new(self.pass_inline_expr(condition)),
1468 then_branch: self.pass_inline_block(then_branch),
1469 else_branch: else_branch.as_ref().map(|e| Box::new(self.pass_inline_expr(e))),
1470 },
1471 Expr::While { condition, body } => Expr::While {
1472 condition: Box::new(self.pass_inline_expr(condition)),
1473 body: self.pass_inline_block(body),
1474 },
1475 Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
1476 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
1477 Expr::Assign { target, value } => Expr::Assign {
1478 target: target.clone(),
1479 value: Box::new(self.pass_inline_expr(value)),
1480 },
1481 Expr::Index { expr: e, index } => Expr::Index {
1482 expr: Box::new(self.pass_inline_expr(e)),
1483 index: Box::new(self.pass_inline_expr(index)),
1484 },
1485 Expr::Array(elements) => Expr::Array(
1486 elements.iter().map(|e| self.pass_inline_expr(e)).collect()
1487 ),
1488 other => other.clone(),
1489 }
1490 }
1491
1492 fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
1498 let stmts = block.stmts.iter().map(|s| self.pass_loop_unroll_stmt(s)).collect();
1499 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e)));
1500 Block { stmts, expr }
1501 }
1502
1503 fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
1504 match stmt {
1505 Stmt::Let { pattern, ty, init } => Stmt::Let {
1506 pattern: pattern.clone(),
1507 ty: ty.clone(),
1508 init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
1509 },
1510 Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
1511 Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
1512 Stmt::Item(item) => Stmt::Item(item.clone()),
1513 }
1514 }
1515
1516 fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
1517 match expr {
1518 Expr::While { condition, body } => {
1519 if let Some(unrolled) = self.try_unroll_loop(condition, body) {
1521 self.stats.loops_optimized += 1;
1522 return unrolled;
1523 }
1524 Expr::While {
1526 condition: Box::new(self.pass_loop_unroll_expr(condition)),
1527 body: self.pass_loop_unroll_block(body),
1528 }
1529 }
1530 Expr::If { condition, then_branch, else_branch } => Expr::If {
1531 condition: Box::new(self.pass_loop_unroll_expr(condition)),
1532 then_branch: self.pass_loop_unroll_block(then_branch),
1533 else_branch: else_branch.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))),
1534 },
1535 Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
1536 Expr::Binary { op, left, right } => Expr::Binary {
1537 op: *op,
1538 left: Box::new(self.pass_loop_unroll_expr(left)),
1539 right: Box::new(self.pass_loop_unroll_expr(right)),
1540 },
1541 Expr::Unary { op, expr: inner } => Expr::Unary {
1542 op: *op,
1543 expr: Box::new(self.pass_loop_unroll_expr(inner)),
1544 },
1545 Expr::Call { func, args } => Expr::Call {
1546 func: func.clone(),
1547 args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
1548 },
1549 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e)))),
1550 Expr::Assign { target, value } => Expr::Assign {
1551 target: target.clone(),
1552 value: Box::new(self.pass_loop_unroll_expr(value)),
1553 },
1554 other => other.clone(),
1555 }
1556 }
1557
1558 fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
1561 let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
1563
1564 if upper_bound > 8 || upper_bound <= 0 {
1566 return None;
1567 }
1568
1569 if !self.body_has_simple_increment(&loop_var, body) {
1571 return None;
1572 }
1573
1574 let stmt_count = body.stmts.len();
1576 if stmt_count > 5 {
1577 return None;
1578 }
1579
1580 let mut unrolled_stmts: Vec<Stmt> = Vec::new();
1582
1583 for i in 0..upper_bound {
1584 let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
1586
1587 for stmt in &substituted_body.stmts {
1589 if !self.is_increment_stmt(&loop_var, stmt) {
1590 unrolled_stmts.push(stmt.clone());
1591 }
1592 }
1593 }
1594
1595 Some(Expr::Block(Block {
1597 stmts: unrolled_stmts,
1598 expr: None,
1599 }))
1600 }
1601
1602 fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
1604 if let Expr::Binary { op: BinOp::Lt, left, right } = condition {
1605 if let Expr::Path(path) = left.as_ref() {
1607 if path.segments.len() == 1 {
1608 let var_name = path.segments[0].ident.name.clone();
1609 if let Some(bound) = self.as_int(right) {
1611 return Some((var_name, bound));
1612 }
1613 }
1614 }
1615 }
1616 None
1617 }
1618
1619 fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
1621 for stmt in &body.stmts {
1622 if self.is_increment_stmt(loop_var, stmt) {
1623 return true;
1624 }
1625 }
1626 false
1627 }
1628
1629 fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
1631 match stmt {
1632 Stmt::Semi(Expr::Assign { target, value }) | Stmt::Expr(Expr::Assign { target, value }) => {
1633 if let Expr::Path(path) = target.as_ref() {
1635 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
1636 if let Expr::Binary { op: BinOp::Add, left, right } = value.as_ref() {
1638 if let Expr::Path(lpath) = left.as_ref() {
1639 if lpath.segments.len() == 1 && lpath.segments[0].ident.name == var_name {
1640 if let Some(1) = self.as_int(right) {
1641 return true;
1642 }
1643 }
1644 }
1645 }
1646 }
1647 }
1648 false
1649 }
1650 _ => false,
1651 }
1652 }
1653
1654 fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
1656 let stmts = block.stmts.iter().map(|s| self.substitute_loop_var_in_stmt(s, var_name, value)).collect();
1657 let expr = block.expr.as_ref().map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
1658 Block { stmts, expr }
1659 }
1660
1661 fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
1662 match stmt {
1663 Stmt::Let { pattern, ty, init } => Stmt::Let {
1664 pattern: pattern.clone(),
1665 ty: ty.clone(),
1666 init: init.as_ref().map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
1667 },
1668 Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
1669 Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
1670 Stmt::Item(item) => Stmt::Item(item.clone()),
1671 }
1672 }
1673
1674 fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
1675 match expr {
1676 Expr::Path(path) => {
1677 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
1678 return Expr::Literal(Literal::Int {
1679 value: value.to_string(),
1680 base: NumBase::Decimal,
1681 suffix: None,
1682 });
1683 }
1684 expr.clone()
1685 }
1686 Expr::Binary { op, left, right } => Expr::Binary {
1687 op: *op,
1688 left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
1689 right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
1690 },
1691 Expr::Unary { op, expr: inner } => Expr::Unary {
1692 op: *op,
1693 expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
1694 },
1695 Expr::Call { func, args } => Expr::Call {
1696 func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
1697 args: args.iter().map(|a| self.substitute_loop_var_in_expr(a, var_name, value)).collect(),
1698 },
1699 Expr::If { condition, then_branch, else_branch } => Expr::If {
1700 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
1701 then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
1702 else_branch: else_branch.as_ref().map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
1703 },
1704 Expr::While { condition, body } => Expr::While {
1705 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
1706 body: self.substitute_loop_var_in_block(body, var_name, value),
1707 },
1708 Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
1709 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)))),
1710 Expr::Assign { target, value: v } => Expr::Assign {
1711 target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
1712 value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
1713 },
1714 Expr::Index { expr: e, index } => Expr::Index {
1715 expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
1716 index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
1717 },
1718 Expr::Array(elements) => Expr::Array(
1719 elements.iter().map(|e| self.substitute_loop_var_in_expr(e, var_name, value)).collect()
1720 ),
1721 other => other.clone(),
1722 }
1723 }
1724
1725 fn pass_licm_block(&mut self, block: &Block) -> Block {
1731 let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
1732 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_licm_expr(e)));
1733 Block { stmts, expr }
1734 }
1735
1736 fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
1737 match stmt {
1738 Stmt::Let { pattern, ty, init } => Stmt::Let {
1739 pattern: pattern.clone(),
1740 ty: ty.clone(),
1741 init: init.as_ref().map(|e| self.pass_licm_expr(e)),
1742 },
1743 Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
1744 Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
1745 Stmt::Item(item) => Stmt::Item(item.clone()),
1746 }
1747 }
1748
1749 fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
1750 match expr {
1751 Expr::While { condition, body } => {
1752 let mut modified_vars = HashSet::new();
1754 self.collect_modified_vars_block(body, &mut modified_vars);
1755
1756 self.collect_modified_vars_expr(condition, &mut modified_vars);
1758
1759 let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
1761 self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);
1762
1763 if invariant_exprs.is_empty() {
1764 return Expr::While {
1766 condition: Box::new(self.pass_licm_expr(condition)),
1767 body: self.pass_licm_block(body),
1768 };
1769 }
1770
1771 let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
1773 let mut substitution_map: HashMap<String, String> = HashMap::new();
1774
1775 for (original_key, invariant_expr) in &invariant_exprs {
1776 let var_name = format!("__licm_{}", self.cse_counter);
1777 self.cse_counter += 1;
1778
1779 pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
1780 substitution_map.insert(original_key.clone(), var_name);
1781 self.stats.loops_optimized += 1;
1782 }
1783
1784 let new_body = self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);
1786
1787 let new_while = Expr::While {
1789 condition: Box::new(self.pass_licm_expr(condition)),
1790 body: self.pass_licm_block(&new_body),
1791 };
1792
1793 pre_loop_stmts.push(Stmt::Expr(new_while));
1795 Expr::Block(Block {
1796 stmts: pre_loop_stmts,
1797 expr: None,
1798 })
1799 }
1800 Expr::If { condition, then_branch, else_branch } => Expr::If {
1801 condition: Box::new(self.pass_licm_expr(condition)),
1802 then_branch: self.pass_licm_block(then_branch),
1803 else_branch: else_branch.as_ref().map(|e| Box::new(self.pass_licm_expr(e))),
1804 },
1805 Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
1806 Expr::Binary { op, left, right } => Expr::Binary {
1807 op: *op,
1808 left: Box::new(self.pass_licm_expr(left)),
1809 right: Box::new(self.pass_licm_expr(right)),
1810 },
1811 Expr::Unary { op, expr: inner } => Expr::Unary {
1812 op: *op,
1813 expr: Box::new(self.pass_licm_expr(inner)),
1814 },
1815 Expr::Call { func, args } => Expr::Call {
1816 func: func.clone(),
1817 args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
1818 },
1819 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
1820 Expr::Assign { target, value } => Expr::Assign {
1821 target: target.clone(),
1822 value: Box::new(self.pass_licm_expr(value)),
1823 },
1824 other => other.clone(),
1825 }
1826 }
1827
1828 fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
1830 for stmt in &block.stmts {
1831 self.collect_modified_vars_stmt(stmt, modified);
1832 }
1833 if let Some(expr) = &block.expr {
1834 self.collect_modified_vars_expr(expr, modified);
1835 }
1836 }
1837
1838 fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
1839 match stmt {
1840 Stmt::Let { pattern, init, .. } => {
1841 if let Pattern::Ident { name, .. } = pattern {
1843 modified.insert(name.name.clone());
1844 }
1845 if let Some(e) = init {
1846 self.collect_modified_vars_expr(e, modified);
1847 }
1848 }
1849 Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
1850 _ => {}
1851 }
1852 }
1853
1854 fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
1855 match expr {
1856 Expr::Assign { target, value } => {
1857 if let Expr::Path(path) = target.as_ref() {
1858 if path.segments.len() == 1 {
1859 modified.insert(path.segments[0].ident.name.clone());
1860 }
1861 }
1862 self.collect_modified_vars_expr(value, modified);
1863 }
1864 Expr::Binary { left, right, .. } => {
1865 self.collect_modified_vars_expr(left, modified);
1866 self.collect_modified_vars_expr(right, modified);
1867 }
1868 Expr::Unary { expr: inner, .. } => {
1869 self.collect_modified_vars_expr(inner, modified);
1870 }
1871 Expr::If { condition, then_branch, else_branch } => {
1872 self.collect_modified_vars_expr(condition, modified);
1873 self.collect_modified_vars_block(then_branch, modified);
1874 if let Some(e) = else_branch {
1875 self.collect_modified_vars_expr(e, modified);
1876 }
1877 }
1878 Expr::While { condition, body } => {
1879 self.collect_modified_vars_expr(condition, modified);
1880 self.collect_modified_vars_block(body, modified);
1881 }
1882 Expr::Block(b) => self.collect_modified_vars_block(b, modified),
1883 Expr::Call { args, .. } => {
1884 for arg in args {
1885 self.collect_modified_vars_expr(arg, modified);
1886 }
1887 }
1888 Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
1889 _ => {}
1890 }
1891 }
1892
1893 fn find_loop_invariants(&self, block: &Block, modified: &HashSet<String>, out: &mut Vec<(String, Expr)>) {
1895 for stmt in &block.stmts {
1896 self.find_loop_invariants_stmt(stmt, modified, out);
1897 }
1898 if let Some(expr) = &block.expr {
1899 self.find_loop_invariants_expr(expr, modified, out);
1900 }
1901 }
1902
1903 fn find_loop_invariants_stmt(&self, stmt: &Stmt, modified: &HashSet<String>, out: &mut Vec<(String, Expr)>) {
1904 match stmt {
1905 Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
1906 Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
1907 _ => {}
1908 }
1909 }
1910
1911 fn find_loop_invariants_expr(&self, expr: &Expr, modified: &HashSet<String>, out: &mut Vec<(String, Expr)>) {
1912 match expr {
1914 Expr::Binary { left, right, .. } => {
1915 self.find_loop_invariants_expr(left, modified, out);
1916 self.find_loop_invariants_expr(right, modified, out);
1917 }
1918 Expr::Unary { expr: inner, .. } => {
1919 self.find_loop_invariants_expr(inner, modified, out);
1920 }
1921 Expr::Call { args, .. } => {
1922 for arg in args {
1923 self.find_loop_invariants_expr(arg, modified, out);
1924 }
1925 }
1926 Expr::Index { expr: e, index } => {
1927 self.find_loop_invariants_expr(e, modified, out);
1928 self.find_loop_invariants_expr(index, modified, out);
1929 }
1930 _ => {}
1931 }
1932
1933 if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
1935 let key = format!("{:?}", expr_hash(expr));
1936 if !out.iter().any(|(k, _)| k == &key) {
1938 out.push((key, expr.clone()));
1939 }
1940 }
1941 }
1942
1943 fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
1945 match expr {
1946 Expr::Literal(_) => true,
1947 Expr::Path(path) => {
1948 if path.segments.len() == 1 {
1949 !modified.contains(&path.segments[0].ident.name)
1950 } else {
1951 true }
1953 }
1954 Expr::Binary { left, right, .. } => {
1955 self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
1956 }
1957 Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
1958 Expr::Index { expr: e, index } => {
1959 self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
1960 }
1961 Expr::Call { .. } => false,
1963 _ => false,
1965 }
1966 }
1967
1968 fn replace_invariants_in_block(&self, block: &Block, invariants: &[(String, Expr)], subs: &HashMap<String, String>) -> Block {
1970 let stmts = block.stmts.iter().map(|s| self.replace_invariants_in_stmt(s, invariants, subs)).collect();
1971 let expr = block.expr.as_ref().map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
1972 Block { stmts, expr }
1973 }
1974
1975 fn replace_invariants_in_stmt(&self, stmt: &Stmt, invariants: &[(String, Expr)], subs: &HashMap<String, String>) -> Stmt {
1976 match stmt {
1977 Stmt::Let { pattern, ty, init } => Stmt::Let {
1978 pattern: pattern.clone(),
1979 ty: ty.clone(),
1980 init: init.as_ref().map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
1981 },
1982 Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
1983 Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
1984 Stmt::Item(item) => Stmt::Item(item.clone()),
1985 }
1986 }
1987
1988 fn replace_invariants_in_expr(&self, expr: &Expr, invariants: &[(String, Expr)], subs: &HashMap<String, String>) -> Expr {
1989 let key = format!("{:?}", expr_hash(expr));
1991 for (inv_key, inv_expr) in invariants {
1992 if &key == inv_key && expr_eq(expr, inv_expr) {
1993 if let Some(var_name) = subs.get(inv_key) {
1994 return Expr::Path(TypePath {
1995 segments: vec![PathSegment {
1996 ident: Ident {
1997 name: var_name.clone(),
1998 evidentiality: None,
1999 affect: None,
2000 span: Span { start: 0, end: 0 },
2001 },
2002 generics: None,
2003 }],
2004 });
2005 }
2006 }
2007 }
2008
2009 match expr {
2011 Expr::Binary { op, left, right } => Expr::Binary {
2012 op: *op,
2013 left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2014 right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2015 },
2016 Expr::Unary { op, expr: inner } => Expr::Unary {
2017 op: *op,
2018 expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2019 },
2020 Expr::Call { func, args } => Expr::Call {
2021 func: func.clone(),
2022 args: args.iter().map(|a| self.replace_invariants_in_expr(a, invariants, subs)).collect(),
2023 },
2024 Expr::If { condition, then_branch, else_branch } => Expr::If {
2025 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2026 then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2027 else_branch: else_branch.as_ref().map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2028 },
2029 Expr::While { condition, body } => Expr::While {
2030 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2031 body: self.replace_invariants_in_block(body, invariants, subs),
2032 },
2033 Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2034 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)))),
2035 Expr::Assign { target, value } => Expr::Assign {
2036 target: target.clone(),
2037 value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2038 },
2039 Expr::Index { expr: e, index } => Expr::Index {
2040 expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2041 index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2042 },
2043 other => other.clone(),
2044 }
2045 }
2046
2047 fn pass_cse_block(&mut self, block: &Block) -> Block {
2052 let mut collected = Vec::new();
2054 collect_exprs_from_block(block, &mut collected);
2055
2056 let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2058 for ce in &collected {
2059 let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2060 let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2062 if !found {
2063 entry.push(ce.expr.clone());
2064 }
2065 }
2066
2067 let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2069 for ce in &collected {
2070 let existing = occurrence_counts.iter_mut().find(|(e, _)| expr_eq(e, &ce.expr));
2072 if let Some((_, count)) = existing {
2073 *count += 1;
2074 } else {
2075 occurrence_counts.push((ce.expr.clone(), 1));
2076 }
2077 }
2078
2079 let candidates: Vec<Expr> = occurrence_counts
2081 .into_iter()
2082 .filter(|(_, count)| *count >= 2)
2083 .map(|(expr, _)| expr)
2084 .collect();
2085
2086 if candidates.is_empty() {
2087 return self.pass_cse_nested(block);
2089 }
2090
2091 let mut result_block = block.clone();
2093 let mut new_lets: Vec<Stmt> = Vec::new();
2094
2095 for expr in candidates {
2096 let var_name = format!("__cse_{}", self.cse_counter);
2097 self.cse_counter += 1;
2098
2099 new_lets.push(make_cse_let(&var_name, expr.clone()));
2101
2102 result_block = replace_in_block(&result_block, &expr, &var_name);
2104
2105 self.stats.expressions_deduplicated += 1;
2106 }
2107
2108 let mut final_stmts = new_lets;
2110 final_stmts.extend(result_block.stmts);
2111
2112 let result = Block {
2114 stmts: final_stmts,
2115 expr: result_block.expr,
2116 };
2117 self.pass_cse_nested(&result)
2118 }
2119
2120 fn pass_cse_nested(&mut self, block: &Block) -> Block {
2122 let stmts = block.stmts.iter().map(|stmt| self.pass_cse_stmt(stmt)).collect();
2123 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2124 Block { stmts, expr }
2125 }
2126
2127 fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2128 match stmt {
2129 Stmt::Let { pattern, ty, init } => Stmt::Let {
2130 pattern: pattern.clone(),
2131 ty: ty.clone(),
2132 init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2133 },
2134 Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2135 Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2136 Stmt::Item(item) => Stmt::Item(item.clone()),
2137 }
2138 }
2139
2140 fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2141 match expr {
2142 Expr::If { condition, then_branch, else_branch } => Expr::If {
2143 condition: Box::new(self.pass_cse_expr(condition)),
2144 then_branch: self.pass_cse_block(then_branch),
2145 else_branch: else_branch.as_ref().map(|e| Box::new(self.pass_cse_expr(e))),
2146 },
2147 Expr::While { condition, body } => Expr::While {
2148 condition: Box::new(self.pass_cse_expr(condition)),
2149 body: self.pass_cse_block(body),
2150 },
2151 Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2152 Expr::Binary { op, left, right } => Expr::Binary {
2153 op: *op,
2154 left: Box::new(self.pass_cse_expr(left)),
2155 right: Box::new(self.pass_cse_expr(right)),
2156 },
2157 Expr::Unary { op, expr: inner } => Expr::Unary {
2158 op: *op,
2159 expr: Box::new(self.pass_cse_expr(inner)),
2160 },
2161 Expr::Call { func, args } => Expr::Call {
2162 func: func.clone(),
2163 args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2164 },
2165 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2166 Expr::Assign { target, value } => Expr::Assign {
2167 target: target.clone(),
2168 value: Box::new(self.pass_cse_expr(value)),
2169 },
2170 other => other.clone(),
2171 }
2172 }
2173}
2174
2175fn expr_hash(expr: &Expr) -> u64 {
2181 use std::hash::Hasher;
2182 use std::collections::hash_map::DefaultHasher;
2183
2184 let mut hasher = DefaultHasher::new();
2185 expr_hash_recursive(expr, &mut hasher);
2186 hasher.finish()
2187}
2188
2189fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2190 use std::hash::Hash;
2191
2192 std::mem::discriminant(expr).hash(hasher);
2193
2194 match expr {
2195 Expr::Literal(lit) => {
2196 match lit {
2197 Literal::Int { value, .. } => value.hash(hasher),
2198 Literal::Float { value, .. } => value.hash(hasher),
2199 Literal::String(s) => s.hash(hasher),
2200 Literal::Char(c) => c.hash(hasher),
2201 Literal::Bool(b) => b.hash(hasher),
2202 _ => {}
2203 }
2204 }
2205 Expr::Path(path) => {
2206 for seg in &path.segments {
2207 seg.ident.name.hash(hasher);
2208 }
2209 }
2210 Expr::Binary { op, left, right } => {
2211 std::mem::discriminant(op).hash(hasher);
2212 expr_hash_recursive(left, hasher);
2213 expr_hash_recursive(right, hasher);
2214 }
2215 Expr::Unary { op, expr } => {
2216 std::mem::discriminant(op).hash(hasher);
2217 expr_hash_recursive(expr, hasher);
2218 }
2219 Expr::Call { func, args } => {
2220 expr_hash_recursive(func, hasher);
2221 args.len().hash(hasher);
2222 for arg in args {
2223 expr_hash_recursive(arg, hasher);
2224 }
2225 }
2226 Expr::Index { expr, index } => {
2227 expr_hash_recursive(expr, hasher);
2228 expr_hash_recursive(index, hasher);
2229 }
2230 _ => {}
2231 }
2232}
2233
2234fn is_pure_expr(expr: &Expr) -> bool {
2236 match expr {
2237 Expr::Literal(_) => true,
2238 Expr::Path(_) => true,
2239 Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
2240 Expr::Unary { expr, .. } => is_pure_expr(expr),
2241 Expr::If { condition, then_branch, else_branch } => {
2242 is_pure_expr(condition)
2243 && then_branch.stmts.is_empty()
2244 && then_branch.expr.as_ref().map(|e| is_pure_expr(e)).unwrap_or(true)
2245 && else_branch.as_ref().map(|e| is_pure_expr(e)).unwrap_or(true)
2246 }
2247 Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
2248 Expr::Array(elements) => elements.iter().all(is_pure_expr),
2249 Expr::Call { .. } => false,
2251 Expr::Assign { .. } => false,
2252 Expr::Return(_) => false,
2253 _ => false,
2254 }
2255}
2256
2257fn is_cse_worthy(expr: &Expr) -> bool {
2259 match expr {
2260 Expr::Literal(_) => false,
2262 Expr::Path(_) => false,
2263 Expr::Binary { .. } => true,
2265 Expr::Unary { .. } => true,
2267 Expr::Call { .. } => false,
2269 Expr::Index { .. } => true,
2271 _ => false,
2272 }
2273}
2274
2275fn expr_eq(a: &Expr, b: &Expr) -> bool {
2277 match (a, b) {
2278 (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
2279 (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
2280 (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
2281 (Literal::String(sa), Literal::String(sb)) => sa == sb,
2282 (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
2283 (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
2284 _ => false,
2285 },
2286 (Expr::Path(pa), Expr::Path(pb)) => {
2287 pa.segments.len() == pb.segments.len()
2288 && pa.segments.iter().zip(&pb.segments).all(|(sa, sb)| {
2289 sa.ident.name == sb.ident.name
2290 })
2291 }
2292 (Expr::Binary { op: oa, left: la, right: ra }, Expr::Binary { op: ob, left: lb, right: rb }) => {
2293 oa == ob && expr_eq(la, lb) && expr_eq(ra, rb)
2294 }
2295 (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
2296 oa == ob && expr_eq(ea, eb)
2297 }
2298 (Expr::Index { expr: ea, index: ia }, Expr::Index { expr: eb, index: ib }) => {
2299 expr_eq(ea, eb) && expr_eq(ia, ib)
2300 }
2301 (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
2302 expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
2303 }
2304 _ => false,
2305 }
2306}
2307
2308#[derive(Clone)]
2310struct CollectedExpr {
2311 expr: Expr,
2312 hash: u64,
2313}
2314
2315fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
2317 match expr {
2319 Expr::Binary { left, right, .. } => {
2320 collect_exprs_from_expr(left, out);
2321 collect_exprs_from_expr(right, out);
2322 }
2323 Expr::Unary { expr: inner, .. } => {
2324 collect_exprs_from_expr(inner, out);
2325 }
2326 Expr::Index { expr: e, index } => {
2327 collect_exprs_from_expr(e, out);
2328 collect_exprs_from_expr(index, out);
2329 }
2330 Expr::Call { func, args } => {
2331 collect_exprs_from_expr(func, out);
2332 for arg in args {
2333 collect_exprs_from_expr(arg, out);
2334 }
2335 }
2336 Expr::If { condition, then_branch, else_branch } => {
2337 collect_exprs_from_expr(condition, out);
2338 collect_exprs_from_block(then_branch, out);
2339 if let Some(else_expr) = else_branch {
2340 collect_exprs_from_expr(else_expr, out);
2341 }
2342 }
2343 Expr::While { condition, body } => {
2344 collect_exprs_from_expr(condition, out);
2345 collect_exprs_from_block(body, out);
2346 }
2347 Expr::Block(block) => {
2348 collect_exprs_from_block(block, out);
2349 }
2350 Expr::Return(Some(e)) => {
2351 collect_exprs_from_expr(e, out);
2352 }
2353 Expr::Assign { value, .. } => {
2354 collect_exprs_from_expr(value, out);
2355 }
2356 Expr::Array(elements) => {
2357 for e in elements {
2358 collect_exprs_from_expr(e, out);
2359 }
2360 }
2361 _ => {}
2362 }
2363
2364 if is_cse_worthy(expr) && is_pure_expr(expr) {
2366 out.push(CollectedExpr {
2367 expr: expr.clone(),
2368 hash: expr_hash(expr),
2369 });
2370 }
2371}
2372
2373fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
2375 for stmt in &block.stmts {
2376 match stmt {
2377 Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
2378 Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
2379 _ => {}
2380 }
2381 }
2382 if let Some(e) = &block.expr {
2383 collect_exprs_from_expr(e, out);
2384 }
2385}
2386
2387fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
2389 if expr_eq(expr, target) {
2391 return Expr::Path(TypePath {
2392 segments: vec![PathSegment {
2393 ident: Ident {
2394 name: var_name.to_string(),
2395 evidentiality: None,
2396 affect: None,
2397 span: Span { start: 0, end: 0 },
2398 },
2399 generics: None,
2400 }],
2401 });
2402 }
2403
2404 match expr {
2406 Expr::Binary { op, left, right } => Expr::Binary {
2407 op: *op,
2408 left: Box::new(replace_in_expr(left, target, var_name)),
2409 right: Box::new(replace_in_expr(right, target, var_name)),
2410 },
2411 Expr::Unary { op, expr: inner } => Expr::Unary {
2412 op: *op,
2413 expr: Box::new(replace_in_expr(inner, target, var_name)),
2414 },
2415 Expr::Index { expr: e, index } => Expr::Index {
2416 expr: Box::new(replace_in_expr(e, target, var_name)),
2417 index: Box::new(replace_in_expr(index, target, var_name)),
2418 },
2419 Expr::Call { func, args } => Expr::Call {
2420 func: Box::new(replace_in_expr(func, target, var_name)),
2421 args: args.iter().map(|a| replace_in_expr(a, target, var_name)).collect(),
2422 },
2423 Expr::If { condition, then_branch, else_branch } => Expr::If {
2424 condition: Box::new(replace_in_expr(condition, target, var_name)),
2425 then_branch: replace_in_block(then_branch, target, var_name),
2426 else_branch: else_branch.as_ref().map(|e| Box::new(replace_in_expr(e, target, var_name))),
2427 },
2428 Expr::While { condition, body } => Expr::While {
2429 condition: Box::new(replace_in_expr(condition, target, var_name)),
2430 body: replace_in_block(body, target, var_name),
2431 },
2432 Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
2433 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(replace_in_expr(e, target, var_name)))),
2434 Expr::Assign { target: t, value } => Expr::Assign {
2435 target: t.clone(),
2436 value: Box::new(replace_in_expr(value, target, var_name)),
2437 },
2438 Expr::Array(elements) => Expr::Array(
2439 elements.iter().map(|e| replace_in_expr(e, target, var_name)).collect()
2440 ),
2441 other => other.clone(),
2442 }
2443}
2444
2445fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
2447 let stmts = block.stmts.iter().map(|stmt| {
2448 match stmt {
2449 Stmt::Let { pattern, ty, init } => Stmt::Let {
2450 pattern: pattern.clone(),
2451 ty: ty.clone(),
2452 init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
2453 },
2454 Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
2455 Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
2456 Stmt::Item(item) => Stmt::Item(item.clone()),
2457 }
2458 }).collect();
2459
2460 let expr = block.expr.as_ref().map(|e| Box::new(replace_in_expr(e, target, var_name)));
2461
2462 Block { stmts, expr }
2463}
2464
2465fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
2467 Stmt::Let {
2468 pattern: Pattern::Ident {
2469 mutable: false,
2470 name: Ident {
2471 name: var_name.to_string(),
2472 evidentiality: None,
2473 affect: None,
2474 span: Span { start: 0, end: 0 },
2475 },
2476 evidentiality: None,
2477 },
2478 ty: None,
2479 init: Some(expr),
2480 }
2481}
2482
2483pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
2489 let mut optimizer = Optimizer::new(level);
2490 let optimized = optimizer.optimize_file(file);
2491 (optimized, optimizer.stats)
2492}
2493
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic identifier with an empty span.
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: Span { start: 0, end: 0 },
        }
    }

    /// Decimal integer literal with no suffix.
    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Single-segment path expression referring to `name`.
    fn var(name: &str) -> Expr {
        let segment = PathSegment {
            ident: ident(name),
            generics: None,
        };
        Expr::Path(TypePath {
            segments: vec![segment],
        })
    }

    fn binary(op: BinOp, left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    fn add(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Add, left, right)
    }

    fn mul(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Mul, left, right)
    }

    /// `let <name> = <init>;` with an immutable identifier pattern.
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    #[test]
    fn test_expr_hash_equal() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("b"));
        assert_eq!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_hash_different() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("c"));
        assert_ne!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_eq() {
        let base = add(var("a"), var("b"));
        let same = add(var("a"), var("b"));
        let other = add(var("a"), var("c"));

        assert!(expr_eq(&base, &same));
        assert!(!expr_eq(&base, &other));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    #[test]
    fn test_cse_basic() {
        let a_plus_b = add(var("a"), var("b"));
        let block = Block {
            stmts: vec![
                let_stmt("x", a_plus_b.clone()),
                let_stmt("y", mul(a_plus_b, int_lit(2))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // One hoisted binding plus the two original statements.
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        match &result.stmts[0] {
            Stmt::Let { pattern: Pattern::Ident { name, .. }, .. } => {
                assert_eq!(name.name, "__cse_0")
            }
            _ => panic!("Expected CSE let binding"),
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        let block = Block {
            stmts: vec![
                let_stmt("x", add(var("a"), var("b"))),
                let_stmt("y", add(var("c"), var("d"))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Nothing repeats, so the block is left as it was.
        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}