sigil_parser/
optimize.rs

1//! Sigil Optimization Passes
2//!
3//! A comprehensive optimization framework for the Sigil language.
4//! Implements industry-standard compiler optimizations.
5
6use crate::ast::{
7    self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8    Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
13// ============================================================================
14// Optimization Pass Infrastructure
15// ============================================================================
16
/// How aggressively the optimizer rewrites a program.
///
/// Each speed-oriented level enables everything the previous one does plus more;
/// `Size` is a separate track that favours small output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptLevel {
    /// `-O0`: leave the program untouched.
    None,
    /// `-O1`: constant folding and dead-code elimination.
    Basic,
    /// `-O2`: additionally CSE, strength reduction and inlining.
    Standard,
    /// `-O3`: additionally loop optimizations and vectorization hints.
    Aggressive,
    /// `-Os`: optimize for code size.
    Size,
}
31
/// Counters describing what the optimizer changed.
///
/// All counters start at zero (`Default`) and are incremented as passes fire.
#[derive(Clone, Debug, Default)]
pub struct OptStats {
    /// Expressions replaced by their compile-time value.
    pub constants_folded: usize,
    /// Unreachable or unused code removed.
    pub dead_code_eliminated: usize,
    /// Repeated expressions merged by common-subexpression elimination.
    pub expressions_deduplicated: usize,
    /// Call sites replaced by the callee's body.
    pub functions_inlined: usize,
    /// Expensive operations replaced by cheaper equivalents.
    pub strength_reductions: usize,
    /// Conditional branches simplified.
    pub branches_simplified: usize,
    /// Loops rewritten by loop optimizations.
    pub loops_optimized: usize,
    /// Functions rewritten into tail-recursive form.
    pub tail_recursion_transforms: usize,
    /// Functions rewritten to use memoization.
    pub memoization_transforms: usize,
}
45
46/// The main optimizer that runs all passes
47pub struct Optimizer {
48    level: OptLevel,
49    stats: OptStats,
50    /// Function bodies for inlining decisions
51    functions: HashMap<String, ast::Function>,
52    /// Track which functions are recursive
53    recursive_functions: HashSet<String>,
54    /// Counter for CSE variable names
55    cse_counter: usize,
56}
57
58impl Optimizer {
59    pub fn new(level: OptLevel) -> Self {
60        Self {
61            level,
62            stats: OptStats::default(),
63            functions: HashMap::new(),
64            recursive_functions: HashSet::new(),
65            cse_counter: 0,
66        }
67    }
68
69    /// Get optimization statistics
70    pub fn stats(&self) -> &OptStats {
71        &self.stats
72    }
73
74    /// Optimize a source file (returns optimized copy)
75    pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76        // First pass: collect function information
77        for item in &file.items {
78            if let Item::Function(func) = &item.node {
79                self.functions.insert(func.name.name.clone(), func.clone());
80                if self.is_recursive(&func.name.name, func) {
81                    self.recursive_functions.insert(func.name.name.clone());
82                }
83            }
84        }
85
86        // Accumulator transformation pass (for Standard and Aggressive optimization)
87        // This transforms double-recursive functions like fib into tail-recursive form
88        // Enabled for Standard+ because it provides massive speedups (O(2^n) -> O(n))
89        let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90        let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92        if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93            for item in &file.items {
94                if let Item::Function(func) = &item.node {
95                    if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96                    {
97                        // Add helper function first
98                        new_items.push(crate::span::Spanned {
99                            node: Item::Function(helper_func),
100                            span: item.span.clone(),
101                        });
102                        transformed_functions
103                            .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104                        self.stats.tail_recursion_transforms += 1;
105                    }
106                }
107            }
108
109            // NOTE: Memoization transform is disabled for now due to complexity with
110            // cache lifetime management. Instead, use iterative implementations for
111            // functions like ackermann where beneficial.
112            //
113            // TODO: Implement proper memoization with thread-local or passed-through caches
114        }
115
116        // Optimization passes
117        let items: Vec<_> = file
118            .items
119            .iter()
120            .map(|item| {
121                let node = match &item.node {
122                    Item::Function(func) => {
123                        // Check if this function was transformed by accumulator
124                        if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125                            if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126                                && transformed_functions.contains_key(&func.name.name)
127                            {
128                                Item::Function(self.optimize_function(&wrapper))
129                            } else {
130                                Item::Function(self.optimize_function(func))
131                            }
132                        } else {
133                            Item::Function(self.optimize_function(func))
134                        }
135                    }
136                    other => other.clone(),
137                };
138                crate::span::Spanned {
139                    node,
140                    span: item.span.clone(),
141                }
142            })
143            .collect();
144
145        // Combine new helper functions with transformed items
146        new_items.extend(items);
147
148        ast::SourceFile {
149            attrs: file.attrs.clone(),
150            config: file.config.clone(),
151            items: new_items,
152        }
153    }
154
155    /// Try to transform a double-recursive function into tail-recursive form
156    /// Returns (helper_function, wrapper_function) if transformation is possible
157    fn try_accumulator_transform(
158        &self,
159        func: &ast::Function,
160    ) -> Option<(ast::Function, ast::Function)> {
161        // Only transform functions with single parameter
162        if func.params.len() != 1 {
163            return None;
164        }
165
166        // Must be recursive
167        if !self.recursive_functions.contains(&func.name.name) {
168            return None;
169        }
170
171        let body = func.body.as_ref()?;
172
173        // Detect fib-like pattern:
174        // if n <= 1 { return n; }
175        // return fib(n - 1) + fib(n - 2);
176        if !self.is_fib_like_pattern(&func.name.name, body) {
177            return None;
178        }
179
180        // Get parameter name
181        let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182            name.name.clone()
183        } else {
184            return None;
185        };
186
187        // Generate helper function name
188        let helper_name = format!("{}_tail", func.name.name);
189
190        // Create helper function: fn fib_tail(n, a, b) { if n <= 0 { return a; } return fib_tail(n - 1, b, a + b); }
191        let helper_func = self.generate_fib_helper(&helper_name, &param_name);
192
193        // Create wrapper function: fn fib(n) { return fib_tail(n, 0, 1); }
194        let wrapper_func =
195            self.generate_fib_wrapper(&func.name.name, &helper_name, &param_name, func);
196
197        Some((helper_func, wrapper_func))
198    }
199
200    /// Check if a function body matches the Fibonacci pattern
201    fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202        // Pattern we're looking for:
203        // { if n <= 1 { return n; } return f(n-1) + f(n-2); }
204        // or
205        // { if n <= 1 { return n; } f(n-1) + f(n-2) }
206
207        // Should have an if statement/expression followed by a recursive expression
208        if body.stmts.is_empty() && body.expr.is_none() {
209            return false;
210        }
211
212        // Check for the pattern in the block expression
213        if let Some(expr) = &body.expr {
214            if let Expr::If {
215                else_branch: Some(else_expr),
216                ..
217            } = expr.as_ref()
218            {
219                // Check if then_branch is a base case (return n or similar)
220                // Check if else_branch has double recursive calls
221                return self.is_double_recursive_expr(func_name, else_expr);
222            }
223        }
224
225        // Check for pattern: if ... { return ...; } return f(n-1) + f(n-2);
226        if body.stmts.len() >= 1 {
227            // Last statement or expression should be the recursive call
228            if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229                if let Expr::Return(Some(ret_expr)) = expr {
230                    return self.is_double_recursive_expr(func_name, ret_expr);
231                }
232            }
233            if let Some(expr) = &body.expr {
234                return self.is_double_recursive_expr(func_name, expr);
235            }
236        }
237
238        false
239    }
240
241    /// Check if expression is f(n-1) + f(n-2) pattern
242    fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243        if let Expr::Binary {
244            op: BinOp::Add,
245            left,
246            right,
247        } = expr
248        {
249            let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250            let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251            return left_is_recursive && right_is_recursive;
252        }
253        false
254    }
255
256    /// Check if expression is f(n - k) for some constant k
257    fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258        if let Expr::Call { func, args } = expr {
259            if let Expr::Path(path) = func.as_ref() {
260                if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261                    // Check if argument is n - constant
262                    if args.len() == 1 {
263                        if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264                            return true;
265                        }
266                    }
267                }
268            }
269        }
270        false
271    }
272
273    /// Generate the tail-recursive helper function
274    fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275        let span = Span { start: 0, end: 0 };
276
277        // fn fib_tail(n, a, b) {
278        //     if n <= 0 { return a; }
279        //     return fib_tail(n - 1, b, a + b);
280        // }
281        let n_ident = Ident {
282            name: "n".to_string(),
283            evidentiality: None,
284            affect: None,
285            span: span.clone(),
286        };
287        let a_ident = Ident {
288            name: "a".to_string(),
289            evidentiality: None,
290            affect: None,
291            span: span.clone(),
292        };
293        let b_ident = Ident {
294            name: "b".to_string(),
295            evidentiality: None,
296            affect: None,
297            span: span.clone(),
298        };
299
300        let params = vec![
301            Param {
302                pattern: Pattern::Ident {
303                    mutable: false,
304                    name: n_ident.clone(),
305                    evidentiality: None,
306                },
307                ty: TypeExpr::Infer,
308            },
309            Param {
310                pattern: Pattern::Ident {
311                    mutable: false,
312                    name: a_ident.clone(),
313                    evidentiality: None,
314                },
315                ty: TypeExpr::Infer,
316            },
317            Param {
318                pattern: Pattern::Ident {
319                    mutable: false,
320                    name: b_ident.clone(),
321                    evidentiality: None,
322                },
323                ty: TypeExpr::Infer,
324            },
325        ];
326
327        // Condition: n <= 0
328        let condition = Expr::Binary {
329            op: BinOp::Le,
330            left: Box::new(Expr::Path(TypePath {
331                segments: vec![PathSegment {
332                    ident: n_ident.clone(),
333                    generics: None,
334                }],
335            })),
336            right: Box::new(Expr::Literal(Literal::Int {
337                value: "0".to_string(),
338                base: NumBase::Decimal,
339                suffix: None,
340            })),
341        };
342
343        // Then branch: return a
344        let then_branch = Block {
345            stmts: vec![],
346            expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347                TypePath {
348                    segments: vec![PathSegment {
349                        ident: a_ident.clone(),
350                        generics: None,
351                    }],
352                },
353            )))))),
354        };
355
356        // Recursive call: fib_tail(n - 1, b, a + b)
357        let recursive_call = Expr::Call {
358            func: Box::new(Expr::Path(TypePath {
359                segments: vec![PathSegment {
360                    ident: Ident {
361                        name: name.to_string(),
362                        evidentiality: None,
363                        affect: None,
364                        span: span.clone(),
365                    },
366                    generics: None,
367                }],
368            })),
369            args: vec![
370                // n - 1
371                Expr::Binary {
372                    op: BinOp::Sub,
373                    left: Box::new(Expr::Path(TypePath {
374                        segments: vec![PathSegment {
375                            ident: n_ident.clone(),
376                            generics: None,
377                        }],
378                    })),
379                    right: Box::new(Expr::Literal(Literal::Int {
380                        value: "1".to_string(),
381                        base: NumBase::Decimal,
382                        suffix: None,
383                    })),
384                },
385                // b
386                Expr::Path(TypePath {
387                    segments: vec![PathSegment {
388                        ident: b_ident.clone(),
389                        generics: None,
390                    }],
391                }),
392                // a + b
393                Expr::Binary {
394                    op: BinOp::Add,
395                    left: Box::new(Expr::Path(TypePath {
396                        segments: vec![PathSegment {
397                            ident: a_ident.clone(),
398                            generics: None,
399                        }],
400                    })),
401                    right: Box::new(Expr::Path(TypePath {
402                        segments: vec![PathSegment {
403                            ident: b_ident.clone(),
404                            generics: None,
405                        }],
406                    })),
407                },
408            ],
409        };
410
411        // if n <= 0 { return a; } else { return fib_tail(n - 1, b, a + b); }
412        let body = Block {
413            stmts: vec![],
414            expr: Some(Box::new(Expr::If {
415                condition: Box::new(condition),
416                then_branch,
417                else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418            })),
419        };
420
421        ast::Function {
422            visibility: Visibility::default(),
423            is_async: false,
424            attrs: FunctionAttrs::default(),
425            name: Ident {
426                name: name.to_string(),
427                evidentiality: None,
428                affect: None,
429                span: span.clone(),
430            },
431            aspect: None,
432            generics: None,
433            params,
434            return_type: None,
435            where_clause: None,
436            body: Some(body),
437        }
438    }
439
440    /// Generate the wrapper function that calls the helper
441    fn generate_fib_wrapper(
442        &self,
443        name: &str,
444        helper_name: &str,
445        param_name: &str,
446        original: &ast::Function,
447    ) -> ast::Function {
448        let span = Span { start: 0, end: 0 };
449
450        // fn fib(n) { return fib_tail(n, 0, 1); }
451        let call_helper = Expr::Call {
452            func: Box::new(Expr::Path(TypePath {
453                segments: vec![PathSegment {
454                    ident: Ident {
455                        name: helper_name.to_string(),
456                        evidentiality: None,
457                        affect: None,
458                        span: span.clone(),
459                    },
460                    generics: None,
461                }],
462            })),
463            args: vec![
464                // n
465                Expr::Path(TypePath {
466                    segments: vec![PathSegment {
467                        ident: Ident {
468                            name: param_name.to_string(),
469                            evidentiality: None,
470                            affect: None,
471                            span: span.clone(),
472                        },
473                        generics: None,
474                    }],
475                }),
476                // 0 (initial acc1)
477                Expr::Literal(Literal::Int {
478                    value: "0".to_string(),
479                    base: NumBase::Decimal,
480                    suffix: None,
481                }),
482                // 1 (initial acc2)
483                Expr::Literal(Literal::Int {
484                    value: "1".to_string(),
485                    base: NumBase::Decimal,
486                    suffix: None,
487                }),
488            ],
489        };
490
491        let body = Block {
492            stmts: vec![],
493            expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
494        };
495
496        ast::Function {
497            visibility: original.visibility,
498            is_async: original.is_async,
499            attrs: original.attrs.clone(),
500            name: Ident {
501                name: name.to_string(),
502                evidentiality: None,
503                affect: None,
504                span: span.clone(),
505            },
506            aspect: original.aspect,
507            generics: original.generics.clone(),
508            params: original.params.clone(),
509            return_type: original.return_type.clone(),
510            where_clause: original.where_clause.clone(),
511            body: Some(body),
512        }
513    }
514
515    // ========================================================================
516    // Memoization Transform
517    // ========================================================================
518
519    /// Try to transform a recursive function into a memoized version
520    /// Returns: (implementation_func, cache_init_func, wrapper_func)
521    #[allow(dead_code)]
522    fn try_memoize_transform(
523        &self,
524        func: &ast::Function,
525    ) -> Option<(ast::Function, ast::Function, ast::Function)> {
526        let param_count = func.params.len();
527        if param_count != 1 && param_count != 2 {
528            return None;
529        }
530
531        let span = Span { start: 0, end: 0 };
532        let func_name = &func.name.name;
533        let impl_name = format!("_memo_impl_{}", func_name);
534        let _cache_name = format!("_memo_cache_{}", func_name);
535        let init_name = format!("_memo_init_{}", func_name);
536
537        // Get parameter names
538        let param_names: Vec<String> = func
539            .params
540            .iter()
541            .filter_map(|p| {
542                if let Pattern::Ident { name, .. } = &p.pattern {
543                    Some(name.name.clone())
544                } else {
545                    None
546                }
547            })
548            .collect();
549
550        if param_names.len() != param_count {
551            return None;
552        }
553
554        // 1. Create implementation function (renamed original with calls redirected)
555        let impl_func = ast::Function {
556            visibility: Visibility::default(),
557            is_async: func.is_async,
558            attrs: func.attrs.clone(),
559            name: Ident {
560                name: impl_name.clone(),
561                evidentiality: None,
562                affect: None,
563                span: span.clone(),
564            },
565            aspect: func.aspect,
566            generics: func.generics.clone(),
567            params: func.params.clone(),
568            return_type: func.return_type.clone(),
569            where_clause: func.where_clause.clone(),
570            body: func
571                .body
572                .as_ref()
573                .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
574        };
575
576        // 2. Create cache initializer function
577        // This is a function that returns the cache, called once at the start
578        let cache_init_body = Block {
579            stmts: vec![],
580            expr: Some(Box::new(Expr::Call {
581                func: Box::new(Expr::Path(TypePath {
582                    segments: vec![PathSegment {
583                        ident: Ident {
584                            name: "sigil_memo_new".to_string(),
585                            evidentiality: None,
586                            affect: None,
587                            span: span.clone(),
588                        },
589                        generics: None,
590                    }],
591                })),
592                args: vec![Expr::Literal(Literal::Int {
593                    value: "65536".to_string(),
594                    base: NumBase::Decimal,
595                    suffix: None,
596                })],
597            })),
598        };
599
600        let cache_init_func = ast::Function {
601            visibility: Visibility::default(),
602            is_async: false,
603            attrs: FunctionAttrs::default(),
604            name: Ident {
605                name: init_name.clone(),
606                evidentiality: None,
607                affect: None,
608                span: span.clone(),
609            },
610            aspect: None,
611            generics: None,
612            params: vec![],
613            return_type: None,
614            where_clause: None,
615            body: Some(cache_init_body),
616        };
617
618        // 3. Create wrapper function
619        let wrapper_func = self.generate_memo_wrapper(func, &impl_name, &param_names);
620
621        Some((impl_func, cache_init_func, wrapper_func))
622    }
623
624    /// Generate the memoized wrapper function
625    #[allow(dead_code)]
626    fn generate_memo_wrapper(
627        &self,
628        original: &ast::Function,
629        impl_name: &str,
630        param_names: &[String],
631    ) -> ast::Function {
632        let span = Span { start: 0, end: 0 };
633        let param_count = param_names.len();
634
635        // Variable for cache - use a static-like pattern with lazy init
636        let cache_var = Ident {
637            name: "__cache".to_string(),
638            evidentiality: None,
639            affect: None,
640            span: span.clone(),
641        };
642        let result_var = Ident {
643            name: "__result".to_string(),
644            evidentiality: None,
645            affect: None,
646            span: span.clone(),
647        };
648        let cached_var = Ident {
649            name: "__cached".to_string(),
650            evidentiality: None,
651            affect: None,
652            span: span.clone(),
653        };
654
655        let mut stmts = vec![];
656
657        // let __cache = sigil_memo_new(65536);
658        stmts.push(Stmt::Let {
659            pattern: Pattern::Ident {
660                mutable: false,
661                name: cache_var.clone(),
662                evidentiality: None,
663            },
664            ty: None,
665            init: Some(Expr::Call {
666                func: Box::new(Expr::Path(TypePath {
667                    segments: vec![PathSegment {
668                        ident: Ident {
669                            name: "sigil_memo_new".to_string(),
670                            evidentiality: None,
671                            affect: None,
672                            span: span.clone(),
673                        },
674                        generics: None,
675                    }],
676                })),
677                args: vec![Expr::Literal(Literal::Int {
678                    value: "65536".to_string(),
679                    base: NumBase::Decimal,
680                    suffix: None,
681                })],
682            }),
683        });
684
685        // Check cache: let __cached = sigil_memo_get_N(__cache, params...);
686        let get_fn_name = if param_count == 1 {
687            "sigil_memo_get_1"
688        } else {
689            "sigil_memo_get_2"
690        };
691        let mut get_args = vec![Expr::Path(TypePath {
692            segments: vec![PathSegment {
693                ident: cache_var.clone(),
694                generics: None,
695            }],
696        })];
697        for name in param_names {
698            get_args.push(Expr::Path(TypePath {
699                segments: vec![PathSegment {
700                    ident: Ident {
701                        name: name.clone(),
702                        evidentiality: None,
703                        affect: None,
704                        span: span.clone(),
705                    },
706                    generics: None,
707                }],
708            }));
709        }
710
711        stmts.push(Stmt::Let {
712            pattern: Pattern::Ident {
713                mutable: false,
714                name: cached_var.clone(),
715                evidentiality: None,
716            },
717            ty: None,
718            init: Some(Expr::Call {
719                func: Box::new(Expr::Path(TypePath {
720                    segments: vec![PathSegment {
721                        ident: Ident {
722                            name: get_fn_name.to_string(),
723                            evidentiality: None,
724                            affect: None,
725                            span: span.clone(),
726                        },
727                        generics: None,
728                    }],
729                })),
730                args: get_args,
731            }),
732        });
733
734        // if __cached != -9223372036854775807 { return __cached; }
735        // Use a large negative number as sentinel (i64::MIN + 1 to avoid overflow issues)
736        let cache_check = Expr::If {
737            condition: Box::new(Expr::Binary {
738                op: BinOp::Ne,
739                left: Box::new(Expr::Path(TypePath {
740                    segments: vec![PathSegment {
741                        ident: cached_var.clone(),
742                        generics: None,
743                    }],
744                })),
745                right: Box::new(Expr::Unary {
746                    op: UnaryOp::Neg,
747                    expr: Box::new(Expr::Literal(Literal::Int {
748                        value: "9223372036854775807".to_string(),
749                        base: NumBase::Decimal,
750                        suffix: None,
751                    })),
752                }),
753            }),
754            then_branch: Block {
755                stmts: vec![],
756                expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
757                    TypePath {
758                        segments: vec![PathSegment {
759                            ident: cached_var.clone(),
760                            generics: None,
761                        }],
762                    },
763                )))))),
764            },
765            else_branch: None,
766        };
767        stmts.push(Stmt::Semi(cache_check));
768
769        // let __result = _memo_impl_func(params...);
770        let mut impl_args = vec![];
771        for name in param_names {
772            impl_args.push(Expr::Path(TypePath {
773                segments: vec![PathSegment {
774                    ident: Ident {
775                        name: name.clone(),
776                        evidentiality: None,
777                        affect: None,
778                        span: span.clone(),
779                    },
780                    generics: None,
781                }],
782            }));
783        }
784
785        stmts.push(Stmt::Let {
786            pattern: Pattern::Ident {
787                mutable: false,
788                name: result_var.clone(),
789                evidentiality: None,
790            },
791            ty: None,
792            init: Some(Expr::Call {
793                func: Box::new(Expr::Path(TypePath {
794                    segments: vec![PathSegment {
795                        ident: Ident {
796                            name: impl_name.to_string(),
797                            evidentiality: None,
798                            affect: None,
799                            span: span.clone(),
800                        },
801                        generics: None,
802                    }],
803                })),
804                args: impl_args,
805            }),
806        });
807
808        // sigil_memo_set_N(__cache, params..., __result);
809        let set_fn_name = if param_count == 1 {
810            "sigil_memo_set_1"
811        } else {
812            "sigil_memo_set_2"
813        };
814        let mut set_args = vec![Expr::Path(TypePath {
815            segments: vec![PathSegment {
816                ident: cache_var.clone(),
817                generics: None,
818            }],
819        })];
820        for name in param_names {
821            set_args.push(Expr::Path(TypePath {
822                segments: vec![PathSegment {
823                    ident: Ident {
824                        name: name.clone(),
825                        evidentiality: None,
826                        affect: None,
827                        span: span.clone(),
828                    },
829                    generics: None,
830                }],
831            }));
832        }
833        set_args.push(Expr::Path(TypePath {
834            segments: vec![PathSegment {
835                ident: result_var.clone(),
836                generics: None,
837            }],
838        }));
839
840        stmts.push(Stmt::Semi(Expr::Call {
841            func: Box::new(Expr::Path(TypePath {
842                segments: vec![PathSegment {
843                    ident: Ident {
844                        name: set_fn_name.to_string(),
845                        evidentiality: None,
846                        affect: None,
847                        span: span.clone(),
848                    },
849                    generics: None,
850                }],
851            })),
852            args: set_args,
853        }));
854
855        // return __result;
856        let body = Block {
857            stmts,
858            expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
859                TypePath {
860                    segments: vec![PathSegment {
861                        ident: result_var.clone(),
862                        generics: None,
863                    }],
864                },
865            )))))),
866        };
867
868        ast::Function {
869            visibility: original.visibility,
870            is_async: original.is_async,
871            attrs: original.attrs.clone(),
872            name: original.name.clone(),
873            aspect: original.aspect,
874            generics: original.generics.clone(),
875            params: original.params.clone(),
876            return_type: original.return_type.clone(),
877            where_clause: original.where_clause.clone(),
878            body: Some(body),
879        }
880    }
881
882    /// Redirect all recursive calls in a block to call the original wrapper (for memoization)
883    #[allow(dead_code)]
884    fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
885        // For memoization, we keep the calls as-is since they'll go through the wrapper
886        block.clone()
887    }
888
889    /// Check if a function is recursive
890    fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
891        if let Some(body) = &func.body {
892            self.block_calls_function(name, body)
893        } else {
894            false
895        }
896    }
897
898    fn block_calls_function(&self, name: &str, block: &Block) -> bool {
899        for stmt in &block.stmts {
900            if self.stmt_calls_function(name, stmt) {
901                return true;
902            }
903        }
904        if let Some(expr) = &block.expr {
905            if self.expr_calls_function(name, expr) {
906                return true;
907            }
908        }
909        false
910    }
911
912    fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
913        match stmt {
914            Stmt::Let {
915                init: Some(expr), ..
916            } => self.expr_calls_function(name, expr),
917            Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
918            _ => false,
919        }
920    }
921
922    fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
923        match expr {
924            Expr::Call { func, args } => {
925                if let Expr::Path(path) = func.as_ref() {
926                    if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
927                        return true;
928                    }
929                }
930                args.iter().any(|a| self.expr_calls_function(name, a))
931            }
932            Expr::Binary { left, right, .. } => {
933                self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
934            }
935            Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
936            Expr::If {
937                condition,
938                then_branch,
939                else_branch,
940            } => {
941                self.expr_calls_function(name, condition)
942                    || self.block_calls_function(name, then_branch)
943                    || else_branch
944                        .as_ref()
945                        .map(|e| self.expr_calls_function(name, e))
946                        .unwrap_or(false)
947            }
948            Expr::While { condition, body } => {
949                self.expr_calls_function(name, condition) || self.block_calls_function(name, body)
950            }
951            Expr::Block(block) => self.block_calls_function(name, block),
952            Expr::Return(Some(e)) => self.expr_calls_function(name, e),
953            _ => false,
954        }
955    }
956
957    /// Optimize a single function
958    fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
959        // Reset CSE counter per function for cleaner variable names
960        self.cse_counter = 0;
961
962        let body = if let Some(body) = &func.body {
963            // Run passes based on optimization level
964            let optimized = match self.level {
965                OptLevel::None => body.clone(),
966                OptLevel::Basic => {
967                    let b = self.pass_constant_fold_block(body);
968                    self.pass_dead_code_block(&b)
969                }
970                OptLevel::Standard | OptLevel::Size => {
971                    let b = self.pass_constant_fold_block(body);
972                    let b = self.pass_inline_block(&b); // Function inlining
973                    let b = self.pass_strength_reduce_block(&b);
974                    let b = self.pass_licm_block(&b); // LICM pass
975                    let b = self.pass_cse_block(&b); // CSE pass
976                    let b = self.pass_dead_code_block(&b);
977                    self.pass_simplify_branches_block(&b)
978                }
979                OptLevel::Aggressive => {
980                    // Multiple iterations for fixed-point
981                    let mut b = body.clone();
982                    for _ in 0..3 {
983                        b = self.pass_constant_fold_block(&b);
984                        b = self.pass_inline_block(&b); // Function inlining
985                        b = self.pass_strength_reduce_block(&b);
986                        b = self.pass_loop_unroll_block(&b); // Loop unrolling
987                        b = self.pass_licm_block(&b); // LICM pass
988                        b = self.pass_cse_block(&b); // CSE pass
989                        b = self.pass_dead_code_block(&b);
990                        b = self.pass_simplify_branches_block(&b);
991                    }
992                    b
993                }
994            };
995            Some(optimized)
996        } else {
997            None
998        };
999
1000        ast::Function {
1001            visibility: func.visibility.clone(),
1002            is_async: func.is_async,
1003            attrs: func.attrs.clone(),
1004            name: func.name.clone(),
1005            aspect: func.aspect,
1006            generics: func.generics.clone(),
1007            params: func.params.clone(),
1008            return_type: func.return_type.clone(),
1009            where_clause: func.where_clause.clone(),
1010            body,
1011        }
1012    }
1013
1014    // ========================================================================
1015    // Pass: Constant Folding
1016    // ========================================================================
1017
1018    fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1019        let stmts = block
1020            .stmts
1021            .iter()
1022            .map(|s| self.pass_constant_fold_stmt(s))
1023            .collect();
1024        let expr = block
1025            .expr
1026            .as_ref()
1027            .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1028        Block { stmts, expr }
1029    }
1030
1031    fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1032        match stmt {
1033            Stmt::Let {
1034                pattern, ty, init, ..
1035            } => Stmt::Let {
1036                pattern: pattern.clone(),
1037                ty: ty.clone(),
1038                init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1039            },
1040            Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1041            Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1042            Stmt::Item(item) => Stmt::Item(item.clone()),
1043        }
1044    }
1045
1046    fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1047        match expr {
1048            Expr::Binary { op, left, right } => {
1049                let left = Box::new(self.pass_constant_fold_expr(left));
1050                let right = Box::new(self.pass_constant_fold_expr(right));
1051
1052                // Try to fold
1053                if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1054                    if let Some(result) = self.fold_binary(op.clone(), l, r) {
1055                        self.stats.constants_folded += 1;
1056                        return Expr::Literal(Literal::Int {
1057                            value: result.to_string(),
1058                            base: NumBase::Decimal,
1059                            suffix: None,
1060                        });
1061                    }
1062                }
1063
1064                Expr::Binary {
1065                    op: op.clone(),
1066                    left,
1067                    right,
1068                }
1069            }
1070            Expr::Unary { op, expr: inner } => {
1071                let inner = Box::new(self.pass_constant_fold_expr(inner));
1072
1073                if let Some(v) = self.as_int(&inner) {
1074                    if let Some(result) = self.fold_unary(*op, v) {
1075                        self.stats.constants_folded += 1;
1076                        return Expr::Literal(Literal::Int {
1077                            value: result.to_string(),
1078                            base: NumBase::Decimal,
1079                            suffix: None,
1080                        });
1081                    }
1082                }
1083
1084                Expr::Unary {
1085                    op: *op,
1086                    expr: inner,
1087                }
1088            }
1089            Expr::If {
1090                condition,
1091                then_branch,
1092                else_branch,
1093            } => {
1094                let condition = Box::new(self.pass_constant_fold_expr(condition));
1095                let then_branch = self.pass_constant_fold_block(then_branch);
1096                let else_branch = else_branch
1097                    .as_ref()
1098                    .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1099
1100                // Fold constant conditions
1101                if let Some(cond) = self.as_bool(&condition) {
1102                    self.stats.branches_simplified += 1;
1103                    if cond {
1104                        return Expr::Block(then_branch);
1105                    } else if let Some(else_expr) = else_branch {
1106                        return *else_expr;
1107                    } else {
1108                        return Expr::Literal(Literal::Bool(false));
1109                    }
1110                }
1111
1112                Expr::If {
1113                    condition,
1114                    then_branch,
1115                    else_branch,
1116                }
1117            }
1118            Expr::While { condition, body } => {
1119                let condition = Box::new(self.pass_constant_fold_expr(condition));
1120                let body = self.pass_constant_fold_block(body);
1121
1122                // while false { ... } -> nothing
1123                if let Some(false) = self.as_bool(&condition) {
1124                    self.stats.branches_simplified += 1;
1125                    return Expr::Block(Block {
1126                        stmts: vec![],
1127                        expr: None,
1128                    });
1129                }
1130
1131                Expr::While { condition, body }
1132            }
1133            Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1134            Expr::Call { func, args } => {
1135                let args = args
1136                    .iter()
1137                    .map(|a| self.pass_constant_fold_expr(a))
1138                    .collect();
1139                Expr::Call {
1140                    func: func.clone(),
1141                    args,
1142                }
1143            }
1144            Expr::Return(e) => Expr::Return(
1145                e.as_ref()
1146                    .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1147            ),
1148            Expr::Assign { target, value } => {
1149                let value = Box::new(self.pass_constant_fold_expr(value));
1150                Expr::Assign {
1151                    target: target.clone(),
1152                    value,
1153                }
1154            }
1155            Expr::Index { expr: e, index } => {
1156                let e = Box::new(self.pass_constant_fold_expr(e));
1157                let index = Box::new(self.pass_constant_fold_expr(index));
1158                Expr::Index { expr: e, index }
1159            }
1160            Expr::Array(elements) => {
1161                let elements = elements
1162                    .iter()
1163                    .map(|e| self.pass_constant_fold_expr(e))
1164                    .collect();
1165                Expr::Array(elements)
1166            }
1167            other => other.clone(),
1168        }
1169    }
1170
1171    fn as_int(&self, expr: &Expr) -> Option<i64> {
1172        match expr {
1173            Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1174            Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1175            _ => None,
1176        }
1177    }
1178
1179    fn as_bool(&self, expr: &Expr) -> Option<bool> {
1180        match expr {
1181            Expr::Literal(Literal::Bool(b)) => Some(*b),
1182            Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1183            _ => None,
1184        }
1185    }
1186
1187    fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1188        match op {
1189            BinOp::Add => Some(l.wrapping_add(r)),
1190            BinOp::Sub => Some(l.wrapping_sub(r)),
1191            BinOp::Mul => Some(l.wrapping_mul(r)),
1192            BinOp::Div if r != 0 => Some(l / r),
1193            BinOp::Rem if r != 0 => Some(l % r),
1194            BinOp::BitAnd => Some(l & r),
1195            BinOp::BitOr => Some(l | r),
1196            BinOp::BitXor => Some(l ^ r),
1197            BinOp::Shl => Some(l << (r & 63)),
1198            BinOp::Shr => Some(l >> (r & 63)),
1199            BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1200            BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1201            BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1202            BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1203            BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1204            BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1205            BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1206            BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1207            _ => None,
1208        }
1209    }
1210
1211    fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1212        match op {
1213            UnaryOp::Neg => Some(-v),
1214            UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1215            _ => None,
1216        }
1217    }
1218
1219    // ========================================================================
1220    // Pass: Strength Reduction
1221    // ========================================================================
1222
1223    fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1224        let stmts = block
1225            .stmts
1226            .iter()
1227            .map(|s| self.pass_strength_reduce_stmt(s))
1228            .collect();
1229        let expr = block
1230            .expr
1231            .as_ref()
1232            .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1233        Block { stmts, expr }
1234    }
1235
1236    fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1237        match stmt {
1238            Stmt::Let {
1239                pattern, ty, init, ..
1240            } => Stmt::Let {
1241                pattern: pattern.clone(),
1242                ty: ty.clone(),
1243                init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1244            },
1245            Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1246            Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1247            Stmt::Item(item) => Stmt::Item(item.clone()),
1248        }
1249    }
1250
1251    fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1252        match expr {
1253            Expr::Binary { op, left, right } => {
1254                let left = Box::new(self.pass_strength_reduce_expr(left));
1255                let right = Box::new(self.pass_strength_reduce_expr(right));
1256
1257                // x * 2 -> x << 1, x * 4 -> x << 2, etc.
1258                if *op == BinOp::Mul {
1259                    if let Some(n) = self.as_int(&right) {
1260                        if n > 0 && (n as u64).is_power_of_two() {
1261                            self.stats.strength_reductions += 1;
1262                            let shift = (n as u64).trailing_zeros() as i64;
1263                            return Expr::Binary {
1264                                op: BinOp::Shl,
1265                                left,
1266                                right: Box::new(Expr::Literal(Literal::Int {
1267                                    value: shift.to_string(),
1268                                    base: NumBase::Decimal,
1269                                    suffix: None,
1270                                })),
1271                            };
1272                        }
1273                    }
1274                    if let Some(n) = self.as_int(&left) {
1275                        if n > 0 && (n as u64).is_power_of_two() {
1276                            self.stats.strength_reductions += 1;
1277                            let shift = (n as u64).trailing_zeros() as i64;
1278                            return Expr::Binary {
1279                                op: BinOp::Shl,
1280                                left: right,
1281                                right: Box::new(Expr::Literal(Literal::Int {
1282                                    value: shift.to_string(),
1283                                    base: NumBase::Decimal,
1284                                    suffix: None,
1285                                })),
1286                            };
1287                        }
1288                    }
1289                }
1290
1291                // x + 0 -> x, x - 0 -> x, x * 1 -> x, x / 1 -> x
1292                if let Some(n) = self.as_int(&right) {
1293                    match (op, n) {
1294                        (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1295                        | (BinOp::Mul | BinOp::Div, 1)
1296                        | (BinOp::Shl | BinOp::Shr, 0) => {
1297                            self.stats.strength_reductions += 1;
1298                            return *left;
1299                        }
1300                        (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1301                            self.stats.strength_reductions += 1;
1302                            return Expr::Literal(Literal::Int {
1303                                value: "0".to_string(),
1304                                base: NumBase::Decimal,
1305                                suffix: None,
1306                            });
1307                        }
1308                        _ => {}
1309                    }
1310                }
1311
1312                // 0 + x -> x, 1 * x -> x
1313                if let Some(n) = self.as_int(&left) {
1314                    match (op, n) {
1315                        (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1316                            self.stats.strength_reductions += 1;
1317                            return *right;
1318                        }
1319                        (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1320                            self.stats.strength_reductions += 1;
1321                            return Expr::Literal(Literal::Int {
1322                                value: "0".to_string(),
1323                                base: NumBase::Decimal,
1324                                suffix: None,
1325                            });
1326                        }
1327                        _ => {}
1328                    }
1329                }
1330
1331                Expr::Binary {
1332                    op: op.clone(),
1333                    left,
1334                    right,
1335                }
1336            }
1337            Expr::Unary { op, expr: inner } => {
1338                let inner = Box::new(self.pass_strength_reduce_expr(inner));
1339
1340                // --x -> x
1341                if *op == UnaryOp::Neg {
1342                    if let Expr::Unary {
1343                        op: UnaryOp::Neg,
1344                        expr: inner2,
1345                    } = inner.as_ref()
1346                    {
1347                        self.stats.strength_reductions += 1;
1348                        return *inner2.clone();
1349                    }
1350                }
1351
1352                // !!x -> x (for booleans)
1353                if *op == UnaryOp::Not {
1354                    if let Expr::Unary {
1355                        op: UnaryOp::Not,
1356                        expr: inner2,
1357                    } = inner.as_ref()
1358                    {
1359                        self.stats.strength_reductions += 1;
1360                        return *inner2.clone();
1361                    }
1362                }
1363
1364                Expr::Unary {
1365                    op: *op,
1366                    expr: inner,
1367                }
1368            }
1369            Expr::If {
1370                condition,
1371                then_branch,
1372                else_branch,
1373            } => {
1374                let condition = Box::new(self.pass_strength_reduce_expr(condition));
1375                let then_branch = self.pass_strength_reduce_block(then_branch);
1376                let else_branch = else_branch
1377                    .as_ref()
1378                    .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1379                Expr::If {
1380                    condition,
1381                    then_branch,
1382                    else_branch,
1383                }
1384            }
1385            Expr::While { condition, body } => {
1386                let condition = Box::new(self.pass_strength_reduce_expr(condition));
1387                let body = self.pass_strength_reduce_block(body);
1388                Expr::While { condition, body }
1389            }
1390            Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1391            Expr::Call { func, args } => {
1392                let args = args
1393                    .iter()
1394                    .map(|a| self.pass_strength_reduce_expr(a))
1395                    .collect();
1396                Expr::Call {
1397                    func: func.clone(),
1398                    args,
1399                }
1400            }
1401            Expr::Return(e) => Expr::Return(
1402                e.as_ref()
1403                    .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1404            ),
1405            Expr::Assign { target, value } => {
1406                let value = Box::new(self.pass_strength_reduce_expr(value));
1407                Expr::Assign {
1408                    target: target.clone(),
1409                    value,
1410                }
1411            }
1412            other => other.clone(),
1413        }
1414    }
1415
1416    // ========================================================================
1417    // Pass: Dead Code Elimination
1418    // ========================================================================
1419
1420    fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1421        // Remove statements after a return
1422        let mut stmts = Vec::new();
1423        let mut found_return = false;
1424
1425        for stmt in &block.stmts {
1426            if found_return {
1427                self.stats.dead_code_eliminated += 1;
1428                continue;
1429            }
1430            let stmt = self.pass_dead_code_stmt(stmt);
1431            if self.stmt_returns(&stmt) {
1432                found_return = true;
1433            }
1434            stmts.push(stmt);
1435        }
1436
1437        // If we found a return, the trailing expression is dead
1438        let expr = if found_return {
1439            if block.expr.is_some() {
1440                self.stats.dead_code_eliminated += 1;
1441            }
1442            None
1443        } else {
1444            block
1445                .expr
1446                .as_ref()
1447                .map(|e| Box::new(self.pass_dead_code_expr(e)))
1448        };
1449
1450        Block { stmts, expr }
1451    }
1452
1453    fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1454        match stmt {
1455            Stmt::Let {
1456                pattern, ty, init, ..
1457            } => Stmt::Let {
1458                pattern: pattern.clone(),
1459                ty: ty.clone(),
1460                init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1461            },
1462            Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1463            Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1464            Stmt::Item(item) => Stmt::Item(item.clone()),
1465        }
1466    }
1467
1468    fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1469        match expr {
1470            Expr::If {
1471                condition,
1472                then_branch,
1473                else_branch,
1474            } => {
1475                let condition = Box::new(self.pass_dead_code_expr(condition));
1476                let then_branch = self.pass_dead_code_block(then_branch);
1477                let else_branch = else_branch
1478                    .as_ref()
1479                    .map(|e| Box::new(self.pass_dead_code_expr(e)));
1480                Expr::If {
1481                    condition,
1482                    then_branch,
1483                    else_branch,
1484                }
1485            }
1486            Expr::While { condition, body } => {
1487                let condition = Box::new(self.pass_dead_code_expr(condition));
1488                let body = self.pass_dead_code_block(body);
1489                Expr::While { condition, body }
1490            }
1491            Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1492            other => other.clone(),
1493        }
1494    }
1495
1496    fn stmt_returns(&self, stmt: &Stmt) -> bool {
1497        match stmt {
1498            Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1499            _ => false,
1500        }
1501    }
1502
1503    fn expr_returns(&self, expr: &Expr) -> bool {
1504        match expr {
1505            Expr::Return(_) => true,
1506            Expr::Block(block) => {
1507                block.stmts.iter().any(|s| self.stmt_returns(s))
1508                    || block
1509                        .expr
1510                        .as_ref()
1511                        .map(|e| self.expr_returns(e))
1512                        .unwrap_or(false)
1513            }
1514            _ => false,
1515        }
1516    }
1517
1518    // ========================================================================
1519    // Pass: Branch Simplification
1520    // ========================================================================
1521
1522    fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1523        let stmts = block
1524            .stmts
1525            .iter()
1526            .map(|s| self.pass_simplify_branches_stmt(s))
1527            .collect();
1528        let expr = block
1529            .expr
1530            .as_ref()
1531            .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1532        Block { stmts, expr }
1533    }
1534
1535    fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1536        match stmt {
1537            Stmt::Let {
1538                pattern, ty, init, ..
1539            } => Stmt::Let {
1540                pattern: pattern.clone(),
1541                ty: ty.clone(),
1542                init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1543            },
1544            Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1545            Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1546            Stmt::Item(item) => Stmt::Item(item.clone()),
1547        }
1548    }
1549
1550    fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1551        match expr {
1552            Expr::If {
1553                condition,
1554                then_branch,
1555                else_branch,
1556            } => {
1557                let condition = Box::new(self.pass_simplify_branches_expr(condition));
1558                let then_branch = self.pass_simplify_branches_block(then_branch);
1559                let else_branch = else_branch
1560                    .as_ref()
1561                    .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1562
1563                // if !cond { a } else { b } -> if cond { b } else { a }
1564                if let Expr::Unary {
1565                    op: UnaryOp::Not,
1566                    expr: inner,
1567                } = condition.as_ref()
1568                {
1569                    if let Some(else_expr) = &else_branch {
1570                        self.stats.branches_simplified += 1;
1571                        let new_else = Some(Box::new(Expr::Block(then_branch)));
1572                        let new_then = match else_expr.as_ref() {
1573                            Expr::Block(b) => b.clone(),
1574                            other => Block {
1575                                stmts: vec![],
1576                                expr: Some(Box::new(other.clone())),
1577                            },
1578                        };
1579                        return Expr::If {
1580                            condition: inner.clone(),
1581                            then_branch: new_then,
1582                            else_branch: new_else,
1583                        };
1584                    }
1585                }
1586
1587                Expr::If {
1588                    condition,
1589                    then_branch,
1590                    else_branch,
1591                }
1592            }
1593            Expr::While { condition, body } => {
1594                let condition = Box::new(self.pass_simplify_branches_expr(condition));
1595                let body = self.pass_simplify_branches_block(body);
1596                Expr::While { condition, body }
1597            }
1598            Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1599            Expr::Binary { op, left, right } => {
1600                let left = Box::new(self.pass_simplify_branches_expr(left));
1601                let right = Box::new(self.pass_simplify_branches_expr(right));
1602                Expr::Binary {
1603                    op: op.clone(),
1604                    left,
1605                    right,
1606                }
1607            }
1608            Expr::Unary { op, expr: inner } => {
1609                let inner = Box::new(self.pass_simplify_branches_expr(inner));
1610                Expr::Unary {
1611                    op: *op,
1612                    expr: inner,
1613                }
1614            }
1615            Expr::Call { func, args } => {
1616                let args = args
1617                    .iter()
1618                    .map(|a| self.pass_simplify_branches_expr(a))
1619                    .collect();
1620                Expr::Call {
1621                    func: func.clone(),
1622                    args,
1623                }
1624            }
1625            Expr::Return(e) => Expr::Return(
1626                e.as_ref()
1627                    .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1628            ),
1629            other => other.clone(),
1630        }
1631    }
1632
1633    // ========================================================================
1634    // Pass: Function Inlining
1635    // ========================================================================
1636
1637    /// Check if a function is small enough to inline
1638    fn should_inline(&self, func: &ast::Function) -> bool {
1639        // Don't inline recursive functions
1640        if self.recursive_functions.contains(&func.name.name) {
1641            return false;
1642        }
1643
1644        // Count the number of statements/expressions in the body
1645        if let Some(body) = &func.body {
1646            let stmt_count = self.count_stmts_in_block(body);
1647            // Inline functions with 10 or fewer statements
1648            stmt_count <= 10
1649        } else {
1650            false
1651        }
1652    }
1653
1654    /// Count statements in a block (for inlining heuristics)
1655    fn count_stmts_in_block(&self, block: &Block) -> usize {
1656        let mut count = block.stmts.len();
1657        if block.expr.is_some() {
1658            count += 1;
1659        }
1660        // Also count nested blocks in if/while
1661        for stmt in &block.stmts {
1662            count += self.count_stmts_in_stmt(stmt);
1663        }
1664        count
1665    }
1666
1667    fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1668        match stmt {
1669            Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1670            Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1671            _ => 0,
1672        }
1673    }
1674
1675    fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1676        match expr {
1677            Expr::If {
1678                then_branch,
1679                else_branch,
1680                ..
1681            } => {
1682                let mut count = self.count_stmts_in_block(then_branch);
1683                if let Some(else_expr) = else_branch {
1684                    count += self.count_stmts_in_expr(else_expr);
1685                }
1686                count
1687            }
1688            Expr::While { body, .. } => self.count_stmts_in_block(body),
1689            Expr::Block(block) => self.count_stmts_in_block(block),
1690            _ => 0,
1691        }
1692    }
1693
1694    /// Inline function call by substituting the body with parameters replaced
1695    fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1696        let body = func.body.as_ref()?;
1697
1698        // Build parameter to argument mapping
1699        let mut param_map: HashMap<String, Expr> = HashMap::new();
1700        for (param, arg) in func.params.iter().zip(args.iter()) {
1701            if let Pattern::Ident { name, .. } = &param.pattern {
1702                param_map.insert(name.name.clone(), arg.clone());
1703            }
1704        }
1705
1706        // Substitute parameters in the function body
1707        let inlined_body = self.substitute_params_in_block(body, &param_map);
1708
1709        self.stats.functions_inlined += 1;
1710
1711        // If the body has a final expression, return it
1712        // If not, wrap in a block
1713        if inlined_body.stmts.is_empty() {
1714            if let Some(expr) = inlined_body.expr {
1715                // If it's a return, unwrap it
1716                if let Expr::Return(Some(inner)) = expr.as_ref() {
1717                    return Some(inner.as_ref().clone());
1718                }
1719                return Some(*expr);
1720            }
1721        }
1722
1723        Some(Expr::Block(inlined_body))
1724    }
1725
1726    /// Substitute parameter references with argument expressions
1727    fn substitute_params_in_block(
1728        &self,
1729        block: &Block,
1730        param_map: &HashMap<String, Expr>,
1731    ) -> Block {
1732        let stmts = block
1733            .stmts
1734            .iter()
1735            .map(|s| self.substitute_params_in_stmt(s, param_map))
1736            .collect();
1737        let expr = block
1738            .expr
1739            .as_ref()
1740            .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1741        Block { stmts, expr }
1742    }
1743
1744    fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1745        match stmt {
1746            Stmt::Let { pattern, ty, init } => Stmt::Let {
1747                pattern: pattern.clone(),
1748                ty: ty.clone(),
1749                init: init
1750                    .as_ref()
1751                    .map(|e| self.substitute_params_in_expr(e, param_map)),
1752            },
1753            Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1754            Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1755            Stmt::Item(item) => Stmt::Item(item.clone()),
1756        }
1757    }
1758
1759    fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1760        match expr {
1761            Expr::Path(path) => {
1762                // Check if this is a parameter reference
1763                if path.segments.len() == 1 {
1764                    let name = &path.segments[0].ident.name;
1765                    if let Some(arg) = param_map.get(name) {
1766                        return arg.clone();
1767                    }
1768                }
1769                expr.clone()
1770            }
1771            Expr::Binary { op, left, right } => Expr::Binary {
1772                op: op.clone(),
1773                left: Box::new(self.substitute_params_in_expr(left, param_map)),
1774                right: Box::new(self.substitute_params_in_expr(right, param_map)),
1775            },
1776            Expr::Unary { op, expr: inner } => Expr::Unary {
1777                op: *op,
1778                expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1779            },
1780            Expr::If {
1781                condition,
1782                then_branch,
1783                else_branch,
1784            } => Expr::If {
1785                condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1786                then_branch: self.substitute_params_in_block(then_branch, param_map),
1787                else_branch: else_branch
1788                    .as_ref()
1789                    .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1790            },
1791            Expr::While { condition, body } => Expr::While {
1792                condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1793                body: self.substitute_params_in_block(body, param_map),
1794            },
1795            Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1796            Expr::Call { func, args } => Expr::Call {
1797                func: Box::new(self.substitute_params_in_expr(func, param_map)),
1798                args: args
1799                    .iter()
1800                    .map(|a| self.substitute_params_in_expr(a, param_map))
1801                    .collect(),
1802            },
1803            Expr::Return(e) => Expr::Return(
1804                e.as_ref()
1805                    .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1806            ),
1807            Expr::Assign { target, value } => Expr::Assign {
1808                target: target.clone(),
1809                value: Box::new(self.substitute_params_in_expr(value, param_map)),
1810            },
1811            Expr::Index { expr: e, index } => Expr::Index {
1812                expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1813                index: Box::new(self.substitute_params_in_expr(index, param_map)),
1814            },
1815            Expr::Array(elements) => Expr::Array(
1816                elements
1817                    .iter()
1818                    .map(|e| self.substitute_params_in_expr(e, param_map))
1819                    .collect(),
1820            ),
1821            other => other.clone(),
1822        }
1823    }
1824
1825    fn pass_inline_block(&mut self, block: &Block) -> Block {
1826        let stmts = block
1827            .stmts
1828            .iter()
1829            .map(|s| self.pass_inline_stmt(s))
1830            .collect();
1831        let expr = block
1832            .expr
1833            .as_ref()
1834            .map(|e| Box::new(self.pass_inline_expr(e)));
1835        Block { stmts, expr }
1836    }
1837
1838    fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1839        match stmt {
1840            Stmt::Let { pattern, ty, init } => Stmt::Let {
1841                pattern: pattern.clone(),
1842                ty: ty.clone(),
1843                init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1844            },
1845            Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1846            Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1847            Stmt::Item(item) => Stmt::Item(item.clone()),
1848        }
1849    }
1850
1851    fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1852        match expr {
1853            Expr::Call { func, args } => {
1854                // First, recursively inline arguments
1855                let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1856
1857                // Check if we can inline this call
1858                if let Expr::Path(path) = func.as_ref() {
1859                    if path.segments.len() == 1 {
1860                        let func_name = &path.segments[0].ident.name;
1861                        if let Some(target_func) = self.functions.get(func_name).cloned() {
1862                            if self.should_inline(&target_func)
1863                                && args.len() == target_func.params.len()
1864                            {
1865                                if let Some(inlined) = self.inline_call(&target_func, &args) {
1866                                    return inlined;
1867                                }
1868                            }
1869                        }
1870                    }
1871                }
1872
1873                Expr::Call {
1874                    func: func.clone(),
1875                    args,
1876                }
1877            }
1878            Expr::Binary { op, left, right } => Expr::Binary {
1879                op: op.clone(),
1880                left: Box::new(self.pass_inline_expr(left)),
1881                right: Box::new(self.pass_inline_expr(right)),
1882            },
1883            Expr::Unary { op, expr: inner } => Expr::Unary {
1884                op: *op,
1885                expr: Box::new(self.pass_inline_expr(inner)),
1886            },
1887            Expr::If {
1888                condition,
1889                then_branch,
1890                else_branch,
1891            } => Expr::If {
1892                condition: Box::new(self.pass_inline_expr(condition)),
1893                then_branch: self.pass_inline_block(then_branch),
1894                else_branch: else_branch
1895                    .as_ref()
1896                    .map(|e| Box::new(self.pass_inline_expr(e))),
1897            },
1898            Expr::While { condition, body } => Expr::While {
1899                condition: Box::new(self.pass_inline_expr(condition)),
1900                body: self.pass_inline_block(body),
1901            },
1902            Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
1903            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
1904            Expr::Assign { target, value } => Expr::Assign {
1905                target: target.clone(),
1906                value: Box::new(self.pass_inline_expr(value)),
1907            },
1908            Expr::Index { expr: e, index } => Expr::Index {
1909                expr: Box::new(self.pass_inline_expr(e)),
1910                index: Box::new(self.pass_inline_expr(index)),
1911            },
1912            Expr::Array(elements) => {
1913                Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
1914            }
1915            other => other.clone(),
1916        }
1917    }
1918
1919    // ========================================================================
1920    // Pass: Loop Unrolling
1921    // ========================================================================
1922
1923    /// Unroll small counted loops for better performance
1924    fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
1925        let stmts = block
1926            .stmts
1927            .iter()
1928            .map(|s| self.pass_loop_unroll_stmt(s))
1929            .collect();
1930        let expr = block
1931            .expr
1932            .as_ref()
1933            .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
1934        Block { stmts, expr }
1935    }
1936
1937    fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
1938        match stmt {
1939            Stmt::Let { pattern, ty, init } => Stmt::Let {
1940                pattern: pattern.clone(),
1941                ty: ty.clone(),
1942                init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
1943            },
1944            Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
1945            Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
1946            Stmt::Item(item) => Stmt::Item(item.clone()),
1947        }
1948    }
1949
1950    fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
1951        match expr {
1952            Expr::While { condition, body } => {
1953                // Try to unroll if this is a countable loop
1954                if let Some(unrolled) = self.try_unroll_loop(condition, body) {
1955                    self.stats.loops_optimized += 1;
1956                    return unrolled;
1957                }
1958                // Otherwise, just recurse
1959                Expr::While {
1960                    condition: Box::new(self.pass_loop_unroll_expr(condition)),
1961                    body: self.pass_loop_unroll_block(body),
1962                }
1963            }
1964            Expr::If {
1965                condition,
1966                then_branch,
1967                else_branch,
1968            } => Expr::If {
1969                condition: Box::new(self.pass_loop_unroll_expr(condition)),
1970                then_branch: self.pass_loop_unroll_block(then_branch),
1971                else_branch: else_branch
1972                    .as_ref()
1973                    .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
1974            },
1975            Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
1976            Expr::Binary { op, left, right } => Expr::Binary {
1977                op: *op,
1978                left: Box::new(self.pass_loop_unroll_expr(left)),
1979                right: Box::new(self.pass_loop_unroll_expr(right)),
1980            },
1981            Expr::Unary { op, expr: inner } => Expr::Unary {
1982                op: *op,
1983                expr: Box::new(self.pass_loop_unroll_expr(inner)),
1984            },
1985            Expr::Call { func, args } => Expr::Call {
1986                func: func.clone(),
1987                args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
1988            },
1989            Expr::Return(e) => {
1990                Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
1991            }
1992            Expr::Assign { target, value } => Expr::Assign {
1993                target: target.clone(),
1994                value: Box::new(self.pass_loop_unroll_expr(value)),
1995            },
1996            other => other.clone(),
1997        }
1998    }
1999
2000    /// Try to unroll a loop with known bounds
2001    /// Pattern: while i < CONST { body; i = i + 1; }
2002    fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2003        // Check for pattern: var < constant
2004        let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2005
2006        // Only unroll small loops (up to 8 iterations from 0)
2007        if upper_bound > 8 || upper_bound <= 0 {
2008            return None;
2009        }
2010
2011        // Check if body contains increment: loop_var = loop_var + 1
2012        if !self.body_has_simple_increment(&loop_var, body) {
2013            return None;
2014        }
2015
2016        // Check that loop body is small enough to unroll (max 5 statements)
2017        let stmt_count = body.stmts.len();
2018        if stmt_count > 5 {
2019            return None;
2020        }
2021
2022        // Generate unrolled body
2023        let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2024
2025        for i in 0..upper_bound {
2026            // For each iteration, substitute the loop variable with the constant
2027            let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2028
2029            // Add all statements except the increment
2030            for stmt in &substituted_body.stmts {
2031                if !self.is_increment_stmt(&loop_var, stmt) {
2032                    unrolled_stmts.push(stmt.clone());
2033                }
2034            }
2035        }
2036
2037        // Return unrolled block
2038        Some(Expr::Block(Block {
2039            stmts: unrolled_stmts,
2040            expr: None,
2041        }))
2042    }
2043
2044    /// Extract loop bounds from condition: var < constant
2045    fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2046        if let Expr::Binary {
2047            op: BinOp::Lt,
2048            left,
2049            right,
2050        } = condition
2051        {
2052            // Left should be a variable
2053            if let Expr::Path(path) = left.as_ref() {
2054                if path.segments.len() == 1 {
2055                    let var_name = path.segments[0].ident.name.clone();
2056                    // Right should be a constant
2057                    if let Some(bound) = self.as_int(right) {
2058                        return Some((var_name, bound));
2059                    }
2060                }
2061            }
2062        }
2063        None
2064    }
2065
2066    /// Check if body contains: loop_var = loop_var + 1
2067    fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2068        for stmt in &body.stmts {
2069            if self.is_increment_stmt(loop_var, stmt) {
2070                return true;
2071            }
2072        }
2073        false
2074    }
2075
2076    /// Check if statement is: var = var + 1
2077    fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2078        match stmt {
2079            Stmt::Semi(Expr::Assign { target, value })
2080            | Stmt::Expr(Expr::Assign { target, value }) => {
2081                // Target should be the loop variable
2082                if let Expr::Path(path) = target.as_ref() {
2083                    if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2084                        // Value should be var + 1
2085                        if let Expr::Binary {
2086                            op: BinOp::Add,
2087                            left,
2088                            right,
2089                        } = value.as_ref()
2090                        {
2091                            if let Expr::Path(lpath) = left.as_ref() {
2092                                if lpath.segments.len() == 1
2093                                    && lpath.segments[0].ident.name == var_name
2094                                {
2095                                    if let Some(1) = self.as_int(right) {
2096                                        return true;
2097                                    }
2098                                }
2099                            }
2100                        }
2101                    }
2102                }
2103                false
2104            }
2105            _ => false,
2106        }
2107    }
2108
2109    /// Substitute loop variable with constant value in block
2110    fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2111        let stmts = block
2112            .stmts
2113            .iter()
2114            .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2115            .collect();
2116        let expr = block
2117            .expr
2118            .as_ref()
2119            .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2120        Block { stmts, expr }
2121    }
2122
2123    fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2124        match stmt {
2125            Stmt::Let { pattern, ty, init } => Stmt::Let {
2126                pattern: pattern.clone(),
2127                ty: ty.clone(),
2128                init: init
2129                    .as_ref()
2130                    .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2131            },
2132            Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2133            Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2134            Stmt::Item(item) => Stmt::Item(item.clone()),
2135        }
2136    }
2137
2138    fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2139        match expr {
2140            Expr::Path(path) => {
2141                if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2142                    return Expr::Literal(Literal::Int {
2143                        value: value.to_string(),
2144                        base: NumBase::Decimal,
2145                        suffix: None,
2146                    });
2147                }
2148                expr.clone()
2149            }
2150            Expr::Binary { op, left, right } => Expr::Binary {
2151                op: *op,
2152                left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2153                right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2154            },
2155            Expr::Unary { op, expr: inner } => Expr::Unary {
2156                op: *op,
2157                expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2158            },
2159            Expr::Call { func, args } => Expr::Call {
2160                func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2161                args: args
2162                    .iter()
2163                    .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2164                    .collect(),
2165            },
2166            Expr::If {
2167                condition,
2168                then_branch,
2169                else_branch,
2170            } => Expr::If {
2171                condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2172                then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2173                else_branch: else_branch
2174                    .as_ref()
2175                    .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2176            },
2177            Expr::While { condition, body } => Expr::While {
2178                condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2179                body: self.substitute_loop_var_in_block(body, var_name, value),
2180            },
2181            Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2182            Expr::Return(e) => Expr::Return(
2183                e.as_ref()
2184                    .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2185            ),
2186            Expr::Assign { target, value: v } => Expr::Assign {
2187                target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
2188                value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2189            },
2190            Expr::Index { expr: e, index } => Expr::Index {
2191                expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2192                index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2193            },
2194            Expr::Array(elements) => Expr::Array(
2195                elements
2196                    .iter()
2197                    .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2198                    .collect(),
2199            ),
2200            other => other.clone(),
2201        }
2202    }
2203
2204    // ========================================================================
2205    // Pass: Loop Invariant Code Motion (LICM)
2206    // ========================================================================
2207
2208    /// Move loop-invariant computations out of loops
2209    fn pass_licm_block(&mut self, block: &Block) -> Block {
2210        let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2211        let expr = block
2212            .expr
2213            .as_ref()
2214            .map(|e| Box::new(self.pass_licm_expr(e)));
2215        Block { stmts, expr }
2216    }
2217
2218    fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2219        match stmt {
2220            Stmt::Let { pattern, ty, init } => Stmt::Let {
2221                pattern: pattern.clone(),
2222                ty: ty.clone(),
2223                init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2224            },
2225            Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2226            Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2227            Stmt::Item(item) => Stmt::Item(item.clone()),
2228        }
2229    }
2230
    /// Loop-invariant code motion over one expression.
    ///
    /// For a `while` loop, the walk proceeds as follows:
    /// 1. Variables assigned or `let`-bound in the body and condition are collected.
    /// 2. `find_loop_invariants` (defined elsewhere) reports `(key, expr)` pairs it
    ///    considers invariant.
    /// 3. Each pair is bound to a fresh `__licm_N` variable through `make_cse_let`,
    ///    placed before the loop.
    /// 4. `replace_invariants_in_block` (defined elsewhere) rewrites the body to use
    ///    those variables.
    /// 5. The loop is processed recursively, and `{ lets...; while ... }` is returned
    ///    as a block with no tail.
    ///
    /// Loops with no invariants, and the other listed kinds, are rebuilt with their
    /// children processed. Unlisted kinds (e.g. index/array expressions) are cloned
    /// unchanged.
    ///
    /// `cse_counter` is threaded through the recursion in evaluation order (condition
    /// before body), so the generated names depend on that order.
    ///
    /// NOTE(review): the hoisted `let`s run even when the loop condition is false on
    /// entry. This is only sound if `find_loop_invariants` returns side-effect-free,
    /// non-trapping expressions; verify there.
    /// NOTE(review): `stats.loops_optimized` is incremented once per hoisted expression,
    /// not once per loop.
    fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
        match expr {
            Expr::While { condition, body } => {
                // Find variables modified in the loop
                let mut modified_vars = HashSet::new();
                self.collect_modified_vars_block(body, &mut modified_vars);

                // Also consider the loop condition might modify vars
                self.collect_modified_vars_expr(condition, &mut modified_vars);

                // Find invariant expressions in the body
                let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
                self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);

                if invariant_exprs.is_empty() {
                    // No LICM opportunities, just recurse
                    return Expr::While {
                        condition: Box::new(self.pass_licm_expr(condition)),
                        body: self.pass_licm_block(body),
                    };
                }

                // Create let bindings for invariant expressions before the loop
                let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
                // Maps each invariant's key (from find_loop_invariants) to its hoisted name.
                let mut substitution_map: HashMap<String, String> = HashMap::new();

                for (original_key, invariant_expr) in &invariant_exprs {
                    // Fresh name shared with the CSE counter so names never collide.
                    let var_name = format!("__licm_{}", self.cse_counter);
                    self.cse_counter += 1;

                    pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
                    substitution_map.insert(original_key.clone(), var_name);
                    // Counted per hoisted expression (see NOTE above).
                    self.stats.loops_optimized += 1;
                }

                // Replace invariant expressions in the loop body
                let new_body =
                    self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);

                // Recurse into the modified loop
                let new_while = Expr::While {
                    condition: Box::new(self.pass_licm_expr(condition)),
                    body: self.pass_licm_block(&new_body),
                };

                // Return block with pre-loop bindings + loop
                pre_loop_stmts.push(Stmt::Expr(new_while));
                Expr::Block(Block {
                    stmts: pre_loop_stmts,
                    expr: None,
                })
            }
            Expr::If {
                condition,
                then_branch,
                else_branch,
            } => Expr::If {
                condition: Box::new(self.pass_licm_expr(condition)),
                then_branch: self.pass_licm_block(then_branch),
                else_branch: else_branch
                    .as_ref()
                    .map(|e| Box::new(self.pass_licm_expr(e))),
            },
            Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
            Expr::Binary { op, left, right } => Expr::Binary {
                op: *op,
                left: Box::new(self.pass_licm_expr(left)),
                right: Box::new(self.pass_licm_expr(right)),
            },
            Expr::Unary { op, expr: inner } => Expr::Unary {
                op: *op,
                expr: Box::new(self.pass_licm_expr(inner)),
            },
            Expr::Call { func, args } => Expr::Call {
                func: func.clone(),
                args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
            },
            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
            Expr::Assign { target, value } => Expr::Assign {
                target: target.clone(),
                value: Box::new(self.pass_licm_expr(value)),
            },
            other => other.clone(),
        }
    }
