1use crate::ast::{
7 self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8 Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
/// How much work [`Optimizer`] does on each function.
///
/// Variant order is kept as-is so any discriminant-based use is unaffected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptLevel {
    /// No passes run: function bodies are copied unchanged.
    None,
    /// Constant folding followed by dead-code elimination.
    Basic,
    /// One round of the full pass sequence; also enables the accumulator
    /// (tail-recursion) rewrite performed by `Optimizer::optimize_file`.
    Standard,
    /// Three rounds of the full pass sequence plus loop unrolling; also
    /// enables the accumulator rewrite.
    Aggressive,
    /// Runs the same per-function passes as `Standard`, but does not enable
    /// the accumulator rewrite in `Optimizer::optimize_file`.
    Size,
}
31
/// Running counters describing what the optimizer changed.
///
/// Every counter starts at zero (`Default`); read them via `Optimizer::stats`.
#[derive(Clone, Debug, Default)]
pub struct OptStats {
    /// Binary/unary expressions replaced by a literal during constant folding.
    pub constants_folded: usize,
    /// Statements or tail expressions removed as unreachable (e.g. after a
    /// returning statement).
    pub dead_code_eliminated: usize,
    /// Presumably bumped by the CSE pass (not visible in this chunk).
    pub expressions_deduplicated: usize,
    /// Presumably bumped by the inlining pass (not visible in this chunk).
    pub functions_inlined: usize,
    /// Rewrites performed by the strength-reduction pass.
    pub strength_reductions: usize,
    /// Constant `if`/`while` conditions resolved during constant folding.
    pub branches_simplified: usize,
    /// Presumably bumped by the loop passes (not visible in this chunk).
    pub loops_optimized: usize,
    /// Functions rewritten into accumulator (tail-recursive) form.
    pub tail_recursion_transforms: usize,
    /// Memoization rewrites; not incremented by the code visible here.
    pub memoization_transforms: usize,
}
45
46pub struct Optimizer {
48 level: OptLevel,
49 stats: OptStats,
50 functions: HashMap<String, ast::Function>,
52 recursive_functions: HashSet<String>,
54 cse_counter: usize,
56}
57
58impl Optimizer {
59 pub fn new(level: OptLevel) -> Self {
60 Self {
61 level,
62 stats: OptStats::default(),
63 functions: HashMap::new(),
64 recursive_functions: HashSet::new(),
65 cse_counter: 0,
66 }
67 }
68
69 pub fn stats(&self) -> &OptStats {
71 &self.stats
72 }
73
74 pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76 for item in &file.items {
78 if let Item::Function(func) = &item.node {
79 self.functions.insert(func.name.name.clone(), func.clone());
80 if self.is_recursive(&func.name.name, func) {
81 self.recursive_functions.insert(func.name.name.clone());
82 }
83 }
84 }
85
86 let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90 let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93 for item in &file.items {
94 if let Item::Function(func) = &item.node {
95 if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96 {
97 new_items.push(crate::span::Spanned {
99 node: Item::Function(helper_func),
100 span: item.span.clone(),
101 });
102 transformed_functions
103 .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104 self.stats.tail_recursion_transforms += 1;
105 }
106 }
107 }
108
109 }
115
116 let items: Vec<_> = file
118 .items
119 .iter()
120 .map(|item| {
121 let node = match &item.node {
122 Item::Function(func) => {
123 if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126 && transformed_functions.contains_key(&func.name.name)
127 {
128 Item::Function(self.optimize_function(&wrapper))
129 } else {
130 Item::Function(self.optimize_function(func))
131 }
132 } else {
133 Item::Function(self.optimize_function(func))
134 }
135 }
136 other => other.clone(),
137 };
138 crate::span::Spanned {
139 node,
140 span: item.span.clone(),
141 }
142 })
143 .collect();
144
145 new_items.extend(items);
147
148 ast::SourceFile {
149 attrs: file.attrs.clone(),
150 config: file.config.clone(),
151 items: new_items,
152 }
153 }
154
155 fn try_accumulator_transform(
158 &self,
159 func: &ast::Function,
160 ) -> Option<(ast::Function, ast::Function)> {
161 if func.params.len() != 1 {
163 return None;
164 }
165
166 if !self.recursive_functions.contains(&func.name.name) {
168 return None;
169 }
170
171 let body = func.body.as_ref()?;
172
173 if !self.is_fib_like_pattern(&func.name.name, body) {
177 return None;
178 }
179
180 let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182 name.name.clone()
183 } else {
184 return None;
185 };
186
187 let helper_name = format!("{}_tail", func.name.name);
189
190 let helper_func = self.generate_fib_helper(&helper_name, ¶m_name);
192
193 let wrapper_func =
195 self.generate_fib_wrapper(&func.name.name, &helper_name, ¶m_name, func);
196
197 Some((helper_func, wrapper_func))
198 }
199
200 fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202 if body.stmts.is_empty() && body.expr.is_none() {
209 return false;
210 }
211
212 if let Some(expr) = &body.expr {
214 if let Expr::If {
215 else_branch: Some(else_expr),
216 ..
217 } = expr.as_ref()
218 {
219 return self.is_double_recursive_expr(func_name, else_expr);
222 }
223 }
224
225 if body.stmts.len() >= 1 {
227 if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229 if let Expr::Return(Some(ret_expr)) = expr {
230 return self.is_double_recursive_expr(func_name, ret_expr);
231 }
232 }
233 if let Some(expr) = &body.expr {
234 return self.is_double_recursive_expr(func_name, expr);
235 }
236 }
237
238 false
239 }
240
241 fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243 if let Expr::Binary {
244 op: BinOp::Add,
245 left,
246 right,
247 } = expr
248 {
249 let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250 let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251 return left_is_recursive && right_is_recursive;
252 }
253 false
254 }
255
256 fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258 if let Expr::Call { func, args } = expr {
259 if let Expr::Path(path) = func.as_ref() {
260 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261 if args.len() == 1 {
263 if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264 return true;
265 }
266 }
267 }
268 }
269 }
270 false
271 }
272
273 fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275 let span = Span { start: 0, end: 0 };
276
277 let n_ident = Ident {
282 name: "n".to_string(),
283 evidentiality: None,
284 affect: None,
285 span: span.clone(),
286 };
287 let a_ident = Ident {
288 name: "a".to_string(),
289 evidentiality: None,
290 affect: None,
291 span: span.clone(),
292 };
293 let b_ident = Ident {
294 name: "b".to_string(),
295 evidentiality: None,
296 affect: None,
297 span: span.clone(),
298 };
299
300 let params = vec![
301 Param {
302 pattern: Pattern::Ident {
303 mutable: false,
304 name: n_ident.clone(),
305 evidentiality: None,
306 },
307 ty: TypeExpr::Infer,
308 },
309 Param {
310 pattern: Pattern::Ident {
311 mutable: false,
312 name: a_ident.clone(),
313 evidentiality: None,
314 },
315 ty: TypeExpr::Infer,
316 },
317 Param {
318 pattern: Pattern::Ident {
319 mutable: false,
320 name: b_ident.clone(),
321 evidentiality: None,
322 },
323 ty: TypeExpr::Infer,
324 },
325 ];
326
327 let condition = Expr::Binary {
329 op: BinOp::Le,
330 left: Box::new(Expr::Path(TypePath {
331 segments: vec![PathSegment {
332 ident: n_ident.clone(),
333 generics: None,
334 }],
335 })),
336 right: Box::new(Expr::Literal(Literal::Int {
337 value: "0".to_string(),
338 base: NumBase::Decimal,
339 suffix: None,
340 })),
341 };
342
343 let then_branch = Block {
345 stmts: vec![],
346 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347 TypePath {
348 segments: vec![PathSegment {
349 ident: a_ident.clone(),
350 generics: None,
351 }],
352 },
353 )))))),
354 };
355
356 let recursive_call = Expr::Call {
358 func: Box::new(Expr::Path(TypePath {
359 segments: vec![PathSegment {
360 ident: Ident {
361 name: name.to_string(),
362 evidentiality: None,
363 affect: None,
364 span: span.clone(),
365 },
366 generics: None,
367 }],
368 })),
369 args: vec![
370 Expr::Binary {
372 op: BinOp::Sub,
373 left: Box::new(Expr::Path(TypePath {
374 segments: vec![PathSegment {
375 ident: n_ident.clone(),
376 generics: None,
377 }],
378 })),
379 right: Box::new(Expr::Literal(Literal::Int {
380 value: "1".to_string(),
381 base: NumBase::Decimal,
382 suffix: None,
383 })),
384 },
385 Expr::Path(TypePath {
387 segments: vec![PathSegment {
388 ident: b_ident.clone(),
389 generics: None,
390 }],
391 }),
392 Expr::Binary {
394 op: BinOp::Add,
395 left: Box::new(Expr::Path(TypePath {
396 segments: vec![PathSegment {
397 ident: a_ident.clone(),
398 generics: None,
399 }],
400 })),
401 right: Box::new(Expr::Path(TypePath {
402 segments: vec![PathSegment {
403 ident: b_ident.clone(),
404 generics: None,
405 }],
406 })),
407 },
408 ],
409 };
410
411 let body = Block {
413 stmts: vec![],
414 expr: Some(Box::new(Expr::If {
415 condition: Box::new(condition),
416 then_branch,
417 else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418 })),
419 };
420
421 ast::Function {
422 visibility: Visibility::default(),
423 is_async: false,
424 is_const: false,
425 is_unsafe: false,
426 attrs: FunctionAttrs::default(),
427 name: Ident {
428 name: name.to_string(),
429 evidentiality: None,
430 affect: None,
431 span: span.clone(),
432 },
433 aspect: None,
434 generics: None,
435 params,
436 return_type: None,
437 where_clause: None,
438 body: Some(body),
439 }
440 }
441
442 fn generate_fib_wrapper(
444 &self,
445 name: &str,
446 helper_name: &str,
447 param_name: &str,
448 original: &ast::Function,
449 ) -> ast::Function {
450 let span = Span { start: 0, end: 0 };
451
452 let call_helper = Expr::Call {
454 func: Box::new(Expr::Path(TypePath {
455 segments: vec![PathSegment {
456 ident: Ident {
457 name: helper_name.to_string(),
458 evidentiality: None,
459 affect: None,
460 span: span.clone(),
461 },
462 generics: None,
463 }],
464 })),
465 args: vec![
466 Expr::Path(TypePath {
468 segments: vec![PathSegment {
469 ident: Ident {
470 name: param_name.to_string(),
471 evidentiality: None,
472 affect: None,
473 span: span.clone(),
474 },
475 generics: None,
476 }],
477 }),
478 Expr::Literal(Literal::Int {
480 value: "0".to_string(),
481 base: NumBase::Decimal,
482 suffix: None,
483 }),
484 Expr::Literal(Literal::Int {
486 value: "1".to_string(),
487 base: NumBase::Decimal,
488 suffix: None,
489 }),
490 ],
491 };
492
493 let body = Block {
494 stmts: vec![],
495 expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
496 };
497
498 ast::Function {
499 visibility: original.visibility,
500 is_async: original.is_async,
501 is_const: original.is_const,
502 is_unsafe: original.is_unsafe,
503 attrs: original.attrs.clone(),
504 name: Ident {
505 name: name.to_string(),
506 evidentiality: None,
507 affect: None,
508 span: span.clone(),
509 },
510 aspect: original.aspect,
511 generics: original.generics.clone(),
512 params: original.params.clone(),
513 return_type: original.return_type.clone(),
514 where_clause: original.where_clause.clone(),
515 body: Some(body),
516 }
517 }
518
519 #[allow(dead_code)]
526 fn try_memoize_transform(
527 &self,
528 func: &ast::Function,
529 ) -> Option<(ast::Function, ast::Function, ast::Function)> {
530 let param_count = func.params.len();
531 if param_count != 1 && param_count != 2 {
532 return None;
533 }
534
535 let span = Span { start: 0, end: 0 };
536 let func_name = &func.name.name;
537 let impl_name = format!("_memo_impl_{}", func_name);
538 let _cache_name = format!("_memo_cache_{}", func_name);
539 let init_name = format!("_memo_init_{}", func_name);
540
541 let param_names: Vec<String> = func
543 .params
544 .iter()
545 .filter_map(|p| {
546 if let Pattern::Ident { name, .. } = &p.pattern {
547 Some(name.name.clone())
548 } else {
549 None
550 }
551 })
552 .collect();
553
554 if param_names.len() != param_count {
555 return None;
556 }
557
558 let impl_func = ast::Function {
560 visibility: Visibility::default(),
561 is_async: func.is_async,
562 is_const: func.is_const,
563 is_unsafe: func.is_unsafe,
564 attrs: func.attrs.clone(),
565 name: Ident {
566 name: impl_name.clone(),
567 evidentiality: None,
568 affect: None,
569 span: span.clone(),
570 },
571 aspect: func.aspect,
572 generics: func.generics.clone(),
573 params: func.params.clone(),
574 return_type: func.return_type.clone(),
575 where_clause: func.where_clause.clone(),
576 body: func
577 .body
578 .as_ref()
579 .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
580 };
581
582 let cache_init_body = Block {
585 stmts: vec![],
586 expr: Some(Box::new(Expr::Call {
587 func: Box::new(Expr::Path(TypePath {
588 segments: vec![PathSegment {
589 ident: Ident {
590 name: "sigil_memo_new".to_string(),
591 evidentiality: None,
592 affect: None,
593 span: span.clone(),
594 },
595 generics: None,
596 }],
597 })),
598 args: vec![Expr::Literal(Literal::Int {
599 value: "65536".to_string(),
600 base: NumBase::Decimal,
601 suffix: None,
602 })],
603 })),
604 };
605
606 let cache_init_func = ast::Function {
607 visibility: Visibility::default(),
608 is_async: false,
609 is_const: false,
610 is_unsafe: false,
611 attrs: FunctionAttrs::default(),
612 name: Ident {
613 name: init_name.clone(),
614 evidentiality: None,
615 affect: None,
616 span: span.clone(),
617 },
618 aspect: None,
619 generics: None,
620 params: vec![],
621 return_type: None,
622 where_clause: None,
623 body: Some(cache_init_body),
624 };
625
626 let wrapper_func = self.generate_memo_wrapper(func, &impl_name, ¶m_names);
628
629 Some((impl_func, cache_init_func, wrapper_func))
630 }
631
632 #[allow(dead_code)]
634 fn generate_memo_wrapper(
635 &self,
636 original: &ast::Function,
637 impl_name: &str,
638 param_names: &[String],
639 ) -> ast::Function {
640 let span = Span { start: 0, end: 0 };
641 let param_count = param_names.len();
642
643 let cache_var = Ident {
645 name: "__cache".to_string(),
646 evidentiality: None,
647 affect: None,
648 span: span.clone(),
649 };
650 let result_var = Ident {
651 name: "__result".to_string(),
652 evidentiality: None,
653 affect: None,
654 span: span.clone(),
655 };
656 let cached_var = Ident {
657 name: "__cached".to_string(),
658 evidentiality: None,
659 affect: None,
660 span: span.clone(),
661 };
662
663 let mut stmts = vec![];
664
665 stmts.push(Stmt::Let {
667 pattern: Pattern::Ident {
668 mutable: false,
669 name: cache_var.clone(),
670 evidentiality: None,
671 },
672 ty: None,
673 init: Some(Expr::Call {
674 func: Box::new(Expr::Path(TypePath {
675 segments: vec![PathSegment {
676 ident: Ident {
677 name: "sigil_memo_new".to_string(),
678 evidentiality: None,
679 affect: None,
680 span: span.clone(),
681 },
682 generics: None,
683 }],
684 })),
685 args: vec![Expr::Literal(Literal::Int {
686 value: "65536".to_string(),
687 base: NumBase::Decimal,
688 suffix: None,
689 })],
690 }),
691 });
692
693 let get_fn_name = if param_count == 1 {
695 "sigil_memo_get_1"
696 } else {
697 "sigil_memo_get_2"
698 };
699 let mut get_args = vec![Expr::Path(TypePath {
700 segments: vec![PathSegment {
701 ident: cache_var.clone(),
702 generics: None,
703 }],
704 })];
705 for name in param_names {
706 get_args.push(Expr::Path(TypePath {
707 segments: vec![PathSegment {
708 ident: Ident {
709 name: name.clone(),
710 evidentiality: None,
711 affect: None,
712 span: span.clone(),
713 },
714 generics: None,
715 }],
716 }));
717 }
718
719 stmts.push(Stmt::Let {
720 pattern: Pattern::Ident {
721 mutable: false,
722 name: cached_var.clone(),
723 evidentiality: None,
724 },
725 ty: None,
726 init: Some(Expr::Call {
727 func: Box::new(Expr::Path(TypePath {
728 segments: vec![PathSegment {
729 ident: Ident {
730 name: get_fn_name.to_string(),
731 evidentiality: None,
732 affect: None,
733 span: span.clone(),
734 },
735 generics: None,
736 }],
737 })),
738 args: get_args,
739 }),
740 });
741
742 let cache_check = Expr::If {
745 condition: Box::new(Expr::Binary {
746 op: BinOp::Ne,
747 left: Box::new(Expr::Path(TypePath {
748 segments: vec![PathSegment {
749 ident: cached_var.clone(),
750 generics: None,
751 }],
752 })),
753 right: Box::new(Expr::Unary {
754 op: UnaryOp::Neg,
755 expr: Box::new(Expr::Literal(Literal::Int {
756 value: "9223372036854775807".to_string(),
757 base: NumBase::Decimal,
758 suffix: None,
759 })),
760 }),
761 }),
762 then_branch: Block {
763 stmts: vec![],
764 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
765 TypePath {
766 segments: vec![PathSegment {
767 ident: cached_var.clone(),
768 generics: None,
769 }],
770 },
771 )))))),
772 },
773 else_branch: None,
774 };
775 stmts.push(Stmt::Semi(cache_check));
776
777 let mut impl_args = vec![];
779 for name in param_names {
780 impl_args.push(Expr::Path(TypePath {
781 segments: vec![PathSegment {
782 ident: Ident {
783 name: name.clone(),
784 evidentiality: None,
785 affect: None,
786 span: span.clone(),
787 },
788 generics: None,
789 }],
790 }));
791 }
792
793 stmts.push(Stmt::Let {
794 pattern: Pattern::Ident {
795 mutable: false,
796 name: result_var.clone(),
797 evidentiality: None,
798 },
799 ty: None,
800 init: Some(Expr::Call {
801 func: Box::new(Expr::Path(TypePath {
802 segments: vec![PathSegment {
803 ident: Ident {
804 name: impl_name.to_string(),
805 evidentiality: None,
806 affect: None,
807 span: span.clone(),
808 },
809 generics: None,
810 }],
811 })),
812 args: impl_args,
813 }),
814 });
815
816 let set_fn_name = if param_count == 1 {
818 "sigil_memo_set_1"
819 } else {
820 "sigil_memo_set_2"
821 };
822 let mut set_args = vec![Expr::Path(TypePath {
823 segments: vec![PathSegment {
824 ident: cache_var.clone(),
825 generics: None,
826 }],
827 })];
828 for name in param_names {
829 set_args.push(Expr::Path(TypePath {
830 segments: vec![PathSegment {
831 ident: Ident {
832 name: name.clone(),
833 evidentiality: None,
834 affect: None,
835 span: span.clone(),
836 },
837 generics: None,
838 }],
839 }));
840 }
841 set_args.push(Expr::Path(TypePath {
842 segments: vec![PathSegment {
843 ident: result_var.clone(),
844 generics: None,
845 }],
846 }));
847
848 stmts.push(Stmt::Semi(Expr::Call {
849 func: Box::new(Expr::Path(TypePath {
850 segments: vec![PathSegment {
851 ident: Ident {
852 name: set_fn_name.to_string(),
853 evidentiality: None,
854 affect: None,
855 span: span.clone(),
856 },
857 generics: None,
858 }],
859 })),
860 args: set_args,
861 }));
862
863 let body = Block {
865 stmts,
866 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
867 TypePath {
868 segments: vec![PathSegment {
869 ident: result_var.clone(),
870 generics: None,
871 }],
872 },
873 )))))),
874 };
875
876 ast::Function {
877 visibility: original.visibility,
878 is_async: original.is_async,
879 is_const: original.is_const,
880 is_unsafe: original.is_unsafe,
881 attrs: original.attrs.clone(),
882 name: original.name.clone(),
883 aspect: original.aspect,
884 generics: original.generics.clone(),
885 params: original.params.clone(),
886 return_type: original.return_type.clone(),
887 where_clause: original.where_clause.clone(),
888 body: Some(body),
889 }
890 }
891
892 #[allow(dead_code)]
894 fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
895 block.clone()
897 }
898
899 fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
901 if let Some(body) = &func.body {
902 self.block_calls_function(name, body)
903 } else {
904 false
905 }
906 }
907
908 fn block_calls_function(&self, name: &str, block: &Block) -> bool {
909 for stmt in &block.stmts {
910 if self.stmt_calls_function(name, stmt) {
911 return true;
912 }
913 }
914 if let Some(expr) = &block.expr {
915 if self.expr_calls_function(name, expr) {
916 return true;
917 }
918 }
919 false
920 }
921
922 fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
923 match stmt {
924 Stmt::Let {
925 init: Some(expr), ..
926 } => self.expr_calls_function(name, expr),
927 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
928 _ => false,
929 }
930 }
931
932 fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
933 match expr {
934 Expr::Call { func, args } => {
935 if let Expr::Path(path) = func.as_ref() {
936 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
937 return true;
938 }
939 }
940 args.iter().any(|a| self.expr_calls_function(name, a))
941 }
942 Expr::Binary { left, right, .. } => {
943 self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
944 }
945 Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
946 Expr::If {
947 condition,
948 then_branch,
949 else_branch,
950 } => {
951 self.expr_calls_function(name, condition)
952 || self.block_calls_function(name, then_branch)
953 || else_branch
954 .as_ref()
955 .map(|e| self.expr_calls_function(name, e))
956 .unwrap_or(false)
957 }
958 Expr::While { label, condition, body } => {
959 self.expr_calls_function(name, condition) || self.block_calls_function(name, body)
960 }
961 Expr::Block(block) => self.block_calls_function(name, block),
962 Expr::Return(Some(e)) => self.expr_calls_function(name, e),
963 _ => false,
964 }
965 }
966
967 fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
969 self.cse_counter = 0;
971
972 let body = if let Some(body) = &func.body {
973 let optimized = match self.level {
975 OptLevel::None => body.clone(),
976 OptLevel::Basic => {
977 let b = self.pass_constant_fold_block(body);
978 self.pass_dead_code_block(&b)
979 }
980 OptLevel::Standard | OptLevel::Size => {
981 let b = self.pass_constant_fold_block(body);
982 let b = self.pass_inline_block(&b); let b = self.pass_strength_reduce_block(&b);
984 let b = self.pass_licm_block(&b); let b = self.pass_cse_block(&b); let b = self.pass_dead_code_block(&b);
987 self.pass_simplify_branches_block(&b)
988 }
989 OptLevel::Aggressive => {
990 let mut b = body.clone();
992 for _ in 0..3 {
993 b = self.pass_constant_fold_block(&b);
994 b = self.pass_inline_block(&b); b = self.pass_strength_reduce_block(&b);
996 b = self.pass_loop_unroll_block(&b); b = self.pass_licm_block(&b); b = self.pass_cse_block(&b); b = self.pass_dead_code_block(&b);
1000 b = self.pass_simplify_branches_block(&b);
1001 }
1002 b
1003 }
1004 };
1005 Some(optimized)
1006 } else {
1007 None
1008 };
1009
1010 ast::Function {
1011 visibility: func.visibility.clone(),
1012 is_async: func.is_async,
1013 is_const: func.is_const,
1014 is_unsafe: func.is_unsafe,
1015 attrs: func.attrs.clone(),
1016 name: func.name.clone(),
1017 aspect: func.aspect,
1018 generics: func.generics.clone(),
1019 params: func.params.clone(),
1020 return_type: func.return_type.clone(),
1021 where_clause: func.where_clause.clone(),
1022 body,
1023 }
1024 }
1025
1026 fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1031 let stmts = block
1032 .stmts
1033 .iter()
1034 .map(|s| self.pass_constant_fold_stmt(s))
1035 .collect();
1036 let expr = block
1037 .expr
1038 .as_ref()
1039 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1040 Block { stmts, expr }
1041 }
1042
1043 fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1044 match stmt {
1045 Stmt::Let {
1046 pattern, ty, init, ..
1047 } => Stmt::Let {
1048 pattern: pattern.clone(),
1049 ty: ty.clone(),
1050 init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1051 },
1052 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1053 pattern: pattern.clone(),
1054 ty: ty.clone(),
1055 init: self.pass_constant_fold_expr(init),
1056 else_branch: Box::new(self.pass_constant_fold_expr(else_branch)),
1057 },
1058 Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1059 Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1060 Stmt::Item(item) => Stmt::Item(item.clone()),
1061 }
1062 }
1063
1064 fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1065 match expr {
1066 Expr::Binary { op, left, right } => {
1067 let left = Box::new(self.pass_constant_fold_expr(left));
1068 let right = Box::new(self.pass_constant_fold_expr(right));
1069
1070 if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1072 if let Some(result) = self.fold_binary(op.clone(), l, r) {
1073 self.stats.constants_folded += 1;
1074 return Expr::Literal(Literal::Int {
1075 value: result.to_string(),
1076 base: NumBase::Decimal,
1077 suffix: None,
1078 });
1079 }
1080 }
1081
1082 Expr::Binary {
1083 op: op.clone(),
1084 left,
1085 right,
1086 }
1087 }
1088 Expr::Unary { op, expr: inner } => {
1089 let inner = Box::new(self.pass_constant_fold_expr(inner));
1090
1091 if let Some(v) = self.as_int(&inner) {
1092 if let Some(result) = self.fold_unary(*op, v) {
1093 self.stats.constants_folded += 1;
1094 return Expr::Literal(Literal::Int {
1095 value: result.to_string(),
1096 base: NumBase::Decimal,
1097 suffix: None,
1098 });
1099 }
1100 }
1101
1102 Expr::Unary {
1103 op: *op,
1104 expr: inner,
1105 }
1106 }
1107 Expr::If {
1108 condition,
1109 then_branch,
1110 else_branch,
1111 } => {
1112 let condition = Box::new(self.pass_constant_fold_expr(condition));
1113 let then_branch = self.pass_constant_fold_block(then_branch);
1114 let else_branch = else_branch
1115 .as_ref()
1116 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1117
1118 if let Some(cond) = self.as_bool(&condition) {
1120 self.stats.branches_simplified += 1;
1121 if cond {
1122 return Expr::Block(then_branch);
1123 } else if let Some(else_expr) = else_branch {
1124 return *else_expr;
1125 } else {
1126 return Expr::Literal(Literal::Bool(false));
1127 }
1128 }
1129
1130 Expr::If {
1131 condition,
1132 then_branch,
1133 else_branch,
1134 }
1135 }
1136 Expr::While { label, condition, body } => {
1137 let condition = Box::new(self.pass_constant_fold_expr(condition));
1138 let body = self.pass_constant_fold_block(body);
1139
1140 if let Some(false) = self.as_bool(&condition) {
1142 self.stats.branches_simplified += 1;
1143 return Expr::Block(Block {
1144 stmts: vec![],
1145 expr: None,
1146 });
1147 }
1148
1149 Expr::While { label: label.clone(), condition, body }
1150 }
1151 Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1152 Expr::Call { func, args } => {
1153 let args = args
1154 .iter()
1155 .map(|a| self.pass_constant_fold_expr(a))
1156 .collect();
1157 Expr::Call {
1158 func: func.clone(),
1159 args,
1160 }
1161 }
1162 Expr::Return(e) => Expr::Return(
1163 e.as_ref()
1164 .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1165 ),
1166 Expr::Assign { target, value } => {
1167 let value = Box::new(self.pass_constant_fold_expr(value));
1168 Expr::Assign {
1169 target: target.clone(),
1170 value,
1171 }
1172 }
1173 Expr::Index { expr: e, index } => {
1174 let e = Box::new(self.pass_constant_fold_expr(e));
1175 let index = Box::new(self.pass_constant_fold_expr(index));
1176 Expr::Index { expr: e, index }
1177 }
1178 Expr::Array(elements) => {
1179 let elements = elements
1180 .iter()
1181 .map(|e| self.pass_constant_fold_expr(e))
1182 .collect();
1183 Expr::Array(elements)
1184 }
1185 other => other.clone(),
1186 }
1187 }
1188
1189 fn as_int(&self, expr: &Expr) -> Option<i64> {
1190 match expr {
1191 Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1192 Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1193 _ => None,
1194 }
1195 }
1196
1197 fn as_bool(&self, expr: &Expr) -> Option<bool> {
1198 match expr {
1199 Expr::Literal(Literal::Bool(b)) => Some(*b),
1200 Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1201 _ => None,
1202 }
1203 }
1204
1205 fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1206 match op {
1207 BinOp::Add => Some(l.wrapping_add(r)),
1208 BinOp::Sub => Some(l.wrapping_sub(r)),
1209 BinOp::Mul => Some(l.wrapping_mul(r)),
1210 BinOp::Div if r != 0 => Some(l / r),
1211 BinOp::Rem if r != 0 => Some(l % r),
1212 BinOp::BitAnd => Some(l & r),
1213 BinOp::BitOr => Some(l | r),
1214 BinOp::BitXor => Some(l ^ r),
1215 BinOp::Shl => Some(l << (r & 63)),
1216 BinOp::Shr => Some(l >> (r & 63)),
1217 BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1218 BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1219 BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1220 BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1221 BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1222 BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1223 BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1224 BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1225 _ => None,
1226 }
1227 }
1228
1229 fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1230 match op {
1231 UnaryOp::Neg => Some(-v),
1232 UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1233 _ => None,
1234 }
1235 }
1236
1237 fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1242 let stmts = block
1243 .stmts
1244 .iter()
1245 .map(|s| self.pass_strength_reduce_stmt(s))
1246 .collect();
1247 let expr = block
1248 .expr
1249 .as_ref()
1250 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1251 Block { stmts, expr }
1252 }
1253
1254 fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1255 match stmt {
1256 Stmt::Let {
1257 pattern, ty, init, ..
1258 } => Stmt::Let {
1259 pattern: pattern.clone(),
1260 ty: ty.clone(),
1261 init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1262 },
1263 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1264 pattern: pattern.clone(),
1265 ty: ty.clone(),
1266 init: self.pass_strength_reduce_expr(init),
1267 else_branch: Box::new(self.pass_strength_reduce_expr(else_branch)),
1268 },
1269 Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1270 Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1271 Stmt::Item(item) => Stmt::Item(item.clone()),
1272 }
1273 }
1274
1275 fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1276 match expr {
1277 Expr::Binary { op, left, right } => {
1278 let left = Box::new(self.pass_strength_reduce_expr(left));
1279 let right = Box::new(self.pass_strength_reduce_expr(right));
1280
1281 if *op == BinOp::Mul {
1283 if let Some(n) = self.as_int(&right) {
1284 if n > 0 && (n as u64).is_power_of_two() {
1285 self.stats.strength_reductions += 1;
1286 let shift = (n as u64).trailing_zeros() as i64;
1287 return Expr::Binary {
1288 op: BinOp::Shl,
1289 left,
1290 right: Box::new(Expr::Literal(Literal::Int {
1291 value: shift.to_string(),
1292 base: NumBase::Decimal,
1293 suffix: None,
1294 })),
1295 };
1296 }
1297 }
1298 if let Some(n) = self.as_int(&left) {
1299 if n > 0 && (n as u64).is_power_of_two() {
1300 self.stats.strength_reductions += 1;
1301 let shift = (n as u64).trailing_zeros() as i64;
1302 return Expr::Binary {
1303 op: BinOp::Shl,
1304 left: right,
1305 right: Box::new(Expr::Literal(Literal::Int {
1306 value: shift.to_string(),
1307 base: NumBase::Decimal,
1308 suffix: None,
1309 })),
1310 };
1311 }
1312 }
1313 }
1314
1315 if let Some(n) = self.as_int(&right) {
1317 match (op, n) {
1318 (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1319 | (BinOp::Mul | BinOp::Div, 1)
1320 | (BinOp::Shl | BinOp::Shr, 0) => {
1321 self.stats.strength_reductions += 1;
1322 return *left;
1323 }
1324 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1325 self.stats.strength_reductions += 1;
1326 return Expr::Literal(Literal::Int {
1327 value: "0".to_string(),
1328 base: NumBase::Decimal,
1329 suffix: None,
1330 });
1331 }
1332 _ => {}
1333 }
1334 }
1335
1336 if let Some(n) = self.as_int(&left) {
1338 match (op, n) {
1339 (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1340 self.stats.strength_reductions += 1;
1341 return *right;
1342 }
1343 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1344 self.stats.strength_reductions += 1;
1345 return Expr::Literal(Literal::Int {
1346 value: "0".to_string(),
1347 base: NumBase::Decimal,
1348 suffix: None,
1349 });
1350 }
1351 _ => {}
1352 }
1353 }
1354
1355 Expr::Binary {
1356 op: op.clone(),
1357 left,
1358 right,
1359 }
1360 }
1361 Expr::Unary { op, expr: inner } => {
1362 let inner = Box::new(self.pass_strength_reduce_expr(inner));
1363
1364 if *op == UnaryOp::Neg {
1366 if let Expr::Unary {
1367 op: UnaryOp::Neg,
1368 expr: inner2,
1369 } = inner.as_ref()
1370 {
1371 self.stats.strength_reductions += 1;
1372 return *inner2.clone();
1373 }
1374 }
1375
1376 if *op == UnaryOp::Not {
1378 if let Expr::Unary {
1379 op: UnaryOp::Not,
1380 expr: inner2,
1381 } = inner.as_ref()
1382 {
1383 self.stats.strength_reductions += 1;
1384 return *inner2.clone();
1385 }
1386 }
1387
1388 Expr::Unary {
1389 op: *op,
1390 expr: inner,
1391 }
1392 }
1393 Expr::If {
1394 condition,
1395 then_branch,
1396 else_branch,
1397 } => {
1398 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1399 let then_branch = self.pass_strength_reduce_block(then_branch);
1400 let else_branch = else_branch
1401 .as_ref()
1402 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1403 Expr::If {
1404 condition,
1405 then_branch,
1406 else_branch,
1407 }
1408 }
1409 Expr::While { label, condition, body } => {
1410 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1411 let body = self.pass_strength_reduce_block(body);
1412 Expr::While { label: label.clone(), condition, body }
1413 }
1414 Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1415 Expr::Call { func, args } => {
1416 let args = args
1417 .iter()
1418 .map(|a| self.pass_strength_reduce_expr(a))
1419 .collect();
1420 Expr::Call {
1421 func: func.clone(),
1422 args,
1423 }
1424 }
1425 Expr::Return(e) => Expr::Return(
1426 e.as_ref()
1427 .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1428 ),
1429 Expr::Assign { target, value } => {
1430 let value = Box::new(self.pass_strength_reduce_expr(value));
1431 Expr::Assign {
1432 target: target.clone(),
1433 value,
1434 }
1435 }
1436 other => other.clone(),
1437 }
1438 }
1439
1440 fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1445 let mut stmts = Vec::new();
1447 let mut found_return = false;
1448
1449 for stmt in &block.stmts {
1450 if found_return {
1451 self.stats.dead_code_eliminated += 1;
1452 continue;
1453 }
1454 let stmt = self.pass_dead_code_stmt(stmt);
1455 if self.stmt_returns(&stmt) {
1456 found_return = true;
1457 }
1458 stmts.push(stmt);
1459 }
1460
1461 let expr = if found_return {
1463 if block.expr.is_some() {
1464 self.stats.dead_code_eliminated += 1;
1465 }
1466 None
1467 } else {
1468 block
1469 .expr
1470 .as_ref()
1471 .map(|e| Box::new(self.pass_dead_code_expr(e)))
1472 };
1473
1474 Block { stmts, expr }
1475 }
1476
1477 fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1478 match stmt {
1479 Stmt::Let {
1480 pattern, ty, init, ..
1481 } => Stmt::Let {
1482 pattern: pattern.clone(),
1483 ty: ty.clone(),
1484 init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1485 },
1486 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1487 pattern: pattern.clone(),
1488 ty: ty.clone(),
1489 init: self.pass_dead_code_expr(init),
1490 else_branch: Box::new(self.pass_dead_code_expr(else_branch)),
1491 },
1492 Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1493 Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1494 Stmt::Item(item) => Stmt::Item(item.clone()),
1495 }
1496 }
1497
1498 fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1499 match expr {
1500 Expr::If {
1501 condition,
1502 then_branch,
1503 else_branch,
1504 } => {
1505 let condition = Box::new(self.pass_dead_code_expr(condition));
1506 let then_branch = self.pass_dead_code_block(then_branch);
1507 let else_branch = else_branch
1508 .as_ref()
1509 .map(|e| Box::new(self.pass_dead_code_expr(e)));
1510 Expr::If {
1511 condition,
1512 then_branch,
1513 else_branch,
1514 }
1515 }
1516 Expr::While { label, condition, body } => {
1517 let condition = Box::new(self.pass_dead_code_expr(condition));
1518 let body = self.pass_dead_code_block(body);
1519 Expr::While { label: label.clone(), condition, body }
1520 }
1521 Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1522 other => other.clone(),
1523 }
1524 }
1525
1526 fn stmt_returns(&self, stmt: &Stmt) -> bool {
1527 match stmt {
1528 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1529 _ => false,
1530 }
1531 }
1532
1533 fn expr_returns(&self, expr: &Expr) -> bool {
1534 match expr {
1535 Expr::Return(_) => true,
1536 Expr::Block(block) => {
1537 block.stmts.iter().any(|s| self.stmt_returns(s))
1538 || block
1539 .expr
1540 .as_ref()
1541 .map(|e| self.expr_returns(e))
1542 .unwrap_or(false)
1543 }
1544 _ => false,
1545 }
1546 }
1547
1548 fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1553 let stmts = block
1554 .stmts
1555 .iter()
1556 .map(|s| self.pass_simplify_branches_stmt(s))
1557 .collect();
1558 let expr = block
1559 .expr
1560 .as_ref()
1561 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1562 Block { stmts, expr }
1563 }
1564
1565 fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1566 match stmt {
1567 Stmt::Let {
1568 pattern, ty, init, ..
1569 } => Stmt::Let {
1570 pattern: pattern.clone(),
1571 ty: ty.clone(),
1572 init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1573 },
1574 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1575 pattern: pattern.clone(),
1576 ty: ty.clone(),
1577 init: self.pass_simplify_branches_expr(init),
1578 else_branch: Box::new(self.pass_simplify_branches_expr(else_branch)),
1579 },
1580 Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1581 Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1582 Stmt::Item(item) => Stmt::Item(item.clone()),
1583 }
1584 }
1585
1586 fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1587 match expr {
1588 Expr::If {
1589 condition,
1590 then_branch,
1591 else_branch,
1592 } => {
1593 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1594 let then_branch = self.pass_simplify_branches_block(then_branch);
1595 let else_branch = else_branch
1596 .as_ref()
1597 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1598
1599 if let Expr::Unary {
1601 op: UnaryOp::Not,
1602 expr: inner,
1603 } = condition.as_ref()
1604 {
1605 if let Some(else_expr) = &else_branch {
1606 self.stats.branches_simplified += 1;
1607 let new_else = Some(Box::new(Expr::Block(then_branch)));
1608 let new_then = match else_expr.as_ref() {
1609 Expr::Block(b) => b.clone(),
1610 other => Block {
1611 stmts: vec![],
1612 expr: Some(Box::new(other.clone())),
1613 },
1614 };
1615 return Expr::If {
1616 condition: inner.clone(),
1617 then_branch: new_then,
1618 else_branch: new_else,
1619 };
1620 }
1621 }
1622
1623 Expr::If {
1624 condition,
1625 then_branch,
1626 else_branch,
1627 }
1628 }
1629 Expr::While { label, condition, body } => {
1630 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1631 let body = self.pass_simplify_branches_block(body);
1632 Expr::While { label: label.clone(), condition, body }
1633 }
1634 Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1635 Expr::Binary { op, left, right } => {
1636 let left = Box::new(self.pass_simplify_branches_expr(left));
1637 let right = Box::new(self.pass_simplify_branches_expr(right));
1638 Expr::Binary {
1639 op: op.clone(),
1640 left,
1641 right,
1642 }
1643 }
1644 Expr::Unary { op, expr: inner } => {
1645 let inner = Box::new(self.pass_simplify_branches_expr(inner));
1646 Expr::Unary {
1647 op: *op,
1648 expr: inner,
1649 }
1650 }
1651 Expr::Call { func, args } => {
1652 let args = args
1653 .iter()
1654 .map(|a| self.pass_simplify_branches_expr(a))
1655 .collect();
1656 Expr::Call {
1657 func: func.clone(),
1658 args,
1659 }
1660 }
1661 Expr::Return(e) => Expr::Return(
1662 e.as_ref()
1663 .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1664 ),
1665 other => other.clone(),
1666 }
1667 }
1668
1669 fn should_inline(&self, func: &ast::Function) -> bool {
1675 if self.recursive_functions.contains(&func.name.name) {
1677 return false;
1678 }
1679
1680 if let Some(body) = &func.body {
1682 let stmt_count = self.count_stmts_in_block(body);
1683 stmt_count <= 10
1685 } else {
1686 false
1687 }
1688 }
1689
1690 fn count_stmts_in_block(&self, block: &Block) -> usize {
1692 let mut count = block.stmts.len();
1693 if block.expr.is_some() {
1694 count += 1;
1695 }
1696 for stmt in &block.stmts {
1698 count += self.count_stmts_in_stmt(stmt);
1699 }
1700 count
1701 }
1702
1703 fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1704 match stmt {
1705 Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1706 Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1707 _ => 0,
1708 }
1709 }
1710
1711 fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1712 match expr {
1713 Expr::If {
1714 then_branch,
1715 else_branch,
1716 ..
1717 } => {
1718 let mut count = self.count_stmts_in_block(then_branch);
1719 if let Some(else_expr) = else_branch {
1720 count += self.count_stmts_in_expr(else_expr);
1721 }
1722 count
1723 }
1724 Expr::While { body, .. } => self.count_stmts_in_block(body),
1725 Expr::Block(block) => self.count_stmts_in_block(block),
1726 _ => 0,
1727 }
1728 }
1729
1730 fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1732 let body = func.body.as_ref()?;
1733
1734 let mut param_map: HashMap<String, Expr> = HashMap::new();
1736 for (param, arg) in func.params.iter().zip(args.iter()) {
1737 if let Pattern::Ident { name, .. } = ¶m.pattern {
1738 param_map.insert(name.name.clone(), arg.clone());
1739 }
1740 }
1741
1742 let inlined_body = self.substitute_params_in_block(body, ¶m_map);
1744
1745 self.stats.functions_inlined += 1;
1746
1747 if inlined_body.stmts.is_empty() {
1750 if let Some(expr) = inlined_body.expr {
1751 if let Expr::Return(Some(inner)) = expr.as_ref() {
1753 return Some(inner.as_ref().clone());
1754 }
1755 return Some(*expr);
1756 }
1757 }
1758
1759 Some(Expr::Block(inlined_body))
1760 }
1761
1762 fn substitute_params_in_block(
1764 &self,
1765 block: &Block,
1766 param_map: &HashMap<String, Expr>,
1767 ) -> Block {
1768 let stmts = block
1769 .stmts
1770 .iter()
1771 .map(|s| self.substitute_params_in_stmt(s, param_map))
1772 .collect();
1773 let expr = block
1774 .expr
1775 .as_ref()
1776 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1777 Block { stmts, expr }
1778 }
1779
1780 fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1781 match stmt {
1782 Stmt::Let { pattern, ty, init } => Stmt::Let {
1783 pattern: pattern.clone(),
1784 ty: ty.clone(),
1785 init: init
1786 .as_ref()
1787 .map(|e| self.substitute_params_in_expr(e, param_map)),
1788 },
1789 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1790 pattern: pattern.clone(),
1791 ty: ty.clone(),
1792 init: self.substitute_params_in_expr(init, param_map),
1793 else_branch: Box::new(self.substitute_params_in_expr(else_branch, param_map)),
1794 },
1795 Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1796 Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1797 Stmt::Item(item) => Stmt::Item(item.clone()),
1798 }
1799 }
1800
1801 fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1802 match expr {
1803 Expr::Path(path) => {
1804 if path.segments.len() == 1 {
1806 let name = &path.segments[0].ident.name;
1807 if let Some(arg) = param_map.get(name) {
1808 return arg.clone();
1809 }
1810 }
1811 expr.clone()
1812 }
1813 Expr::Binary { op, left, right } => Expr::Binary {
1814 op: op.clone(),
1815 left: Box::new(self.substitute_params_in_expr(left, param_map)),
1816 right: Box::new(self.substitute_params_in_expr(right, param_map)),
1817 },
1818 Expr::Unary { op, expr: inner } => Expr::Unary {
1819 op: *op,
1820 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1821 },
1822 Expr::If {
1823 condition,
1824 then_branch,
1825 else_branch,
1826 } => Expr::If {
1827 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1828 then_branch: self.substitute_params_in_block(then_branch, param_map),
1829 else_branch: else_branch
1830 .as_ref()
1831 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1832 },
1833 Expr::While { label, condition, body } => Expr::While {
1834 label: label.clone(),
1835 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1836 body: self.substitute_params_in_block(body, param_map),
1837 },
1838 Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1839 Expr::Call { func, args } => Expr::Call {
1840 func: Box::new(self.substitute_params_in_expr(func, param_map)),
1841 args: args
1842 .iter()
1843 .map(|a| self.substitute_params_in_expr(a, param_map))
1844 .collect(),
1845 },
1846 Expr::Return(e) => Expr::Return(
1847 e.as_ref()
1848 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1849 ),
1850 Expr::Assign { target, value } => Expr::Assign {
1851 target: target.clone(),
1852 value: Box::new(self.substitute_params_in_expr(value, param_map)),
1853 },
1854 Expr::Index { expr: e, index } => Expr::Index {
1855 expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1856 index: Box::new(self.substitute_params_in_expr(index, param_map)),
1857 },
1858 Expr::Array(elements) => Expr::Array(
1859 elements
1860 .iter()
1861 .map(|e| self.substitute_params_in_expr(e, param_map))
1862 .collect(),
1863 ),
1864 other => other.clone(),
1865 }
1866 }
1867
1868 fn pass_inline_block(&mut self, block: &Block) -> Block {
1869 let stmts = block
1870 .stmts
1871 .iter()
1872 .map(|s| self.pass_inline_stmt(s))
1873 .collect();
1874 let expr = block
1875 .expr
1876 .as_ref()
1877 .map(|e| Box::new(self.pass_inline_expr(e)));
1878 Block { stmts, expr }
1879 }
1880
1881 fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1882 match stmt {
1883 Stmt::Let { pattern, ty, init } => Stmt::Let {
1884 pattern: pattern.clone(),
1885 ty: ty.clone(),
1886 init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1887 },
1888 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1889 pattern: pattern.clone(),
1890 ty: ty.clone(),
1891 init: self.pass_inline_expr(init),
1892 else_branch: Box::new(self.pass_inline_expr(else_branch)),
1893 },
1894 Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1895 Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1896 Stmt::Item(item) => Stmt::Item(item.clone()),
1897 }
1898 }
1899
1900 fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1901 match expr {
1902 Expr::Call { func, args } => {
1903 let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1905
1906 if let Expr::Path(path) = func.as_ref() {
1908 if path.segments.len() == 1 {
1909 let func_name = &path.segments[0].ident.name;
1910 if let Some(target_func) = self.functions.get(func_name).cloned() {
1911 if self.should_inline(&target_func)
1912 && args.len() == target_func.params.len()
1913 {
1914 if let Some(inlined) = self.inline_call(&target_func, &args) {
1915 return inlined;
1916 }
1917 }
1918 }
1919 }
1920 }
1921
1922 Expr::Call {
1923 func: func.clone(),
1924 args,
1925 }
1926 }
1927 Expr::Binary { op, left, right } => Expr::Binary {
1928 op: op.clone(),
1929 left: Box::new(self.pass_inline_expr(left)),
1930 right: Box::new(self.pass_inline_expr(right)),
1931 },
1932 Expr::Unary { op, expr: inner } => Expr::Unary {
1933 op: *op,
1934 expr: Box::new(self.pass_inline_expr(inner)),
1935 },
1936 Expr::If {
1937 condition,
1938 then_branch,
1939 else_branch,
1940 } => Expr::If {
1941 condition: Box::new(self.pass_inline_expr(condition)),
1942 then_branch: self.pass_inline_block(then_branch),
1943 else_branch: else_branch
1944 .as_ref()
1945 .map(|e| Box::new(self.pass_inline_expr(e))),
1946 },
1947 Expr::While { label, condition, body } => Expr::While {
1948 label: label.clone(),
1949 condition: Box::new(self.pass_inline_expr(condition)),
1950 body: self.pass_inline_block(body),
1951 },
1952 Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
1953 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
1954 Expr::Assign { target, value } => Expr::Assign {
1955 target: target.clone(),
1956 value: Box::new(self.pass_inline_expr(value)),
1957 },
1958 Expr::Index { expr: e, index } => Expr::Index {
1959 expr: Box::new(self.pass_inline_expr(e)),
1960 index: Box::new(self.pass_inline_expr(index)),
1961 },
1962 Expr::Array(elements) => {
1963 Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
1964 }
1965 other => other.clone(),
1966 }
1967 }
1968
1969 fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
1975 let stmts = block
1976 .stmts
1977 .iter()
1978 .map(|s| self.pass_loop_unroll_stmt(s))
1979 .collect();
1980 let expr = block
1981 .expr
1982 .as_ref()
1983 .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
1984 Block { stmts, expr }
1985 }
1986
1987 fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
1988 match stmt {
1989 Stmt::Let { pattern, ty, init } => Stmt::Let {
1990 pattern: pattern.clone(),
1991 ty: ty.clone(),
1992 init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
1993 },
1994 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1995 pattern: pattern.clone(),
1996 ty: ty.clone(),
1997 init: self.pass_loop_unroll_expr(init),
1998 else_branch: Box::new(self.pass_loop_unroll_expr(else_branch)),
1999 },
2000 Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
2001 Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
2002 Stmt::Item(item) => Stmt::Item(item.clone()),
2003 }
2004 }
2005
2006 fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
2007 match expr {
2008 Expr::While { label, condition, body } => {
2009 if let Some(unrolled) = self.try_unroll_loop(condition, body) {
2011 self.stats.loops_optimized += 1;
2012 return unrolled;
2013 }
2014 Expr::While {
2016 label: label.clone(),
2017 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2018 body: self.pass_loop_unroll_block(body),
2019 }
2020 }
2021 Expr::If {
2022 condition,
2023 then_branch,
2024 else_branch,
2025 } => Expr::If {
2026 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2027 then_branch: self.pass_loop_unroll_block(then_branch),
2028 else_branch: else_branch
2029 .as_ref()
2030 .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
2031 },
2032 Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
2033 Expr::Binary { op, left, right } => Expr::Binary {
2034 op: *op,
2035 left: Box::new(self.pass_loop_unroll_expr(left)),
2036 right: Box::new(self.pass_loop_unroll_expr(right)),
2037 },
2038 Expr::Unary { op, expr: inner } => Expr::Unary {
2039 op: *op,
2040 expr: Box::new(self.pass_loop_unroll_expr(inner)),
2041 },
2042 Expr::Call { func, args } => Expr::Call {
2043 func: func.clone(),
2044 args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
2045 },
2046 Expr::Return(e) => {
2047 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
2048 }
2049 Expr::Assign { target, value } => Expr::Assign {
2050 target: target.clone(),
2051 value: Box::new(self.pass_loop_unroll_expr(value)),
2052 },
2053 other => other.clone(),
2054 }
2055 }
2056
2057 fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2060 let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2062
2063 if upper_bound > 8 || upper_bound <= 0 {
2065 return None;
2066 }
2067
2068 if !self.body_has_simple_increment(&loop_var, body) {
2070 return None;
2071 }
2072
2073 let stmt_count = body.stmts.len();
2075 if stmt_count > 5 {
2076 return None;
2077 }
2078
2079 let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2081
2082 for i in 0..upper_bound {
2083 let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2085
2086 for stmt in &substituted_body.stmts {
2088 if !self.is_increment_stmt(&loop_var, stmt) {
2089 unrolled_stmts.push(stmt.clone());
2090 }
2091 }
2092 }
2093
2094 Some(Expr::Block(Block {
2096 stmts: unrolled_stmts,
2097 expr: None,
2098 }))
2099 }
2100
2101 fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2103 if let Expr::Binary {
2104 op: BinOp::Lt,
2105 left,
2106 right,
2107 } = condition
2108 {
2109 if let Expr::Path(path) = left.as_ref() {
2111 if path.segments.len() == 1 {
2112 let var_name = path.segments[0].ident.name.clone();
2113 if let Some(bound) = self.as_int(right) {
2115 return Some((var_name, bound));
2116 }
2117 }
2118 }
2119 }
2120 None
2121 }
2122
2123 fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2125 for stmt in &body.stmts {
2126 if self.is_increment_stmt(loop_var, stmt) {
2127 return true;
2128 }
2129 }
2130 false
2131 }
2132
2133 fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2135 match stmt {
2136 Stmt::Semi(Expr::Assign { target, value })
2137 | Stmt::Expr(Expr::Assign { target, value }) => {
2138 if let Expr::Path(path) = target.as_ref() {
2140 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2141 if let Expr::Binary {
2143 op: BinOp::Add,
2144 left,
2145 right,
2146 } = value.as_ref()
2147 {
2148 if let Expr::Path(lpath) = left.as_ref() {
2149 if lpath.segments.len() == 1
2150 && lpath.segments[0].ident.name == var_name
2151 {
2152 if let Some(1) = self.as_int(right) {
2153 return true;
2154 }
2155 }
2156 }
2157 }
2158 }
2159 }
2160 false
2161 }
2162 _ => false,
2163 }
2164 }
2165
2166 fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2168 let stmts = block
2169 .stmts
2170 .iter()
2171 .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2172 .collect();
2173 let expr = block
2174 .expr
2175 .as_ref()
2176 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2177 Block { stmts, expr }
2178 }
2179
2180 fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2181 match stmt {
2182 Stmt::Let { pattern, ty, init } => Stmt::Let {
2183 pattern: pattern.clone(),
2184 ty: ty.clone(),
2185 init: init
2186 .as_ref()
2187 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2188 },
2189 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2190 pattern: pattern.clone(),
2191 ty: ty.clone(),
2192 init: self.substitute_loop_var_in_expr(init, var_name, value),
2193 else_branch: Box::new(self.substitute_loop_var_in_expr(else_branch, var_name, value)),
2194 },
2195 Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2196 Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2197 Stmt::Item(item) => Stmt::Item(item.clone()),
2198 }
2199 }
2200
2201 fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2202 match expr {
2203 Expr::Path(path) => {
2204 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2205 return Expr::Literal(Literal::Int {
2206 value: value.to_string(),
2207 base: NumBase::Decimal,
2208 suffix: None,
2209 });
2210 }
2211 expr.clone()
2212 }
2213 Expr::Binary { op, left, right } => Expr::Binary {
2214 op: *op,
2215 left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2216 right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2217 },
2218 Expr::Unary { op, expr: inner } => Expr::Unary {
2219 op: *op,
2220 expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2221 },
2222 Expr::Call { func, args } => Expr::Call {
2223 func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2224 args: args
2225 .iter()
2226 .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2227 .collect(),
2228 },
2229 Expr::If {
2230 condition,
2231 then_branch,
2232 else_branch,
2233 } => Expr::If {
2234 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2235 then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2236 else_branch: else_branch
2237 .as_ref()
2238 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2239 },
2240 Expr::While { label, condition, body } => Expr::While {
2241 label: label.clone(),
2242 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2243 body: self.substitute_loop_var_in_block(body, var_name, value),
2244 },
2245 Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2246 Expr::Return(e) => Expr::Return(
2247 e.as_ref()
2248 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2249 ),
2250 Expr::Assign { target, value: v } => Expr::Assign {
2251 target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
2252 value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2253 },
2254 Expr::Index { expr: e, index } => Expr::Index {
2255 expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2256 index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2257 },
2258 Expr::Array(elements) => Expr::Array(
2259 elements
2260 .iter()
2261 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2262 .collect(),
2263 ),
2264 other => other.clone(),
2265 }
2266 }
2267
2268 fn pass_licm_block(&mut self, block: &Block) -> Block {
2274 let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2275 let expr = block
2276 .expr
2277 .as_ref()
2278 .map(|e| Box::new(self.pass_licm_expr(e)));
2279 Block { stmts, expr }
2280 }
2281
2282 fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2283 match stmt {
2284 Stmt::Let { pattern, ty, init } => Stmt::Let {
2285 pattern: pattern.clone(),
2286 ty: ty.clone(),
2287 init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2288 },
2289 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2290 pattern: pattern.clone(),
2291 ty: ty.clone(),
2292 init: self.pass_licm_expr(init),
2293 else_branch: Box::new(self.pass_licm_expr(else_branch)),
2294 },
2295 Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2296 Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2297 Stmt::Item(item) => Stmt::Item(item.clone()),
2298 }
2299 }
2300
2301 fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
2302 match expr {
2303 Expr::While { label, condition, body } => {
2304 let mut modified_vars = HashSet::new();
2306 self.collect_modified_vars_block(body, &mut modified_vars);
2307
2308 self.collect_modified_vars_expr(condition, &mut modified_vars);
2310
2311 let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
2313 self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);
2314
2315 if invariant_exprs.is_empty() {
2316 return Expr::While {
2318 label: label.clone(),
2319 condition: Box::new(self.pass_licm_expr(condition)),
2320 body: self.pass_licm_block(body),
2321 };
2322 }
2323
2324 let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
2326 let mut substitution_map: HashMap<String, String> = HashMap::new();
2327
2328 for (original_key, invariant_expr) in &invariant_exprs {
2329 let var_name = format!("__licm_{}", self.cse_counter);
2330 self.cse_counter += 1;
2331
2332 pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
2333 substitution_map.insert(original_key.clone(), var_name);
2334 self.stats.loops_optimized += 1;
2335 }
2336
2337 let new_body =
2339 self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);
2340
2341 let new_while = Expr::While {
2343 label: label.clone(),
2344 condition: Box::new(self.pass_licm_expr(condition)),
2345 body: self.pass_licm_block(&new_body),
2346 };
2347
2348 pre_loop_stmts.push(Stmt::Expr(new_while));
2350 Expr::Block(Block {
2351 stmts: pre_loop_stmts,
2352 expr: None,
2353 })
2354 }
2355 Expr::If {
2356 condition,
2357 then_branch,
2358 else_branch,
2359 } => Expr::If {
2360 condition: Box::new(self.pass_licm_expr(condition)),
2361 then_branch: self.pass_licm_block(then_branch),
2362 else_branch: else_branch
2363 .as_ref()
2364 .map(|e| Box::new(self.pass_licm_expr(e))),
2365 },
2366 Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
2367 Expr::Binary { op, left, right } => Expr::Binary {
2368 op: *op,
2369 left: Box::new(self.pass_licm_expr(left)),
2370 right: Box::new(self.pass_licm_expr(right)),
2371 },
2372 Expr::Unary { op, expr: inner } => Expr::Unary {
2373 op: *op,
2374 expr: Box::new(self.pass_licm_expr(inner)),
2375 },
2376 Expr::Call { func, args } => Expr::Call {
2377 func: func.clone(),
2378 args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
2379 },
2380 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
2381 Expr::Assign { target, value } => Expr::Assign {
2382 target: target.clone(),
2383 value: Box::new(self.pass_licm_expr(value)),
2384 },
2385 other => other.clone(),
2386 }
2387 }
2388
2389 fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2391 for stmt in &block.stmts {
2392 self.collect_modified_vars_stmt(stmt, modified);
2393 }
2394 if let Some(expr) = &block.expr {
2395 self.collect_modified_vars_expr(expr, modified);
2396 }
2397 }
2398
2399 fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2400 match stmt {
2401 Stmt::Let { pattern, init, .. } => {
2402 if let Pattern::Ident { name, .. } = pattern {
2404 modified.insert(name.name.clone());
2405 }
2406 if let Some(e) = init {
2407 self.collect_modified_vars_expr(e, modified);
2408 }
2409 }
2410 Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2411 _ => {}
2412 }
2413 }
2414
2415 fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2416 match expr {
2417 Expr::Assign { target, value } => {
2418 if let Expr::Path(path) = target.as_ref() {
2419 if path.segments.len() == 1 {
2420 modified.insert(path.segments[0].ident.name.clone());
2421 }
2422 }
2423 self.collect_modified_vars_expr(value, modified);
2424 }
2425 Expr::Binary { left, right, .. } => {
2426 self.collect_modified_vars_expr(left, modified);
2427 self.collect_modified_vars_expr(right, modified);
2428 }
2429 Expr::Unary { expr: inner, .. } => {
2430 self.collect_modified_vars_expr(inner, modified);
2431 }
2432 Expr::If {
2433 condition,
2434 then_branch,
2435 else_branch,
2436 } => {
2437 self.collect_modified_vars_expr(condition, modified);
2438 self.collect_modified_vars_block(then_branch, modified);
2439 if let Some(e) = else_branch {
2440 self.collect_modified_vars_expr(e, modified);
2441 }
2442 }
2443 Expr::While { label, condition, body } => {
2444 self.collect_modified_vars_expr(condition, modified);
2445 self.collect_modified_vars_block(body, modified);
2446 }
2447 Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2448 Expr::Call { args, .. } => {
2449 for arg in args {
2450 self.collect_modified_vars_expr(arg, modified);
2451 }
2452 }
2453 Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2454 _ => {}
2455 }
2456 }
2457
2458 fn find_loop_invariants(
2460 &self,
2461 block: &Block,
2462 modified: &HashSet<String>,
2463 out: &mut Vec<(String, Expr)>,
2464 ) {
2465 for stmt in &block.stmts {
2466 self.find_loop_invariants_stmt(stmt, modified, out);
2467 }
2468 if let Some(expr) = &block.expr {
2469 self.find_loop_invariants_expr(expr, modified, out);
2470 }
2471 }
2472
2473 fn find_loop_invariants_stmt(
2474 &self,
2475 stmt: &Stmt,
2476 modified: &HashSet<String>,
2477 out: &mut Vec<(String, Expr)>,
2478 ) {
2479 match stmt {
2480 Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2481 Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2482 _ => {}
2483 }
2484 }
2485
2486 fn find_loop_invariants_expr(
2487 &self,
2488 expr: &Expr,
2489 modified: &HashSet<String>,
2490 out: &mut Vec<(String, Expr)>,
2491 ) {
2492 match expr {
2494 Expr::Binary { left, right, .. } => {
2495 self.find_loop_invariants_expr(left, modified, out);
2496 self.find_loop_invariants_expr(right, modified, out);
2497 }
2498 Expr::Unary { expr: inner, .. } => {
2499 self.find_loop_invariants_expr(inner, modified, out);
2500 }
2501 Expr::Call { args, .. } => {
2502 for arg in args {
2503 self.find_loop_invariants_expr(arg, modified, out);
2504 }
2505 }
2506 Expr::Index { expr: e, index } => {
2507 self.find_loop_invariants_expr(e, modified, out);
2508 self.find_loop_invariants_expr(index, modified, out);
2509 }
2510 _ => {}
2511 }
2512
2513 if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2515 let key = format!("{:?}", expr_hash(expr));
2516 if !out.iter().any(|(k, _)| k == &key) {
2518 out.push((key, expr.clone()));
2519 }
2520 }
2521 }
2522
2523 fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2525 match expr {
2526 Expr::Literal(_) => true,
2527 Expr::Path(path) => {
2528 if path.segments.len() == 1 {
2529 !modified.contains(&path.segments[0].ident.name)
2530 } else {
2531 true }
2533 }
2534 Expr::Binary { left, right, .. } => {
2535 self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2536 }
2537 Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2538 Expr::Index { expr: e, index } => {
2539 self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2540 }
2541 Expr::Call { .. } => false,
2543 _ => false,
2545 }
2546 }
2547
2548 fn replace_invariants_in_block(
2550 &self,
2551 block: &Block,
2552 invariants: &[(String, Expr)],
2553 subs: &HashMap<String, String>,
2554 ) -> Block {
2555 let stmts = block
2556 .stmts
2557 .iter()
2558 .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2559 .collect();
2560 let expr = block
2561 .expr
2562 .as_ref()
2563 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2564 Block { stmts, expr }
2565 }
2566
2567 fn replace_invariants_in_stmt(
2568 &self,
2569 stmt: &Stmt,
2570 invariants: &[(String, Expr)],
2571 subs: &HashMap<String, String>,
2572 ) -> Stmt {
2573 match stmt {
2574 Stmt::Let { pattern, ty, init } => Stmt::Let {
2575 pattern: pattern.clone(),
2576 ty: ty.clone(),
2577 init: init
2578 .as_ref()
2579 .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2580 },
2581 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2582 pattern: pattern.clone(),
2583 ty: ty.clone(),
2584 init: self.replace_invariants_in_expr(init, invariants, subs),
2585 else_branch: Box::new(self.replace_invariants_in_expr(else_branch, invariants, subs)),
2586 },
2587 Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2588 Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2589 Stmt::Item(item) => Stmt::Item(item.clone()),
2590 }
2591 }
2592
2593 fn replace_invariants_in_expr(
2594 &self,
2595 expr: &Expr,
2596 invariants: &[(String, Expr)],
2597 subs: &HashMap<String, String>,
2598 ) -> Expr {
2599 let key = format!("{:?}", expr_hash(expr));
2601 for (inv_key, inv_expr) in invariants {
2602 if &key == inv_key && expr_eq(expr, inv_expr) {
2603 if let Some(var_name) = subs.get(inv_key) {
2604 return Expr::Path(TypePath {
2605 segments: vec![PathSegment {
2606 ident: Ident {
2607 name: var_name.clone(),
2608 evidentiality: None,
2609 affect: None,
2610 span: Span { start: 0, end: 0 },
2611 },
2612 generics: None,
2613 }],
2614 });
2615 }
2616 }
2617 }
2618
2619 match expr {
2621 Expr::Binary { op, left, right } => Expr::Binary {
2622 op: *op,
2623 left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2624 right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2625 },
2626 Expr::Unary { op, expr: inner } => Expr::Unary {
2627 op: *op,
2628 expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2629 },
2630 Expr::Call { func, args } => Expr::Call {
2631 func: func.clone(),
2632 args: args
2633 .iter()
2634 .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2635 .collect(),
2636 },
2637 Expr::If {
2638 condition,
2639 then_branch,
2640 else_branch,
2641 } => Expr::If {
2642 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2643 then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2644 else_branch: else_branch
2645 .as_ref()
2646 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2647 },
2648 Expr::While { label, condition, body } => Expr::While {
2649 label: label.clone(),
2650 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2651 body: self.replace_invariants_in_block(body, invariants, subs),
2652 },
2653 Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2654 Expr::Return(e) => Expr::Return(
2655 e.as_ref()
2656 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2657 ),
2658 Expr::Assign { target, value } => Expr::Assign {
2659 target: target.clone(),
2660 value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2661 },
2662 Expr::Index { expr: e, index } => Expr::Index {
2663 expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2664 index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2665 },
2666 other => other.clone(),
2667 }
2668 }
2669
2670 fn pass_cse_block(&mut self, block: &Block) -> Block {
2675 let mut collected = Vec::new();
2677 collect_exprs_from_block(block, &mut collected);
2678
2679 let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2681 for ce in &collected {
2682 let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2683 let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2685 if !found {
2686 entry.push(ce.expr.clone());
2687 }
2688 }
2689
2690 let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2692 for ce in &collected {
2693 let existing = occurrence_counts
2695 .iter_mut()
2696 .find(|(e, _)| expr_eq(e, &ce.expr));
2697 if let Some((_, count)) = existing {
2698 *count += 1;
2699 } else {
2700 occurrence_counts.push((ce.expr.clone(), 1));
2701 }
2702 }
2703
2704 let candidates: Vec<Expr> = occurrence_counts
2706 .into_iter()
2707 .filter(|(_, count)| *count >= 2)
2708 .map(|(expr, _)| expr)
2709 .collect();
2710
2711 if candidates.is_empty() {
2712 return self.pass_cse_nested(block);
2714 }
2715
2716 let mut result_block = block.clone();
2718 let mut new_lets: Vec<Stmt> = Vec::new();
2719
2720 for expr in candidates {
2721 let var_name = format!("__cse_{}", self.cse_counter);
2722 self.cse_counter += 1;
2723
2724 new_lets.push(make_cse_let(&var_name, expr.clone()));
2726
2727 result_block = replace_in_block(&result_block, &expr, &var_name);
2729
2730 self.stats.expressions_deduplicated += 1;
2731 }
2732
2733 let mut final_stmts = new_lets;
2735 final_stmts.extend(result_block.stmts);
2736
2737 let result = Block {
2739 stmts: final_stmts,
2740 expr: result_block.expr,
2741 };
2742 self.pass_cse_nested(&result)
2743 }
2744
2745 fn pass_cse_nested(&mut self, block: &Block) -> Block {
2747 let stmts = block
2748 .stmts
2749 .iter()
2750 .map(|stmt| self.pass_cse_stmt(stmt))
2751 .collect();
2752 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2753 Block { stmts, expr }
2754 }
2755
2756 fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2757 match stmt {
2758 Stmt::Let { pattern, ty, init } => Stmt::Let {
2759 pattern: pattern.clone(),
2760 ty: ty.clone(),
2761 init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2762 },
2763 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2764 pattern: pattern.clone(),
2765 ty: ty.clone(),
2766 init: self.pass_cse_expr(init),
2767 else_branch: Box::new(self.pass_cse_expr(else_branch)),
2768 },
2769 Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2770 Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2771 Stmt::Item(item) => Stmt::Item(item.clone()),
2772 }
2773 }
2774
2775 fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2776 match expr {
2777 Expr::If {
2778 condition,
2779 then_branch,
2780 else_branch,
2781 } => Expr::If {
2782 condition: Box::new(self.pass_cse_expr(condition)),
2783 then_branch: self.pass_cse_block(then_branch),
2784 else_branch: else_branch
2785 .as_ref()
2786 .map(|e| Box::new(self.pass_cse_expr(e))),
2787 },
2788 Expr::While { label, condition, body } => Expr::While {
2789 label: label.clone(),
2790 condition: Box::new(self.pass_cse_expr(condition)),
2791 body: self.pass_cse_block(body),
2792 },
2793 Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2794 Expr::Binary { op, left, right } => Expr::Binary {
2795 op: *op,
2796 left: Box::new(self.pass_cse_expr(left)),
2797 right: Box::new(self.pass_cse_expr(right)),
2798 },
2799 Expr::Unary { op, expr: inner } => Expr::Unary {
2800 op: *op,
2801 expr: Box::new(self.pass_cse_expr(inner)),
2802 },
2803 Expr::Call { func, args } => Expr::Call {
2804 func: func.clone(),
2805 args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2806 },
2807 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2808 Expr::Assign { target, value } => Expr::Assign {
2809 target: target.clone(),
2810 value: Box::new(self.pass_cse_expr(value)),
2811 },
2812 other => other.clone(),
2813 }
2814 }
2815}
2816
2817fn expr_hash(expr: &Expr) -> u64 {
2823 use std::collections::hash_map::DefaultHasher;
2824 use std::hash::Hasher;
2825
2826 let mut hasher = DefaultHasher::new();
2827 expr_hash_recursive(expr, &mut hasher);
2828 hasher.finish()
2829}
2830
2831fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2832 use std::hash::Hash;
2833
2834 std::mem::discriminant(expr).hash(hasher);
2835
2836 match expr {
2837 Expr::Literal(lit) => match lit {
2838 Literal::Int { value, .. } => value.hash(hasher),
2839 Literal::Float { value, .. } => value.hash(hasher),
2840 Literal::String(s) => s.hash(hasher),
2841 Literal::Char(c) => c.hash(hasher),
2842 Literal::Bool(b) => b.hash(hasher),
2843 _ => {}
2844 },
2845 Expr::Path(path) => {
2846 for seg in &path.segments {
2847 seg.ident.name.hash(hasher);
2848 }
2849 }
2850 Expr::Binary { op, left, right } => {
2851 std::mem::discriminant(op).hash(hasher);
2852 expr_hash_recursive(left, hasher);
2853 expr_hash_recursive(right, hasher);
2854 }
2855 Expr::Unary { op, expr } => {
2856 std::mem::discriminant(op).hash(hasher);
2857 expr_hash_recursive(expr, hasher);
2858 }
2859 Expr::Call { func, args } => {
2860 expr_hash_recursive(func, hasher);
2861 args.len().hash(hasher);
2862 for arg in args {
2863 expr_hash_recursive(arg, hasher);
2864 }
2865 }
2866 Expr::Index { expr, index } => {
2867 expr_hash_recursive(expr, hasher);
2868 expr_hash_recursive(index, hasher);
2869 }
2870 _ => {}
2871 }
2872}
2873
2874fn is_pure_expr(expr: &Expr) -> bool {
2876 match expr {
2877 Expr::Literal(_) => true,
2878 Expr::Path(_) => true,
2879 Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
2880 Expr::Unary { expr, .. } => is_pure_expr(expr),
2881 Expr::If {
2882 condition,
2883 then_branch,
2884 else_branch,
2885 } => {
2886 is_pure_expr(condition)
2887 && then_branch.stmts.is_empty()
2888 && then_branch
2889 .expr
2890 .as_ref()
2891 .map(|e| is_pure_expr(e))
2892 .unwrap_or(true)
2893 && else_branch
2894 .as_ref()
2895 .map(|e| is_pure_expr(e))
2896 .unwrap_or(true)
2897 }
2898 Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
2899 Expr::Array(elements) => elements.iter().all(is_pure_expr),
2900 Expr::Call { .. } => false,
2902 Expr::Assign { .. } => false,
2903 Expr::Return(_) => false,
2904 _ => false,
2905 }
2906}
2907
2908fn is_cse_worthy(expr: &Expr) -> bool {
2910 match expr {
2911 Expr::Literal(_) => false,
2913 Expr::Path(_) => false,
2914 Expr::Binary { .. } => true,
2916 Expr::Unary { .. } => true,
2918 Expr::Call { .. } => false,
2920 Expr::Index { .. } => true,
2922 _ => false,
2923 }
2924}
2925
2926fn expr_eq(a: &Expr, b: &Expr) -> bool {
2928 match (a, b) {
2929 (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
2930 (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
2931 (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
2932 (Literal::String(sa), Literal::String(sb)) => sa == sb,
2933 (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
2934 (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
2935 _ => false,
2936 },
2937 (Expr::Path(pa), Expr::Path(pb)) => {
2938 pa.segments.len() == pb.segments.len()
2939 && pa
2940 .segments
2941 .iter()
2942 .zip(&pb.segments)
2943 .all(|(sa, sb)| sa.ident.name == sb.ident.name)
2944 }
2945 (
2946 Expr::Binary {
2947 op: oa,
2948 left: la,
2949 right: ra,
2950 },
2951 Expr::Binary {
2952 op: ob,
2953 left: lb,
2954 right: rb,
2955 },
2956 ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
2957 (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
2958 oa == ob && expr_eq(ea, eb)
2959 }
2960 (
2961 Expr::Index {
2962 expr: ea,
2963 index: ia,
2964 },
2965 Expr::Index {
2966 expr: eb,
2967 index: ib,
2968 },
2969 ) => expr_eq(ea, eb) && expr_eq(ia, ib),
2970 (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
2971 expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
2972 }
2973 _ => false,
2974 }
2975}
2976
2977#[derive(Clone)]
2979struct CollectedExpr {
2980 expr: Expr,
2981 hash: u64,
2982}
2983
2984fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
2986 match expr {
2988 Expr::Binary { left, right, .. } => {
2989 collect_exprs_from_expr(left, out);
2990 collect_exprs_from_expr(right, out);
2991 }
2992 Expr::Unary { expr: inner, .. } => {
2993 collect_exprs_from_expr(inner, out);
2994 }
2995 Expr::Index { expr: e, index } => {
2996 collect_exprs_from_expr(e, out);
2997 collect_exprs_from_expr(index, out);
2998 }
2999 Expr::Call { func, args } => {
3000 collect_exprs_from_expr(func, out);
3001 for arg in args {
3002 collect_exprs_from_expr(arg, out);
3003 }
3004 }
3005 Expr::If {
3006 condition,
3007 then_branch,
3008 else_branch,
3009 } => {
3010 collect_exprs_from_expr(condition, out);
3011 collect_exprs_from_block(then_branch, out);
3012 if let Some(else_expr) = else_branch {
3013 collect_exprs_from_expr(else_expr, out);
3014 }
3015 }
3016 Expr::While { label, condition, body } => {
3017 collect_exprs_from_expr(condition, out);
3018 collect_exprs_from_block(body, out);
3019 }
3020 Expr::Block(block) => {
3021 collect_exprs_from_block(block, out);
3022 }
3023 Expr::Return(Some(e)) => {
3024 collect_exprs_from_expr(e, out);
3025 }
3026 Expr::Assign { value, .. } => {
3027 collect_exprs_from_expr(value, out);
3028 }
3029 Expr::Array(elements) => {
3030 for e in elements {
3031 collect_exprs_from_expr(e, out);
3032 }
3033 }
3034 _ => {}
3035 }
3036
3037 if is_cse_worthy(expr) && is_pure_expr(expr) {
3039 out.push(CollectedExpr {
3040 expr: expr.clone(),
3041 hash: expr_hash(expr),
3042 });
3043 }
3044}
3045
3046fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
3048 for stmt in &block.stmts {
3049 match stmt {
3050 Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
3051 Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
3052 _ => {}
3053 }
3054 }
3055 if let Some(e) = &block.expr {
3056 collect_exprs_from_expr(e, out);
3057 }
3058}
3059
3060fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
3062 if expr_eq(expr, target) {
3064 return Expr::Path(TypePath {
3065 segments: vec![PathSegment {
3066 ident: Ident {
3067 name: var_name.to_string(),
3068 evidentiality: None,
3069 affect: None,
3070 span: Span { start: 0, end: 0 },
3071 },
3072 generics: None,
3073 }],
3074 });
3075 }
3076
3077 match expr {
3079 Expr::Binary { op, left, right } => Expr::Binary {
3080 op: *op,
3081 left: Box::new(replace_in_expr(left, target, var_name)),
3082 right: Box::new(replace_in_expr(right, target, var_name)),
3083 },
3084 Expr::Unary { op, expr: inner } => Expr::Unary {
3085 op: *op,
3086 expr: Box::new(replace_in_expr(inner, target, var_name)),
3087 },
3088 Expr::Index { expr: e, index } => Expr::Index {
3089 expr: Box::new(replace_in_expr(e, target, var_name)),
3090 index: Box::new(replace_in_expr(index, target, var_name)),
3091 },
3092 Expr::Call { func, args } => Expr::Call {
3093 func: Box::new(replace_in_expr(func, target, var_name)),
3094 args: args
3095 .iter()
3096 .map(|a| replace_in_expr(a, target, var_name))
3097 .collect(),
3098 },
3099 Expr::If {
3100 condition,
3101 then_branch,
3102 else_branch,
3103 } => Expr::If {
3104 condition: Box::new(replace_in_expr(condition, target, var_name)),
3105 then_branch: replace_in_block(then_branch, target, var_name),
3106 else_branch: else_branch
3107 .as_ref()
3108 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3109 },
3110 Expr::While { label, condition, body } => Expr::While {
3111 label: label.clone(),
3112 condition: Box::new(replace_in_expr(condition, target, var_name)),
3113 body: replace_in_block(body, target, var_name),
3114 },
3115 Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3116 Expr::Return(e) => Expr::Return(
3117 e.as_ref()
3118 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3119 ),
3120 Expr::Assign { target: t, value } => Expr::Assign {
3121 target: t.clone(),
3122 value: Box::new(replace_in_expr(value, target, var_name)),
3123 },
3124 Expr::Array(elements) => Expr::Array(
3125 elements
3126 .iter()
3127 .map(|e| replace_in_expr(e, target, var_name))
3128 .collect(),
3129 ),
3130 other => other.clone(),
3131 }
3132}
3133
3134fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3136 let stmts = block
3137 .stmts
3138 .iter()
3139 .map(|stmt| match stmt {
3140 Stmt::Let { pattern, ty, init } => Stmt::Let {
3141 pattern: pattern.clone(),
3142 ty: ty.clone(),
3143 init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3144 },
3145 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
3146 pattern: pattern.clone(),
3147 ty: ty.clone(),
3148 init: replace_in_expr(init, target, var_name),
3149 else_branch: Box::new(replace_in_expr(else_branch, target, var_name)),
3150 },
3151 Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3152 Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3153 Stmt::Item(item) => Stmt::Item(item.clone()),
3154 })
3155 .collect();
3156
3157 let expr = block
3158 .expr
3159 .as_ref()
3160 .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3161
3162 Block { stmts, expr }
3163}
3164
3165fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3167 Stmt::Let {
3168 pattern: Pattern::Ident {
3169 mutable: false,
3170 name: Ident {
3171 name: var_name.to_string(),
3172 evidentiality: None,
3173 affect: None,
3174 span: Span { start: 0, end: 0 },
3175 },
3176 evidentiality: None,
3177 },
3178 ty: None,
3179 init: Some(expr),
3180 }
3181}
3182
3183pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3189 let mut optimizer = Optimizer::new(level);
3190 let optimized = optimizer.optimize_file(file);
3191 (optimized, optimizer.stats)
3192}
3193
#[cfg(test)]
mod tests {
    use super::*;

