Skip to main content

oxilean_codegen/opt_passes/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::functions::*;
9use std::collections::VecDeque;
10
11/// Beta reduction pass -- reduce lambda applications.
12pub struct BetaReductionPass {
13    pub reductions: u32,
14}
15impl BetaReductionPass {
16    pub fn new() -> Self {
17        BetaReductionPass { reductions: 0 }
18    }
19    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
20        for decl in decls.iter_mut() {
21            self.reduce_expr(&mut decl.body);
22        }
23    }
24    pub(super) fn reduce_expr(&mut self, expr: &mut LcnfExpr) {
25        match expr {
26            LcnfExpr::Let { body, .. } => {
27                self.reduce_expr(body);
28            }
29            LcnfExpr::Case { alts, default, .. } => {
30                for alt in alts.iter_mut() {
31                    self.reduce_expr(&mut alt.body);
32                }
33                if let Some(def) = default {
34                    self.reduce_expr(def);
35                }
36            }
37            LcnfExpr::TailCall(LcnfArg::Lit(_), _) => {
38                self.reductions += 1;
39            }
40            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
41        }
42    }
43}
/// A single ordering constraint between two named passes.
///
/// Read it as: `pass` must run after `depends_on`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PassDependency {
    /// Name of the pass that depends on another.
    pub pass: String,
    /// Name of the pass that must run first.
    pub depends_on: String,
}
impl PassDependency {
    /// Build a constraint stating that `pass` requires `depends_on` to run first.
    pub fn new(pass: impl Into<String>, depends_on: impl Into<String>) -> Self {
        let pass: String = pass.into();
        let depends_on: String = depends_on.into();
        Self { pass, depends_on }
    }
}
61#[allow(dead_code)]
62pub struct OPPassRegistry {
63    pub(super) configs: Vec<OPPassConfig>,
64    pub(super) stats: std::collections::HashMap<String, OPPassStats>,
65}
66impl OPPassRegistry {
67    #[allow(dead_code)]
68    pub fn new() -> Self {
69        OPPassRegistry {
70            configs: Vec::new(),
71            stats: std::collections::HashMap::new(),
72        }
73    }
74    #[allow(dead_code)]
75    pub fn register(&mut self, config: OPPassConfig) {
76        self.stats
77            .insert(config.pass_name.clone(), OPPassStats::new());
78        self.configs.push(config);
79    }
80    #[allow(dead_code)]
81    pub fn enabled_passes(&self) -> Vec<&OPPassConfig> {
82        self.configs.iter().filter(|c| c.enabled).collect()
83    }
84    #[allow(dead_code)]
85    pub fn get_stats(&self, name: &str) -> Option<&OPPassStats> {
86        self.stats.get(name)
87    }
88    #[allow(dead_code)]
89    pub fn total_passes(&self) -> usize {
90        self.configs.len()
91    }
92    #[allow(dead_code)]
93    pub fn enabled_count(&self) -> usize {
94        self.enabled_passes().len()
95    }
96    #[allow(dead_code)]
97    pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
98        if let Some(stats) = self.stats.get_mut(name) {
99            stats.record_run(changes, time_ms, iter);
100        }
101    }
102}
/// Coarse category a pass belongs to.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum OPPassPhase {
    /// Reported as non-modifying by [`OPPassPhase::is_modifying`].
    Analysis,
    /// Reported as modifying by [`OPPassPhase::is_modifying`].
    Transformation,
    /// Reported as non-modifying by [`OPPassPhase::is_modifying`].
    Verification,
    /// Reported as modifying by [`OPPassPhase::is_modifying`].
    Cleanup,
}
impl OPPassPhase {
    /// Lower-case name of the phase.
    ///
    /// The names are string literals, so the result is `&'static str`. It is
    /// not tied to the borrow of `self` and can outlive the phase value.
    #[allow(dead_code)]
    pub fn name(&self) -> &'static str {
        match self {
            OPPassPhase::Analysis => "analysis",
            OPPassPhase::Transformation => "transformation",
            OPPassPhase::Verification => "verification",
            OPPassPhase::Cleanup => "cleanup",
        }
    }
    /// Whether passes in this phase may change the program
    /// (`Transformation` and `Cleanup`).
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        matches!(self, OPPassPhase::Transformation | OPPassPhase::Cleanup)
    }
}
126#[allow(dead_code)]
127#[derive(Debug, Clone)]
128pub struct OPWorklist {
129    pub(super) items: std::collections::VecDeque<u32>,
130    pub(super) in_worklist: std::collections::HashSet<u32>,
131}
132impl OPWorklist {
133    #[allow(dead_code)]
134    pub fn new() -> Self {
135        OPWorklist {
136            items: std::collections::VecDeque::new(),
137            in_worklist: std::collections::HashSet::new(),
138        }
139    }
140    #[allow(dead_code)]
141    pub fn push(&mut self, item: u32) -> bool {
142        if self.in_worklist.insert(item) {
143            self.items.push_back(item);
144            true
145        } else {
146            false
147        }
148    }
149    #[allow(dead_code)]
150    pub fn pop(&mut self) -> Option<u32> {
151        let item = self.items.pop_front()?;
152        self.in_worklist.remove(&item);
153        Some(item)
154    }
155    #[allow(dead_code)]
156    pub fn is_empty(&self) -> bool {
157        self.items.is_empty()
158    }
159    #[allow(dead_code)]
160    pub fn len(&self) -> usize {
161        self.items.len()
162    }
163    #[allow(dead_code)]
164    pub fn contains(&self, item: u32) -> bool {
165        self.in_worklist.contains(&item)
166    }
167}
168#[allow(dead_code)]
169#[derive(Debug, Clone)]
170pub struct OPAnalysisCache {
171    pub(super) entries: std::collections::HashMap<String, OPCacheEntry>,
172    pub(super) max_size: usize,
173    pub(super) hits: u64,
174    pub(super) misses: u64,
175}
176impl OPAnalysisCache {
177    #[allow(dead_code)]
178    pub fn new(max_size: usize) -> Self {
179        OPAnalysisCache {
180            entries: std::collections::HashMap::new(),
181            max_size,
182            hits: 0,
183            misses: 0,
184        }
185    }
186    #[allow(dead_code)]
187    pub fn get(&mut self, key: &str) -> Option<&OPCacheEntry> {
188        if self.entries.contains_key(key) {
189            self.hits += 1;
190            self.entries.get(key)
191        } else {
192            self.misses += 1;
193            None
194        }
195    }
196    #[allow(dead_code)]
197    pub fn insert(&mut self, key: String, data: Vec<u8>) {
198        if self.entries.len() >= self.max_size {
199            if let Some(oldest) = self.entries.keys().next().cloned() {
200                self.entries.remove(&oldest);
201            }
202        }
203        self.entries.insert(
204            key.clone(),
205            OPCacheEntry {
206                key,
207                data,
208                timestamp: 0,
209                valid: true,
210            },
211        );
212    }
213    #[allow(dead_code)]
214    pub fn invalidate(&mut self, key: &str) {
215        if let Some(entry) = self.entries.get_mut(key) {
216            entry.valid = false;
217        }
218    }
219    #[allow(dead_code)]
220    pub fn clear(&mut self) {
221        self.entries.clear();
222    }
223    #[allow(dead_code)]
224    pub fn hit_rate(&self) -> f64 {
225        let total = self.hits + self.misses;
226        if total == 0 {
227            return 0.0;
228        }
229        self.hits as f64 / total as f64
230    }
231    #[allow(dead_code)]
232    pub fn size(&self) -> usize {
233        self.entries.len()
234    }
235}
236/// Strength reduction: replaces expensive operations with cheaper equivalents.
237/// For example, multiplication by a power of 2 becomes a left shift.
238pub struct StrengthReductionPass {
239    pub reductions: u32,
240}
241impl StrengthReductionPass {
242    pub fn new() -> Self {
243        StrengthReductionPass { reductions: 0 }
244    }
245    /// Check if a value is a power of two.
246    pub fn is_power_of_two(n: u64) -> bool {
247        n > 0 && (n & (n - 1)) == 0
248    }
249    /// Compute log2 for a power of two, returning None if not a power of two.
250    pub fn log2_exact(n: u64) -> Option<u32> {
251        if Self::is_power_of_two(n) {
252            Some(n.trailing_zeros())
253        } else {
254            None
255        }
256    }
257    /// Check if a value is a power of two minus one (e.g. 0x7F, 0xFF, 0xFFFF).
258    pub fn is_mask(n: u64) -> bool {
259        n > 0 && (n & (n + 1)) == 0
260    }
261    /// Count trailing zeros.
262    pub fn ctz(n: u64) -> u32 {
263        if n == 0 {
264            64
265        } else {
266            n.trailing_zeros()
267        }
268    }
269    /// Count leading zeros.
270    pub fn clz(n: u64) -> u32 {
271        n.leading_zeros()
272    }
273    /// Population count (number of set bits).
274    pub fn popcount(n: u64) -> u32 {
275        n.count_ones()
276    }
277}
278/// Statistics for a single pass execution.
279#[derive(Debug, Clone, Default)]
280pub struct PassStats {
281    /// Name of the pass.
282    pub name: String,
283    /// Number of times the pass has been run.
284    pub run_count: u32,
285    /// Total number of changes made across all runs.
286    pub total_changes: usize,
287    /// Duration of the last run in microseconds.
288    pub last_duration_us: u64,
289    /// Whether the last run made any changes.
290    pub last_changed: bool,
291}
292impl PassStats {
293    /// Create a new stats entry for the named pass.
294    pub fn new(name: impl Into<String>) -> Self {
295        PassStats {
296            name: name.into(),
297            ..Default::default()
298        }
299    }
300    /// Record a run of this pass.
301    pub fn record_run(&mut self, changes: usize, duration_us: u64) {
302        self.run_count += 1;
303        self.total_changes += changes;
304        self.last_duration_us = duration_us;
305        self.last_changed = changes > 0;
306    }
307    /// Average changes per run.
308    pub fn avg_changes(&self) -> f64 {
309        if self.run_count == 0 {
310            0.0
311        } else {
312            self.total_changes as f64 / self.run_count as f64
313        }
314    }
315}
316/// Estimates the size and complexity of LCNF expressions.
317///
318/// Used by inlining heuristics to decide whether a function body is small
319/// enough to inline.
320pub struct ExprSizeEstimator;
321impl ExprSizeEstimator {
322    /// Count the number of let-bindings in an expression.
323    pub fn count_lets(expr: &LcnfExpr) -> usize {
324        match expr {
325            LcnfExpr::Let { body, .. } => 1 + Self::count_lets(body),
326            LcnfExpr::Case { alts, default, .. } => {
327                let alt_sum: usize = alts.iter().map(|a| Self::count_lets(&a.body)).sum();
328                let def_sum = default.as_ref().map(|d| Self::count_lets(d)).unwrap_or(0);
329                alt_sum + def_sum
330            }
331            _ => 0,
332        }
333    }
334    /// Count the number of case expressions.
335    pub fn count_cases(expr: &LcnfExpr) -> usize {
336        match expr {
337            LcnfExpr::Let { body, .. } => Self::count_cases(body),
338            LcnfExpr::Case { alts, default, .. } => {
339                let alt_sum: usize = alts.iter().map(|a| Self::count_cases(&a.body)).sum();
340                let def_sum = default.as_ref().map(|d| Self::count_cases(d)).unwrap_or(0);
341                1 + alt_sum + def_sum
342            }
343            _ => 0,
344        }
345    }
346    /// Compute a complexity score (lets + 2*cases + tail_calls).
347    pub fn complexity(expr: &LcnfExpr) -> usize {
348        match expr {
349            LcnfExpr::Let { body, .. } => 1 + Self::complexity(body),
350            LcnfExpr::Case { alts, default, .. } => {
351                let alt_sum: usize = alts.iter().map(|a| Self::complexity(&a.body)).sum();
352                let def_sum = default.as_ref().map(|d| Self::complexity(d)).unwrap_or(0);
353                2 + alt_sum + def_sum
354            }
355            LcnfExpr::TailCall(_, _) => 1,
356            LcnfExpr::Return(_) => 0,
357            LcnfExpr::Unreachable => 0,
358        }
359    }
360    /// Maximum nesting depth of the expression.
361    pub fn max_depth(expr: &LcnfExpr) -> usize {
362        match expr {
363            LcnfExpr::Let { body, .. } => 1 + Self::max_depth(body),
364            LcnfExpr::Case { alts, default, .. } => {
365                let max_alt = alts
366                    .iter()
367                    .map(|a| Self::max_depth(&a.body))
368                    .max()
369                    .unwrap_or(0);
370                let max_def = default.as_ref().map(|d| Self::max_depth(d)).unwrap_or(0);
371                1 + max_alt.max(max_def)
372            }
373            _ => 0,
374        }
375    }
376    /// Count all variable references in the expression.
377    pub fn count_var_refs(expr: &LcnfExpr) -> usize {
378        match expr {
379            LcnfExpr::Let { value, body, .. } => {
380                Self::count_var_refs_in_value(value) + Self::count_var_refs(body)
381            }
382            LcnfExpr::Case { alts, default, .. } => {
383                let alt_sum: usize = alts.iter().map(|a| Self::count_var_refs(&a.body)).sum();
384                let def_sum = default
385                    .as_ref()
386                    .map(|d| Self::count_var_refs(d))
387                    .unwrap_or(0);
388                1 + alt_sum + def_sum
389            }
390            LcnfExpr::Return(LcnfArg::Var(_)) => 1,
391            LcnfExpr::TailCall(f, args) => {
392                let f_count = if matches!(f, LcnfArg::Var(_)) { 1 } else { 0 };
393                let a_count = args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count();
394                f_count + a_count
395            }
396            _ => 0,
397        }
398    }
399    pub(super) fn count_var_refs_in_value(value: &LcnfLetValue) -> usize {
400        match value {
401            LcnfLetValue::App(f, args) => {
402                let f_count = if matches!(f, LcnfArg::Var(_)) { 1 } else { 0 };
403                let a_count = args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count();
404                f_count + a_count
405            }
406            LcnfLetValue::FVar(_) => 1,
407            LcnfLetValue::Proj(_, _, _) => 1,
408            LcnfLetValue::Reset(_) => 1,
409            LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
410                args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count()
411            }
412            LcnfLetValue::Lit(_) | LcnfLetValue::Erased => 0,
413        }
414    }
415    /// Whether an expression is "trivial" (just a return or unreachable).
416    pub fn is_trivial(expr: &LcnfExpr) -> bool {
417        matches!(expr, LcnfExpr::Return(_) | LcnfExpr::Unreachable)
418    }
419    /// Whether an expression is suitable for inlining (complexity below threshold).
420    pub fn should_inline(expr: &LcnfExpr, threshold: usize) -> bool {
421        Self::complexity(expr) <= threshold
422    }
423}
424/// Manages a pipeline of optimization passes with dependency ordering.
425///
426/// Passes are executed in topological order based on their declared
427/// dependencies. Cycle detection uses Kahn's algorithm.
428#[derive(Debug, Default)]
429pub struct PassManager {
430    /// Registered pass names in insertion order.
431    pub(super) pass_names: Vec<String>,
432    /// Dependencies between passes.
433    pub(super) dependencies: Vec<PassDependency>,
434    /// Per-pass statistics.
435    pub(super) stats: HashMap<String, PassStats>,
436    /// Maximum number of fixed-point iterations.
437    pub max_iterations: u32,
438}
439impl PassManager {
440    /// Create a new pass manager.
441    pub fn new() -> Self {
442        PassManager {
443            pass_names: Vec::new(),
444            dependencies: Vec::new(),
445            stats: HashMap::new(),
446            max_iterations: 10,
447        }
448    }
449    /// Register a pass by name.
450    pub fn add_pass(&mut self, name: impl Into<String>) {
451        let n = name.into();
452        if !self.pass_names.contains(&n) {
453            self.stats.insert(n.clone(), PassStats::new(&n));
454            self.pass_names.push(n);
455        }
456    }
457    /// Add a dependency: `pass` depends on `depends_on`.
458    pub fn add_dependency(&mut self, pass: impl Into<String>, depends_on: impl Into<String>) {
459        let dep = PassDependency::new(pass, depends_on);
460        if !self.dependencies.contains(&dep) {
461            self.dependencies.push(dep);
462        }
463    }
464    /// Record a run of the named pass.
465    pub fn record_run(&mut self, name: &str, changes: usize, duration_us: u64) {
466        if let Some(stats) = self.stats.get_mut(name) {
467            stats.record_run(changes, duration_us);
468        }
469    }
470    /// Get statistics for a named pass.
471    pub fn get_stats(&self, name: &str) -> Option<&PassStats> {
472        self.stats.get(name)
473    }
474    /// Get all statistics.
475    pub fn all_stats(&self) -> &HashMap<String, PassStats> {
476        &self.stats
477    }
478    /// Number of registered passes.
479    pub fn num_passes(&self) -> usize {
480        self.pass_names.len()
481    }
482    /// Compute topological ordering of passes using Kahn's algorithm.
483    ///
484    /// Returns `None` if there is a cycle in the dependency graph.
485    pub fn topological_order(&self) -> Option<Vec<String>> {
486        let mut in_degree: HashMap<&str, usize> = HashMap::new();
487        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
488        for name in &self.pass_names {
489            in_degree.insert(name.as_str(), 0);
490            adj.entry(name.as_str()).or_default();
491        }
492        for dep in &self.dependencies {
493            if self.pass_names.contains(&dep.pass) && self.pass_names.contains(&dep.depends_on) {
494                adj.entry(dep.depends_on.as_str())
495                    .or_default()
496                    .push(dep.pass.as_str());
497                *in_degree.entry(dep.pass.as_str()).or_insert(0) += 1;
498            }
499        }
500        let mut queue: Vec<&str> = in_degree
501            .iter()
502            .filter(|(_, &deg)| deg == 0)
503            .map(|(&name, _)| name)
504            .collect();
505        queue.sort();
506        let mut result = Vec::new();
507        while let Some(node) = queue.pop() {
508            result.push(node.to_string());
509            if let Some(neighbors) = adj.get(node) {
510                for &neighbor in neighbors {
511                    let deg = in_degree
512                        .get_mut(neighbor)
513                        .expect(
514                            "neighbor must be in in_degree; all passes were inserted during initialization",
515                        );
516                    *deg -= 1;
517                    if *deg == 0 {
518                        queue.push(neighbor);
519                        queue.sort();
520                    }
521                }
522            }
523        }
524        if result.len() == self.pass_names.len() {
525            Some(result)
526        } else {
527            None
528        }
529    }
530    /// Check if the dependency graph has a cycle.
531    pub fn has_cycle(&self) -> bool {
532        self.topological_order().is_none()
533    }
534    /// Total changes across all passes.
535    pub fn total_changes(&self) -> usize {
536        self.stats.values().map(|s| s.total_changes).sum()
537    }
538    /// Total runs across all passes.
539    pub fn total_runs(&self) -> u32 {
540        self.stats.values().map(|s| s.run_count).sum()
541    }
542}
543/// Constant folding pass -- evaluate constant expressions at compile time.
544pub struct ConstantFoldingPass {
545    pub folds_performed: u32,
546}
547impl ConstantFoldingPass {
548    pub fn new() -> Self {
549        ConstantFoldingPass { folds_performed: 0 }
550    }
551    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
552        for decl in decls.iter_mut() {
553            self.fold_expr(&mut decl.body);
554        }
555    }
556    pub(super) fn fold_expr(&mut self, expr: &mut LcnfExpr) {
557        match expr {
558            LcnfExpr::Let { value, body, .. } => {
559                if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Nat(lhs)), args) = value {
560                    if args.len() == 2 {
561                        if let (LcnfArg::Lit(LcnfLit::Nat(rhs)), LcnfArg::Lit(LcnfLit::Nat(op_n))) =
562                            (&args[0], &args[1])
563                        {
564                            let op = match op_n {
565                                0 => "add",
566                                1 => "sub",
567                                2 => "mul",
568                                _ => "",
569                            };
570                            if let Some(result) = self.try_fold_nat_op(op, *lhs, *rhs) {
571                                *value = LcnfLetValue::Lit(LcnfLit::Nat(result));
572                                self.folds_performed += 1;
573                            }
574                        }
575                    }
576                }
577                self.fold_expr(body);
578            }
579            LcnfExpr::Case { alts, default, .. } => {
580                for alt in alts.iter_mut() {
581                    self.fold_expr(&mut alt.body);
582                }
583                if let Some(def) = default {
584                    self.fold_expr(def);
585                }
586            }
587            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
588        }
589    }
590    /// Try to fold a nat binary operation.
591    pub fn try_fold_nat_op(&self, op: &str, lhs: u64, rhs: u64) -> Option<u64> {
592        match op {
593            "add" => Some(lhs.wrapping_add(rhs)),
594            "sub" => Some(lhs.saturating_sub(rhs)),
595            "mul" => Some(lhs.wrapping_mul(rhs)),
596            "div" => lhs.checked_div(rhs),
597            "mod" => lhs.checked_rem(rhs),
598            "min" => Some(lhs.min(rhs)),
599            "max" => Some(lhs.max(rhs)),
600            "pow" => Some(lhs.wrapping_pow(rhs as u32)),
601            "and" => Some(lhs & rhs),
602            "or" => Some(lhs | rhs),
603            "xor" => Some(lhs ^ rhs),
604            "shl" => Some(lhs.wrapping_shl(rhs as u32)),
605            "shr" => Some(lhs.wrapping_shr(rhs as u32)),
606            _ => None,
607        }
608    }
609    /// Try to fold a boolean operation.
610    pub fn try_fold_bool_op(&self, op: &str, lhs: bool, rhs: bool) -> Option<bool> {
611        match op {
612            "and" => Some(lhs && rhs),
613            "or" => Some(lhs || rhs),
614            "xor" => Some(lhs ^ rhs),
615            "eq" => Some(lhs == rhs),
616            "ne" => Some(lhs != rhs),
617            _ => None,
618        }
619    }
620    /// Try to fold a comparison operation.
621    pub fn try_fold_cmp(&self, op: &str, lhs: u64, rhs: u64) -> Option<bool> {
622        match op {
623            "eq" => Some(lhs == rhs),
624            "ne" => Some(lhs != rhs),
625            "lt" => Some(lhs < rhs),
626            "le" => Some(lhs <= rhs),
627            "gt" => Some(lhs > rhs),
628            "ge" => Some(lhs >= rhs),
629            _ => None,
630        }
631    }
632}
633#[allow(dead_code)]
634#[derive(Debug, Clone)]
635pub struct OPCacheEntry {
636    pub key: String,
637    pub data: Vec<u8>,
638    pub timestamp: u64,
639    pub valid: bool,
640}
641#[allow(dead_code)]
642#[derive(Debug, Clone)]
643pub struct OPDominatorTree {
644    pub idom: Vec<Option<u32>>,
645    pub dom_children: Vec<Vec<u32>>,
646    pub dom_depth: Vec<u32>,
647}
648impl OPDominatorTree {
649    #[allow(dead_code)]
650    pub fn new(size: usize) -> Self {
651        OPDominatorTree {
652            idom: vec![None; size],
653            dom_children: vec![Vec::new(); size],
654            dom_depth: vec![0; size],
655        }
656    }
657    #[allow(dead_code)]
658    pub fn set_idom(&mut self, node: usize, idom: u32) {
659        self.idom[node] = Some(idom);
660    }
661    #[allow(dead_code)]
662    pub fn dominates(&self, a: usize, b: usize) -> bool {
663        if a == b {
664            return true;
665        }
666        let mut cur = b;
667        loop {
668            match self.idom[cur] {
669                Some(parent) if parent as usize == a => return true,
670                Some(parent) if parent as usize == cur => return false,
671                Some(parent) => cur = parent as usize,
672                None => return false,
673            }
674        }
675    }
676    #[allow(dead_code)]
677    pub fn depth(&self, node: usize) -> u32 {
678        self.dom_depth.get(node).copied().unwrap_or(0)
679    }
680}
681/// Profile-guided optimization hints
682#[derive(Debug, Clone)]
683pub struct PgoHints {
684    pub hot_functions: Vec<String>,
685    pub likely_branches: Vec<(String, u32, bool)>,
686    pub inline_candidates: Vec<String>,
687    pub cold_functions: Vec<String>,
688    pub call_counts: HashMap<String, u64>,
689}
690impl PgoHints {
691    pub fn new() -> Self {
692        PgoHints {
693            hot_functions: Vec::new(),
694            likely_branches: Vec::new(),
695            inline_candidates: Vec::new(),
696            cold_functions: Vec::new(),
697            call_counts: HashMap::new(),
698        }
699    }
700    pub fn mark_hot(&mut self, func_name: &str) {
701        if !self.hot_functions.iter().any(|f| f == func_name) {
702            self.hot_functions.push(func_name.to_string());
703        }
704    }
705    pub fn mark_cold(&mut self, func_name: &str) {
706        if !self.cold_functions.iter().any(|f| f == func_name) {
707            self.cold_functions.push(func_name.to_string());
708        }
709    }
710    pub fn mark_inline(&mut self, func_name: &str) {
711        if !self.inline_candidates.iter().any(|f| f == func_name) {
712            self.inline_candidates.push(func_name.to_string());
713        }
714    }
715    pub fn record_call(&mut self, func_name: &str, count: u64) {
716        *self.call_counts.entry(func_name.to_string()).or_insert(0) += count;
717    }
718    pub fn is_hot(&self, func_name: &str) -> bool {
719        self.hot_functions.iter().any(|f| f == func_name)
720    }
721    pub fn is_cold(&self, func_name: &str) -> bool {
722        self.cold_functions.iter().any(|f| f == func_name)
723    }
724    pub fn should_inline(&self, func_name: &str) -> bool {
725        self.inline_candidates.iter().any(|f| f == func_name)
726    }
727    pub fn call_count(&self, func_name: &str) -> u64 {
728        self.call_counts.get(func_name).copied().unwrap_or(0)
729    }
730    /// Total number of hints across all categories.
731    pub fn total_hints(&self) -> usize {
732        self.hot_functions.len()
733            + self.cold_functions.len()
734            + self.inline_candidates.len()
735            + self.likely_branches.len()
736            + self.call_counts.len()
737    }
738    /// Merge another set of hints into this one.
739    pub fn merge(&mut self, other: &PgoHints) {
740        for f in &other.hot_functions {
741            self.mark_hot(f);
742        }
743        for f in &other.cold_functions {
744            self.mark_cold(f);
745        }
746        for f in &other.inline_candidates {
747            self.mark_inline(f);
748        }
749        for (name, count) in &other.call_counts {
750            self.record_call(name, *count);
751        }
752    }
753    /// Classify a function by its hotness: Hot, Cold, or Normal.
754    pub fn classify(&self, func_name: &str) -> &'static str {
755        if self.is_hot(func_name) {
756            "hot"
757        } else if self.is_cold(func_name) {
758            "cold"
759        } else {
760            "normal"
761        }
762    }
763}
764/// Dead code elimination -- remove unreachable let expressions.
765pub struct DeadCodeEliminationPass {
766    pub removed: u32,
767}
768impl DeadCodeEliminationPass {
769    pub fn new() -> Self {
770        DeadCodeEliminationPass { removed: 0 }
771    }
772    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
773        for decl in decls.iter_mut() {
774            let mut used: HashSet<LcnfVarId> = HashSet::new();
775            Self::collect_used_vars(&decl.body, &mut used);
776            let mut body = decl.body.clone();
777            self.eliminate_dead_lets(&mut body, &used);
778            decl.body = body;
779        }
780    }
781    pub(super) fn eliminate_dead_lets(&mut self, expr: &mut LcnfExpr, used: &HashSet<LcnfVarId>) {
782        match expr {
783            LcnfExpr::Let {
784                id, value, body, ..
785            } => {
786                let is_pure = matches!(
787                    value,
788                    LcnfLetValue::Lit(_) | LcnfLetValue::FVar(_) | LcnfLetValue::Erased
789                );
790                if is_pure && !used.contains(id) {
791                    let new_body = *body.clone();
792                    *expr = new_body;
793                    self.removed += 1;
794                    self.eliminate_dead_lets(expr, used);
795                } else {
796                    self.eliminate_dead_lets(body, used);
797                }
798            }
799            LcnfExpr::Case { alts, default, .. } => {
800                for alt in alts.iter_mut() {
801                    self.eliminate_dead_lets(&mut alt.body, used);
802                }
803                if let Some(def) = default {
804                    self.eliminate_dead_lets(def, used);
805                }
806            }
807            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
808        }
809    }
810    pub(super) fn collect_used_vars(expr: &LcnfExpr, used: &mut HashSet<LcnfVarId>) {
811        match expr {
812            LcnfExpr::Let {
813                id: _, value, body, ..
814            } => {
815                match value {
816                    LcnfLetValue::App(func, args) => {
817                        if let LcnfArg::Var(v) = func {
818                            used.insert(*v);
819                        }
820                        for a in args {
821                            if let LcnfArg::Var(v) = a {
822                                used.insert(*v);
823                            }
824                        }
825                    }
826                    LcnfLetValue::FVar(v) => {
827                        used.insert(*v);
828                    }
829                    LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
830                        for a in args {
831                            if let LcnfArg::Var(v) = a {
832                                used.insert(*v);
833                            }
834                        }
835                    }
836                    LcnfLetValue::Proj(_, _, v) => {
837                        used.insert(*v);
838                    }
839                    LcnfLetValue::Reset(v) => {
840                        used.insert(*v);
841                    }
842                    LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
843                }
844                Self::collect_used_vars(body, used);
845            }
846            LcnfExpr::Case {
847                scrutinee,
848                alts,
849                default,
850                ..
851            } => {
852                used.insert(*scrutinee);
853                for alt in alts {
854                    Self::collect_used_vars(&alt.body, used);
855                }
856                if let Some(def) = default {
857                    Self::collect_used_vars(def, used);
858                }
859            }
860            LcnfExpr::Return(a) | LcnfExpr::TailCall(a, _) => {
861                if let LcnfArg::Var(v) = a {
862                    used.insert(*v);
863                }
864                if let LcnfExpr::TailCall(_, args) = expr {
865                    for a in args {
866                        if let LcnfArg::Var(v) = a {
867                            used.insert(*v);
868                        }
869                    }
870                }
871            }
872            LcnfExpr::Unreachable => {}
873        }
874    }
875}
876#[allow(dead_code)]
877#[derive(Debug, Clone)]
878pub struct OPPassConfig {
879    pub phase: OPPassPhase,
880    pub enabled: bool,
881    pub max_iterations: u32,
882    pub debug_output: bool,
883    pub pass_name: String,
884}
885impl OPPassConfig {
886    #[allow(dead_code)]
887    pub fn new(name: impl Into<String>, phase: OPPassPhase) -> Self {
888        OPPassConfig {
889            phase,
890            enabled: true,
891            max_iterations: 10,
892            debug_output: false,
893            pass_name: name.into(),
894        }
895    }
896    #[allow(dead_code)]
897    pub fn disabled(mut self) -> Self {
898        self.enabled = false;
899        self
900    }
901    #[allow(dead_code)]
902    pub fn with_debug(mut self) -> Self {
903        self.debug_output = true;
904        self
905    }
906    #[allow(dead_code)]
907    pub fn max_iter(mut self, n: u32) -> Self {
908        self.max_iterations = n;
909        self
910    }
911}
912/// Copy propagation -- replace uses of copied variables with originals.
913pub struct CopyPropagationPass {
914    pub substitutions: u32,
915}
916impl CopyPropagationPass {
917    pub fn new() -> Self {
918        CopyPropagationPass { substitutions: 0 }
919    }
920    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
921        for decl in decls.iter_mut() {
922            self.propagate_copies_in_expr(&mut decl.body);
923        }
924    }
925    pub(super) fn propagate_copies_in_expr(&mut self, expr: &mut LcnfExpr) {
926        if let LcnfExpr::Let {
927            id,
928            value: LcnfLetValue::FVar(src),
929            body,
930            ..
931        } = expr
932        {
933            let from = *id;
934            let to = *src;
935            substitute_var_in_expr(body, from, to);
936            self.substitutions += 1;
937            self.propagate_copies_in_expr(body);
938        } else {
939            match expr {
940                LcnfExpr::Let { body, .. } => self.propagate_copies_in_expr(body),
941                LcnfExpr::Case { alts, default, .. } => {
942                    for alt in alts.iter_mut() {
943                        self.propagate_copies_in_expr(&mut alt.body);
944                    }
945                    if let Some(def) = default {
946                        self.propagate_copies_in_expr(def);
947                    }
948                }
949                _ => {}
950            }
951        }
952    }
953}
/// Accumulated statistics over repeated runs of an optimization pass.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct OPPassStats {
    /// Number of runs recorded.
    pub total_runs: u32,
    /// Number of runs counted as successful (`record_run` bumps this together
    /// with `total_runs`).
    pub successful_runs: u32,
    /// Sum of the change counts reported by each run.
    pub total_changes: u64,
    /// Sum of the reported run times, in milliseconds.
    pub time_ms: u64,
    /// Iteration count of the most recently recorded run (overwritten, not summed).
    pub iterations_used: u32,
}
impl OPPassStats {
    /// Creates an all-zero statistics record.
    #[allow(dead_code)]
    pub fn new() -> Self {
        OPPassStats {
            total_runs: 0,
            successful_runs: 0,
            total_changes: 0,
            time_ms: 0,
            iterations_used: 0,
        }
    }
    /// Records one successful run that made `changes` changes, took `time_ms`
    /// milliseconds and used `iterations` iterations.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_changes += changes;
        self.time_ms += time_ms;
        self.total_runs += 1;
        self.successful_runs += 1;
        self.iterations_used = iterations;
    }
    /// Mean number of changes per recorded run; `0.0` when nothing was recorded.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / f64::from(runs),
        }
    }
    /// Fraction of recorded runs that succeeded; `0.0` when nothing was recorded.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => f64::from(self.successful_runs) / f64::from(runs),
        }
    }
    /// One-line human-readable summary, e.g. `Runs: 2/2, Changes: 6, Time: 15ms`.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let (ok, all) = (self.successful_runs, self.total_runs);
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            ok, all, self.total_changes, self.time_ms
        )
    }
}
/// Per-block liveness sets for variables identified by `u32`, indexed by block number.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OPLivenessInfo {
    /// Variables live on entry to each block.
    pub live_in: Vec<std::collections::HashSet<u32>>,
    /// Variables live on exit from each block.
    pub live_out: Vec<std::collections::HashSet<u32>>,
    /// Variables defined in each block.
    pub defs: Vec<std::collections::HashSet<u32>>,
    /// Variables used in each block.
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl OPLivenessInfo {
    /// Creates empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        fn blank(n: usize) -> Vec<std::collections::HashSet<u32>> {
            vec![std::collections::HashSet::new(); n]
        }
        Self {
            live_in: blank(block_count),
            live_out: blank(block_count),
            defs: blank(block_count),
            uses: blank(block_count),
        }
    }
    /// Marks `var` as defined in `block`; out-of-range blocks are silently ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Marks `var` as used in `block`; out-of-range blocks are silently ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Whether `var` is in the live-in set of `block` (`false` for unknown blocks).
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        matches!(self.live_in.get(block), Some(set) if set.contains(&var))
    }
    /// Whether `var` is in the live-out set of `block` (`false` for unknown blocks).
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        matches!(self.live_out.get(block), Some(set) if set.contains(&var))
    }
}
1043#[allow(dead_code)]
1044#[derive(Debug, Clone)]
1045pub struct OPDepGraph {
1046    pub(super) nodes: Vec<u32>,
1047    pub(super) edges: Vec<(u32, u32)>,
1048}
1049impl OPDepGraph {
1050    #[allow(dead_code)]
1051    pub fn new() -> Self {
1052        OPDepGraph {
1053            nodes: Vec::new(),
1054            edges: Vec::new(),
1055        }
1056    }
1057    #[allow(dead_code)]
1058    pub fn add_node(&mut self, id: u32) {
1059        if !self.nodes.contains(&id) {
1060            self.nodes.push(id);
1061        }
1062    }
1063    #[allow(dead_code)]
1064    pub fn add_dep(&mut self, dep: u32, dependent: u32) {
1065        self.add_node(dep);
1066        self.add_node(dependent);
1067        self.edges.push((dep, dependent));
1068    }
1069    #[allow(dead_code)]
1070    pub fn dependents_of(&self, node: u32) -> Vec<u32> {
1071        self.edges
1072            .iter()
1073            .filter(|(d, _)| *d == node)
1074            .map(|(_, dep)| *dep)
1075            .collect()
1076    }
1077    #[allow(dead_code)]
1078    pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
1079        self.edges
1080            .iter()
1081            .filter(|(_, dep)| *dep == node)
1082            .map(|(d, _)| *d)
1083            .collect()
1084    }
1085    #[allow(dead_code)]
1086    pub fn topological_sort(&self) -> Vec<u32> {
1087        let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
1088        for &n in &self.nodes {
1089            in_degree.insert(n, 0);
1090        }
1091        for (_, dep) in &self.edges {
1092            *in_degree.entry(*dep).or_insert(0) += 1;
1093        }
1094        let mut queue: std::collections::VecDeque<u32> = self
1095            .nodes
1096            .iter()
1097            .filter(|&&n| in_degree[&n] == 0)
1098            .copied()
1099            .collect();
1100        let mut result = Vec::new();
1101        while let Some(node) = queue.pop_front() {
1102            result.push(node);
1103            for dep in self.dependents_of(node) {
1104                let cnt = in_degree.entry(dep).or_insert(0);
1105                *cnt = cnt.saturating_sub(1);
1106                if *cnt == 0 {
1107                    queue.push_back(dep);
1108                }
1109            }
1110        }
1111        result
1112    }
1113    #[allow(dead_code)]
1114    pub fn has_cycle(&self) -> bool {
1115        self.topological_sort().len() < self.nodes.len()
1116    }
1117}
/// Pure helpers for compile-time evaluation of primitive operations.
///
/// The integer helpers that return `Option` yield `None` whenever the
/// operation cannot be evaluated without overflow or a division by zero.
/// `None` means "do not fold", leaving the operation for run time.
#[allow(dead_code)]
pub struct OPConstantFoldingHelper;
impl OPConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// Truncating `a / b`, or `None` when `b == 0` or for `i64::MIN / -1`.
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already rejects both a zero divisor and the overflowing case.
        a.checked_div(b)
    }
    /// IEEE-754 `a + b`.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// IEEE-754 `a * b`.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`, or `None` for `i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Logical negation.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Logical conjunction.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Logical disjunction.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`, or `None` when `b >= 64`.
    ///
    /// Note: like `i64::checked_shl`, only the shift amount is checked; bits
    /// shifted out of the value are discarded silently.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// Arithmetic `a >> b`, or `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// Truncating remainder `a % b`, or `None` when `b == 0` or for `i64::MIN % -1`.
    ///
    /// The overflowing case is rejected rather than evaluated, because a plain
    /// `i64::MIN % -1` panics. This mirrors `fold_div_i64`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_rem(b)
    }
    /// Bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Bitwise NOT.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