2316
2317    /// Collect all variables that are modified in a block
2318    fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2319        for stmt in &block.stmts {
2320            self.collect_modified_vars_stmt(stmt, modified);
2321        }
2322        if let Some(expr) = &block.expr {
2323            self.collect_modified_vars_expr(expr, modified);
2324        }
2325    }
2326
2327    fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2328        match stmt {
2329            Stmt::Let { pattern, init, .. } => {
2330                // Let bindings introduce new variables
2331                if let Pattern::Ident { name, .. } = pattern {
2332                    modified.insert(name.name.clone());
2333                }
2334                if let Some(e) = init {
2335                    self.collect_modified_vars_expr(e, modified);
2336                }
2337            }
2338            Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2339            _ => {}
2340        }
2341    }
2342
2343    fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2344        match expr {
2345            Expr::Assign { target, value } => {
2346                if let Expr::Path(path) = target.as_ref() {
2347                    if path.segments.len() == 1 {
2348                        modified.insert(path.segments[0].ident.name.clone());
2349                    }
2350                }
2351                self.collect_modified_vars_expr(value, modified);
2352            }
2353            Expr::Binary { left, right, .. } => {
2354                self.collect_modified_vars_expr(left, modified);
2355                self.collect_modified_vars_expr(right, modified);
2356            }
2357            Expr::Unary { expr: inner, .. } => {
2358                self.collect_modified_vars_expr(inner, modified);
2359            }
2360            Expr::If {
2361                condition,
2362                then_branch,
2363                else_branch,
2364            } => {
2365                self.collect_modified_vars_expr(condition, modified);
2366                self.collect_modified_vars_block(then_branch, modified);
2367                if let Some(e) = else_branch {
2368                    self.collect_modified_vars_expr(e, modified);
2369                }
2370            }
2371            Expr::While { condition, body } => {
2372                self.collect_modified_vars_expr(condition, modified);
2373                self.collect_modified_vars_block(body, modified);
2374            }
2375            Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2376            Expr::Call { args, .. } => {
2377                for arg in args {
2378                    self.collect_modified_vars_expr(arg, modified);
2379                }
2380            }
2381            Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2382            _ => {}
2383        }
2384    }
2385
2386    /// Find loop-invariant expressions in a block
2387    fn find_loop_invariants(
2388        &self,
2389        block: &Block,
2390        modified: &HashSet<String>,
2391        out: &mut Vec<(String, Expr)>,
2392    ) {
2393        for stmt in &block.stmts {
2394            self.find_loop_invariants_stmt(stmt, modified, out);
2395        }
2396        if let Some(expr) = &block.expr {
2397            self.find_loop_invariants_expr(expr, modified, out);
2398        }
2399    }
2400
2401    fn find_loop_invariants_stmt(
2402        &self,
2403        stmt: &Stmt,
2404        modified: &HashSet<String>,
2405        out: &mut Vec<(String, Expr)>,
2406    ) {
2407        match stmt {
2408            Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2409            Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2410            _ => {}
2411        }
2412    }
2413
2414    fn find_loop_invariants_expr(
2415        &self,
2416        expr: &Expr,
2417        modified: &HashSet<String>,
2418        out: &mut Vec<(String, Expr)>,
2419    ) {
2420        // First recurse into subexpressions
2421        match expr {
2422            Expr::Binary { left, right, .. } => {
2423                self.find_loop_invariants_expr(left, modified, out);
2424                self.find_loop_invariants_expr(right, modified, out);
2425            }
2426            Expr::Unary { expr: inner, .. } => {
2427                self.find_loop_invariants_expr(inner, modified, out);
2428            }
2429            Expr::Call { args, .. } => {
2430                for arg in args {
2431                    self.find_loop_invariants_expr(arg, modified, out);
2432                }
2433            }
2434            Expr::Index { expr: e, index } => {
2435                self.find_loop_invariants_expr(e, modified, out);
2436                self.find_loop_invariants_expr(index, modified, out);
2437            }
2438            _ => {}
2439        }
2440
2441        // Then check if this expression is loop-invariant and worth hoisting
2442        if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2443            let key = format!("{:?}", expr_hash(expr));
2444            // Check if we already have this exact expression
2445            if !out.iter().any(|(k, _)| k == &key) {
2446                out.push((key, expr.clone()));
2447            }
2448        }
2449    }
2450
2451    /// Check if an expression is loop-invariant (doesn't depend on modified variables)
2452    fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2453        match expr {
2454            Expr::Literal(_) => true,
2455            Expr::Path(path) => {
2456                if path.segments.len() == 1 {
2457                    !modified.contains(&path.segments[0].ident.name)
2458                } else {
2459                    true // Qualified paths are assumed invariant
2460                }
2461            }
2462            Expr::Binary { left, right, .. } => {
2463                self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2464            }
2465            Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2466            Expr::Index { expr: e, index } => {
2467                self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2468            }
2469            // Calls are not invariant (might have side effects)
2470            Expr::Call { .. } => false,
2471            // Other expressions are not invariant
2472            _ => false,
2473        }
2474    }
2475
2476    /// Replace invariant expressions with variable references
2477    fn replace_invariants_in_block(
2478        &self,
2479        block: &Block,
2480        invariants: &[(String, Expr)],
2481        subs: &HashMap<String, String>,
2482    ) -> Block {
2483        let stmts = block
2484            .stmts
2485            .iter()
2486            .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2487            .collect();
2488        let expr = block
2489            .expr
2490            .as_ref()
2491            .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2492        Block { stmts, expr }
2493    }
2494
2495    fn replace_invariants_in_stmt(
2496        &self,
2497        stmt: &Stmt,
2498        invariants: &[(String, Expr)],
2499        subs: &HashMap<String, String>,
2500    ) -> Stmt {
2501        match stmt {
2502            Stmt::Let { pattern, ty, init } => Stmt::Let {
2503                pattern: pattern.clone(),
2504                ty: ty.clone(),
2505                init: init
2506                    .as_ref()
2507                    .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2508            },
2509            Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2510            Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2511            Stmt::Item(item) => Stmt::Item(item.clone()),
2512        }
2513    }
2514
2515    fn replace_invariants_in_expr(
2516        &self,
2517        expr: &Expr,
2518        invariants: &[(String, Expr)],
2519        subs: &HashMap<String, String>,
2520    ) -> Expr {
2521        // Check if this expression matches an invariant
2522        let key = format!("{:?}", expr_hash(expr));
2523        for (inv_key, inv_expr) in invariants {
2524            if &key == inv_key && expr_eq(expr, inv_expr) {
2525                if let Some(var_name) = subs.get(inv_key) {
2526                    return Expr::Path(TypePath {
2527                        segments: vec![PathSegment {
2528                            ident: Ident {
2529                                name: var_name.clone(),
2530                                evidentiality: None,
2531                                affect: None,
2532                                span: Span { start: 0, end: 0 },
2533                            },
2534                            generics: None,
2535                        }],
2536                    });
2537                }
2538            }
2539        }
2540
2541        // Otherwise recurse
2542        match expr {
2543            Expr::Binary { op, left, right } => Expr::Binary {
2544                op: *op,
2545                left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2546                right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2547            },
2548            Expr::Unary { op, expr: inner } => Expr::Unary {
2549                op: *op,
2550                expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2551            },
2552            Expr::Call { func, args } => Expr::Call {
2553                func: func.clone(),
2554                args: args
2555                    .iter()
2556                    .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2557                    .collect(),
2558            },
2559            Expr::If {
2560                condition,
2561                then_branch,
2562                else_branch,
2563            } => Expr::If {
2564                condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2565                then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2566                else_branch: else_branch
2567                    .as_ref()
2568                    .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2569            },
2570            Expr::While { condition, body } => Expr::While {
2571                condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2572                body: self.replace_invariants_in_block(body, invariants, subs),
2573            },
2574            Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2575            Expr::Return(e) => Expr::Return(
2576                e.as_ref()
2577                    .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2578            ),
2579            Expr::Assign { target, value } => Expr::Assign {
2580                target: target.clone(),
2581                value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2582            },
2583            Expr::Index { expr: e, index } => Expr::Index {
2584                expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2585                index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2586            },
2587            other => other.clone(),
2588        }
2589    }
2590
2591    // ========================================================================
2592    // Pass: Common Subexpression Elimination (CSE)
2593    // ========================================================================
2594
2595    fn pass_cse_block(&mut self, block: &Block) -> Block {
2596        // Step 1: Collect all expressions in this block
2597        let mut collected = Vec::new();
2598        collect_exprs_from_block(block, &mut collected);
2599
2600        // Step 2: Count occurrences using hash + equality check
2601        let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2602        for ce in &collected {
2603            let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2604            // Check if this exact expression is already in the bucket
2605            let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2606            if !found {
2607                entry.push(ce.expr.clone());
2608            }
2609        }
2610
2611        // Count actual occurrences (need to count duplicates)
2612        let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2613        for ce in &collected {
2614            // Find or create entry for this expression
2615            let existing = occurrence_counts
2616                .iter_mut()
2617                .find(|(e, _)| expr_eq(e, &ce.expr));
2618            if let Some((_, count)) = existing {
2619                *count += 1;
2620            } else {
2621                occurrence_counts.push((ce.expr.clone(), 1));
2622            }
2623        }
2624
2625        // Step 3: Find expressions that occur 2+ times
2626        let candidates: Vec<Expr> = occurrence_counts
2627            .into_iter()
2628            .filter(|(_, count)| *count >= 2)
2629            .map(|(expr, _)| expr)
2630            .collect();
2631
2632        if candidates.is_empty() {
2633            // No CSE opportunities, just recurse into nested blocks
2634            return self.pass_cse_nested(block);
2635        }
2636
2637        // Step 4: Create let bindings for each candidate and replace occurrences
2638        let mut result_block = block.clone();
2639        let mut new_lets: Vec<Stmt> = Vec::new();
2640
2641        for expr in candidates {
2642            let var_name = format!("__cse_{}", self.cse_counter);
2643            self.cse_counter += 1;
2644
2645            // Create the let binding
2646            new_lets.push(make_cse_let(&var_name, expr.clone()));
2647
2648            // Replace all occurrences in the block
2649            result_block = replace_in_block(&result_block, &expr, &var_name);
2650
2651            self.stats.expressions_deduplicated += 1;
2652        }
2653
2654        // Step 5: Prepend the new let bindings to the block
2655        let mut final_stmts = new_lets;
2656        final_stmts.extend(result_block.stmts);
2657
2658        // Step 6: Recurse into nested blocks
2659        let result = Block {
2660            stmts: final_stmts,
2661            expr: result_block.expr,
2662        };
2663        self.pass_cse_nested(&result)
2664    }
2665
2666    /// Recurse CSE into nested blocks (if, while, block expressions)
2667    fn pass_cse_nested(&mut self, block: &Block) -> Block {
2668        let stmts = block
2669            .stmts
2670            .iter()
2671            .map(|stmt| self.pass_cse_stmt(stmt))
2672            .collect();
2673        let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2674        Block { stmts, expr }
2675    }
2676
2677    fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2678        match stmt {
2679            Stmt::Let { pattern, ty, init } => Stmt::Let {
2680                pattern: pattern.clone(),
2681                ty: ty.clone(),
2682                init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2683            },
2684            Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2685            Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2686            Stmt::Item(item) => Stmt::Item(item.clone()),
2687        }
2688    }
2689
2690    fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2691        match expr {
2692            Expr::If {
2693                condition,
2694                then_branch,
2695                else_branch,
2696            } => Expr::If {
2697                condition: Box::new(self.pass_cse_expr(condition)),
2698                then_branch: self.pass_cse_block(then_branch),
2699                else_branch: else_branch
2700                    .as_ref()
2701                    .map(|e| Box::new(self.pass_cse_expr(e))),
2702            },
2703            Expr::While { condition, body } => Expr::While {
2704                condition: Box::new(self.pass_cse_expr(condition)),
2705                body: self.pass_cse_block(body),
2706            },
2707            Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2708            Expr::Binary { op, left, right } => Expr::Binary {
2709                op: *op,
2710                left: Box::new(self.pass_cse_expr(left)),
2711                right: Box::new(self.pass_cse_expr(right)),
2712            },
2713            Expr::Unary { op, expr: inner } => Expr::Unary {
2714                op: *op,
2715                expr: Box::new(self.pass_cse_expr(inner)),
2716            },
2717            Expr::Call { func, args } => Expr::Call {
2718                func: func.clone(),
2719                args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2720            },
2721            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2722            Expr::Assign { target, value } => Expr::Assign {
2723                target: target.clone(),
2724                value: Box::new(self.pass_cse_expr(value)),
2725            },
2726            other => other.clone(),
2727        }
2728    }
2729}
2730
2731// ============================================================================
2732// Common Subexpression Elimination (CSE) - Helper Functions
2733// ============================================================================
2734
2735/// Expr hash for CSE - identifies structurally equivalent expressions
2736fn expr_hash(expr: &Expr) -> u64 {
2737    use std::collections::hash_map::DefaultHasher;
2738    use std::hash::Hasher;
2739
2740    let mut hasher = DefaultHasher::new();
2741    expr_hash_recursive(expr, &mut hasher);
2742    hasher.finish()
2743}
2744
2745fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2746    use std::hash::Hash;
2747
2748    std::mem::discriminant(expr).hash(hasher);
2749
2750    match expr {
2751        Expr::Literal(lit) => match lit {
2752            Literal::Int { value, .. } => value.hash(hasher),
2753            Literal::Float { value, .. } => value.hash(hasher),
2754            Literal::String(s) => s.hash(hasher),
2755            Literal::Char(c) => c.hash(hasher),
2756            Literal::Bool(b) => b.hash(hasher),
2757            _ => {}
2758        },
2759        Expr::Path(path) => {
2760            for seg in &path.segments {
2761                seg.ident.name.hash(hasher);
2762            }
2763        }
2764        Expr::Binary { op, left, right } => {
2765            std::mem::discriminant(op).hash(hasher);
2766            expr_hash_recursive(left, hasher);
2767            expr_hash_recursive(right, hasher);
2768        }
2769        Expr::Unary { op, expr } => {
2770            std::mem::discriminant(op).hash(hasher);
2771            expr_hash_recursive(expr, hasher);
2772        }
2773        Expr::Call { func, args } => {
2774            expr_hash_recursive(func, hasher);
2775            args.len().hash(hasher);
2776            for arg in args {
2777                expr_hash_recursive(arg, hasher);
2778            }
2779        }
2780        Expr::Index { expr, index } => {
2781            expr_hash_recursive(expr, hasher);
2782            expr_hash_recursive(index, hasher);
2783        }
2784        _ => {}
2785    }
2786}
2787
2788/// Check if an expression is pure (no side effects)
2789fn is_pure_expr(expr: &Expr) -> bool {
2790    match expr {
2791        Expr::Literal(_) => true,
2792        Expr::Path(_) => true,
2793        Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
2794        Expr::Unary { expr, .. } => is_pure_expr(expr),
2795        Expr::If {
2796            condition,
2797            then_branch,
2798            else_branch,
2799        } => {
2800            is_pure_expr(condition)
2801                && then_branch.stmts.is_empty()
2802                && then_branch
2803                    .expr
2804                    .as_ref()
2805                    .map(|e| is_pure_expr(e))
2806                    .unwrap_or(true)
2807                && else_branch
2808                    .as_ref()
2809                    .map(|e| is_pure_expr(e))
2810                    .unwrap_or(true)
2811        }
2812        Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
2813        Expr::Array(elements) => elements.iter().all(is_pure_expr),
2814        // These have side effects
2815        Expr::Call { .. } => false,
2816        Expr::Assign { .. } => false,
2817        Expr::Return(_) => false,
2818        _ => false,
2819    }
2820}
2821
2822/// Check if an expression is worth caching (complex enough)
2823fn is_cse_worthy(expr: &Expr) -> bool {
2824    match expr {
2825        // Simple expressions - not worth CSE overhead
2826        Expr::Literal(_) => false,
2827        Expr::Path(_) => false,
2828        // Binary operations are worth it
2829        Expr::Binary { .. } => true,
2830        // Unary operations might be worth it
2831        Expr::Unary { .. } => true,
2832        // Calls might be worth it if pure (but we can't know easily)
2833        Expr::Call { .. } => false,
2834        // Index operations are worth it
2835        Expr::Index { .. } => true,
2836        _ => false,
2837    }
2838}
2839
2840/// Check if two expressions are structurally equal
2841fn expr_eq(a: &Expr, b: &Expr) -> bool {
2842    match (a, b) {
2843        (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
2844            (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
2845            (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
2846            (Literal::String(sa), Literal::String(sb)) => sa == sb,
2847            (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
2848            (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
2849            _ => false,
2850        },
2851        (Expr::Path(pa), Expr::Path(pb)) => {
2852            pa.segments.len() == pb.segments.len()
2853                && pa
2854                    .segments
2855                    .iter()
2856                    .zip(&pb.segments)
2857                    .all(|(sa, sb)| sa.ident.name == sb.ident.name)
2858        }
2859        (
2860            Expr::Binary {
2861                op: oa,
2862                left: la,
2863                right: ra,
2864            },
2865            Expr::Binary {
2866                op: ob,
2867                left: lb,
2868                right: rb,
2869            },
2870        ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
2871        (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
2872            oa == ob && expr_eq(ea, eb)
2873        }
2874        (
2875            Expr::Index {
2876                expr: ea,
2877                index: ia,
2878            },
2879            Expr::Index {
2880                expr: eb,
2881                index: ib,
2882            },
2883        ) => expr_eq(ea, eb) && expr_eq(ia, ib),
2884        (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
2885            expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
2886        }
2887        _ => false,
2888    }
2889}
2890
2891/// Collected expression with its location info
2892#[derive(Clone)]
2893struct CollectedExpr {
2894    expr: Expr,
2895    hash: u64,
2896}
2897
2898/// Collect all CSE-worthy expressions from an expression tree
2899fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
2900    // First, recurse into subexpressions
2901    match expr {
2902        Expr::Binary { left, right, .. } => {
2903            collect_exprs_from_expr(left, out);
2904            collect_exprs_from_expr(right, out);
2905        }
2906        Expr::Unary { expr: inner, .. } => {
2907            collect_exprs_from_expr(inner, out);
2908        }
2909        Expr::Index { expr: e, index } => {
2910            collect_exprs_from_expr(e, out);
2911            collect_exprs_from_expr(index, out);
2912        }
2913        Expr::Call { func, args } => {
2914            collect_exprs_from_expr(func, out);
2915            for arg in args {
2916                collect_exprs_from_expr(arg, out);
2917            }
2918        }
2919        Expr::If {
2920            condition,
2921            then_branch,
2922            else_branch,
2923        } => {
2924            collect_exprs_from_expr(condition, out);
2925            collect_exprs_from_block(then_branch, out);
2926            if let Some(else_expr) = else_branch {
2927                collect_exprs_from_expr(else_expr, out);
2928            }
2929        }
2930        Expr::While { condition, body } => {
2931            collect_exprs_from_expr(condition, out);
2932            collect_exprs_from_block(body, out);
2933        }
2934        Expr::Block(block) => {
2935            collect_exprs_from_block(block, out);
2936        }
2937        Expr::Return(Some(e)) => {
2938            collect_exprs_from_expr(e, out);
2939        }
2940        Expr::Assign { value, .. } => {
2941            collect_exprs_from_expr(value, out);
2942        }
2943        Expr::Array(elements) => {
2944            for e in elements {
2945                collect_exprs_from_expr(e, out);
2946            }
2947        }
2948        _ => {}
2949    }
2950
2951    // Then, if this expression is CSE-worthy and pure, add it
2952    if is_cse_worthy(expr) && is_pure_expr(expr) {
2953        out.push(CollectedExpr {
2954            expr: expr.clone(),
2955            hash: expr_hash(expr),
2956        });
2957    }
2958}
2959
2960/// Collect expressions from a block
2961fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
2962    for stmt in &block.stmts {
2963        match stmt {
2964            Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
2965            Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
2966            _ => {}
2967        }
2968    }
2969    if let Some(e) = &block.expr {
2970        collect_exprs_from_expr(e, out);
2971    }
2972}
2973
2974/// Replace all occurrences of target expression with a variable reference
2975fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
2976    // Check if this expression matches the target
2977    if expr_eq(expr, target) {
2978        return Expr::Path(TypePath {
2979            segments: vec![PathSegment {
2980                ident: Ident {
2981                    name: var_name.to_string(),
2982                    evidentiality: None,
2983                    affect: None,
2984                    span: Span { start: 0, end: 0 },
2985                },
2986                generics: None,
2987            }],
2988        });
2989    }
2990
2991    // Otherwise, recurse into subexpressions
2992    match expr {
2993        Expr::Binary { op, left, right } => Expr::Binary {
2994            op: *op,
2995            left: Box::new(replace_in_expr(left, target, var_name)),
2996            right: Box::new(replace_in_expr(right, target, var_name)),
2997        },
2998        Expr::Unary { op, expr: inner } => Expr::Unary {
2999            op: *op,
3000            expr: Box::new(replace_in_expr(inner, target, var_name)),
3001        },
3002        Expr::Index { expr: e, index } => Expr::Index {
3003            expr: Box::new(replace_in_expr(e, target, var_name)),
3004            index: Box::new(replace_in_expr(index, target, var_name)),
3005        },
3006        Expr::Call { func, args } => Expr::Call {
3007            func: Box::new(replace_in_expr(func, target, var_name)),
3008            args: args
3009                .iter()
3010                .map(|a| replace_in_expr(a, target, var_name))
3011                .collect(),
3012        },
3013        Expr::If {
3014            condition,
3015            then_branch,
3016            else_branch,
3017        } => Expr::If {
3018            condition: Box::new(replace_in_expr(condition, target, var_name)),
3019            then_branch: replace_in_block(then_branch, target, var_name),
3020            else_branch: else_branch
3021                .as_ref()
3022                .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3023        },
3024        Expr::While { condition, body } => Expr::While {
3025            condition: Box::new(replace_in_expr(condition, target, var_name)),
3026            body: replace_in_block(body, target, var_name),
3027        },
3028        Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3029        Expr::Return(e) => Expr::Return(
3030            e.as_ref()
3031                .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3032        ),
3033        Expr::Assign { target: t, value } => Expr::Assign {
3034            target: t.clone(),
3035            value: Box::new(replace_in_expr(value, target, var_name)),
3036        },
3037        Expr::Array(elements) => Expr::Array(
3038            elements
3039                .iter()
3040                .map(|e| replace_in_expr(e, target, var_name))
3041                .collect(),
3042        ),
3043        other => other.clone(),
3044    }
3045}
3046
3047/// Replace in a block
3048fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3049    let stmts = block
3050        .stmts
3051        .iter()
3052        .map(|stmt| match stmt {
3053            Stmt::Let { pattern, ty, init } => Stmt::Let {
3054                pattern: pattern.clone(),
3055                ty: ty.clone(),
3056                init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3057            },
3058            Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3059            Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3060            Stmt::Item(item) => Stmt::Item(item.clone()),
3061        })
3062        .collect();
3063
3064    let expr = block
3065        .expr
3066        .as_ref()
3067        .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3068
3069    Block { stmts, expr }
3070}
3071
3072/// Create a let statement for a CSE variable
3073fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3074    Stmt::Let {
3075        pattern: Pattern::Ident {
3076            mutable: false,
3077            name: Ident {
3078                name: var_name.to_string(),
3079                evidentiality: None,
3080                affect: None,
3081                span: Span { start: 0, end: 0 },
3082            },
3083            evidentiality: None,
3084        },
3085        ty: None,
3086        init: Some(expr),
3087    }
3088}
3089
3090// ============================================================================
3091// Public API
3092// ============================================================================
3093
3094/// Optimize a source file at the given optimization level
3095pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3096    let mut optimizer = Optimizer::new(level);
3097    let optimized = optimizer.optimize_file(file);
3098    (optimized, optimizer.stats)
3099}
3100
3101// ============================================================================
3102// Tests
3103// ============================================================================
3104
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // AST builders shared by the tests below
    // ------------------------------------------------------------------

    /// Zero-width span used for every synthetic node built here.
    fn dummy_span() -> Span {
        Span { start: 0, end: 0 }
    }

    /// Plain identifier with no evidentiality or affect markers.
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: dummy_span(),
        }
    }

    /// Decimal integer literal `v` without a type suffix.
    fn int_lit(v: i64) -> Expr {
        let value = v.to_string();
        Expr::Literal(Literal::Int {
            value,
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Single-segment, non-generic path expression naming `name`.
    fn var(name: &str) -> Expr {
        let segment = PathSegment {
            ident: ident(name),
            generics: None,
        };
        Expr::Path(TypePath {
            segments: vec![segment],
        })
    }

    /// Binary expression `left <op> right`.
    fn binary(op: BinOp, left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// `left + right`
    fn add(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Add, left, right)
    }

    /// `left * right`
    fn mul(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Mul, left, right)
    }

    /// `let <name> = <init>;` with an immutable identifier pattern and no type annotation.
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    /// Block made of `stmts` with no trailing expression.
    fn block_of(stmts: Vec<Stmt>) -> Block {
        Block { stmts, expr: None }
    }

    // ------------------------------------------------------------------
    // Expression hashing / equality
    // ------------------------------------------------------------------

    #[test]
    fn test_expr_hash_equal() {
        // Structurally identical expressions must hash identically.
        let (first, second) = (add(var("a"), var("b")), add(var("a"), var("b")));
        assert_eq!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_hash_different() {
        // Changing a single operand should change the hash.
        let lhs = add(var("a"), var("b"));
        let rhs = add(var("a"), var("c"));
        assert_ne!(expr_hash(&lhs), expr_hash(&rhs));
    }

    #[test]
    fn test_expr_eq() {
        let a_plus_b = add(var("a"), var("b"));
        let same_again = add(var("a"), var("b"));
        let a_plus_c = add(var("a"), var("c"));

        assert!(expr_eq(&a_plus_b, &same_again));
        assert!(!expr_eq(&a_plus_b, &a_plus_c));
    }

    // ------------------------------------------------------------------
    // Purity / CSE eligibility
    // ------------------------------------------------------------------

    #[test]
    fn test_is_pure_expr() {
        let pure_cases = [int_lit(42), var("x"), add(var("a"), var("b"))];
        for expr in &pure_cases {
            assert!(is_pure_expr(expr));
        }

        // A function call may have side effects, so it must not count as pure.
        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        // Leaves are too cheap to hoist; compound operations are worth it.
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    // ------------------------------------------------------------------
    // Common subexpression elimination pass
    // ------------------------------------------------------------------

    #[test]
    fn test_cse_basic() {
        // let x = a + b;
        // let y = (a + b) * 2;
        // The repeated `a + b` should be hoisted into its own binding.
        let shared = add(var("a"), var("b"));
        let block = block_of(vec![
            let_stmt("x", shared.clone()),
            let_stmt("y", mul(shared, int_lit(2))),
        ]);

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Expected shape: __cse_0 = a + b; x = __cse_0; y = __cse_0 * 2
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        // The hoisted binding has to come before its uses.
        match &result.stmts[0] {
            Stmt::Let {
                pattern: Pattern::Ident { name, .. },
                ..
            } => assert_eq!(name.name, "__cse_0"),
            _ => panic!("Expected CSE let binding"),
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        // Nothing repeats here, so no extra bindings may be introduced.
        let block = block_of(vec![
            let_stmt("x", add(var("a"), var("b"))),
            let_stmt("y", add(var("c"), var("d"))),
        ]);

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}