    /// Decimal integer literal `v` with no suffix.
    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Zero-span identifier with no evidentiality or affect.
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: Span { start: 0, end: 0 },
        }
    }

    /// Single-segment path expression referring to `name`.
    fn var(name: &str) -> Expr {
        Expr::Path(TypePath {
            segments: vec![PathSegment {
                ident: ident(name),
                generics: None,
            }],
        })
    }

    /// `left + right`.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Add,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// `left * right`.
    fn mul(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Mul,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Immutable, unannotated `let <name> = <init>;`.
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    #[test]
    fn test_expr_hash_equal() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        assert_eq!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_hash_different() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("c"));
        assert_ne!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_eq() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        let e3 = add(var("a"), var("c"));

        assert!(expr_eq(&e1, &e2));
        assert!(!expr_eq(&e1, &e3));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        // Literals and bare variables are not worth a temporary.
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    #[test]
    fn test_cse_basic() {
        // `a + b` appears twice: on its own and inside `(a + b) * 2`.
        let a_plus_b = add(var("a"), var("b"));
        let block = Block {
            stmts: vec![
                let_stmt("x", a_plus_b.clone()),
                let_stmt("y", mul(a_plus_b, int_lit(2))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // One new binding plus the two original statements.
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        if let Stmt::Let {
            pattern: Pattern::Ident { name, .. },
            ..
        } = &result.stmts[0]
        {
            assert_eq!(name.name, "__cse_0");
        } else {
            panic!("Expected CSE let binding");
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        // `a + b` and `c + d` each appear once, so nothing is deduplicated.
        let block = Block {
            stmts: vec![
                let_stmt("x", add(var("a"), var("b"))),
                let_stmt("y", add(var("c"), var("d"))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}