sigil_parser/
optimize.rs

1//! Sigil Optimization Passes
2//!
3//! A comprehensive optimization framework for the Sigil language.
4//! Implements industry-standard compiler optimizations.
5
6use crate::ast::{
7    self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8    Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
13// ============================================================================
14// Optimization Pass Infrastructure
15// ============================================================================
16
/// How hard the optimizer should work, analogous to the usual `-O` compiler flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptLevel {
    /// `O0`: no optimizations.
    None,
    /// `O1`: constant folding and dead code elimination.
    Basic,
    /// `O2`: `Basic` plus common subexpression elimination, strength
    /// reduction and inlining.
    Standard,
    /// `O3`: `Standard` plus loop optimizations and vectorization hints.
    Aggressive,
    /// `Os`: optimize for code size.
    Size,
}
31
/// Per-run counters describing what the optimization passes did.
///
/// Every counter starts at zero via `OptStats::default()`.
#[derive(Debug, Default, Clone)]
pub struct OptStats {
    /// Constant-folding rewrites.
    pub constants_folded: usize,
    /// Dead code removals.
    pub dead_code_eliminated: usize,
    /// Common subexpressions deduplicated.
    pub expressions_deduplicated: usize,
    /// Call sites inlined.
    pub functions_inlined: usize,
    /// Strength reductions applied.
    pub strength_reductions: usize,
    /// Branches simplified.
    pub branches_simplified: usize,
    /// Loops optimized.
    pub loops_optimized: usize,
    /// Functions rewritten by the accumulator (tail-recursion) transform.
    pub tail_recursion_transforms: usize,
    /// Memoization rewrites (that transform is currently disabled).
    pub memoization_transforms: usize,
}
45
46/// The main optimizer that runs all passes
47pub struct Optimizer {
48    level: OptLevel,
49    stats: OptStats,
50    /// Function bodies for inlining decisions
51    functions: HashMap<String, ast::Function>,
52    /// Track which functions are recursive
53    recursive_functions: HashSet<String>,
54    /// Counter for CSE variable names
55    cse_counter: usize,
56}
57
58impl Optimizer {
59    pub fn new(level: OptLevel) -> Self {
60        Self {
61            level,
62            stats: OptStats::default(),
63            functions: HashMap::new(),
64            recursive_functions: HashSet::new(),
65            cse_counter: 0,
66        }
67    }
68
69    /// Get optimization statistics
70    pub fn stats(&self) -> &OptStats {
71        &self.stats
72    }
73
74    /// Optimize a source file (returns optimized copy)
75    pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76        // First pass: collect function information
77        for item in &file.items {
78            if let Item::Function(func) = &item.node {
79                self.functions.insert(func.name.name.clone(), func.clone());
80                if self.is_recursive(&func.name.name, func) {
81                    self.recursive_functions.insert(func.name.name.clone());
82                }
83            }
84        }
85
86        // Accumulator transformation pass (for Standard and Aggressive optimization)
87        // This transforms double-recursive functions like fib into tail-recursive form
88        // Enabled for Standard+ because it provides massive speedups (O(2^n) -> O(n))
89        let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90        let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92        if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93            for item in &file.items {
94                if let Item::Function(func) = &item.node {
95                    if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96                    {
97                        // Add helper function first
98                        new_items.push(crate::span::Spanned {
99                            node: Item::Function(helper_func),
100                            span: item.span.clone(),
101                        });
102                        transformed_functions
103                            .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104                        self.stats.tail_recursion_transforms += 1;
105                    }
106                }
107            }
108
109            // NOTE: Memoization transform is disabled for now due to complexity with
110            // cache lifetime management. Instead, use iterative implementations for
111            // functions like ackermann where beneficial.
112            //
113            // TODO: Implement proper memoization with thread-local or passed-through caches
114        }
115
116        // Optimization passes
117        let items: Vec<_> = file
118            .items
119            .iter()
120            .map(|item| {
121                let node = match &item.node {
122                    Item::Function(func) => {
123                        // Check if this function was transformed by accumulator
124                        if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125                            if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126                                && transformed_functions.contains_key(&func.name.name)
127                            {
128                                Item::Function(self.optimize_function(&wrapper))
129                            } else {
130                                Item::Function(self.optimize_function(func))
131                            }
132                        } else {
133                            Item::Function(self.optimize_function(func))
134                        }
135                    }
136                    other => other.clone(),
137                };
138                crate::span::Spanned {
139                    node,
140                    span: item.span.clone(),
141                }
142            })
143            .collect();
144
145        // Combine new helper functions with transformed items
146        new_items.extend(items);
147
148        ast::SourceFile {
149            attrs: file.attrs.clone(),
150            config: file.config.clone(),
151            items: new_items,
152        }
153    }
154
155    /// Try to transform a double-recursive function into tail-recursive form
156    /// Returns (helper_function, wrapper_function) if transformation is possible
157    fn try_accumulator_transform(
158        &self,
159        func: &ast::Function,
160    ) -> Option<(ast::Function, ast::Function)> {
161        // Only transform functions with single parameter
162        if func.params.len() != 1 {
163            return None;
164        }
165
166        // Must be recursive
167        if !self.recursive_functions.contains(&func.name.name) {
168            return None;
169        }
170
171        let body = func.body.as_ref()?;
172
173        // Detect fib-like pattern:
174        // if n <= 1 { return n; }
175        // return fib(n - 1) + fib(n - 2);
176        if !self.is_fib_like_pattern(&func.name.name, body) {
177            return None;
178        }
179
180        // Get parameter name
181        let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182            name.name.clone()
183        } else {
184            return None;
185        };
186
187        // Generate helper function name
188        let helper_name = format!("{}_tail", func.name.name);
189
190        // Create helper function: fn fib_tail(n, a, b) { if n <= 0 { return a; } return fib_tail(n - 1, b, a + b); }
191        let helper_func = self.generate_fib_helper(&helper_name, &param_name);
192
193        // Create wrapper function: fn fib(n) { return fib_tail(n, 0, 1); }
194        let wrapper_func =
195            self.generate_fib_wrapper(&func.name.name, &helper_name, &param_name, func);
196
197        Some((helper_func, wrapper_func))
198    }
199
200    /// Check if a function body matches the Fibonacci pattern
201    fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202        // Pattern we're looking for:
203        // { if n <= 1 { return n; } return f(n-1) + f(n-2); }
204        // or
205        // { if n <= 1 { return n; } f(n-1) + f(n-2) }
206
207        // Should have an if statement/expression followed by a recursive expression
208        if body.stmts.is_empty() && body.expr.is_none() {
209            return false;
210        }
211
212        // Check for the pattern in the block expression
213        if let Some(expr) = &body.expr {
214            if let Expr::If {
215                else_branch: Some(else_expr),
216                ..
217            } = expr.as_ref()
218            {
219                // Check if then_branch is a base case (return n or similar)
220                // Check if else_branch has double recursive calls
221                return self.is_double_recursive_expr(func_name, else_expr);
222            }
223        }
224
225        // Check for pattern: if ... { return ...; } return f(n-1) + f(n-2);
226        if body.stmts.len() >= 1 {
227            // Last statement or expression should be the recursive call
228            if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229                if let Expr::Return(Some(ret_expr)) = expr {
230                    return self.is_double_recursive_expr(func_name, ret_expr);
231                }
232            }
233            if let Some(expr) = &body.expr {
234                return self.is_double_recursive_expr(func_name, expr);
235            }
236        }
237
238        false
239    }
240
241    /// Check if expression is f(n-1) + f(n-2) pattern
242    fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243        if let Expr::Binary {
244            op: BinOp::Add,
245            left,
246            right,
247        } = expr
248        {
249            let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250            let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251            return left_is_recursive && right_is_recursive;
252        }
253        false
254    }
255
256    /// Check if expression is f(n - k) for some constant k
257    fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258        if let Expr::Call { func, args } = expr {
259            if let Expr::Path(path) = func.as_ref() {
260                if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261                    // Check if argument is n - constant
262                    if args.len() == 1 {
263                        if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264                            return true;
265                        }
266                    }
267                }
268            }
269        }
270        false
271    }
272
273    /// Generate the tail-recursive helper function
274    fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275        let span = Span { start: 0, end: 0 };
276
277        // fn fib_tail(n, a, b) {
278        //     if n <= 0 { return a; }
279        //     return fib_tail(n - 1, b, a + b);
280        // }
281        let n_ident = Ident {
282            name: "n".to_string(),
283            evidentiality: None,
284            affect: None,
285            span: span.clone(),
286        };
287        let a_ident = Ident {
288            name: "a".to_string(),
289            evidentiality: None,
290            affect: None,
291            span: span.clone(),
292        };
293        let b_ident = Ident {
294            name: "b".to_string(),
295            evidentiality: None,
296            affect: None,
297            span: span.clone(),
298        };
299
300        let params = vec![
301            Param {
302                pattern: Pattern::Ident {
303                    mutable: false,
304                    name: n_ident.clone(),
305                    evidentiality: None,
306                },
307                ty: TypeExpr::Infer,
308            },
309            Param {
310                pattern: Pattern::Ident {
311                    mutable: false,
312                    name: a_ident.clone(),
313                    evidentiality: None,
314                },
315                ty: TypeExpr::Infer,
316            },
317            Param {
318                pattern: Pattern::Ident {
319                    mutable: false,
320                    name: b_ident.clone(),
321                    evidentiality: None,
322                },
323                ty: TypeExpr::Infer,
324            },
325        ];
326
327        // Condition: n <= 0
328        let condition = Expr::Binary {
329            op: BinOp::Le,
330            left: Box::new(Expr::Path(TypePath {
331                segments: vec![PathSegment {
332                    ident: n_ident.clone(),
333                    generics: None,
334                }],
335            })),
336            right: Box::new(Expr::Literal(Literal::Int {
337                value: "0".to_string(),
338                base: NumBase::Decimal,
339                suffix: None,
340            })),
341        };
342
343        // Then branch: return a
344        let then_branch = Block {
345            stmts: vec![],
346            expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347                TypePath {
348                    segments: vec![PathSegment {
349                        ident: a_ident.clone(),
350                        generics: None,
351                    }],
352                },
353            )))))),
354        };
355
356        // Recursive call: fib_tail(n - 1, b, a + b)
357        let recursive_call = Expr::Call {
358            func: Box::new(Expr::Path(TypePath {
359                segments: vec![PathSegment {
360                    ident: Ident {
361                        name: name.to_string(),
362                        evidentiality: None,
363                        affect: None,
364                        span: span.clone(),
365                    },
366                    generics: None,
367                }],
368            })),
369            args: vec![
370                // n - 1
371                Expr::Binary {
372                    op: BinOp::Sub,
373                    left: Box::new(Expr::Path(TypePath {
374                        segments: vec![PathSegment {
375                            ident: n_ident.clone(),
376                            generics: None,
377                        }],
378                    })),
379                    right: Box::new(Expr::Literal(Literal::Int {
380                        value: "1".to_string(),
381                        base: NumBase::Decimal,
382                        suffix: None,
383                    })),
384                },
385                // b
386                Expr::Path(TypePath {
387                    segments: vec![PathSegment {
388                        ident: b_ident.clone(),
389                        generics: None,
390                    }],
391                }),
392                // a + b
393                Expr::Binary {
394                    op: BinOp::Add,
395                    left: Box::new(Expr::Path(TypePath {
396                        segments: vec![PathSegment {
397                            ident: a_ident.clone(),
398                            generics: None,
399                        }],
400                    })),
401                    right: Box::new(Expr::Path(TypePath {
402                        segments: vec![PathSegment {
403                            ident: b_ident.clone(),
404                            generics: None,
405                        }],
406                    })),
407                },
408            ],
409        };
410
411        // if n <= 0 { return a; } else { return fib_tail(n - 1, b, a + b); }
412        let body = Block {
413            stmts: vec![],
414            expr: Some(Box::new(Expr::If {
415                condition: Box::new(condition),
416                then_branch,
417                else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418            })),
419        };
420
421        ast::Function {
422            visibility: Visibility::default(),
423            is_async: false,
424            is_const: false,
425            is_unsafe: false,
426            attrs: FunctionAttrs::default(),
427            name: Ident {
428                name: name.to_string(),
429                evidentiality: None,
430                affect: None,
431                span: span.clone(),
432            },
433            aspect: None,
434            generics: None,
435            params,
436            return_type: None,
437            where_clause: None,
438            body: Some(body),
439        }
440    }
441
442    /// Generate the wrapper function that calls the helper
443    fn generate_fib_wrapper(
444        &self,
445        name: &str,
446        helper_name: &str,
447        param_name: &str,
448        original: &ast::Function,
449    ) -> ast::Function {
450        let span = Span { start: 0, end: 0 };
451
452        // fn fib(n) { return fib_tail(n, 0, 1); }
453        let call_helper = Expr::Call {
454            func: Box::new(Expr::Path(TypePath {
455                segments: vec![PathSegment {
456                    ident: Ident {
457                        name: helper_name.to_string(),
458                        evidentiality: None,
459                        affect: None,
460                        span: span.clone(),
461                    },
462                    generics: None,
463                }],
464            })),
465            args: vec![
466                // n
467                Expr::Path(TypePath {
468                    segments: vec![PathSegment {
469                        ident: Ident {
470                            name: param_name.to_string(),
471                            evidentiality: None,
472                            affect: None,
473                            span: span.clone(),
474                        },
475                        generics: None,
476                    }],
477                }),
478                // 0 (initial acc1)
479                Expr::Literal(Literal::Int {
480                    value: "0".to_string(),
481                    base: NumBase::Decimal,
482                    suffix: None,
483                }),
484                // 1 (initial acc2)
485                Expr::Literal(Literal::Int {
486                    value: "1".to_string(),
487                    base: NumBase::Decimal,
488                    suffix: None,
489                }),
490            ],
491        };
492
493        let body = Block {
494            stmts: vec![],
495            expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
496        };
497
498        ast::Function {
499            visibility: original.visibility,
500            is_async: original.is_async,
501            is_const: original.is_const,
502            is_unsafe: original.is_unsafe,
503            attrs: original.attrs.clone(),
504            name: Ident {
505                name: name.to_string(),
506                evidentiality: None,
507                affect: None,
508                span: span.clone(),
509            },
510            aspect: original.aspect,
511            generics: original.generics.clone(),
512            params: original.params.clone(),
513            return_type: original.return_type.clone(),
514            where_clause: original.where_clause.clone(),
515            body: Some(body),
516        }
517    }
518
519    // ========================================================================
520    // Memoization Transform
521    // ========================================================================
522
523    /// Try to transform a recursive function into a memoized version
524    /// Returns: (implementation_func, cache_init_func, wrapper_func)
525    #[allow(dead_code)]
526    fn try_memoize_transform(
527        &self,
528        func: &ast::Function,
529    ) -> Option<(ast::Function, ast::Function, ast::Function)> {
530        let param_count = func.params.len();
531        if param_count != 1 && param_count != 2 {
532            return None;
533        }
534
535        let span = Span { start: 0, end: 0 };
536        let func_name = &func.name.name;
537        let impl_name = format!("_memo_impl_{}", func_name);
538        let _cache_name = format!("_memo_cache_{}", func_name);
539        let init_name = format!("_memo_init_{}", func_name);
540
541        // Get parameter names
542        let param_names: Vec<String> = func
543            .params
544            .iter()
545            .filter_map(|p| {
546                if let Pattern::Ident { name, .. } = &p.pattern {
547                    Some(name.name.clone())
548                } else {
549                    None
550                }
551            })
552            .collect();
553
554        if param_names.len() != param_count {
555            return None;
556        }
557
558        // 1. Create implementation function (renamed original with calls redirected)
559        let impl_func = ast::Function {
560            visibility: Visibility::default(),
561            is_async: func.is_async,
562            is_const: func.is_const,
563            is_unsafe: func.is_unsafe,
564            attrs: func.attrs.clone(),
565            name: Ident {
566                name: impl_name.clone(),
567                evidentiality: None,
568                affect: None,
569                span: span.clone(),
570            },
571            aspect: func.aspect,
572            generics: func.generics.clone(),
573            params: func.params.clone(),
574            return_type: func.return_type.clone(),
575            where_clause: func.where_clause.clone(),
576            body: func
577                .body
578                .as_ref()
579                .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
580        };
581
582        // 2. Create cache initializer function
583        // This is a function that returns the cache, called once at the start
584        let cache_init_body = Block {
585            stmts: vec![],
586            expr: Some(Box::new(Expr::Call {
587                func: Box::new(Expr::Path(TypePath {
588                    segments: vec![PathSegment {
589                        ident: Ident {
590                            name: "sigil_memo_new".to_string(),
591                            evidentiality: None,
592                            affect: None,
593                            span: span.clone(),
594                        },
595                        generics: None,
596                    }],
597                })),
598                args: vec![Expr::Literal(Literal::Int {
599                    value: "65536".to_string(),
600                    base: NumBase::Decimal,
601                    suffix: None,
602                })],
603            })),
604        };
605
606        let cache_init_func = ast::Function {
607            visibility: Visibility::default(),
608            is_async: false,
609            is_const: false,
610            is_unsafe: false,
611            attrs: FunctionAttrs::default(),
612            name: Ident {
613                name: init_name.clone(),
614                evidentiality: None,
615                affect: None,
616                span: span.clone(),
617            },
618            aspect: None,
619            generics: None,
620            params: vec![],
621            return_type: None,
622            where_clause: None,
623            body: Some(cache_init_body),
624        };
625
626        // 3. Create wrapper function
627        let wrapper_func = self.generate_memo_wrapper(func, &impl_name, &param_names);
628
629        Some((impl_func, cache_init_func, wrapper_func))
630    }
631
632    /// Generate the memoized wrapper function
633    #[allow(dead_code)]
634    fn generate_memo_wrapper(
635        &self,
636        original: &ast::Function,
637        impl_name: &str,
638        param_names: &[String],
639    ) -> ast::Function {
640        let span = Span { start: 0, end: 0 };
641        let param_count = param_names.len();
642
643        // Variable for cache - use a static-like pattern with lazy init
644        let cache_var = Ident {
645            name: "__cache".to_string(),
646            evidentiality: None,
647            affect: None,
648            span: span.clone(),
649        };
650        let result_var = Ident {
651            name: "__result".to_string(),
652            evidentiality: None,
653            affect: None,
654            span: span.clone(),
655        };
656        let cached_var = Ident {
657            name: "__cached".to_string(),
658            evidentiality: None,
659            affect: None,
660            span: span.clone(),
661        };
662
663        let mut stmts = vec![];
664
665        // let __cache = sigil_memo_new(65536);
666        stmts.push(Stmt::Let {
667            pattern: Pattern::Ident {
668                mutable: false,
669                name: cache_var.clone(),
670                evidentiality: None,
671            },
672            ty: None,
673            init: Some(Expr::Call {
674                func: Box::new(Expr::Path(TypePath {
675                    segments: vec![PathSegment {
676                        ident: Ident {
677                            name: "sigil_memo_new".to_string(),
678                            evidentiality: None,
679                            affect: None,
680                            span: span.clone(),
681                        },
682                        generics: None,
683                    }],
684                })),
685                args: vec![Expr::Literal(Literal::Int {
686                    value: "65536".to_string(),
687                    base: NumBase::Decimal,
688                    suffix: None,
689                })],
690            }),
691        });
692
693        // Check cache: let __cached = sigil_memo_get_N(__cache, params...);
694        let get_fn_name = if param_count == 1 {
695            "sigil_memo_get_1"
696        } else {
697            "sigil_memo_get_2"
698        };
699        let mut get_args = vec![Expr::Path(TypePath {
700            segments: vec![PathSegment {
701                ident: cache_var.clone(),
702                generics: None,
703            }],
704        })];
705        for name in param_names {
706            get_args.push(Expr::Path(TypePath {
707                segments: vec![PathSegment {
708                    ident: Ident {
709                        name: name.clone(),
710                        evidentiality: None,
711                        affect: None,
712                        span: span.clone(),
713                    },
714                    generics: None,
715                }],
716            }));
717        }
718
719        stmts.push(Stmt::Let {
720            pattern: Pattern::Ident {
721                mutable: false,
722                name: cached_var.clone(),
723                evidentiality: None,
724            },
725            ty: None,
726            init: Some(Expr::Call {
727                func: Box::new(Expr::Path(TypePath {
728                    segments: vec![PathSegment {
729                        ident: Ident {
730                            name: get_fn_name.to_string(),
731                            evidentiality: None,
732                            affect: None,
733                            span: span.clone(),
734                        },
735                        generics: None,
736                    }],
737                })),
738                args: get_args,
739            }),
740        });
741
742        // if __cached != -9223372036854775807 { return __cached; }
743        // Use a large negative number as sentinel (i64::MIN + 1 to avoid overflow issues)
744        let cache_check = Expr::If {
745            condition: Box::new(Expr::Binary {
746                op: BinOp::Ne,
747                left: Box::new(Expr::Path(TypePath {
748                    segments: vec![PathSegment {
749                        ident: cached_var.clone(),
750                        generics: None,
751                    }],
752                })),
753                right: Box::new(Expr::Unary {
754                    op: UnaryOp::Neg,
755                    expr: Box::new(Expr::Literal(Literal::Int {
756                        value: "9223372036854775807".to_string(),
757                        base: NumBase::Decimal,
758                        suffix: None,
759                    })),
760                }),
761            }),
762            then_branch: Block {
763                stmts: vec![],
764                expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
765                    TypePath {
766                        segments: vec![PathSegment {
767                            ident: cached_var.clone(),
768                            generics: None,
769                        }],
770                    },
771                )))))),
772            },
773            else_branch: None,
774        };
775        stmts.push(Stmt::Semi(cache_check));
776
777        // let __result = _memo_impl_func(params...);
778        let mut impl_args = vec![];
779        for name in param_names {
780            impl_args.push(Expr::Path(TypePath {
781                segments: vec![PathSegment {
782                    ident: Ident {
783                        name: name.clone(),
784                        evidentiality: None,
785                        affect: None,
786                        span: span.clone(),
787                    },
788                    generics: None,
789                }],
790            }));
791        }
792
793        stmts.push(Stmt::Let {
794            pattern: Pattern::Ident {
795                mutable: false,
796                name: result_var.clone(),
797                evidentiality: None,
798            },
799            ty: None,
800            init: Some(Expr::Call {
801                func: Box::new(Expr::Path(TypePath {
802                    segments: vec![PathSegment {
803                        ident: Ident {
804                            name: impl_name.to_string(),
805                            evidentiality: None,
806                            affect: None,
807                            span: span.clone(),
808                        },
809                        generics: None,
810                    }],
811                })),
812                args: impl_args,
813            }),
814        });
815
816        // sigil_memo_set_N(__cache, params..., __result);
817        let set_fn_name = if param_count == 1 {
818            "sigil_memo_set_1"
819        } else {
820            "sigil_memo_set_2"
821        };
822        let mut set_args = vec![Expr::Path(TypePath {
823            segments: vec![PathSegment {
824                ident: cache_var.clone(),
825                generics: None,
826            }],
827        })];
828        for name in param_names {
829            set_args.push(Expr::Path(TypePath {
830                segments: vec![PathSegment {
831                    ident: Ident {
832                        name: name.clone(),
833                        evidentiality: None,
834                        affect: None,
835                        span: span.clone(),
836                    },
837                    generics: None,
838                }],
839            }));
840        }
841        set_args.push(Expr::Path(TypePath {
842            segments: vec![PathSegment {
843                ident: result_var.clone(),
844                generics: None,
845            }],
846        }));
847
848        stmts.push(Stmt::Semi(Expr::Call {
849            func: Box::new(Expr::Path(TypePath {
850                segments: vec![PathSegment {
851                    ident: Ident {
852                        name: set_fn_name.to_string(),
853                        evidentiality: None,
854                        affect: None,
855                        span: span.clone(),
856                    },
857                    generics: None,
858                }],
859            })),
860            args: set_args,
861        }));
862
863        // return __result;
864        let body = Block {
865            stmts,
866            expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
867                TypePath {
868                    segments: vec![PathSegment {
869                        ident: result_var.clone(),
870                        generics: None,
871                    }],
872                },
873            )))))),
874        };
875
876        ast::Function {
877            visibility: original.visibility,
878            is_async: original.is_async,
879            is_const: original.is_const,
880            is_unsafe: original.is_unsafe,
881            attrs: original.attrs.clone(),
882            name: original.name.clone(),
883            aspect: original.aspect,
884            generics: original.generics.clone(),
885            params: original.params.clone(),
886            return_type: original.return_type.clone(),
887            where_clause: original.where_clause.clone(),
888            body: Some(body),
889        }
890    }
891
892    /// Redirect all recursive calls in a block to call the original wrapper (for memoization)
893    #[allow(dead_code)]
894    fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
895        // For memoization, we keep the calls as-is since they'll go through the wrapper
896        block.clone()
897    }
898
899    /// Check if a function is recursive
900    fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
901        if let Some(body) = &func.body {
902            self.block_calls_function(name, body)
903        } else {
904            false
905        }
906    }
907
908    fn block_calls_function(&self, name: &str, block: &Block) -> bool {
909        for stmt in &block.stmts {
910            if self.stmt_calls_function(name, stmt) {
911                return true;
912            }
913        }
914        if let Some(expr) = &block.expr {
915            if self.expr_calls_function(name, expr) {
916                return true;
917            }
918        }
919        false
920    }
921
922    fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
923        match stmt {
924            Stmt::Let {
925                init: Some(expr), ..
926            } => self.expr_calls_function(name, expr),
927            Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
928            _ => false,
929        }
930    }
931
932    fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
933        match expr {
934            Expr::Call { func, args } => {
935                if let Expr::Path(path) = func.as_ref() {
936                    if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
937                        return true;
938                    }
939                }
940                args.iter().any(|a| self.expr_calls_function(name, a))
941            }
942            Expr::Binary { left, right, .. } => {
943                self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
944            }
945            Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
946            Expr::If {
947                condition,
948                then_branch,
949                else_branch,
950            } => {
951                self.expr_calls_function(name, condition)
952                    || self.block_calls_function(name, then_branch)
953                    || else_branch
954                        .as_ref()
955                        .map(|e| self.expr_calls_function(name, e))
956                        .unwrap_or(false)
957            }
958            Expr::While {
959                label,
960                condition,
961                body,
962            } => self.expr_calls_function(name, condition) || self.block_calls_function(name, body),
963            Expr::Block(block) => self.block_calls_function(name, block),
964            Expr::Return(Some(e)) => self.expr_calls_function(name, e),
965            _ => false,
966        }
967    }
968
969    /// Optimize a single function
970    fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
971        // Reset CSE counter per function for cleaner variable names
972        self.cse_counter = 0;
973
974        let body = if let Some(body) = &func.body {
975            // Run passes based on optimization level
976            let optimized = match self.level {
977                OptLevel::None => body.clone(),
978                OptLevel::Basic => {
979                    let b = self.pass_constant_fold_block(body);
980                    self.pass_dead_code_block(&b)
981                }
982                OptLevel::Standard | OptLevel::Size => {
983                    let b = self.pass_constant_fold_block(body);
984                    let b = self.pass_inline_block(&b); // Function inlining
985                    let b = self.pass_strength_reduce_block(&b);
986                    let b = self.pass_licm_block(&b); // LICM pass
987                    let b = self.pass_cse_block(&b); // CSE pass
988                    let b = self.pass_dead_code_block(&b);
989                    self.pass_simplify_branches_block(&b)
990                }
991                OptLevel::Aggressive => {
992                    // Multiple iterations for fixed-point
993                    let mut b = body.clone();
994                    for _ in 0..3 {
995                        b = self.pass_constant_fold_block(&b);
996                        b = self.pass_inline_block(&b); // Function inlining
997                        b = self.pass_strength_reduce_block(&b);
998                        b = self.pass_loop_unroll_block(&b); // Loop unrolling
999                        b = self.pass_licm_block(&b); // LICM pass
1000                        b = self.pass_cse_block(&b); // CSE pass
1001                        b = self.pass_dead_code_block(&b);
1002                        b = self.pass_simplify_branches_block(&b);
1003                    }
1004                    b
1005                }
1006            };
1007            Some(optimized)
1008        } else {
1009            None
1010        };
1011
1012        ast::Function {
1013            visibility: func.visibility.clone(),
1014            is_async: func.is_async,
1015            is_const: func.is_const,
1016            is_unsafe: func.is_unsafe,
1017            attrs: func.attrs.clone(),
1018            name: func.name.clone(),
1019            aspect: func.aspect,
1020            generics: func.generics.clone(),
1021            params: func.params.clone(),
1022            return_type: func.return_type.clone(),
1023            where_clause: func.where_clause.clone(),
1024            body,
1025        }
1026    }
1027
1028    // ========================================================================
1029    // Pass: Constant Folding
1030    // ========================================================================
1031
1032    fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1033        let stmts = block
1034            .stmts
1035            .iter()
1036            .map(|s| self.pass_constant_fold_stmt(s))
1037            .collect();
1038        let expr = block
1039            .expr
1040            .as_ref()
1041            .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1042        Block { stmts, expr }
1043    }
1044
1045    fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1046        match stmt {
1047            Stmt::Let {
1048                pattern, ty, init, ..
1049            } => Stmt::Let {
1050                pattern: pattern.clone(),
1051                ty: ty.clone(),
1052                init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1053            },
1054            Stmt::LetElse {
1055                pattern,
1056                ty,
1057                init,
1058                else_branch,
1059            } => Stmt::LetElse {
1060                pattern: pattern.clone(),
1061                ty: ty.clone(),
1062                init: self.pass_constant_fold_expr(init),
1063                else_branch: Box::new(self.pass_constant_fold_expr(else_branch)),
1064            },
1065            Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1066            Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1067            Stmt::Item(item) => Stmt::Item(item.clone()),
1068        }
1069    }
1070
1071    fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1072        match expr {
1073            Expr::Binary { op, left, right } => {
1074                let left = Box::new(self.pass_constant_fold_expr(left));
1075                let right = Box::new(self.pass_constant_fold_expr(right));
1076
1077                // Try to fold
1078                if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1079                    if let Some(result) = self.fold_binary(op.clone(), l, r) {
1080                        self.stats.constants_folded += 1;
1081                        return Expr::Literal(Literal::Int {
1082                            value: result.to_string(),
1083                            base: NumBase::Decimal,
1084                            suffix: None,
1085                        });
1086                    }
1087                }
1088
1089                Expr::Binary {
1090                    op: op.clone(),
1091                    left,
1092                    right,
1093                }
1094            }
1095            Expr::Unary { op, expr: inner } => {
1096                let inner = Box::new(self.pass_constant_fold_expr(inner));
1097
1098                if let Some(v) = self.as_int(&inner) {
1099                    if let Some(result) = self.fold_unary(*op, v) {
1100                        self.stats.constants_folded += 1;
1101                        return Expr::Literal(Literal::Int {
1102                            value: result.to_string(),
1103                            base: NumBase::Decimal,
1104                            suffix: None,
1105                        });
1106                    }
1107                }
1108
1109                Expr::Unary {
1110                    op: *op,
1111                    expr: inner,
1112                }
1113            }
1114            Expr::If {
1115                condition,
1116                then_branch,
1117                else_branch,
1118            } => {
1119                let condition = Box::new(self.pass_constant_fold_expr(condition));
1120                let then_branch = self.pass_constant_fold_block(then_branch);
1121                let else_branch = else_branch
1122                    .as_ref()
1123                    .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1124
1125                // Fold constant conditions
1126                if let Some(cond) = self.as_bool(&condition) {
1127                    self.stats.branches_simplified += 1;
1128                    if cond {
1129                        return Expr::Block(then_branch);
1130                    } else if let Some(else_expr) = else_branch {
1131                        return *else_expr;
1132                    } else {
1133                        return Expr::Literal(Literal::Bool(false));
1134                    }
1135                }
1136
1137                Expr::If {
1138                    condition,
1139                    then_branch,
1140                    else_branch,
1141                }
1142            }
1143            Expr::While {
1144                label,
1145                condition,
1146                body,
1147            } => {
1148                let condition = Box::new(self.pass_constant_fold_expr(condition));
1149                let body = self.pass_constant_fold_block(body);
1150
1151                // while false { ... } -> nothing
1152                if let Some(false) = self.as_bool(&condition) {
1153                    self.stats.branches_simplified += 1;
1154                    return Expr::Block(Block {
1155                        stmts: vec![],
1156                        expr: None,
1157                    });
1158                }
1159
1160                Expr::While {
1161                    label: label.clone(),
1162                    condition,
1163                    body,
1164                }
1165            }
1166            Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1167            Expr::Call { func, args } => {
1168                let args = args
1169                    .iter()
1170                    .map(|a| self.pass_constant_fold_expr(a))
1171                    .collect();
1172                Expr::Call {
1173                    func: func.clone(),
1174                    args,
1175                }
1176            }
1177            Expr::Return(e) => Expr::Return(
1178                e.as_ref()
1179                    .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1180            ),
1181            Expr::Assign { target, value } => {
1182                let value = Box::new(self.pass_constant_fold_expr(value));
1183                Expr::Assign {
1184                    target: target.clone(),
1185                    value,
1186                }
1187            }
1188            Expr::Index { expr: e, index } => {
1189                let e = Box::new(self.pass_constant_fold_expr(e));
1190                let index = Box::new(self.pass_constant_fold_expr(index));
1191                Expr::Index { expr: e, index }
1192            }
1193            Expr::Array(elements) => {
1194                let elements = elements
1195                    .iter()
1196                    .map(|e| self.pass_constant_fold_expr(e))
1197                    .collect();
1198                Expr::Array(elements)
1199            }
1200            other => other.clone(),
1201        }
1202    }
1203
1204    fn as_int(&self, expr: &Expr) -> Option<i64> {
1205        match expr {
1206            Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1207            Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1208            _ => None,
1209        }
1210    }
1211
1212    fn as_bool(&self, expr: &Expr) -> Option<bool> {
1213        match expr {
1214            Expr::Literal(Literal::Bool(b)) => Some(*b),
1215            Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1216            _ => None,
1217        }
1218    }
1219
1220    fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1221        match op {
1222            BinOp::Add => Some(l.wrapping_add(r)),
1223            BinOp::Sub => Some(l.wrapping_sub(r)),
1224            BinOp::Mul => Some(l.wrapping_mul(r)),
1225            BinOp::Div if r != 0 => Some(l / r),
1226            BinOp::Rem if r != 0 => Some(l % r),
1227            BinOp::BitAnd => Some(l & r),
1228            BinOp::BitOr => Some(l | r),
1229            BinOp::BitXor => Some(l ^ r),
1230            BinOp::Shl => Some(l << (r & 63)),
1231            BinOp::Shr => Some(l >> (r & 63)),
1232            BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1233            BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1234            BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1235            BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1236            BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1237            BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1238            BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1239            BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1240            _ => None,
1241        }
1242    }
1243
1244    fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1245        match op {
1246            UnaryOp::Neg => Some(-v),
1247            UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1248            _ => None,
1249        }
1250    }
1251
1252    // ========================================================================
1253    // Pass: Strength Reduction
1254    // ========================================================================
1255
1256    fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1257        let stmts = block
1258            .stmts
1259            .iter()
1260            .map(|s| self.pass_strength_reduce_stmt(s))
1261            .collect();
1262        let expr = block
1263            .expr
1264            .as_ref()
1265            .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1266        Block { stmts, expr }
1267    }
1268
1269    fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1270        match stmt {
1271            Stmt::Let {
1272                pattern, ty, init, ..
1273            } => Stmt::Let {
1274                pattern: pattern.clone(),
1275                ty: ty.clone(),
1276                init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1277            },
1278            Stmt::LetElse {
1279                pattern,
1280                ty,
1281                init,
1282                else_branch,
1283            } => Stmt::LetElse {
1284                pattern: pattern.clone(),
1285                ty: ty.clone(),
1286                init: self.pass_strength_reduce_expr(init),
1287                else_branch: Box::new(self.pass_strength_reduce_expr(else_branch)),
1288            },
1289            Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1290            Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1291            Stmt::Item(item) => Stmt::Item(item.clone()),
1292        }
1293    }
1294
1295    fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1296        match expr {
1297            Expr::Binary { op, left, right } => {
1298                let left = Box::new(self.pass_strength_reduce_expr(left));
1299                let right = Box::new(self.pass_strength_reduce_expr(right));
1300
1301                // x * 2 -> x << 1, x * 4 -> x << 2, etc.
1302                if *op == BinOp::Mul {
1303                    if let Some(n) = self.as_int(&right) {
1304                        if n > 0 && (n as u64).is_power_of_two() {
1305                            self.stats.strength_reductions += 1;
1306                            let shift = (n as u64).trailing_zeros() as i64;
1307                            return Expr::Binary {
1308                                op: BinOp::Shl,
1309                                left,
1310                                right: Box::new(Expr::Literal(Literal::Int {
1311                                    value: shift.to_string(),
1312                                    base: NumBase::Decimal,
1313                                    suffix: None,
1314                                })),
1315                            };
1316                        }
1317                    }
1318                    if let Some(n) = self.as_int(&left) {
1319                        if n > 0 && (n as u64).is_power_of_two() {
1320                            self.stats.strength_reductions += 1;
1321                            let shift = (n as u64).trailing_zeros() as i64;
1322                            return Expr::Binary {
1323                                op: BinOp::Shl,
1324                                left: right,
1325                                right: Box::new(Expr::Literal(Literal::Int {
1326                                    value: shift.to_string(),
1327                                    base: NumBase::Decimal,
1328                                    suffix: None,
1329                                })),
1330                            };
1331                        }
1332                    }
1333                }
1334
1335                // x + 0 -> x, x - 0 -> x, x * 1 -> x, x / 1 -> x
1336                if let Some(n) = self.as_int(&right) {
1337                    match (op, n) {
1338                        (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1339                        | (BinOp::Mul | BinOp::Div, 1)
1340                        | (BinOp::Shl | BinOp::Shr, 0) => {
1341                            self.stats.strength_reductions += 1;
1342                            return *left;
1343                        }
1344                        (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1345                            self.stats.strength_reductions += 1;
1346                            return Expr::Literal(Literal::Int {
1347                                value: "0".to_string(),
1348                                base: NumBase::Decimal,
1349                                suffix: None,
1350                            });
1351                        }
1352                        _ => {}
1353                    }
1354                }
1355
1356                // 0 + x -> x, 1 * x -> x
1357                if let Some(n) = self.as_int(&left) {
1358                    match (op, n) {
1359                        (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1360                            self.stats.strength_reductions += 1;
1361                            return *right;
1362                        }
1363                        (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1364                            self.stats.strength_reductions += 1;
1365                            return Expr::Literal(Literal::Int {
1366                                value: "0".to_string(),
1367                                base: NumBase::Decimal,
1368                                suffix: None,
1369                            });
1370                        }
1371                        _ => {}
1372                    }
1373                }
1374
1375                Expr::Binary {
1376                    op: op.clone(),
1377                    left,
1378                    right,
1379                }
1380            }
1381            Expr::Unary { op, expr: inner } => {
1382                let inner = Box::new(self.pass_strength_reduce_expr(inner));
1383
1384                // --x -> x
1385                if *op == UnaryOp::Neg {
1386                    if let Expr::Unary {
1387                        op: UnaryOp::Neg,
1388                        expr: inner2,
1389                    } = inner.as_ref()
1390                    {
1391                        self.stats.strength_reductions += 1;
1392                        return *inner2.clone();
1393                    }
1394                }
1395
1396                // !!x -> x (for booleans)
1397                if *op == UnaryOp::Not {
1398                    if let Expr::Unary {
1399                        op: UnaryOp::Not,
1400                        expr: inner2,
1401                    } = inner.as_ref()
1402                    {
1403                        self.stats.strength_reductions += 1;
1404                        return *inner2.clone();
1405                    }
1406                }
1407
1408                Expr::Unary {
1409                    op: *op,
1410                    expr: inner,
1411                }
1412            }
1413            Expr::If {
1414                condition,
1415                then_branch,
1416                else_branch,
1417            } => {
1418                let condition = Box::new(self.pass_strength_reduce_expr(condition));
1419                let then_branch = self.pass_strength_reduce_block(then_branch);
1420                let else_branch = else_branch
1421                    .as_ref()
1422                    .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1423                Expr::If {
1424                    condition,
1425                    then_branch,
1426                    else_branch,
1427                }
1428            }
1429            Expr::While {
1430                label,
1431                condition,
1432                body,
1433            } => {
1434                let condition = Box::new(self.pass_strength_reduce_expr(condition));
1435                let body = self.pass_strength_reduce_block(body);
1436                Expr::While {
1437                    label: label.clone(),
1438                    condition,
1439                    body,
1440                }
1441            }
1442            Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1443            Expr::Call { func, args } => {
1444                let args = args
1445                    .iter()
1446                    .map(|a| self.pass_strength_reduce_expr(a))
1447                    .collect();
1448                Expr::Call {
1449                    func: func.clone(),
1450                    args,
1451                }
1452            }
1453            Expr::Return(e) => Expr::Return(
1454                e.as_ref()
1455                    .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1456            ),
1457            Expr::Assign { target, value } => {
1458                let value = Box::new(self.pass_strength_reduce_expr(value));
1459                Expr::Assign {
1460                    target: target.clone(),
1461                    value,
1462                }
1463            }
1464            other => other.clone(),
1465        }
1466    }
1467
1468    // ========================================================================
1469    // Pass: Dead Code Elimination
1470    // ========================================================================
1471
1472    fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1473        // Remove statements after a return
1474        let mut stmts = Vec::new();
1475        let mut found_return = false;
1476
1477        for stmt in &block.stmts {
1478            if found_return {
1479                self.stats.dead_code_eliminated += 1;
1480                continue;
1481            }
1482            let stmt = self.pass_dead_code_stmt(stmt);
1483            if self.stmt_returns(&stmt) {
1484                found_return = true;
1485            }
1486            stmts.push(stmt);
1487        }
1488
1489        // If we found a return, the trailing expression is dead
1490        let expr = if found_return {
1491            if block.expr.is_some() {
1492                self.stats.dead_code_eliminated += 1;
1493            }
1494            None
1495        } else {
1496            block
1497                .expr
1498                .as_ref()
1499                .map(|e| Box::new(self.pass_dead_code_expr(e)))
1500        };
1501
1502        Block { stmts, expr }
1503    }
1504
1505    fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1506        match stmt {
1507            Stmt::Let {
1508                pattern, ty, init, ..
1509            } => Stmt::Let {
1510                pattern: pattern.clone(),
1511                ty: ty.clone(),
1512                init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1513            },
1514            Stmt::LetElse {
1515                pattern,
1516                ty,
1517                init,
1518                else_branch,
1519            } => Stmt::LetElse {
1520                pattern: pattern.clone(),
1521                ty: ty.clone(),
1522                init: self.pass_dead_code_expr(init),
1523                else_branch: Box::new(self.pass_dead_code_expr(else_branch)),
1524            },
1525            Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1526            Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1527            Stmt::Item(item) => Stmt::Item(item.clone()),
1528        }
1529    }
1530
1531    fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1532        match expr {
1533            Expr::If {
1534                condition,
1535                then_branch,
1536                else_branch,
1537            } => {
1538                let condition = Box::new(self.pass_dead_code_expr(condition));
1539                let then_branch = self.pass_dead_code_block(then_branch);
1540                let else_branch = else_branch
1541                    .as_ref()
1542                    .map(|e| Box::new(self.pass_dead_code_expr(e)));
1543                Expr::If {
1544                    condition,
1545                    then_branch,
1546                    else_branch,
1547                }
1548            }
1549            Expr::While {
1550                label,
1551                condition,
1552                body,
1553            } => {
1554                let condition = Box::new(self.pass_dead_code_expr(condition));
1555                let body = self.pass_dead_code_block(body);
1556                Expr::While {
1557                    label: label.clone(),
1558                    condition,
1559                    body,
1560                }
1561            }
1562            Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1563            other => other.clone(),
1564        }
1565    }
1566
1567    fn stmt_returns(&self, stmt: &Stmt) -> bool {
1568        match stmt {
1569            Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1570            _ => false,
1571        }
1572    }
1573
1574    fn expr_returns(&self, expr: &Expr) -> bool {
1575        match expr {
1576            Expr::Return(_) => true,
1577            Expr::Block(block) => {
1578                block.stmts.iter().any(|s| self.stmt_returns(s))
1579                    || block
1580                        .expr
1581                        .as_ref()
1582                        .map(|e| self.expr_returns(e))
1583                        .unwrap_or(false)
1584            }
1585            _ => false,
1586        }
1587    }
1588
1589    // ========================================================================
1590    // Pass: Branch Simplification
1591    // ========================================================================
1592
1593    fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1594        let stmts = block
1595            .stmts
1596            .iter()
1597            .map(|s| self.pass_simplify_branches_stmt(s))
1598            .collect();
1599        let expr = block
1600            .expr
1601            .as_ref()
1602            .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1603        Block { stmts, expr }
1604    }
1605
1606    fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1607        match stmt {
1608            Stmt::Let {
1609                pattern, ty, init, ..
1610            } => Stmt::Let {
1611                pattern: pattern.clone(),
1612                ty: ty.clone(),
1613                init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1614            },
1615            Stmt::LetElse {
1616                pattern,
1617                ty,
1618                init,
1619                else_branch,
1620            } => Stmt::LetElse {
1621                pattern: pattern.clone(),
1622                ty: ty.clone(),
1623                init: self.pass_simplify_branches_expr(init),
1624                else_branch: Box::new(self.pass_simplify_branches_expr(else_branch)),
1625            },
1626            Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1627            Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1628            Stmt::Item(item) => Stmt::Item(item.clone()),
1629        }
1630    }
1631
1632    fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1633        match expr {
1634            Expr::If {
1635                condition,
1636                then_branch,
1637                else_branch,
1638            } => {
1639                let condition = Box::new(self.pass_simplify_branches_expr(condition));
1640                let then_branch = self.pass_simplify_branches_block(then_branch);
1641                let else_branch = else_branch
1642                    .as_ref()
1643                    .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1644
1645                // if !cond { a } else { b } -> if cond { b } else { a }
1646                if let Expr::Unary {
1647                    op: UnaryOp::Not,
1648                    expr: inner,
1649                } = condition.as_ref()
1650                {
1651                    if let Some(else_expr) = &else_branch {
1652                        self.stats.branches_simplified += 1;
1653                        let new_else = Some(Box::new(Expr::Block(then_branch)));
1654                        let new_then = match else_expr.as_ref() {
1655                            Expr::Block(b) => b.clone(),
1656                            other => Block {
1657                                stmts: vec![],
1658                                expr: Some(Box::new(other.clone())),
1659                            },
1660                        };
1661                        return Expr::If {
1662                            condition: inner.clone(),
1663                            then_branch: new_then,
1664                            else_branch: new_else,
1665                        };
1666                    }
1667                }
1668
1669                Expr::If {
1670                    condition,
1671                    then_branch,
1672                    else_branch,
1673                }
1674            }
1675            Expr::While {
1676                label,
1677                condition,
1678                body,
1679            } => {
1680                let condition = Box::new(self.pass_simplify_branches_expr(condition));
1681                let body = self.pass_simplify_branches_block(body);
1682                Expr::While {
1683                    label: label.clone(),
1684                    condition,
1685                    body,
1686                }
1687            }
1688            Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1689            Expr::Binary { op, left, right } => {
1690                let left = Box::new(self.pass_simplify_branches_expr(left));
1691                let right = Box::new(self.pass_simplify_branches_expr(right));
1692                Expr::Binary {
1693                    op: op.clone(),
1694                    left,
1695                    right,
1696                }
1697            }
1698            Expr::Unary { op, expr: inner } => {
1699                let inner = Box::new(self.pass_simplify_branches_expr(inner));
1700                Expr::Unary {
1701                    op: *op,
1702                    expr: inner,
1703                }
1704            }
1705            Expr::Call { func, args } => {
1706                let args = args
1707                    .iter()
1708                    .map(|a| self.pass_simplify_branches_expr(a))
1709                    .collect();
1710                Expr::Call {
1711                    func: func.clone(),
1712                    args,
1713                }
1714            }
1715            Expr::Return(e) => Expr::Return(
1716                e.as_ref()
1717                    .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1718            ),
1719            other => other.clone(),
1720        }
1721    }
1722
1723    // ========================================================================
1724    // Pass: Function Inlining
1725    // ========================================================================
1726
1727    /// Check if a function is small enough to inline
1728    fn should_inline(&self, func: &ast::Function) -> bool {
1729        // Don't inline recursive functions
1730        if self.recursive_functions.contains(&func.name.name) {
1731            return false;
1732        }
1733
1734        // Count the number of statements/expressions in the body
1735        if let Some(body) = &func.body {
1736            let stmt_count = self.count_stmts_in_block(body);
1737            // Inline functions with 10 or fewer statements
1738            stmt_count <= 10
1739        } else {
1740            false
1741        }
1742    }
1743
1744    /// Count statements in a block (for inlining heuristics)
1745    fn count_stmts_in_block(&self, block: &Block) -> usize {
1746        let mut count = block.stmts.len();
1747        if block.expr.is_some() {
1748            count += 1;
1749        }
1750        // Also count nested blocks in if/while
1751        for stmt in &block.stmts {
1752            count += self.count_stmts_in_stmt(stmt);
1753        }
1754        count
1755    }
1756
1757    fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1758        match stmt {
1759            Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1760            Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1761            _ => 0,
1762        }
1763    }
1764
1765    fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1766        match expr {
1767            Expr::If {
1768                then_branch,
1769                else_branch,
1770                ..
1771            } => {
1772                let mut count = self.count_stmts_in_block(then_branch);
1773                if let Some(else_expr) = else_branch {
1774                    count += self.count_stmts_in_expr(else_expr);
1775                }
1776                count
1777            }
1778            Expr::While { body, .. } => self.count_stmts_in_block(body),
1779            Expr::Block(block) => self.count_stmts_in_block(block),
1780            _ => 0,
1781        }
1782    }
1783
1784    /// Inline function call by substituting the body with parameters replaced
1785    fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1786        let body = func.body.as_ref()?;
1787
1788        // Build parameter to argument mapping
1789        let mut param_map: HashMap<String, Expr> = HashMap::new();
1790        for (param, arg) in func.params.iter().zip(args.iter()) {
1791            if let Pattern::Ident { name, .. } = &param.pattern {
1792                param_map.insert(name.name.clone(), arg.clone());
1793            }
1794        }
1795
1796        // Substitute parameters in the function body
1797        let inlined_body = self.substitute_params_in_block(body, &param_map);
1798
1799        self.stats.functions_inlined += 1;
1800
1801        // If the body has a final expression, return it
1802        // If not, wrap in a block
1803        if inlined_body.stmts.is_empty() {
1804            if let Some(expr) = inlined_body.expr {
1805                // If it's a return, unwrap it
1806                if let Expr::Return(Some(inner)) = expr.as_ref() {
1807                    return Some(inner.as_ref().clone());
1808                }
1809                return Some(*expr);
1810            }
1811        }
1812
1813        Some(Expr::Block(inlined_body))
1814    }
1815
1816    /// Substitute parameter references with argument expressions
1817    fn substitute_params_in_block(
1818        &self,
1819        block: &Block,
1820        param_map: &HashMap<String, Expr>,
1821    ) -> Block {
1822        let stmts = block
1823            .stmts
1824            .iter()
1825            .map(|s| self.substitute_params_in_stmt(s, param_map))
1826            .collect();
1827        let expr = block
1828            .expr
1829            .as_ref()
1830            .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1831        Block { stmts, expr }
1832    }
1833
1834    fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1835        match stmt {
1836            Stmt::Let { pattern, ty, init } => Stmt::Let {
1837                pattern: pattern.clone(),
1838                ty: ty.clone(),
1839                init: init
1840                    .as_ref()
1841                    .map(|e| self.substitute_params_in_expr(e, param_map)),
1842            },
1843            Stmt::LetElse {
1844                pattern,
1845                ty,
1846                init,
1847                else_branch,
1848            } => Stmt::LetElse {
1849                pattern: pattern.clone(),
1850                ty: ty.clone(),
1851                init: self.substitute_params_in_expr(init, param_map),
1852                else_branch: Box::new(self.substitute_params_in_expr(else_branch, param_map)),
1853            },
1854            Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1855            Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1856            Stmt::Item(item) => Stmt::Item(item.clone()),
1857        }
1858    }
1859
1860    fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1861        match expr {
1862            Expr::Path(path) => {
1863                // Check if this is a parameter reference
1864                if path.segments.len() == 1 {
1865                    let name = &path.segments[0].ident.name;
1866                    if let Some(arg) = param_map.get(name) {
1867                        return arg.clone();
1868                    }
1869                }
1870                expr.clone()
1871            }
1872            Expr::Binary { op, left, right } => Expr::Binary {
1873                op: op.clone(),
1874                left: Box::new(self.substitute_params_in_expr(left, param_map)),
1875                right: Box::new(self.substitute_params_in_expr(right, param_map)),
1876            },
1877            Expr::Unary { op, expr: inner } => Expr::Unary {
1878                op: *op,
1879                expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1880            },
1881            Expr::If {
1882                condition,
1883                then_branch,
1884                else_branch,
1885            } => Expr::If {
1886                condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1887                then_branch: self.substitute_params_in_block(then_branch, param_map),
1888                else_branch: else_branch
1889                    .as_ref()
1890                    .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1891            },
1892            Expr::While {
1893                label,
1894                condition,
1895                body,
1896            } => Expr::While {
1897                label: label.clone(),
1898                condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1899                body: self.substitute_params_in_block(body, param_map),
1900            },
1901            Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1902            Expr::Call { func, args } => Expr::Call {
1903                func: Box::new(self.substitute_params_in_expr(func, param_map)),
1904                args: args
1905                    .iter()
1906                    .map(|a| self.substitute_params_in_expr(a, param_map))
1907                    .collect(),
1908            },
1909            Expr::Return(e) => Expr::Return(
1910                e.as_ref()
1911                    .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1912            ),
1913            Expr::Assign { target, value } => Expr::Assign {
1914                target: target.clone(),
1915                value: Box::new(self.substitute_params_in_expr(value, param_map)),
1916            },
1917            Expr::Index { expr: e, index } => Expr::Index {
1918                expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1919                index: Box::new(self.substitute_params_in_expr(index, param_map)),
1920            },
1921            Expr::Array(elements) => Expr::Array(
1922                elements
1923                    .iter()
1924                    .map(|e| self.substitute_params_in_expr(e, param_map))
1925                    .collect(),
1926            ),
1927            other => other.clone(),
1928        }
1929    }
1930
1931    fn pass_inline_block(&mut self, block: &Block) -> Block {
1932        let stmts = block
1933            .stmts
1934            .iter()
1935            .map(|s| self.pass_inline_stmt(s))
1936            .collect();
1937        let expr = block
1938            .expr
1939            .as_ref()
1940            .map(|e| Box::new(self.pass_inline_expr(e)));
1941        Block { stmts, expr }
1942    }
1943
1944    fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1945        match stmt {
1946            Stmt::Let { pattern, ty, init } => Stmt::Let {
1947                pattern: pattern.clone(),
1948                ty: ty.clone(),
1949                init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1950            },
1951            Stmt::LetElse {
1952                pattern,
1953                ty,
1954                init,
1955                else_branch,
1956            } => Stmt::LetElse {
1957                pattern: pattern.clone(),
1958                ty: ty.clone(),
1959                init: self.pass_inline_expr(init),
1960                else_branch: Box::new(self.pass_inline_expr(else_branch)),
1961            },
1962            Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1963            Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1964            Stmt::Item(item) => Stmt::Item(item.clone()),
1965        }
1966    }
1967
1968    fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1969        match expr {
1970            Expr::Call { func, args } => {
1971                // First, recursively inline arguments
1972                let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1973
1974                // Check if we can inline this call
1975                if let Expr::Path(path) = func.as_ref() {
1976                    if path.segments.len() == 1 {
1977                        let func_name = &path.segments[0].ident.name;
1978                        if let Some(target_func) = self.functions.get(func_name).cloned() {
1979                            if self.should_inline(&target_func)
1980                                && args.len() == target_func.params.len()
1981                            {
1982                                if let Some(inlined) = self.inline_call(&target_func, &args) {
1983                                    return inlined;
1984                                }
1985                            }
1986                        }
1987                    }
1988                }
1989
1990                Expr::Call {
1991                    func: func.clone(),
1992                    args,
1993                }
1994            }
1995            Expr::Binary { op, left, right } => Expr::Binary {
1996                op: op.clone(),
1997                left: Box::new(self.pass_inline_expr(left)),
1998                right: Box::new(self.pass_inline_expr(right)),
1999            },
2000            Expr::Unary { op, expr: inner } => Expr::Unary {
2001                op: *op,
2002                expr: Box::new(self.pass_inline_expr(inner)),
2003            },
2004            Expr::If {
2005                condition,
2006                then_branch,
2007                else_branch,
2008            } => Expr::If {
2009                condition: Box::new(self.pass_inline_expr(condition)),
2010                then_branch: self.pass_inline_block(then_branch),
2011                else_branch: else_branch
2012                    .as_ref()
2013                    .map(|e| Box::new(self.pass_inline_expr(e))),
2014            },
2015            Expr::While {
2016                label,
2017                condition,
2018                body,
2019            } => Expr::While {
2020                label: label.clone(),
2021                condition: Box::new(self.pass_inline_expr(condition)),
2022                body: self.pass_inline_block(body),
2023            },
2024            Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
2025            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
2026            Expr::Assign { target, value } => Expr::Assign {
2027                target: target.clone(),
2028                value: Box::new(self.pass_inline_expr(value)),
2029            },
2030            Expr::Index { expr: e, index } => Expr::Index {
2031                expr: Box::new(self.pass_inline_expr(e)),
2032                index: Box::new(self.pass_inline_expr(index)),
2033            },
2034            Expr::Array(elements) => {
2035                Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
2036            }
2037            other => other.clone(),
2038        }
2039    }
2040
2041    // ========================================================================
2042    // Pass: Loop Unrolling
2043    // ========================================================================
2044
2045    /// Unroll small counted loops for better performance
2046    fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
2047        let stmts = block
2048            .stmts
2049            .iter()
2050            .map(|s| self.pass_loop_unroll_stmt(s))
2051            .collect();
2052        let expr = block
2053            .expr
2054            .as_ref()
2055            .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
2056        Block { stmts, expr }
2057    }
2058
2059    fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
2060        match stmt {
2061            Stmt::Let { pattern, ty, init } => Stmt::Let {
2062                pattern: pattern.clone(),
2063                ty: ty.clone(),
2064                init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
2065            },
2066            Stmt::LetElse {
2067                pattern,
2068                ty,
2069                init,
2070                else_branch,
2071            } => Stmt::LetElse {
2072                pattern: pattern.clone(),
2073                ty: ty.clone(),
2074                init: self.pass_loop_unroll_expr(init),
2075                else_branch: Box::new(self.pass_loop_unroll_expr(else_branch)),
2076            },
2077            Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
2078            Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
2079            Stmt::Item(item) => Stmt::Item(item.clone()),
2080        }
2081    }
2082
2083    fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
2084        match expr {
2085            Expr::While {
2086                label,
2087                condition,
2088                body,
2089            } => {
2090                // Try to unroll if this is a countable loop
2091                if let Some(unrolled) = self.try_unroll_loop(condition, body) {
2092                    self.stats.loops_optimized += 1;
2093                    return unrolled;
2094                }
2095                // Otherwise, just recurse
2096                Expr::While {
2097                    label: label.clone(),
2098                    condition: Box::new(self.pass_loop_unroll_expr(condition)),
2099                    body: self.pass_loop_unroll_block(body),
2100                }
2101            }
2102            Expr::If {
2103                condition,
2104                then_branch,
2105                else_branch,
2106            } => Expr::If {
2107                condition: Box::new(self.pass_loop_unroll_expr(condition)),
2108                then_branch: self.pass_loop_unroll_block(then_branch),
2109                else_branch: else_branch
2110                    .as_ref()
2111                    .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
2112            },
2113            Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
2114            Expr::Binary { op, left, right } => Expr::Binary {
2115                op: *op,
2116                left: Box::new(self.pass_loop_unroll_expr(left)),
2117                right: Box::new(self.pass_loop_unroll_expr(right)),
2118            },
2119            Expr::Unary { op, expr: inner } => Expr::Unary {
2120                op: *op,
2121                expr: Box::new(self.pass_loop_unroll_expr(inner)),
2122            },
2123            Expr::Call { func, args } => Expr::Call {
2124                func: func.clone(),
2125                args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
2126            },
2127            Expr::Return(e) => {
2128                Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
2129            }
2130            Expr::Assign { target, value } => Expr::Assign {
2131                target: target.clone(),
2132                value: Box::new(self.pass_loop_unroll_expr(value)),
2133            },
2134            other => other.clone(),
2135        }
2136    }
2137
2138    /// Try to unroll a loop with known bounds
2139    /// Pattern: while i < CONST { body; i = i + 1; }
2140    fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2141        // Check for pattern: var < constant
2142        let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2143
2144        // Only unroll small loops (up to 8 iterations from 0)
2145        if upper_bound > 8 || upper_bound <= 0 {
2146            return None;
2147        }
2148
2149        // Check if body contains increment: loop_var = loop_var + 1
2150        if !self.body_has_simple_increment(&loop_var, body) {
2151            return None;
2152        }
2153
2154        // Check that loop body is small enough to unroll (max 5 statements)
2155        let stmt_count = body.stmts.len();
2156        if stmt_count > 5 {
2157            return None;
2158        }
2159
2160        // Generate unrolled body
2161        let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2162
2163        for i in 0..upper_bound {
2164            // For each iteration, substitute the loop variable with the constant
2165            let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2166
2167            // Add all statements except the increment
2168            for stmt in &substituted_body.stmts {
2169                if !self.is_increment_stmt(&loop_var, stmt) {
2170                    unrolled_stmts.push(stmt.clone());
2171                }
2172            }
2173        }
2174
2175        // Return unrolled block
2176        Some(Expr::Block(Block {
2177            stmts: unrolled_stmts,
2178            expr: None,
2179        }))
2180    }
2181
2182    /// Extract loop bounds from condition: var < constant
2183    fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2184        if let Expr::Binary {
2185            op: BinOp::Lt,
2186            left,
2187            right,
2188        } = condition
2189        {
2190            // Left should be a variable
2191            if let Expr::Path(path) = left.as_ref() {
2192                if path.segments.len() == 1 {
2193                    let var_name = path.segments[0].ident.name.clone();
2194                    // Right should be a constant
2195                    if let Some(bound) = self.as_int(right) {
2196                        return Some((var_name, bound));
2197                    }
2198                }
2199            }
2200        }
2201        None
2202    }
2203
2204    /// Check if body contains: loop_var = loop_var + 1
2205    fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2206        for stmt in &body.stmts {
2207            if self.is_increment_stmt(loop_var, stmt) {
2208                return true;
2209            }
2210        }
2211        false
2212    }
2213
2214    /// Check if statement is: var = var + 1
2215    fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2216        match stmt {
2217            Stmt::Semi(Expr::Assign { target, value })
2218            | Stmt::Expr(Expr::Assign { target, value }) => {
2219                // Target should be the loop variable
2220                if let Expr::Path(path) = target.as_ref() {
2221                    if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2222                        // Value should be var + 1
2223                        if let Expr::Binary {
2224                            op: BinOp::Add,
2225                            left,
2226                            right,
2227                        } = value.as_ref()
2228                        {
2229                            if let Expr::Path(lpath) = left.as_ref() {
2230                                if lpath.segments.len() == 1
2231                                    && lpath.segments[0].ident.name == var_name
2232                                {
2233                                    if let Some(1) = self.as_int(right) {
2234                                        return true;
2235                                    }
2236                                }
2237                            }
2238                        }
2239                    }
2240                }
2241                false
2242            }
2243            _ => false,
2244        }
2245    }
2246
2247    /// Substitute loop variable with constant value in block
2248    fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2249        let stmts = block
2250            .stmts
2251            .iter()
2252            .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2253            .collect();
2254        let expr = block
2255            .expr
2256            .as_ref()
2257            .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2258        Block { stmts, expr }
2259    }
2260
2261    fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2262        match stmt {
2263            Stmt::Let { pattern, ty, init } => Stmt::Let {
2264                pattern: pattern.clone(),
2265                ty: ty.clone(),
2266                init: init
2267                    .as_ref()
2268                    .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2269            },
2270            Stmt::LetElse {
2271                pattern,
2272                ty,
2273                init,
2274                else_branch,
2275            } => Stmt::LetElse {
2276                pattern: pattern.clone(),
2277                ty: ty.clone(),
2278                init: self.substitute_loop_var_in_expr(init, var_name, value),
2279                else_branch: Box::new(self.substitute_loop_var_in_expr(
2280                    else_branch,
2281                    var_name,
2282                    value,
2283                )),
2284            },
2285            Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2286            Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2287            Stmt::Item(item) => Stmt::Item(item.clone()),
2288        }
2289    }
2290
2291    fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2292        match expr {
2293            Expr::Path(path) => {
2294                if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2295                    return Expr::Literal(Literal::Int {
2296                        value: value.to_string(),
2297                        base: NumBase::Decimal,
2298                        suffix: None,
2299                    });
2300                }
2301                expr.clone()
2302            }
2303            Expr::Binary { op, left, right } => Expr::Binary {
2304                op: *op,
2305                left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2306                right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2307            },
2308            Expr::Unary { op, expr: inner } => Expr::Unary {
2309                op: *op,
2310                expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2311            },
2312            Expr::Call { func, args } => Expr::Call {
2313                func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2314                args: args
2315                    .iter()
2316                    .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2317                    .collect(),
2318            },
2319            Expr::If {
2320                condition,
2321                then_branch,
2322                else_branch,
2323            } => Expr::If {
2324                condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2325                then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2326                else_branch: else_branch
2327                    .as_ref()
2328                    .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2329            },
2330            Expr::While {
2331                label,
2332                condition,
2333                body,
2334            } => Expr::While {
2335                label: label.clone(),
2336                condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2337                body: self.substitute_loop_var_in_block(body, var_name, value),
2338            },
2339            Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2340            Expr::Return(e) => Expr::Return(
2341                e.as_ref()
2342                    .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2343            ),
2344            Expr::Assign { target, value: v } => Expr::Assign {
2345                target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
2346                value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2347            },
2348            Expr::Index { expr: e, index } => Expr::Index {
2349                expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2350                index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2351            },
2352            Expr::Array(elements) => Expr::Array(
2353                elements
2354                    .iter()
2355                    .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2356                    .collect(),
2357            ),
2358            other => other.clone(),
2359        }
2360    }
2361
2362    // ========================================================================
2363    // Pass: Loop Invariant Code Motion (LICM)
2364    // ========================================================================
2365
2366    /// Move loop-invariant computations out of loops
2367    fn pass_licm_block(&mut self, block: &Block) -> Block {
2368        let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2369        let expr = block
2370            .expr
2371            .as_ref()
2372            .map(|e| Box::new(self.pass_licm_expr(e)));
2373        Block { stmts, expr }
2374    }
2375
2376    fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2377        match stmt {
2378            Stmt::Let { pattern, ty, init } => Stmt::Let {
2379                pattern: pattern.clone(),
2380                ty: ty.clone(),
2381                init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2382            },
2383            Stmt::LetElse {
2384                pattern,
2385                ty,
2386                init,
2387                else_branch,
2388            } => Stmt::LetElse {
2389                pattern: pattern.clone(),
2390                ty: ty.clone(),
2391                init: self.pass_licm_expr(init),
2392                else_branch: Box::new(self.pass_licm_expr(else_branch)),
2393            },
2394            Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2395            Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2396            Stmt::Item(item) => Stmt::Item(item.clone()),
2397        }
2398    }
2399
2400    fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
2401        match expr {
2402            Expr::While {
2403                label,
2404                condition,
2405                body,
2406            } => {
2407                // Find variables modified in the loop
2408                let mut modified_vars = HashSet::new();
2409                self.collect_modified_vars_block(body, &mut modified_vars);
2410
2411                // Also consider the loop condition might modify vars
2412                self.collect_modified_vars_expr(condition, &mut modified_vars);
2413
2414                // Find invariant expressions in the body
2415                let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
2416                self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);
2417
2418                if invariant_exprs.is_empty() {
2419                    // No LICM opportunities, just recurse
2420                    return Expr::While {
2421                        label: label.clone(),
2422                        condition: Box::new(self.pass_licm_expr(condition)),
2423                        body: self.pass_licm_block(body),
2424                    };
2425                }
2426
2427                // Create let bindings for invariant expressions before the loop
2428                let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
2429                let mut substitution_map: HashMap<String, String> = HashMap::new();
2430
2431                for (original_key, invariant_expr) in &invariant_exprs {
2432                    let var_name = format!("__licm_{}", self.cse_counter);
2433                    self.cse_counter += 1;
2434
2435                    pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
2436                    substitution_map.insert(original_key.clone(), var_name);
2437                    self.stats.loops_optimized += 1;
2438                }
2439
2440                // Replace invariant expressions in the loop body
2441                let new_body =
2442                    self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);
2443
2444                // Recurse into the modified loop
2445                let new_while = Expr::While {
2446                    label: label.clone(),
2447                    condition: Box::new(self.pass_licm_expr(condition)),
2448                    body: self.pass_licm_block(&new_body),
2449                };
2450
2451                // Return block with pre-loop bindings + loop
2452                pre_loop_stmts.push(Stmt::Expr(new_while));
2453                Expr::Block(Block {
2454                    stmts: pre_loop_stmts,
2455                    expr: None,
2456                })
2457            }
2458            Expr::If {
2459                condition,
2460                then_branch,
2461                else_branch,
2462            } => Expr::If {
2463                condition: Box::new(self.pass_licm_expr(condition)),
2464                then_branch: self.pass_licm_block(then_branch),
2465                else_branch: else_branch
2466                    .as_ref()
2467                    .map(|e| Box::new(self.pass_licm_expr(e))),
2468            },
2469            Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
2470            Expr::Binary { op, left, right } => Expr::Binary {
2471                op: *op,
2472                left: Box::new(self.pass_licm_expr(left)),
2473                right: Box::new(self.pass_licm_expr(right)),
2474            },
2475            Expr::Unary { op, expr: inner } => Expr::Unary {
2476                op: *op,
2477                expr: Box::new(self.pass_licm_expr(inner)),
2478            },
2479            Expr::Call { func, args } => Expr::Call {
2480                func: func.clone(),
2481                args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
2482            },
2483            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
2484            Expr::Assign { target, value } => Expr::Assign {
2485                target: target.clone(),
2486                value: Box::new(self.pass_licm_expr(value)),
2487            },
2488            other => other.clone(),
2489        }
2490    }
2491
2492    /// Collect all variables that are modified in a block
2493    fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2494        for stmt in &block.stmts {
2495            self.collect_modified_vars_stmt(stmt, modified);
2496        }
2497        if let Some(expr) = &block.expr {
2498            self.collect_modified_vars_expr(expr, modified);
2499        }
2500    }
2501
2502    fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2503        match stmt {
2504            Stmt::Let { pattern, init, .. } => {
2505                // Let bindings introduce new variables
2506                if let Pattern::Ident { name, .. } = pattern {
2507                    modified.insert(name.name.clone());
2508                }
2509                if let Some(e) = init {
2510                    self.collect_modified_vars_expr(e, modified);
2511                }
2512            }
2513            Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2514            _ => {}
2515        }
2516    }
2517
2518    fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2519        match expr {
2520            Expr::Assign { target, value } => {
2521                if let Expr::Path(path) = target.as_ref() {
2522                    if path.segments.len() == 1 {
2523                        modified.insert(path.segments[0].ident.name.clone());
2524                    }
2525                }
2526                self.collect_modified_vars_expr(value, modified);
2527            }
2528            Expr::Binary { left, right, .. } => {
2529                self.collect_modified_vars_expr(left, modified);
2530                self.collect_modified_vars_expr(right, modified);
2531            }
2532            Expr::Unary { expr: inner, .. } => {
2533                self.collect_modified_vars_expr(inner, modified);
2534            }
2535            Expr::If {
2536                condition,
2537                then_branch,
2538                else_branch,
2539            } => {
2540                self.collect_modified_vars_expr(condition, modified);
2541                self.collect_modified_vars_block(then_branch, modified);
2542                if let Some(e) = else_branch {
2543                    self.collect_modified_vars_expr(e, modified);
2544                }
2545            }
2546            Expr::While {
2547                label,
2548                condition,
2549                body,
2550            } => {
2551                self.collect_modified_vars_expr(condition, modified);
2552                self.collect_modified_vars_block(body, modified);
2553            }
2554            Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2555            Expr::Call { args, .. } => {
2556                for arg in args {
2557                    self.collect_modified_vars_expr(arg, modified);
2558                }
2559            }
2560            Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2561            _ => {}
2562        }
2563    }
2564
2565    /// Find loop-invariant expressions in a block
2566    fn find_loop_invariants(
2567        &self,
2568        block: &Block,
2569        modified: &HashSet<String>,
2570        out: &mut Vec<(String, Expr)>,
2571    ) {
2572        for stmt in &block.stmts {
2573            self.find_loop_invariants_stmt(stmt, modified, out);
2574        }
2575        if let Some(expr) = &block.expr {
2576            self.find_loop_invariants_expr(expr, modified, out);
2577        }
2578    }
2579
2580    fn find_loop_invariants_stmt(
2581        &self,
2582        stmt: &Stmt,
2583        modified: &HashSet<String>,
2584        out: &mut Vec<(String, Expr)>,
2585    ) {
2586        match stmt {
2587            Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2588            Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2589            _ => {}
2590        }
2591    }
2592
2593    fn find_loop_invariants_expr(
2594        &self,
2595        expr: &Expr,
2596        modified: &HashSet<String>,
2597        out: &mut Vec<(String, Expr)>,
2598    ) {
2599        // First recurse into subexpressions
2600        match expr {
2601            Expr::Binary { left, right, .. } => {
2602                self.find_loop_invariants_expr(left, modified, out);
2603                self.find_loop_invariants_expr(right, modified, out);
2604            }
2605            Expr::Unary { expr: inner, .. } => {
2606                self.find_loop_invariants_expr(inner, modified, out);
2607            }
2608            Expr::Call { args, .. } => {
2609                for arg in args {
2610                    self.find_loop_invariants_expr(arg, modified, out);
2611                }
2612            }
2613            Expr::Index { expr: e, index } => {
2614                self.find_loop_invariants_expr(e, modified, out);
2615                self.find_loop_invariants_expr(index, modified, out);
2616            }
2617            _ => {}
2618        }
2619
2620        // Then check if this expression is loop-invariant and worth hoisting
2621        if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2622            let key = format!("{:?}", expr_hash(expr));
2623            // Check if we already have this exact expression
2624            if !out.iter().any(|(k, _)| k == &key) {
2625                out.push((key, expr.clone()));
2626            }
2627        }
2628    }
2629
2630    /// Check if an expression is loop-invariant (doesn't depend on modified variables)
2631    fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2632        match expr {
2633            Expr::Literal(_) => true,
2634            Expr::Path(path) => {
2635                if path.segments.len() == 1 {
2636                    !modified.contains(&path.segments[0].ident.name)
2637                } else {
2638                    true // Qualified paths are assumed invariant
2639                }
2640            }
2641            Expr::Binary { left, right, .. } => {
2642                self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2643            }
2644            Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2645            Expr::Index { expr: e, index } => {
2646                self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2647            }
2648            // Calls are not invariant (might have side effects)
2649            Expr::Call { .. } => false,
2650            // Other expressions are not invariant
2651            _ => false,
2652        }
2653    }
2654
2655    /// Replace invariant expressions with variable references
2656    fn replace_invariants_in_block(
2657        &self,
2658        block: &Block,
2659        invariants: &[(String, Expr)],
2660        subs: &HashMap<String, String>,
2661    ) -> Block {
2662        let stmts = block
2663            .stmts
2664            .iter()
2665            .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2666            .collect();
2667        let expr = block
2668            .expr
2669            .as_ref()
2670            .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2671        Block { stmts, expr }
2672    }
2673
2674    fn replace_invariants_in_stmt(
2675        &self,
2676        stmt: &Stmt,
2677        invariants: &[(String, Expr)],
2678        subs: &HashMap<String, String>,
2679    ) -> Stmt {
2680        match stmt {
2681            Stmt::Let { pattern, ty, init } => Stmt::Let {
2682                pattern: pattern.clone(),
2683                ty: ty.clone(),
2684                init: init
2685                    .as_ref()
2686                    .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2687            },
2688            Stmt::LetElse {
2689                pattern,
2690                ty,
2691                init,
2692                else_branch,
2693            } => Stmt::LetElse {
2694                pattern: pattern.clone(),
2695                ty: ty.clone(),
2696                init: self.replace_invariants_in_expr(init, invariants, subs),
2697                else_branch: Box::new(self.replace_invariants_in_expr(
2698                    else_branch,
2699                    invariants,
2700                    subs,
2701                )),
2702            },
2703            Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2704            Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2705            Stmt::Item(item) => Stmt::Item(item.clone()),
2706        }
2707    }
2708
2709    fn replace_invariants_in_expr(
2710        &self,
2711        expr: &Expr,
2712        invariants: &[(String, Expr)],
2713        subs: &HashMap<String, String>,
2714    ) -> Expr {
2715        // Check if this expression matches an invariant
2716        let key = format!("{:?}", expr_hash(expr));
2717        for (inv_key, inv_expr) in invariants {
2718            if &key == inv_key && expr_eq(expr, inv_expr) {
2719                if let Some(var_name) = subs.get(inv_key) {
2720                    return Expr::Path(TypePath {
2721                        segments: vec![PathSegment {
2722                            ident: Ident {
2723                                name: var_name.clone(),
2724                                evidentiality: None,
2725                                affect: None,
2726                                span: Span { start: 0, end: 0 },
2727                            },
2728                            generics: None,
2729                        }],
2730                    });
2731                }
2732            }
2733        }
2734
2735        // Otherwise recurse
2736        match expr {
2737            Expr::Binary { op, left, right } => Expr::Binary {
2738                op: *op,
2739                left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2740                right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2741            },
2742            Expr::Unary { op, expr: inner } => Expr::Unary {
2743                op: *op,
2744                expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2745            },
2746            Expr::Call { func, args } => Expr::Call {
2747                func: func.clone(),
2748                args: args
2749                    .iter()
2750                    .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2751                    .collect(),
2752            },
2753            Expr::If {
2754                condition,
2755                then_branch,
2756                else_branch,
2757            } => Expr::If {
2758                condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2759                then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2760                else_branch: else_branch
2761                    .as_ref()
2762                    .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2763            },
2764            Expr::While {
2765                label,
2766                condition,
2767                body,
2768            } => Expr::While {
2769                label: label.clone(),
2770                condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2771                body: self.replace_invariants_in_block(body, invariants, subs),
2772            },
2773            Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2774            Expr::Return(e) => Expr::Return(
2775                e.as_ref()
2776                    .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2777            ),
2778            Expr::Assign { target, value } => Expr::Assign {
2779                target: target.clone(),
2780                value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2781            },
2782            Expr::Index { expr: e, index } => Expr::Index {
2783                expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2784                index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2785            },
2786            other => other.clone(),
2787        }
2788    }
2789
2790    // ========================================================================
2791    // Pass: Common Subexpression Elimination (CSE)
2792    // ========================================================================
2793
2794    fn pass_cse_block(&mut self, block: &Block) -> Block {
2795        // Step 1: Collect all expressions in this block
2796        let mut collected = Vec::new();
2797        collect_exprs_from_block(block, &mut collected);
2798
2799        // Step 2: Count occurrences using hash + equality check
2800        let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2801        for ce in &collected {
2802            let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2803            // Check if this exact expression is already in the bucket
2804            let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2805            if !found {
2806                entry.push(ce.expr.clone());
2807            }
2808        }
2809
2810        // Count actual occurrences (need to count duplicates)
2811        let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2812        for ce in &collected {
2813            // Find or create entry for this expression
2814            let existing = occurrence_counts
2815                .iter_mut()
2816                .find(|(e, _)| expr_eq(e, &ce.expr));
2817            if let Some((_, count)) = existing {
2818                *count += 1;
2819            } else {
2820                occurrence_counts.push((ce.expr.clone(), 1));
2821            }
2822        }
2823
2824        // Step 3: Find expressions that occur 2+ times
2825        let candidates: Vec<Expr> = occurrence_counts
2826            .into_iter()
2827            .filter(|(_, count)| *count >= 2)
2828            .map(|(expr, _)| expr)
2829            .collect();
2830
2831        if candidates.is_empty() {
2832            // No CSE opportunities, just recurse into nested blocks
2833            return self.pass_cse_nested(block);
2834        }
2835
2836        // Step 4: Create let bindings for each candidate and replace occurrences
2837        let mut result_block = block.clone();
2838        let mut new_lets: Vec<Stmt> = Vec::new();
2839
2840        for expr in candidates {
2841            let var_name = format!("__cse_{}", self.cse_counter);
2842            self.cse_counter += 1;
2843
2844            // Create the let binding
2845            new_lets.push(make_cse_let(&var_name, expr.clone()));
2846
2847            // Replace all occurrences in the block
2848            result_block = replace_in_block(&result_block, &expr, &var_name);
2849
2850            self.stats.expressions_deduplicated += 1;
2851        }
2852
2853        // Step 5: Prepend the new let bindings to the block
2854        let mut final_stmts = new_lets;
2855        final_stmts.extend(result_block.stmts);
2856
2857        // Step 6: Recurse into nested blocks
2858        let result = Block {
2859            stmts: final_stmts,
2860            expr: result_block.expr,
2861        };
2862        self.pass_cse_nested(&result)
2863    }
2864
2865    /// Recurse CSE into nested blocks (if, while, block expressions)
2866    fn pass_cse_nested(&mut self, block: &Block) -> Block {
2867        let stmts = block
2868            .stmts
2869            .iter()
2870            .map(|stmt| self.pass_cse_stmt(stmt))
2871            .collect();
2872        let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2873        Block { stmts, expr }
2874    }
2875
2876    fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2877        match stmt {
2878            Stmt::Let { pattern, ty, init } => Stmt::Let {
2879                pattern: pattern.clone(),
2880                ty: ty.clone(),
2881                init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2882            },
2883            Stmt::LetElse {
2884                pattern,
2885                ty,
2886                init,
2887                else_branch,
2888            } => Stmt::LetElse {
2889                pattern: pattern.clone(),
2890                ty: ty.clone(),
2891                init: self.pass_cse_expr(init),
2892                else_branch: Box::new(self.pass_cse_expr(else_branch)),
2893            },
2894            Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2895            Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2896            Stmt::Item(item) => Stmt::Item(item.clone()),
2897        }
2898    }
2899
2900    fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2901        match expr {
2902            Expr::If {
2903                condition,
2904                then_branch,
2905                else_branch,
2906            } => Expr::If {
2907                condition: Box::new(self.pass_cse_expr(condition)),
2908                then_branch: self.pass_cse_block(then_branch),
2909                else_branch: else_branch
2910                    .as_ref()
2911                    .map(|e| Box::new(self.pass_cse_expr(e))),
2912            },
2913            Expr::While {
2914                label,
2915                condition,
2916                body,
2917            } => Expr::While {
2918                label: label.clone(),
2919                condition: Box::new(self.pass_cse_expr(condition)),
2920                body: self.pass_cse_block(body),
2921            },
2922            Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2923            Expr::Binary { op, left, right } => Expr::Binary {
2924                op: *op,
2925                left: Box::new(self.pass_cse_expr(left)),
2926                right: Box::new(self.pass_cse_expr(right)),
2927            },
2928            Expr::Unary { op, expr: inner } => Expr::Unary {
2929                op: *op,
2930                expr: Box::new(self.pass_cse_expr(inner)),
2931            },
2932            Expr::Call { func, args } => Expr::Call {
2933                func: func.clone(),
2934                args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2935            },
2936            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2937            Expr::Assign { target, value } => Expr::Assign {
2938                target: target.clone(),
2939                value: Box::new(self.pass_cse_expr(value)),
2940            },
2941            other => other.clone(),
2942        }
2943    }
2944}
2945
2946// ============================================================================
2947// Common Subexpression Elimination (CSE) - Helper Functions
2948// ============================================================================
2949
2950/// Expr hash for CSE - identifies structurally equivalent expressions
2951fn expr_hash(expr: &Expr) -> u64 {
2952    use std::collections::hash_map::DefaultHasher;
2953    use std::hash::Hasher;
2954
2955    let mut hasher = DefaultHasher::new();
2956    expr_hash_recursive(expr, &mut hasher);
2957    hasher.finish()
2958}
2959
2960fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2961    use std::hash::Hash;
2962
2963    std::mem::discriminant(expr).hash(hasher);
2964
2965    match expr {
2966        Expr::Literal(lit) => match lit {
2967            Literal::Int { value, .. } => value.hash(hasher),
2968            Literal::Float { value, .. } => value.hash(hasher),
2969            Literal::String(s) => s.hash(hasher),
2970            Literal::Char(c) => c.hash(hasher),
2971            Literal::Bool(b) => b.hash(hasher),
2972            _ => {}
2973        },
2974        Expr::Path(path) => {
2975            for seg in &path.segments {
2976                seg.ident.name.hash(hasher);
2977            }
2978        }
2979        Expr::Binary { op, left, right } => {
2980            std::mem::discriminant(op).hash(hasher);
2981            expr_hash_recursive(left, hasher);
2982            expr_hash_recursive(right, hasher);
2983        }
2984        Expr::Unary { op, expr } => {
2985            std::mem::discriminant(op).hash(hasher);
2986            expr_hash_recursive(expr, hasher);
2987        }
2988        Expr::Call { func, args } => {
2989            expr_hash_recursive(func, hasher);
2990            args.len().hash(hasher);
2991            for arg in args {
2992                expr_hash_recursive(arg, hasher);
2993            }
2994        }
2995        Expr::Index { expr, index } => {
2996            expr_hash_recursive(expr, hasher);
2997            expr_hash_recursive(index, hasher);
2998        }
2999        _ => {}
3000    }
3001}
3002
3003/// Check if an expression is pure (no side effects)
3004fn is_pure_expr(expr: &Expr) -> bool {
3005    match expr {
3006        Expr::Literal(_) => true,
3007        Expr::Path(_) => true,
3008        Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
3009        Expr::Unary { expr, .. } => is_pure_expr(expr),
3010        Expr::If {
3011            condition,
3012            then_branch,
3013            else_branch,
3014        } => {
3015            is_pure_expr(condition)
3016                && then_branch.stmts.is_empty()
3017                && then_branch
3018                    .expr
3019                    .as_ref()
3020                    .map(|e| is_pure_expr(e))
3021                    .unwrap_or(true)
3022                && else_branch
3023                    .as_ref()
3024                    .map(|e| is_pure_expr(e))
3025                    .unwrap_or(true)
3026        }
3027        Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
3028        Expr::Array(elements) => elements.iter().all(is_pure_expr),
3029        // These have side effects
3030        Expr::Call { .. } => false,
3031        Expr::Assign { .. } => false,
3032        Expr::Return(_) => false,
3033        _ => false,
3034    }
3035}
3036
3037/// Check if an expression is worth caching (complex enough)
3038fn is_cse_worthy(expr: &Expr) -> bool {
3039    match expr {
3040        // Simple expressions - not worth CSE overhead
3041        Expr::Literal(_) => false,
3042        Expr::Path(_) => false,
3043        // Binary operations are worth it
3044        Expr::Binary { .. } => true,
3045        // Unary operations might be worth it
3046        Expr::Unary { .. } => true,
3047        // Calls might be worth it if pure (but we can't know easily)
3048        Expr::Call { .. } => false,
3049        // Index operations are worth it
3050        Expr::Index { .. } => true,
3051        _ => false,
3052    }
3053}
3054
3055/// Check if two expressions are structurally equal
3056fn expr_eq(a: &Expr, b: &Expr) -> bool {
3057    match (a, b) {
3058        (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
3059            (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
3060            (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
3061            (Literal::String(sa), Literal::String(sb)) => sa == sb,
3062            (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
3063            (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
3064            _ => false,
3065        },
3066        (Expr::Path(pa), Expr::Path(pb)) => {
3067            pa.segments.len() == pb.segments.len()
3068                && pa
3069                    .segments
3070                    .iter()
3071                    .zip(&pb.segments)
3072                    .all(|(sa, sb)| sa.ident.name == sb.ident.name)
3073        }
3074        (
3075            Expr::Binary {
3076                op: oa,
3077                left: la,
3078                right: ra,
3079            },
3080            Expr::Binary {
3081                op: ob,
3082                left: lb,
3083                right: rb,
3084            },
3085        ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
3086        (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
3087            oa == ob && expr_eq(ea, eb)
3088        }
3089        (
3090            Expr::Index {
3091                expr: ea,
3092                index: ia,
3093            },
3094            Expr::Index {
3095                expr: eb,
3096                index: ib,
3097            },
3098        ) => expr_eq(ea, eb) && expr_eq(ia, ib),
3099        (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
3100            expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
3101        }
3102        _ => false,
3103    }
3104}
3105
3106/// Collected expression with its location info
3107#[derive(Clone)]
3108struct CollectedExpr {
3109    expr: Expr,
3110    hash: u64,
3111}
3112
3113/// Collect all CSE-worthy expressions from an expression tree
3114fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
3115    // First, recurse into subexpressions
3116    match expr {
3117        Expr::Binary { left, right, .. } => {
3118            collect_exprs_from_expr(left, out);
3119            collect_exprs_from_expr(right, out);
3120        }
3121        Expr::Unary { expr: inner, .. } => {
3122            collect_exprs_from_expr(inner, out);
3123        }
3124        Expr::Index { expr: e, index } => {
3125            collect_exprs_from_expr(e, out);
3126            collect_exprs_from_expr(index, out);
3127        }
3128        Expr::Call { func, args } => {
3129            collect_exprs_from_expr(func, out);
3130            for arg in args {
3131                collect_exprs_from_expr(arg, out);
3132            }
3133        }
3134        Expr::If {
3135            condition,
3136            then_branch,
3137            else_branch,
3138        } => {
3139            collect_exprs_from_expr(condition, out);
3140            collect_exprs_from_block(then_branch, out);
3141            if let Some(else_expr) = else_branch {
3142                collect_exprs_from_expr(else_expr, out);
3143            }
3144        }
3145        Expr::While {
3146            label,
3147            condition,
3148            body,
3149        } => {
3150            collect_exprs_from_expr(condition, out);
3151            collect_exprs_from_block(body, out);
3152        }
3153        Expr::Block(block) => {
3154            collect_exprs_from_block(block, out);
3155        }
3156        Expr::Return(Some(e)) => {
3157            collect_exprs_from_expr(e, out);
3158        }
3159        Expr::Assign { value, .. } => {
3160            collect_exprs_from_expr(value, out);
3161        }
3162        Expr::Array(elements) => {
3163            for e in elements {
3164                collect_exprs_from_expr(e, out);
3165            }
3166        }
3167        _ => {}
3168    }
3169
3170    // Then, if this expression is CSE-worthy and pure, add it
3171    if is_cse_worthy(expr) && is_pure_expr(expr) {
3172        out.push(CollectedExpr {
3173            expr: expr.clone(),
3174            hash: expr_hash(expr),
3175        });
3176    }
3177}
3178
3179/// Collect expressions from a block
3180fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
3181    for stmt in &block.stmts {
3182        match stmt {
3183            Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
3184            Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
3185            _ => {}
3186        }
3187    }
3188    if let Some(e) = &block.expr {
3189        collect_exprs_from_expr(e, out);
3190    }
3191}
3192
3193/// Replace all occurrences of target expression with a variable reference
3194fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
3195    // Check if this expression matches the target
3196    if expr_eq(expr, target) {
3197        return Expr::Path(TypePath {
3198            segments: vec![PathSegment {
3199                ident: Ident {
3200                    name: var_name.to_string(),
3201                    evidentiality: None,
3202                    affect: None,
3203                    span: Span { start: 0, end: 0 },
3204                },
3205                generics: None,
3206            }],
3207        });
3208    }
3209
3210    // Otherwise, recurse into subexpressions
3211    match expr {
3212        Expr::Binary { op, left, right } => Expr::Binary {
3213            op: *op,
3214            left: Box::new(replace_in_expr(left, target, var_name)),
3215            right: Box::new(replace_in_expr(right, target, var_name)),
3216        },
3217        Expr::Unary { op, expr: inner } => Expr::Unary {
3218            op: *op,
3219            expr: Box::new(replace_in_expr(inner, target, var_name)),
3220        },
3221        Expr::Index { expr: e, index } => Expr::Index {
3222            expr: Box::new(replace_in_expr(e, target, var_name)),
3223            index: Box::new(replace_in_expr(index, target, var_name)),
3224        },
3225        Expr::Call { func, args } => Expr::Call {
3226            func: Box::new(replace_in_expr(func, target, var_name)),
3227            args: args
3228                .iter()
3229                .map(|a| replace_in_expr(a, target, var_name))
3230                .collect(),
3231        },
3232        Expr::If {
3233            condition,
3234            then_branch,
3235            else_branch,
3236        } => Expr::If {
3237            condition: Box::new(replace_in_expr(condition, target, var_name)),
3238            then_branch: replace_in_block(then_branch, target, var_name),
3239            else_branch: else_branch
3240                .as_ref()
3241                .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3242        },
3243        Expr::While {
3244            label,
3245            condition,
3246            body,
3247        } => Expr::While {
3248            label: label.clone(),
3249            condition: Box::new(replace_in_expr(condition, target, var_name)),
3250            body: replace_in_block(body, target, var_name),
3251        },
3252        Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3253        Expr::Return(e) => Expr::Return(
3254            e.as_ref()
3255                .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3256        ),
3257        Expr::Assign { target: t, value } => Expr::Assign {
3258            target: t.clone(),
3259            value: Box::new(replace_in_expr(value, target, var_name)),
3260        },
3261        Expr::Array(elements) => Expr::Array(
3262            elements
3263                .iter()
3264                .map(|e| replace_in_expr(e, target, var_name))
3265                .collect(),
3266        ),
3267        other => other.clone(),
3268    }
3269}
3270
3271/// Replace in a block
3272fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3273    let stmts = block
3274        .stmts
3275        .iter()
3276        .map(|stmt| match stmt {
3277            Stmt::Let { pattern, ty, init } => Stmt::Let {
3278                pattern: pattern.clone(),
3279                ty: ty.clone(),
3280                init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3281            },
3282            Stmt::LetElse {
3283                pattern,
3284                ty,
3285                init,
3286                else_branch,
3287            } => Stmt::LetElse {
3288                pattern: pattern.clone(),
3289                ty: ty.clone(),
3290                init: replace_in_expr(init, target, var_name),
3291                else_branch: Box::new(replace_in_expr(else_branch, target, var_name)),
3292            },
3293            Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3294            Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3295            Stmt::Item(item) => Stmt::Item(item.clone()),
3296        })
3297        .collect();
3298
3299    let expr = block
3300        .expr
3301        .as_ref()
3302        .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3303
3304    Block { stmts, expr }
3305}
3306
3307/// Create a let statement for a CSE variable
3308fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3309    Stmt::Let {
3310        pattern: Pattern::Ident {
3311            mutable: false,
3312            name: Ident {
3313                name: var_name.to_string(),
3314                evidentiality: None,
3315                affect: None,
3316                span: Span { start: 0, end: 0 },
3317            },
3318            evidentiality: None,
3319        },
3320        ty: None,
3321        init: Some(expr),
3322    }
3323}
3324
3325// ============================================================================
3326// Public API
3327// ============================================================================
3328
3329/// Optimize a source file at the given optimization level
3330pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3331    let mut optimizer = Optimizer::new(level);
3332    let optimized = optimizer.optimize_file(file);
3333    (optimized, optimizer.stats)
3334}
3335
3336// ============================================================================
3337// Tests
3338// ============================================================================
3339
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to create an identifier with a dummy span
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: Span { start: 0, end: 0 },
        }
    }

    /// Helper to create an integer literal expression
    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Helper to create a variable reference
    fn var(name: &str) -> Expr {
        Expr::Path(TypePath {
            segments: vec![PathSegment {
                ident: ident(name),
                generics: None,
            }],
        })
    }

    /// Helper to create a binary add expression
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Add,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Helper to create a binary multiply expression
    fn mul(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Mul,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Helper to create an immutable, unannotated `let <name> = <init>;`
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    #[test]
    fn test_expr_hash_equal() {
        // Same expressions should have same hash
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        assert_eq!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_hash_different() {
        // Different expressions should have different hashes
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("c"));
        assert_ne!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_eq() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        let e3 = add(var("a"), var("c"));

        assert!(expr_eq(&e1, &e2));
        assert!(!expr_eq(&e1, &e3));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        // Calls are not pure
        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        assert!(!is_cse_worthy(&int_lit(42))); // literals not worth it
        assert!(!is_cse_worthy(&var("x"))); // variables not worth it
        assert!(is_cse_worthy(&add(var("a"), var("b")))); // binary ops worth it
    }

    #[test]
    fn test_collect_exprs_post_order() {
        // Children are collected before their parent, so `a + b` comes first
        // when collecting from `(a + b) * 2`.
        let a_plus_b = add(var("a"), var("b"));
        let mut collected = Vec::new();
        collect_exprs_from_expr(&mul(a_plus_b.clone(), int_lit(2)), &mut collected);

        assert!(!collected.is_empty());
        assert!(expr_eq(&collected[0].expr, &a_plus_b));
        assert_eq!(collected[0].hash, expr_hash(&a_plus_b));
    }

    #[test]
    fn test_collect_skips_impure_calls() {
        // The call itself is impure and must not be collected, but its pure
        // argument still is.
        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![add(var("a"), var("b"))],
        };
        let mut collected = Vec::new();
        collect_exprs_from_expr(&call, &mut collected);

        assert_eq!(collected.len(), 1);
        assert!(expr_eq(&collected[0].expr, &add(var("a"), var("b"))));
    }

    #[test]
    fn test_replace_in_expr() {
        let a_plus_b = add(var("a"), var("b"));
        let replaced = replace_in_expr(&mul(a_plus_b.clone(), int_lit(2)), &a_plus_b, "__cse_0");

        match replaced {
            Expr::Binary {
                op: BinOp::Mul,
                left,
                ..
            } => assert!(expr_eq(&left, &var("__cse_0"))),
            _ => panic!("expected the multiplication to be preserved"),
        }
    }

    #[test]
    fn test_replace_in_block_rewrites_stmts_and_tail() {
        let a_plus_b = add(var("a"), var("b"));
        let block = Block {
            stmts: vec![let_stmt("x", a_plus_b.clone())],
            expr: Some(Box::new(a_plus_b.clone())),
        };

        let out = replace_in_block(&block, &a_plus_b, "t");

        match &out.stmts[0] {
            Stmt::Let {
                init: Some(init), ..
            } => assert!(expr_eq(init, &var("t"))),
            _ => panic!("expected the let binding to be preserved"),
        }
        let tail = out.expr.as_deref().expect("tail expression should be kept");
        assert!(expr_eq(tail, &var("t")));
    }

    #[test]
    fn test_make_cse_let() {
        match make_cse_let("__cse_3", add(var("a"), var("b"))) {
            Stmt::Let {
                pattern: Pattern::Ident { mutable, name, .. },
                ty: None,
                init: Some(init),
            } => {
                assert!(!mutable);
                assert_eq!(name.name, "__cse_3");
                assert!(expr_eq(&init, &add(var("a"), var("b"))));
            }
            _ => panic!("expected an immutable, unannotated let binding"),
        }
    }

    #[test]
    fn test_cse_basic() {
        // Create a block with repeated subexpression:
        // let x = a + b;
        // let y = (a + b) * 2;
        // The (a + b) should be extracted
        let a_plus_b = add(var("a"), var("b"));

        let block = Block {
            stmts: vec![
                let_stmt("x", a_plus_b.clone()),
                let_stmt("y", mul(a_plus_b.clone(), int_lit(2))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Should have 3 statements now: __cse_0 = a + b, x = __cse_0, y = __cse_0 * 2
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        // First statement should be the CSE let binding
        if let Stmt::Let {
            pattern: Pattern::Ident { name, .. },
            ..
        } = &result.stmts[0]
        {
            assert_eq!(name.name, "__cse_0");
        } else {
            panic!("Expected CSE let binding");
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        // No repeated expressions - should not add any CSE bindings
        let block = Block {
            stmts: vec![
                let_stmt("x", add(var("a"), var("b"))),
                let_stmt("y", add(var("c"), var("d"))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Should still have 2 statements (no CSE applied)
        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}