1198/// Estimates the cost of inlining a function.
1199#[derive(Debug, Clone)]
1200pub struct InlineCostEstimator {
1201    /// Base cost threshold below which functions are always inlined.
1202    pub always_inline_threshold: usize,
1203    /// Threshold for functions in hot call sites.
1204    pub hot_threshold: usize,
1205    /// Threshold for cold call sites.
1206    pub cold_threshold: usize,
1207    /// Bonus for tail-recursive functions (they benefit less from inlining).
1208    pub tail_recursive_penalty: usize,
1209}
1210impl InlineCostEstimator {
1211    /// Compute the inlining cost for a function body.
1212    pub fn cost(&self, decl: &LcnfFunDecl) -> usize {
1213        let base = ExprSizeEstimator::complexity(&decl.body);
1214        let penalty = if decl.is_recursive {
1215            self.tail_recursive_penalty
1216        } else {
1217            0
1218        };
1219        base + penalty
1220    }
1221    /// Decide whether to inline based on cost and PGO hints.
1222    pub fn should_inline(&self, decl: &LcnfFunDecl, pgo: Option<&PgoHints>) -> bool {
1223        let cost = self.cost(decl);
1224        if cost <= self.always_inline_threshold {
1225            return true;
1226        }
1227        if let Some(hints) = pgo {
1228            if hints.should_inline(&decl.name) {
1229                return true;
1230            }
1231            if hints.is_hot(&decl.name) {
1232                return cost <= self.hot_threshold;
1233            }
1234            if hints.is_cold(&decl.name) {
1235                return cost <= self.cold_threshold;
1236            }
1237        }
1238        cost <= self.cold_threshold
1239    }
1240}
1241/// Eliminates identity let-bindings of the form `let x = x`.
1242pub struct IdentityEliminationPass {
1243    pub eliminated: u32,
1244}
1245impl IdentityEliminationPass {
1246    pub fn new() -> Self {
1247        IdentityEliminationPass { eliminated: 0 }
1248    }
1249    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
1250        for decl in decls.iter_mut() {
1251            self.elim_expr(&mut decl.body);
1252        }
1253    }
1254    pub(super) fn elim_expr(&mut self, expr: &mut LcnfExpr) {
1255        loop {
1256            if let LcnfExpr::Let {
1257                id,
1258                value: LcnfLetValue::FVar(src),
1259                body,
1260                ..
1261            } = expr
1262            {
1263                if *id == *src {
1264                    let new_body = *body.clone();
1265                    *expr = new_body;
1266                    self.eliminated += 1;
1267                    continue;
1268                }
1269            }
1270            break;
1271        }
1272        match expr {
1273            LcnfExpr::Let { body, .. } => self.elim_expr(body),
1274            LcnfExpr::Case { alts, default, .. } => {
1275                for alt in alts.iter_mut() {
1276                    self.elim_expr(&mut alt.body);
1277                }
1278                if let Some(def) = default {
1279                    self.elim_expr(def);
1280                }
1281            }
1282            _ => {}
1283        }
1284    }
1285}
1286/// Eliminates code after `Unreachable` terminators.
1287pub struct UnreachableCodeEliminationPass {
1288    pub eliminated: u32,
1289}
1290impl UnreachableCodeEliminationPass {
1291    pub fn new() -> Self {
1292        UnreachableCodeEliminationPass { eliminated: 0 }
1293    }
1294    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
1295        for decl in decls.iter_mut() {
1296            self.elim_expr(&mut decl.body);
1297        }
1298    }
1299    pub(super) fn elim_expr(&mut self, expr: &mut LcnfExpr) {
1300        match expr {
1301            LcnfExpr::Let { body, .. } => {
1302                self.elim_expr(body);
1303                if matches!(**body, LcnfExpr::Unreachable) {
1304                    *expr = LcnfExpr::Unreachable;
1305                    self.eliminated += 1;
1306                }
1307            }
1308            LcnfExpr::Case { alts, default, .. } => {
1309                for alt in alts.iter_mut() {
1310                    self.elim_expr(&mut alt.body);
1311                }
1312                if let Some(def) = default {
1313                    self.elim_expr(def);
1314                }
1315            }
1316            _ => {}
1317        }
1318    }
1319}