Skip to main content

oxilean_codegen/opt_passes/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::functions::*;
9use std::collections::VecDeque;
10
11/// Beta reduction pass -- reduce lambda applications.
12pub struct BetaReductionPass {
13    pub reductions: u32,
14}
15impl BetaReductionPass {
16    pub fn new() -> Self {
17        BetaReductionPass { reductions: 0 }
18    }
19    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
20        for decl in decls.iter_mut() {
21            self.reduce_expr(&mut decl.body);
22        }
23    }
24    pub(super) fn reduce_expr(&mut self, expr: &mut LcnfExpr) {
25        match expr {
26            LcnfExpr::Let { body, .. } => {
27                self.reduce_expr(body);
28            }
29            LcnfExpr::Case { alts, default, .. } => {
30                for alt in alts.iter_mut() {
31                    self.reduce_expr(&mut alt.body);
32                }
33                if let Some(def) = default {
34                    self.reduce_expr(def);
35                }
36            }
37            LcnfExpr::TailCall(LcnfArg::Lit(_), _) => {
38                self.reductions += 1;
39            }
40            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
41        }
42    }
43}
/// A "must run after" edge between two passes.
///
/// The pass named `pass` may only run once the pass named `depends_on` has run.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PassDependency {
    /// The dependent pass (the one that runs later).
    pub pass: String,
    /// The prerequisite pass (the one that must run first).
    pub depends_on: String,
}
impl PassDependency {
    /// Build an edge stating that `pass` depends on `depends_on`.
    pub fn new(pass: impl Into<String>, depends_on: impl Into<String>) -> Self {
        let pass: String = pass.into();
        let depends_on: String = depends_on.into();
        Self { pass, depends_on }
    }
}
61#[allow(dead_code)]
62pub struct OPPassRegistry {
63    pub(super) configs: Vec<OPPassConfig>,
64    pub(super) stats: std::collections::HashMap<String, OPPassStats>,
65}
66impl OPPassRegistry {
67    #[allow(dead_code)]
68    pub fn new() -> Self {
69        OPPassRegistry {
70            configs: Vec::new(),
71            stats: std::collections::HashMap::new(),
72        }
73    }
74    #[allow(dead_code)]
75    pub fn register(&mut self, config: OPPassConfig) {
76        self.stats
77            .insert(config.pass_name.clone(), OPPassStats::new());
78        self.configs.push(config);
79    }
80    #[allow(dead_code)]
81    pub fn enabled_passes(&self) -> Vec<&OPPassConfig> {
82        self.configs.iter().filter(|c| c.enabled).collect()
83    }
84    #[allow(dead_code)]
85    pub fn get_stats(&self, name: &str) -> Option<&OPPassStats> {
86        self.stats.get(name)
87    }
88    #[allow(dead_code)]
89    pub fn total_passes(&self) -> usize {
90        self.configs.len()
91    }
92    #[allow(dead_code)]
93    pub fn enabled_count(&self) -> usize {
94        self.enabled_passes().len()
95    }
96    #[allow(dead_code)]
97    pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
98        if let Some(stats) = self.stats.get_mut(name) {
99            stats.record_run(changes, time_ms, iter);
100        }
101    }
102}
/// The pipeline phase a pass belongs to.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum OPPassPhase {
    /// Analysis phase (`is_modifying` is false).
    Analysis,
    /// Transformation phase (`is_modifying` is true).
    Transformation,
    /// Verification phase (`is_modifying` is false).
    Verification,
    /// Cleanup phase (`is_modifying` is true).
    Cleanup,
}
impl OPPassPhase {
    /// Lower-case name of the phase.
    ///
    /// The result is a string literal, so it is `'static` and can outlive `self`.
    #[allow(dead_code)]
    pub fn name(&self) -> &'static str {
        match self {
            OPPassPhase::Analysis => "analysis",
            OPPassPhase::Transformation => "transformation",
            OPPassPhase::Verification => "verification",
            OPPassPhase::Cleanup => "cleanup",
        }
    }
    /// Whether passes in this phase are expected to modify the program.
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        matches!(self, OPPassPhase::Transformation | OPPassPhase::Cleanup)
    }
}
126#[allow(dead_code)]
127#[derive(Debug, Clone)]
128pub struct OPWorklist {
129    pub(super) items: std::collections::VecDeque<u32>,
130    pub(super) in_worklist: std::collections::HashSet<u32>,
131}
132impl OPWorklist {
133    #[allow(dead_code)]
134    pub fn new() -> Self {
135        OPWorklist {
136            items: std::collections::VecDeque::new(),
137            in_worklist: std::collections::HashSet::new(),
138        }
139    }
140    #[allow(dead_code)]
141    pub fn push(&mut self, item: u32) -> bool {
142        if self.in_worklist.insert(item) {
143            self.items.push_back(item);
144            true
145        } else {
146            false
147        }
148    }
149    #[allow(dead_code)]
150    pub fn pop(&mut self) -> Option<u32> {
151        let item = self.items.pop_front()?;
152        self.in_worklist.remove(&item);
153        Some(item)
154    }
155    #[allow(dead_code)]
156    pub fn is_empty(&self) -> bool {
157        self.items.is_empty()
158    }
159    #[allow(dead_code)]
160    pub fn len(&self) -> usize {
161        self.items.len()
162    }
163    #[allow(dead_code)]
164    pub fn contains(&self, item: u32) -> bool {
165        self.in_worklist.contains(&item)
166    }
167}
168#[allow(dead_code)]
169#[derive(Debug, Clone)]
170pub struct OPAnalysisCache {
171    pub(super) entries: std::collections::HashMap<String, OPCacheEntry>,
172    pub(super) max_size: usize,
173    pub(super) hits: u64,
174    pub(super) misses: u64,
175}
176impl OPAnalysisCache {
177    #[allow(dead_code)]
178    pub fn new(max_size: usize) -> Self {
179        OPAnalysisCache {
180            entries: std::collections::HashMap::new(),
181            max_size,
182            hits: 0,
183            misses: 0,
184        }
185    }
186    #[allow(dead_code)]
187    pub fn get(&mut self, key: &str) -> Option<&OPCacheEntry> {
188        if self.entries.contains_key(key) {
189            self.hits += 1;
190            self.entries.get(key)
191        } else {
192            self.misses += 1;
193            None
194        }
195    }
196    #[allow(dead_code)]
197    pub fn insert(&mut self, key: String, data: Vec<u8>) {
198        if self.entries.len() >= self.max_size {
199            if let Some(oldest) = self.entries.keys().next().cloned() {
200                self.entries.remove(&oldest);
201            }
202        }
203        self.entries.insert(
204            key.clone(),
205            OPCacheEntry {
206                key,
207                data,
208                timestamp: 0,
209                valid: true,
210            },
211        );
212    }
213    #[allow(dead_code)]
214    pub fn invalidate(&mut self, key: &str) {
215        if let Some(entry) = self.entries.get_mut(key) {
216            entry.valid = false;
217        }
218    }
219    #[allow(dead_code)]
220    pub fn clear(&mut self) {
221        self.entries.clear();
222    }
223    #[allow(dead_code)]
224    pub fn hit_rate(&self) -> f64 {
225        let total = self.hits + self.misses;
226        if total == 0 {
227            return 0.0;
228        }
229        self.hits as f64 / total as f64
230    }
231    #[allow(dead_code)]
232    pub fn size(&self) -> usize {
233        self.entries.len()
234    }
235}
/// Strength reduction: replaces expensive operations with cheaper equivalents.
/// For example, multiplication by a power of 2 becomes a left shift.
///
/// The associated functions are bit-twiddling helpers for recognising such
/// opportunities on `u64` constants.
pub struct StrengthReductionPass {
    /// Reduction counter; `new` starts it at zero.
    pub reductions: u32,
}
impl StrengthReductionPass {
    /// Create a pass with a zeroed counter.
    pub fn new() -> Self {
        StrengthReductionPass { reductions: 0 }
    }
    /// Check if a value is a power of two (0 is not).
    pub fn is_power_of_two(n: u64) -> bool {
        n.is_power_of_two()
    }
    /// Compute log2 for a power of two, returning None if not a power of two.
    pub fn log2_exact(n: u64) -> Option<u32> {
        Self::is_power_of_two(n).then(|| n.trailing_zeros())
    }
    /// Check if a value is a power of two minus one (e.g. 0x7F, 0xFF, `u64::MAX`).
    pub fn is_mask(n: u64) -> bool {
        // `wrapping_add` so `u64::MAX` (all ones) does not overflow in debug builds.
        n != 0 && (n & n.wrapping_add(1)) == 0
    }
    /// Count trailing zeros; returns 64 for 0.
    pub fn ctz(n: u64) -> u32 {
        n.trailing_zeros()
    }
    /// Count leading zeros; returns 64 for 0.
    pub fn clz(n: u64) -> u32 {
        n.leading_zeros()
    }
    /// Population count (number of set bits).
    pub fn popcount(n: u64) -> u32 {
        n.count_ones()
    }
}
/// Execution statistics accumulated for one optimization pass.
#[derive(Debug, Clone, Default)]
pub struct PassStats {
    /// Name of the pass.
    pub name: String,
    /// How many times the pass has run.
    pub run_count: u32,
    /// Sum of the change counts reported over all runs.
    pub total_changes: usize,
    /// Duration of the most recent run, in microseconds.
    pub last_duration_us: u64,
    /// `true` if the most recent run reported at least one change.
    pub last_changed: bool,
}
impl PassStats {
    /// Fresh, all-zero statistics for the pass called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            run_count: 0,
            total_changes: 0,
            last_duration_us: 0,
            last_changed: false,
        }
    }
    /// Fold one run (with `changes` edits taking `duration_us`) into the totals.
    pub fn record_run(&mut self, changes: usize, duration_us: u64) {
        self.run_count += 1;
        self.total_changes += changes;
        self.last_changed = changes != 0;
        self.last_duration_us = duration_us;
    }
    /// Mean number of changes per run, or 0.0 if the pass never ran.
    pub fn avg_changes(&self) -> f64 {
        match self.run_count {
            0 => 0.0,
            runs => self.total_changes as f64 / f64::from(runs),
        }
    }
}
316/// Estimates the size and complexity of LCNF expressions.
317///
318/// Used by inlining heuristics to decide whether a function body is small
319/// enough to inline.
320pub struct ExprSizeEstimator;
321impl ExprSizeEstimator {
322    /// Count the number of let-bindings in an expression.
323    pub fn count_lets(expr: &LcnfExpr) -> usize {
324        match expr {
325            LcnfExpr::Let { body, .. } => 1 + Self::count_lets(body),
326            LcnfExpr::Case { alts, default, .. } => {
327                let alt_sum: usize = alts.iter().map(|a| Self::count_lets(&a.body)).sum();
328                let def_sum = default.as_ref().map(|d| Self::count_lets(d)).unwrap_or(0);
329                alt_sum + def_sum
330            }
331            _ => 0,
332        }
333    }
334    /// Count the number of case expressions.
335    pub fn count_cases(expr: &LcnfExpr) -> usize {
336        match expr {
337            LcnfExpr::Let { body, .. } => Self::count_cases(body),
338            LcnfExpr::Case { alts, default, .. } => {
339                let alt_sum: usize = alts.iter().map(|a| Self::count_cases(&a.body)).sum();
340                let def_sum = default.as_ref().map(|d| Self::count_cases(d)).unwrap_or(0);
341                1 + alt_sum + def_sum
342            }
343            _ => 0,
344        }
345    }
346    /// Compute a complexity score (lets + 2*cases + tail_calls).
347    pub fn complexity(expr: &LcnfExpr) -> usize {
348        match expr {
349            LcnfExpr::Let { body, .. } => 1 + Self::complexity(body),
350            LcnfExpr::Case { alts, default, .. } => {
351                let alt_sum: usize = alts.iter().map(|a| Self::complexity(&a.body)).sum();
352                let def_sum = default.as_ref().map(|d| Self::complexity(d)).unwrap_or(0);
353                2 + alt_sum + def_sum
354            }
355            LcnfExpr::TailCall(_, _) => 1,
356            LcnfExpr::Return(_) => 0,
357            LcnfExpr::Unreachable => 0,
358        }
359    }
360    /// Maximum nesting depth of the expression.
361    pub fn max_depth(expr: &LcnfExpr) -> usize {
362        match expr {
363            LcnfExpr::Let { body, .. } => 1 + Self::max_depth(body),
364            LcnfExpr::Case { alts, default, .. } => {
365                let max_alt = alts
366                    .iter()
367                    .map(|a| Self::max_depth(&a.body))
368                    .max()
369                    .unwrap_or(0);
370                let max_def = default.as_ref().map(|d| Self::max_depth(d)).unwrap_or(0);
371                1 + max_alt.max(max_def)
372            }
373            _ => 0,
374        }
375    }
376    /// Count all variable references in the expression.
377    pub fn count_var_refs(expr: &LcnfExpr) -> usize {
378        match expr {
379            LcnfExpr::Let { value, body, .. } => {
380                Self::count_var_refs_in_value(value) + Self::count_var_refs(body)
381            }
382            LcnfExpr::Case { alts, default, .. } => {
383                let alt_sum: usize = alts.iter().map(|a| Self::count_var_refs(&a.body)).sum();
384                let def_sum = default
385                    .as_ref()
386                    .map(|d| Self::count_var_refs(d))
387                    .unwrap_or(0);
388                1 + alt_sum + def_sum
389            }
390            LcnfExpr::Return(LcnfArg::Var(_)) => 1,
391            LcnfExpr::TailCall(f, args) => {
392                let f_count = if matches!(f, LcnfArg::Var(_)) { 1 } else { 0 };
393                let a_count = args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count();
394                f_count + a_count
395            }
396            _ => 0,
397        }
398    }
399    pub(super) fn count_var_refs_in_value(value: &LcnfLetValue) -> usize {
400        match value {
401            LcnfLetValue::App(f, args) => {
402                let f_count = if matches!(f, LcnfArg::Var(_)) { 1 } else { 0 };
403                let a_count = args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count();
404                f_count + a_count
405            }
406            LcnfLetValue::FVar(_) => 1,
407            LcnfLetValue::Proj(_, _, _) => 1,
408            LcnfLetValue::Reset(_) => 1,
409            LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
410                args.iter().filter(|a| matches!(a, LcnfArg::Var(_))).count()
411            }
412            LcnfLetValue::Lit(_) | LcnfLetValue::Erased => 0,
413        }
414    }
415    /// Whether an expression is "trivial" (just a return or unreachable).
416    pub fn is_trivial(expr: &LcnfExpr) -> bool {
417        matches!(expr, LcnfExpr::Return(_) | LcnfExpr::Unreachable)
418    }
419    /// Whether an expression is suitable for inlining (complexity below threshold).
420    pub fn should_inline(expr: &LcnfExpr, threshold: usize) -> bool {
421        Self::complexity(expr) <= threshold
422    }
423}
424/// Manages a pipeline of optimization passes with dependency ordering.
425///
426/// Passes are executed in topological order based on their declared
427/// dependencies. Cycle detection uses Kahn's algorithm.
428#[derive(Debug, Default)]
429pub struct PassManager {
430    /// Registered pass names in insertion order.
431    pub(super) pass_names: Vec<String>,
432    /// Dependencies between passes.
433    pub(super) dependencies: Vec<PassDependency>,
434    /// Per-pass statistics.
435    pub(super) stats: HashMap<String, PassStats>,
436    /// Maximum number of fixed-point iterations.
437    pub max_iterations: u32,
438}
439impl PassManager {
440    /// Create a new pass manager.
441    pub fn new() -> Self {
442        PassManager {
443            pass_names: Vec::new(),
444            dependencies: Vec::new(),
445            stats: HashMap::new(),
446            max_iterations: 10,
447        }
448    }
449    /// Register a pass by name.
450    pub fn add_pass(&mut self, name: impl Into<String>) {
451        let n = name.into();
452        if !self.pass_names.contains(&n) {
453            self.stats.insert(n.clone(), PassStats::new(&n));
454            self.pass_names.push(n);
455        }
456    }
457    /// Add a dependency: `pass` depends on `depends_on`.
458    pub fn add_dependency(&mut self, pass: impl Into<String>, depends_on: impl Into<String>) {
459        let dep = PassDependency::new(pass, depends_on);
460        if !self.dependencies.contains(&dep) {
461            self.dependencies.push(dep);
462        }
463    }
464    /// Record a run of the named pass.
465    pub fn record_run(&mut self, name: &str, changes: usize, duration_us: u64) {
466        if let Some(stats) = self.stats.get_mut(name) {
467            stats.record_run(changes, duration_us);
468        }
469    }
470    /// Get statistics for a named pass.
471    pub fn get_stats(&self, name: &str) -> Option<&PassStats> {
472        self.stats.get(name)
473    }
474    /// Get all statistics.
475    pub fn all_stats(&self) -> &HashMap<String, PassStats> {
476        &self.stats
477    }
478    /// Number of registered passes.
479    pub fn num_passes(&self) -> usize {
480        self.pass_names.len()
481    }
482    /// Compute topological ordering of passes using Kahn's algorithm.
483    ///
484    /// Returns `None` if there is a cycle in the dependency graph.
485    pub fn topological_order(&self) -> Option<Vec<String>> {
486        let mut in_degree: HashMap<&str, usize> = HashMap::new();
487        let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
488        for name in &self.pass_names {
489            in_degree.insert(name.as_str(), 0);
490            adj.entry(name.as_str()).or_default();
491        }
492        for dep in &self.dependencies {
493            if self.pass_names.contains(&dep.pass) && self.pass_names.contains(&dep.depends_on) {
494                adj.entry(dep.depends_on.as_str())
495                    .or_default()
496                    .push(dep.pass.as_str());
497                *in_degree.entry(dep.pass.as_str()).or_insert(0) += 1;
498            }
499        }
500        let mut queue: Vec<&str> = in_degree
501            .iter()
502            .filter(|(_, &deg)| deg == 0)
503            .map(|(&name, _)| name)
504            .collect();
505        queue.sort();
506        let mut result = Vec::new();
507        while let Some(node) = queue.pop() {
508            result.push(node.to_string());
509            if let Some(neighbors) = adj.get(node) {
510                for &neighbor in neighbors {
511                    let deg = in_degree
512                        .get_mut(neighbor)
513                        .expect(
514                            "neighbor must be in in_degree; all passes were inserted during initialization",
515                        );
516                    *deg -= 1;
517                    if *deg == 0 {
518                        queue.push(neighbor);
519                        queue.sort();
520                    }
521                }
522            }
523        }
524        if result.len() == self.pass_names.len() {
525            Some(result)
526        } else {
527            None
528        }
529    }
530    /// Check if the dependency graph has a cycle.
531    pub fn has_cycle(&self) -> bool {
532        self.topological_order().is_none()
533    }
534    /// Total changes across all passes.
535    pub fn total_changes(&self) -> usize {
536        self.stats.values().map(|s| s.total_changes).sum()
537    }
538    /// Total runs across all passes.
539    pub fn total_runs(&self) -> u32 {
540        self.stats.values().map(|s| s.run_count).sum()
541    }
542}
543/// Constant folding pass -- evaluate constant expressions at compile time.
544pub struct ConstantFoldingPass {
545    pub folds_performed: u32,
546}
547impl ConstantFoldingPass {
548    pub fn new() -> Self {
549        ConstantFoldingPass { folds_performed: 0 }
550    }
551    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
552        for decl in decls.iter_mut() {
553            self.fold_expr(&mut decl.body);
554        }
555    }
556    pub(super) fn fold_expr(&mut self, expr: &mut LcnfExpr) {
557        match expr {
558            LcnfExpr::Let { value, body, .. } => {
559                if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Nat(lhs)), args) = value {
560                    if args.len() == 2 {
561                        if let (LcnfArg::Lit(LcnfLit::Nat(rhs)), LcnfArg::Lit(LcnfLit::Nat(op_n))) =
562                            (&args[0], &args[1])
563                        {
564                            let op = match op_n {
565                                0 => "add",
566                                1 => "sub",
567                                2 => "mul",
568                                _ => "",
569                            };
570                            if let Some(result) = self.try_fold_nat_op(op, *lhs, *rhs) {
571                                *value = LcnfLetValue::Lit(LcnfLit::Nat(result));
572                                self.folds_performed += 1;
573                            }
574                        }
575                    }
576                }
577                self.fold_expr(body);
578            }
579            LcnfExpr::Case { alts, default, .. } => {
580                for alt in alts.iter_mut() {
581                    self.fold_expr(&mut alt.body);
582                }
583                if let Some(def) = default {
584                    self.fold_expr(def);
585                }
586            }
587            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
588        }
589    }
590    /// Try to fold a nat binary operation.
591    pub fn try_fold_nat_op(&self, op: &str, lhs: u64, rhs: u64) -> Option<u64> {
592        match op {
593            "add" => Some(lhs.wrapping_add(rhs)),
594            "sub" => Some(lhs.saturating_sub(rhs)),
595            "mul" => Some(lhs.wrapping_mul(rhs)),
596            "div" => {
597                if rhs == 0 {
598                    None
599                } else {
600                    Some(lhs / rhs)
601                }
602            }
603            "mod" => {
604                if rhs == 0 {
605                    None
606                } else {
607                    Some(lhs % rhs)
608                }
609            }
610            "min" => Some(lhs.min(rhs)),
611            "max" => Some(lhs.max(rhs)),
612            "pow" => Some(lhs.wrapping_pow(rhs as u32)),
613            "and" => Some(lhs & rhs),
614            "or" => Some(lhs | rhs),
615            "xor" => Some(lhs ^ rhs),
616            "shl" => Some(lhs.wrapping_shl(rhs as u32)),
617            "shr" => Some(lhs.wrapping_shr(rhs as u32)),
618            _ => None,
619        }
620    }
621    /// Try to fold a boolean operation.
622    pub fn try_fold_bool_op(&self, op: &str, lhs: bool, rhs: bool) -> Option<bool> {
623        match op {
624            "and" => Some(lhs && rhs),
625            "or" => Some(lhs || rhs),
626            "xor" => Some(lhs ^ rhs),
627            "eq" => Some(lhs == rhs),
628            "ne" => Some(lhs != rhs),
629            _ => None,
630        }
631    }
632    /// Try to fold a comparison operation.
633    pub fn try_fold_cmp(&self, op: &str, lhs: u64, rhs: u64) -> Option<bool> {
634        match op {
635            "eq" => Some(lhs == rhs),
636            "ne" => Some(lhs != rhs),
637            "lt" => Some(lhs < rhs),
638            "le" => Some(lhs <= rhs),
639            "gt" => Some(lhs > rhs),
640            "ge" => Some(lhs >= rhs),
641            _ => None,
642        }
643    }
644}
/// One stored entry of an `OPAnalysisCache`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct OPCacheEntry {
    /// The key this entry is stored under.
    pub key: String,
    /// Opaque cached payload bytes.
    pub data: Vec<u8>,
    /// Timestamp slot for the entry.
    pub timestamp: u64,
    /// Whether the entry is still considered valid.
    pub valid: bool,
}
/// Dominator tree over nodes numbered `0..size`.
///
/// `idom[n]` is the immediate dominator of node `n`. A node is treated as the
/// root when its entry is `None` or points at itself. `dom_children` and
/// `dom_depth` are stored alongside but are not maintained by `set_idom`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OPDominatorTree {
    pub idom: Vec<Option<u32>>,
    pub dom_children: Vec<Vec<u32>>,
    pub dom_depth: Vec<u32>,
}
impl OPDominatorTree {
    /// Create a tree for `size` nodes with no dominators set and all depths 0.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        OPDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }
    /// Set the immediate dominator of `node`.
    ///
    /// # Panics
    /// Panics if `node >= size`.
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        self.idom[node] = Some(idom);
    }
    /// Whether `a` dominates `b` (every node dominates itself).
    ///
    /// Walks the `idom` chain upward from `b`. The walk is bounded by the number
    /// of nodes and uses checked lookups, so it returns `false` rather than
    /// looping forever or panicking in two cases: a cyclic `idom` table, or an
    /// out-of-range index.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        // A well-formed chain reaches the root within `idom.len()` steps.
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) => {
                    let parent = parent as usize;
                    if parent == a {
                        return true;
                    }
                    if parent == cur {
                        return false;
                    }
                    cur = parent;
                }
                None => return false,
            }
        }
        false
    }
    /// Recorded depth of `node` in the tree, or 0 if out of range.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
/// Profile-guided optimization hints
#[derive(Debug, Clone)]
pub struct PgoHints {
    pub hot_functions: Vec<String>,
    pub likely_branches: Vec<(String, u32, bool)>,
    pub inline_candidates: Vec<String>,
    pub cold_functions: Vec<String>,
    pub call_counts: HashMap<String, u64>,
}
impl PgoHints {
    /// Create an empty set of hints.
    pub fn new() -> Self {
        PgoHints {
            hot_functions: Vec::new(),
            likely_branches: Vec::new(),
            inline_candidates: Vec::new(),
            cold_functions: Vec::new(),
            call_counts: HashMap::new(),
        }
    }
    /// Mark `func_name` as hot (no duplicates are stored).
    pub fn mark_hot(&mut self, func_name: &str) {
        if !self.hot_functions.iter().any(|f| f == func_name) {
            self.hot_functions.push(func_name.to_string());
        }
    }
    /// Mark `func_name` as cold (no duplicates are stored).
    pub fn mark_cold(&mut self, func_name: &str) {
        if !self.cold_functions.iter().any(|f| f == func_name) {
            self.cold_functions.push(func_name.to_string());
        }
    }
    /// Mark `func_name` as an inlining candidate (no duplicates are stored).
    pub fn mark_inline(&mut self, func_name: &str) {
        if !self.inline_candidates.iter().any(|f| f == func_name) {
            self.inline_candidates.push(func_name.to_string());
        }
    }
    /// Add `count` to the call counter of `func_name`.
    pub fn record_call(&mut self, func_name: &str, count: u64) {
        *self.call_counts.entry(func_name.to_string()).or_insert(0) += count;
    }
    pub fn is_hot(&self, func_name: &str) -> bool {
        self.hot_functions.iter().any(|f| f == func_name)
    }
    pub fn is_cold(&self, func_name: &str) -> bool {
        self.cold_functions.iter().any(|f| f == func_name)
    }
    pub fn should_inline(&self, func_name: &str) -> bool {
        self.inline_candidates.iter().any(|f| f == func_name)
    }
    /// Recorded call count for `func_name`, or 0.
    pub fn call_count(&self, func_name: &str) -> u64 {
        self.call_counts.get(func_name).copied().unwrap_or(0)
    }
    /// Total number of hints across all categories.
    pub fn total_hints(&self) -> usize {
        self.hot_functions.len()
            + self.cold_functions.len()
            + self.inline_candidates.len()
            + self.likely_branches.len()
            + self.call_counts.len()
    }
    /// Merge another set of hints into this one.
    ///
    /// The function-name lists and likely-branch hints are unioned, skipping
    /// exact duplicates. Call counts are summed.
    pub fn merge(&mut self, other: &PgoHints) {
        for f in &other.hot_functions {
            self.mark_hot(f);
        }
        for f in &other.cold_functions {
            self.mark_cold(f);
        }
        for f in &other.inline_candidates {
            self.mark_inline(f);
        }
        // Previously omitted: branch hints from `other` were silently dropped.
        for branch in &other.likely_branches {
            if !self.likely_branches.contains(branch) {
                self.likely_branches.push(branch.clone());
            }
        }
        for (name, count) in &other.call_counts {
            self.record_call(name, *count);
        }
    }
    /// Classify a function as `"hot"`, `"cold"` or `"normal"`; hot takes precedence.
    pub fn classify(&self, func_name: &str) -> &'static str {
        if self.is_hot(func_name) {
            "hot"
        } else if self.is_cold(func_name) {
            "cold"
        } else {
            "normal"
        }
    }
}
776/// Dead code elimination -- remove unreachable let expressions.
777pub struct DeadCodeEliminationPass {
778    pub removed: u32,
779}
780impl DeadCodeEliminationPass {
781    pub fn new() -> Self {
782        DeadCodeEliminationPass { removed: 0 }
783    }
784    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
785        for decl in decls.iter_mut() {
786            let mut used: HashSet<LcnfVarId> = HashSet::new();
787            Self::collect_used_vars(&decl.body, &mut used);
788            let mut body = decl.body.clone();
789            self.eliminate_dead_lets(&mut body, &used);
790            decl.body = body;
791        }
792    }
793    pub(super) fn eliminate_dead_lets(&mut self, expr: &mut LcnfExpr, used: &HashSet<LcnfVarId>) {
794        match expr {
795            LcnfExpr::Let {
796                id, value, body, ..
797            } => {
798                let is_pure = matches!(
799                    value,
800                    LcnfLetValue::Lit(_) | LcnfLetValue::FVar(_) | LcnfLetValue::Erased
801                );
802                if is_pure && !used.contains(id) {
803                    let new_body = *body.clone();
804                    *expr = new_body;
805                    self.removed += 1;
806                    self.eliminate_dead_lets(expr, used);
807                } else {
808                    self.eliminate_dead_lets(body, used);
809                }
810            }
811            LcnfExpr::Case { alts, default, .. } => {
812                for alt in alts.iter_mut() {
813                    self.eliminate_dead_lets(&mut alt.body, used);
814                }
815                if let Some(def) = default {
816                    self.eliminate_dead_lets(def, used);
817                }
818            }
819            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
820        }
821    }
822    pub(super) fn collect_used_vars(expr: &LcnfExpr, used: &mut HashSet<LcnfVarId>) {
823        match expr {
824            LcnfExpr::Let {
825                id: _, value, body, ..
826            } => {
827                match value {
828                    LcnfLetValue::App(func, args) => {
829                        if let LcnfArg::Var(v) = func {
830                            used.insert(*v);
831                        }
832                        for a in args {
833                            if let LcnfArg::Var(v) = a {
834                                used.insert(*v);
835                            }
836                        }
837                    }
838                    LcnfLetValue::FVar(v) => {
839                        used.insert(*v);
840                    }
841                    LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
842                        for a in args {
843                            if let LcnfArg::Var(v) = a {
844                                used.insert(*v);
845                            }
846                        }
847                    }
848                    LcnfLetValue::Proj(_, _, v) => {
849                        used.insert(*v);
850                    }
851                    LcnfLetValue::Reset(v) => {
852                        used.insert(*v);
853                    }
854                    LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
855                }
856                Self::collect_used_vars(body, used);
857            }
858            LcnfExpr::Case {
859                scrutinee,
860                alts,
861                default,
862                ..
863            } => {
864                used.insert(*scrutinee);
865                for alt in alts {
866                    Self::collect_used_vars(&alt.body, used);
867                }
868                if let Some(def) = default {
869                    Self::collect_used_vars(def, used);
870                }
871            }
872            LcnfExpr::Return(a) | LcnfExpr::TailCall(a, _) => {
873                if let LcnfArg::Var(v) = a {
874                    used.insert(*v);
875                }
876                if let LcnfExpr::TailCall(_, args) = expr {
877                    for a in args {
878                        if let LcnfArg::Var(v) = a {
879                            used.insert(*v);
880                        }
881                    }
882                }
883            }
884            LcnfExpr::Unreachable => {}
885        }
886    }
887}
888#[allow(dead_code)]
889#[derive(Debug, Clone)]
890pub struct OPPassConfig {
891    pub phase: OPPassPhase,
892    pub enabled: bool,
893    pub max_iterations: u32,
894    pub debug_output: bool,
895    pub pass_name: String,
896}
897impl OPPassConfig {
898    #[allow(dead_code)]
899    pub fn new(name: impl Into<String>, phase: OPPassPhase) -> Self {
900        OPPassConfig {
901            phase,
902            enabled: true,
903            max_iterations: 10,
904            debug_output: false,
905            pass_name: name.into(),
906        }
907    }
908    #[allow(dead_code)]
909    pub fn disabled(mut self) -> Self {
910        self.enabled = false;
911        self
912    }
913    #[allow(dead_code)]
914    pub fn with_debug(mut self) -> Self {
915        self.debug_output = true;
916        self
917    }
918    #[allow(dead_code)]
919    pub fn max_iter(mut self, n: u32) -> Self {
920        self.max_iterations = n;
921        self
922    }
923}
924/// Copy propagation -- replace uses of copied variables with originals.
925pub struct CopyPropagationPass {
926    pub substitutions: u32,
927}
928impl CopyPropagationPass {
929    pub fn new() -> Self {
930        CopyPropagationPass { substitutions: 0 }
931    }
932    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
933        for decl in decls.iter_mut() {
934            self.propagate_copies_in_expr(&mut decl.body);
935        }
936    }
937    pub(super) fn propagate_copies_in_expr(&mut self, expr: &mut LcnfExpr) {
938        if let LcnfExpr::Let {
939            id,
940            value: LcnfLetValue::FVar(src),
941            body,
942            ..
943        } = expr
944        {
945            let from = *id;
946            let to = *src;
947            substitute_var_in_expr(body, from, to);
948            self.substitutions += 1;
949            self.propagate_copies_in_expr(body);
950        } else {
951            match expr {
952                LcnfExpr::Let { body, .. } => self.propagate_copies_in_expr(body),
953                LcnfExpr::Case { alts, default, .. } => {
954                    for alt in alts.iter_mut() {
955                        self.propagate_copies_in_expr(&mut alt.body);
956                    }
957                    if let Some(def) = default {
958                        self.propagate_copies_in_expr(def);
959                    }
960                }
961                _ => {}
962            }
963        }
964    }
965}
/// Aggregate statistics for an optimization pass across repeated runs.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct OPPassStats {
    /// Number of recorded runs, successful or failed.
    pub total_runs: u32,
    /// Number of runs recorded through [`OPPassStats::record_run`].
    pub successful_runs: u32,
    /// Sum of `changes` over all successful runs.
    pub total_changes: u64,
    /// Accumulated time in milliseconds over all recorded runs.
    pub time_ms: u64,
    /// Iteration count of the most recently recorded run (overwritten, not summed).
    pub iterations_used: u32,
}
impl OPPassStats {
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Records a successful run that made `changes` changes.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.successful_runs += 1;
        self.total_changes += changes;
        self.time_ms += time_ms;
        self.iterations_used = iterations;
    }
    /// Records a failed run.
    ///
    /// The run counts toward `total_runs` and `time_ms` but not toward
    /// `successful_runs`. Without this method `success_rate` could never drop
    /// below 1.0.
    #[allow(dead_code)]
    pub fn record_failed_run(&mut self, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.time_ms += time_ms;
        self.iterations_used = iterations;
    }
    /// Mean number of changes per recorded run; 0.0 when nothing was recorded.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        if self.total_runs == 0 {
            return 0.0;
        }
        self.total_changes as f64 / self.total_runs as f64
    }
    /// Fraction of recorded runs that succeeded; 0.0 when nothing was recorded.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        if self.total_runs == 0 {
            return 0.0;
        }
        self.successful_runs as f64 / self.total_runs as f64
    }
    /// One-line human-readable summary of the counters.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            self.successful_runs, self.total_runs, self.total_changes, self.time_ms
        )
    }
}
/// Per-block liveness sets, indexed by block number.
///
/// Every vector has one entry per block. Out-of-range block indices are
/// ignored by the mutators and report "not live" from the queries.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OPLivenessInfo {
    /// Variables live on entry to each block.
    pub live_in: Vec<HashSet<u32>>,
    /// Variables live on exit from each block.
    pub live_out: Vec<HashSet<u32>>,
    /// Variables defined in each block.
    pub defs: Vec<HashSet<u32>>,
    /// Variables used in each block.
    pub uses: Vec<HashSet<u32>>,
}
impl OPLivenessInfo {
    /// Creates empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty_sets = || -> Vec<HashSet<u32>> {
            (0..block_count).map(|_| HashSet::new()).collect()
        };
        Self {
            live_in: empty_sets(),
            live_out: empty_sets(),
            defs: empty_sets(),
            uses: empty_sets(),
        }
    }
    /// Marks `var` as defined in `block`; no-op if `block` is out of range.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Marks `var` as used in `block`; no-op if `block` is out of range.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// True when `var` is in the live-in set of `block`.
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        matches!(self.live_in.get(block), Some(set) if set.contains(&var))
    }
    /// True when `var` is in the live-out set of `block`.
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        matches!(self.live_out.get(block), Some(set) if set.contains(&var))
    }
}
1055#[allow(dead_code)]
1056#[derive(Debug, Clone)]
1057pub struct OPDepGraph {
1058    pub(super) nodes: Vec<u32>,
1059    pub(super) edges: Vec<(u32, u32)>,
1060}
1061impl OPDepGraph {
1062    #[allow(dead_code)]
1063    pub fn new() -> Self {
1064        OPDepGraph {
1065            nodes: Vec::new(),
1066            edges: Vec::new(),
1067        }
1068    }
1069    #[allow(dead_code)]
1070    pub fn add_node(&mut self, id: u32) {
1071        if !self.nodes.contains(&id) {
1072            self.nodes.push(id);
1073        }
1074    }
1075    #[allow(dead_code)]
1076    pub fn add_dep(&mut self, dep: u32, dependent: u32) {
1077        self.add_node(dep);
1078        self.add_node(dependent);
1079        self.edges.push((dep, dependent));
1080    }
1081    #[allow(dead_code)]
1082    pub fn dependents_of(&self, node: u32) -> Vec<u32> {
1083        self.edges
1084            .iter()
1085            .filter(|(d, _)| *d == node)
1086            .map(|(_, dep)| *dep)
1087            .collect()
1088    }
1089    #[allow(dead_code)]
1090    pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
1091        self.edges
1092            .iter()
1093            .filter(|(_, dep)| *dep == node)
1094            .map(|(d, _)| *d)
1095            .collect()
1096    }
1097    #[allow(dead_code)]
1098    pub fn topological_sort(&self) -> Vec<u32> {
1099        let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
1100        for &n in &self.nodes {
1101            in_degree.insert(n, 0);
1102        }
1103        for (_, dep) in &self.edges {
1104            *in_degree.entry(*dep).or_insert(0) += 1;
1105        }
1106        let mut queue: std::collections::VecDeque<u32> = self
1107            .nodes
1108            .iter()
1109            .filter(|&&n| in_degree[&n] == 0)
1110            .copied()
1111            .collect();
1112        let mut result = Vec::new();
1113        while let Some(node) = queue.pop_front() {
1114            result.push(node);
1115            for dep in self.dependents_of(node) {
1116                let cnt = in_degree.entry(dep).or_insert(0);
1117                *cnt = cnt.saturating_sub(1);
1118                if *cnt == 0 {
1119                    queue.push_back(dep);
1120                }
1121            }
1122        }
1123        result
1124    }
1125    #[allow(dead_code)]
1126    pub fn has_cycle(&self) -> bool {
1127        self.topological_sort().len() < self.nodes.len()
1128    }
1129}
/// Compile-time evaluation helpers for primitive arithmetic and logic.
///
/// Integer helpers that can overflow or trap return `None` instead of
/// panicking or wrapping, so callers can simply skip folding in that case.
#[allow(dead_code)]
pub struct OPConstantFoldingHelper;
impl OPConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// Truncating `a / b`, or `None` when `b == 0` or for `i64::MIN / -1`.
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already rejects both a zero divisor and the overflow case.
        a.checked_div(b)
    }
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`, or `None` for `i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`, or `None` when `b >= 64`.
    ///
    /// Like `i64::checked_shl`, this does NOT detect bits shifted out of the
    /// value; only an out-of-range shift amount yields `None`.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// Arithmetic `a >> b`, or `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// `a % b` (sign follows `a`), or `None` when `b == 0` or for `i64::MIN % -1`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        // A plain `a % b` panics on `i64::MIN % -1` even with a non-zero divisor.
        a.checked_rem(b)
    }
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
1210/// Estimates the cost of inlining a function.
1211#[derive(Debug, Clone)]
1212pub struct InlineCostEstimator {
1213    /// Base cost threshold below which functions are always inlined.
1214    pub always_inline_threshold: usize,
1215    /// Threshold for functions in hot call sites.
1216    pub hot_threshold: usize,
1217    /// Threshold for cold call sites.
1218    pub cold_threshold: usize,
1219    /// Bonus for tail-recursive functions (they benefit less from inlining).
1220    pub tail_recursive_penalty: usize,
1221}
1222impl InlineCostEstimator {
1223    /// Compute the inlining cost for a function body.
1224    pub fn cost(&self, decl: &LcnfFunDecl) -> usize {
1225        let base = ExprSizeEstimator::complexity(&decl.body);
1226        let penalty = if decl.is_recursive {
1227            self.tail_recursive_penalty
1228        } else {
1229            0
1230        };
1231        base + penalty
1232    }
1233    /// Decide whether to inline based on cost and PGO hints.
1234    pub fn should_inline(&self, decl: &LcnfFunDecl, pgo: Option<&PgoHints>) -> bool {
1235        let cost = self.cost(decl);
1236        if cost <= self.always_inline_threshold {
1237            return true;
1238        }
1239        if let Some(hints) = pgo {
1240            if hints.should_inline(&decl.name) {
1241                return true;
1242            }
1243            if hints.is_hot(&decl.name) {
1244                return cost <= self.hot_threshold;
1245            }
1246            if hints.is_cold(&decl.name) {
1247                return cost <= self.cold_threshold;
1248            }
1249        }
1250        cost <= self.cold_threshold
1251    }
1252}
1253/// Eliminates identity let-bindings of the form `let x = x`.
1254pub struct IdentityEliminationPass {
1255    pub eliminated: u32,
1256}
1257impl IdentityEliminationPass {
1258    pub fn new() -> Self {
1259        IdentityEliminationPass { eliminated: 0 }
1260    }
1261    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
1262        for decl in decls.iter_mut() {
1263            self.elim_expr(&mut decl.body);
1264        }
1265    }
1266    pub(super) fn elim_expr(&mut self, expr: &mut LcnfExpr) {
1267        loop {
1268            if let LcnfExpr::Let {
1269                id,
1270                value: LcnfLetValue::FVar(src),
1271                body,
1272                ..
1273            } = expr
1274            {
1275                if *id == *src {
1276                    let new_body = *body.clone();
1277                    *expr = new_body;
1278                    self.eliminated += 1;
1279                    continue;
1280                }
1281            }
1282            break;
1283        }
1284        match expr {
1285            LcnfExpr::Let { body, .. } => self.elim_expr(body),
1286            LcnfExpr::Case { alts, default, .. } => {
1287                for alt in alts.iter_mut() {
1288                    self.elim_expr(&mut alt.body);
1289                }
1290                if let Some(def) = default {
1291                    self.elim_expr(def);
1292                }
1293            }
1294            _ => {}
1295        }
1296    }
1297}
1298/// Eliminates code after `Unreachable` terminators.
1299pub struct UnreachableCodeEliminationPass {
1300    pub eliminated: u32,
1301}
1302impl UnreachableCodeEliminationPass {
1303    pub fn new() -> Self {
1304        UnreachableCodeEliminationPass { eliminated: 0 }
1305    }
1306    pub fn run(&mut self, decls: &mut [LcnfFunDecl]) {
1307        for decl in decls.iter_mut() {
1308            self.elim_expr(&mut decl.body);
1309        }
1310    }
1311    pub(super) fn elim_expr(&mut self, expr: &mut LcnfExpr) {
1312        match expr {
1313            LcnfExpr::Let { body, .. } => {
1314                self.elim_expr(body);
1315                if matches!(**body, LcnfExpr::Unreachable) {
1316                    *expr = LcnfExpr::Unreachable;
1317                    self.eliminated += 1;
1318                }
1319            }
1320            LcnfExpr::Case { alts, default, .. } => {
1321                for alt in alts.iter_mut() {
1322                    self.elim_expr(&mut alt.body);
1323                }
1324                if let Some(def) = default {
1325                    self.elim_expr(def);
1326                }
1327            }
1328            _ => {}
1329        }
1330    }
1331}