Skip to main content

oxilean_codegen/opt_join/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::functions::*;
9use std::collections::VecDeque;
10
/// Knobs controlling which join-point transformations run and how aggressively.
#[derive(Debug, Clone)]
pub struct JoinPointConfig {
    /// Upper bound on a let-value's size (as measured by the optimizer) for it to be inlined.
    pub max_join_size: usize,
    /// Enables inlining of small let-bound values into `FVar` aliases.
    pub inline_small_joins: bool,
    /// Enables rewriting `let x = f(args); return x` into a tail call.
    pub detect_tail_calls: bool,
    /// Enables the contification analysis.
    pub enable_contification: bool,
    /// Enables sinking let-bindings into the single case branch that uses them.
    pub float_join_points: bool,
    /// Enables removal of unused, pure let-bindings.
    pub eliminate_dead_joins: bool,
    /// Maximum number of passes over a function body before giving up on a fixpoint.
    pub max_iterations: usize,
}
29/// A generic key-value configuration store for OJoin.
30#[derive(Debug, Clone, Default)]
31pub struct OJoinConfig {
32    pub(super) entries: std::collections::HashMap<String, String>,
33}
34impl OJoinConfig {
35    pub fn new() -> Self {
36        OJoinConfig::default()
37    }
38    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
39        self.entries.insert(key.into(), value.into());
40    }
41    pub fn get(&self, key: &str) -> Option<&str> {
42        self.entries.get(key).map(|s| s.as_str())
43    }
44    pub fn get_bool(&self, key: &str) -> bool {
45        matches!(self.get(key), Some("true") | Some("1") | Some("yes"))
46    }
47    pub fn get_int(&self, key: &str) -> Option<i64> {
48        self.get(key)?.parse().ok()
49    }
50    pub fn len(&self) -> usize {
51        self.entries.len()
52    }
53    pub fn is_empty(&self) -> bool {
54        self.entries.is_empty()
55    }
56}
57/// A feature flag set for OJoin capabilities.
58#[derive(Debug, Clone, Default)]
59pub struct OJoinFeatures {
60    pub(super) flags: std::collections::HashSet<String>,
61}
62impl OJoinFeatures {
63    pub fn new() -> Self {
64        OJoinFeatures::default()
65    }
66    pub fn enable(&mut self, flag: impl Into<String>) {
67        self.flags.insert(flag.into());
68    }
69    pub fn disable(&mut self, flag: &str) {
70        self.flags.remove(flag);
71    }
72    pub fn is_enabled(&self, flag: &str) -> bool {
73        self.flags.contains(flag)
74    }
75    pub fn len(&self) -> usize {
76        self.flags.len()
77    }
78    pub fn is_empty(&self) -> bool {
79        self.flags.is_empty()
80    }
81    pub fn union(&self, other: &OJoinFeatures) -> OJoinFeatures {
82        OJoinFeatures {
83            flags: self.flags.union(&other.flags).cloned().collect(),
84        }
85    }
86    pub fn intersection(&self, other: &OJoinFeatures) -> OJoinFeatures {
87        OJoinFeatures {
88            flags: self.flags.intersection(&other.flags).cloned().collect(),
89        }
90    }
91}
/// Constant-folding primitives for `i64`, `f64` and `bool` operands.
///
/// Integer folds that can fail (overflow, division/remainder by zero, shift
/// amount out of range) return `None` so the caller can leave the operation
/// unfolded instead of panicking at compile time.
#[allow(dead_code)]
pub struct OJConstantFoldingHelper;
impl OJConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// `a / b`, or `None` when `b == 0` or on overflow (`i64::MIN / -1`).
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already returns `None` for a zero divisor.
        a.checked_div(b)
    }
    /// `a + b` under IEEE-754 semantics.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// `a * b` under IEEE-754 semantics.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`, or `None` when `a == i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Logical negation.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Logical conjunction.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Logical disjunction.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`, or `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// Arithmetic `a >> b`, or `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// `a % b`, or `None` when `b == 0` or on overflow.
    ///
    /// The previous implementation used `a % b` directly. That panics for
    /// `i64::MIN % -1` in every build profile. `checked_rem` reports that case
    /// as `None`, consistent with `fold_div_i64`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_rem(b)
    }
    /// Bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Bitwise NOT.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
172/// Information about a function call site
173#[derive(Debug, Clone)]
174pub struct CallSiteInfo {
175    /// The calling function
176    pub(super) caller: String,
177    /// Whether this call is in tail position
178    pub(super) is_tail: bool,
179    /// Number of arguments passed
180    pub(super) arg_count: usize,
181    /// The variable ID of the callee if it's a local var
182    pub(super) callee_var: Option<LcnfVarId>,
183}
184#[allow(dead_code)]
185#[derive(Debug, Clone)]
186pub struct OJDepGraph {
187    pub(super) nodes: Vec<u32>,
188    pub(super) edges: Vec<(u32, u32)>,
189}
190impl OJDepGraph {
191    #[allow(dead_code)]
192    pub fn new() -> Self {
193        OJDepGraph {
194            nodes: Vec::new(),
195            edges: Vec::new(),
196        }
197    }
198    #[allow(dead_code)]
199    pub fn add_node(&mut self, id: u32) {
200        if !self.nodes.contains(&id) {
201            self.nodes.push(id);
202        }
203    }
204    #[allow(dead_code)]
205    pub fn add_dep(&mut self, dep: u32, dependent: u32) {
206        self.add_node(dep);
207        self.add_node(dependent);
208        self.edges.push((dep, dependent));
209    }
210    #[allow(dead_code)]
211    pub fn dependents_of(&self, node: u32) -> Vec<u32> {
212        self.edges
213            .iter()
214            .filter(|(d, _)| *d == node)
215            .map(|(_, dep)| *dep)
216            .collect()
217    }
218    #[allow(dead_code)]
219    pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
220        self.edges
221            .iter()
222            .filter(|(_, dep)| *dep == node)
223            .map(|(d, _)| *d)
224            .collect()
225    }
226    #[allow(dead_code)]
227    pub fn topological_sort(&self) -> Vec<u32> {
228        let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
229        for &n in &self.nodes {
230            in_degree.insert(n, 0);
231        }
232        for (_, dep) in &self.edges {
233            *in_degree.entry(*dep).or_insert(0) += 1;
234        }
235        let mut queue: std::collections::VecDeque<u32> = self
236            .nodes
237            .iter()
238            .filter(|&&n| in_degree[&n] == 0)
239            .copied()
240            .collect();
241        let mut result = Vec::new();
242        while let Some(node) = queue.pop_front() {
243            result.push(node);
244            for dep in self.dependents_of(node) {
245                let cnt = in_degree.entry(dep).or_insert(0);
246                *cnt = cnt.saturating_sub(1);
247                if *cnt == 0 {
248                    queue.push_back(dep);
249                }
250            }
251        }
252        result
253    }
254    #[allow(dead_code)]
255    pub fn has_cycle(&self) -> bool {
256        self.topological_sort().len() < self.nodes.len()
257    }
258}
259/// Statistics for join point optimization
260#[derive(Debug, Clone, Default)]
261pub struct JoinPointStats {
262    /// Number of join points created
263    pub joins_created: usize,
264    /// Number of join points inlined
265    pub joins_inlined: usize,
266    /// Number of dead join points eliminated
267    pub joins_eliminated: usize,
268    /// Number of tail calls detected
269    pub tail_calls_detected: usize,
270    /// Number of functions contified
271    pub functions_contified: usize,
272    /// Number of join points floated
273    pub joins_floated: usize,
274    /// Total optimization iterations run
275    pub iterations: usize,
276}
277impl JoinPointStats {
278    pub(super) fn total_changes(&self) -> usize {
279        self.joins_created
280            + self.joins_inlined
281            + self.joins_eliminated
282            + self.tail_calls_detected
283            + self.functions_contified
284            + self.joins_floated
285    }
286}
287/// A text buffer for building OJoin output source code.
288#[derive(Debug, Default)]
289pub struct OJoinSourceBuffer {
290    pub(super) buf: String,
291    pub(super) indent_level: usize,
292    pub(super) indent_str: String,
293}
294impl OJoinSourceBuffer {
295    pub fn new() -> Self {
296        OJoinSourceBuffer {
297            buf: String::new(),
298            indent_level: 0,
299            indent_str: "    ".to_string(),
300        }
301    }
302    pub fn with_indent(mut self, indent: impl Into<String>) -> Self {
303        self.indent_str = indent.into();
304        self
305    }
306    pub fn push_line(&mut self, line: &str) {
307        for _ in 0..self.indent_level {
308            self.buf.push_str(&self.indent_str);
309        }
310        self.buf.push_str(line);
311        self.buf.push('\n');
312    }
313    pub fn push_raw(&mut self, s: &str) {
314        self.buf.push_str(s);
315    }
316    pub fn indent(&mut self) {
317        self.indent_level += 1;
318    }
319    pub fn dedent(&mut self) {
320        self.indent_level = self.indent_level.saturating_sub(1);
321    }
322    pub fn as_str(&self) -> &str {
323        &self.buf
324    }
325    pub fn len(&self) -> usize {
326        self.buf.len()
327    }
328    pub fn is_empty(&self) -> bool {
329        self.buf.is_empty()
330    }
331    pub fn line_count(&self) -> usize {
332        self.buf.lines().count()
333    }
334    pub fn into_string(self) -> String {
335        self.buf
336    }
337    pub fn reset(&mut self) {
338        self.buf.clear();
339        self.indent_level = 0;
340    }
341}
342/// Tracks declared names for OJoin scope analysis.
343#[derive(Debug, Default)]
344pub struct OJoinNameScope {
345    pub(super) declared: std::collections::HashSet<String>,
346    pub(super) depth: usize,
347    pub(super) parent: Option<Box<OJoinNameScope>>,
348}
349impl OJoinNameScope {
350    pub fn new() -> Self {
351        OJoinNameScope::default()
352    }
353    pub fn declare(&mut self, name: impl Into<String>) -> bool {
354        self.declared.insert(name.into())
355    }
356    pub fn is_declared(&self, name: &str) -> bool {
357        self.declared.contains(name)
358    }
359    pub fn push_scope(self) -> Self {
360        OJoinNameScope {
361            declared: std::collections::HashSet::new(),
362            depth: self.depth + 1,
363            parent: Some(Box::new(self)),
364        }
365    }
366    pub fn pop_scope(self) -> Self {
367        *self.parent.unwrap_or_default()
368    }
369    pub fn depth(&self) -> usize {
370        self.depth
371    }
372    pub fn len(&self) -> usize {
373        self.declared.len()
374    }
375}
376#[allow(dead_code)]
377#[derive(Debug, Clone)]
378pub struct OJAnalysisCache {
379    pub(super) entries: std::collections::HashMap<String, OJCacheEntry>,
380    pub(super) max_size: usize,
381    pub(super) hits: u64,
382    pub(super) misses: u64,
383}
384impl OJAnalysisCache {
385    #[allow(dead_code)]
386    pub fn new(max_size: usize) -> Self {
387        OJAnalysisCache {
388            entries: std::collections::HashMap::new(),
389            max_size,
390            hits: 0,
391            misses: 0,
392        }
393    }
394    #[allow(dead_code)]
395    pub fn get(&mut self, key: &str) -> Option<&OJCacheEntry> {
396        if self.entries.contains_key(key) {
397            self.hits += 1;
398            self.entries.get(key)
399        } else {
400            self.misses += 1;
401            None
402        }
403    }
404    #[allow(dead_code)]
405    pub fn insert(&mut self, key: String, data: Vec<u8>) {
406        if self.entries.len() >= self.max_size {
407            if let Some(oldest) = self.entries.keys().next().cloned() {
408                self.entries.remove(&oldest);
409            }
410        }
411        self.entries.insert(
412            key.clone(),
413            OJCacheEntry {
414                key,
415                data,
416                timestamp: 0,
417                valid: true,
418            },
419        );
420    }
421    #[allow(dead_code)]
422    pub fn invalidate(&mut self, key: &str) {
423        if let Some(entry) = self.entries.get_mut(key) {
424            entry.valid = false;
425        }
426    }
427    #[allow(dead_code)]
428    pub fn clear(&mut self) {
429        self.entries.clear();
430    }
431    #[allow(dead_code)]
432    pub fn hit_rate(&self) -> f64 {
433        let total = self.hits + self.misses;
434        if total == 0 {
435            return 0.0;
436        }
437        self.hits as f64 / total as f64
438    }
439    #[allow(dead_code)]
440    pub fn size(&self) -> usize {
441        self.entries.len()
442    }
443}
444/// A monotonically increasing ID generator for OJoin.
445#[derive(Debug, Default)]
446pub struct OJoinIdGen {
447    pub(super) next: u32,
448}
449impl OJoinIdGen {
450    pub fn new() -> Self {
451        OJoinIdGen::default()
452    }
453    pub fn next_id(&mut self) -> u32 {
454        let id = self.next;
455        self.next += 1;
456        id
457    }
458    pub fn peek_next(&self) -> u32 {
459        self.next
460    }
461    pub fn reset(&mut self) {
462        self.next = 0;
463    }
464    pub fn skip(&mut self, n: u32) {
465        self.next += n;
466    }
467}
/// A pair of hashes (input content, configuration) used to decide whether an
/// incremental OJoin result can be reused.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OJoinIncrKey {
    pub content_hash: u64,
    pub config_hash: u64,
}
impl OJoinIncrKey {
    /// Builds a key from a content hash and a configuration hash.
    pub fn new(content: u64, config: u64) -> Self {
        Self {
            content_hash: content,
            config_hash: config,
        }
    }
    /// Mixes both hashes into one value.
    ///
    /// The content hash is multiplied by the 64-bit golden-ratio constant
    /// (wrapping) and then XORed with the config hash.
    pub fn combined_hash(&self) -> u64 {
        const GOLDEN_RATIO_64: u64 = 0x9e37_79b9_7f4a_7c15;
        let spread = self.content_hash.wrapping_mul(GOLDEN_RATIO_64);
        spread ^ self.config_hash
    }
    /// Whether both hashes are equal to `other`'s.
    pub fn matches(&self, other: &OJoinIncrKey) -> bool {
        self == other
    }
}
488/// A diagnostic message from a OJoin pass.
489#[derive(Debug, Clone)]
490pub struct OJoinDiagMsg {
491    pub severity: OJoinDiagSeverity,
492    pub pass: String,
493    pub message: String,
494}
495impl OJoinDiagMsg {
496    pub fn error(pass: impl Into<String>, msg: impl Into<String>) -> Self {
497        OJoinDiagMsg {
498            severity: OJoinDiagSeverity::Error,
499            pass: pass.into(),
500            message: msg.into(),
501        }
502    }
503    pub fn warning(pass: impl Into<String>, msg: impl Into<String>) -> Self {
504        OJoinDiagMsg {
505            severity: OJoinDiagSeverity::Warning,
506            pass: pass.into(),
507            message: msg.into(),
508        }
509    }
510    pub fn note(pass: impl Into<String>, msg: impl Into<String>) -> Self {
511        OJoinDiagMsg {
512            severity: OJoinDiagSeverity::Note,
513            pass: pass.into(),
514            message: msg.into(),
515        }
516    }
517}
/// A version tag for OJoin output artifacts.
///
/// Versions are ordered by Semantic Versioning 2.0.0 precedence:
/// - `major`, `minor` and `patch` are compared numerically.
/// - A version with a pre-release tag sorts *before* the same version without
///   one (`1.0.0-alpha < 1.0.0`).
/// - Pre-release tags are compared dot-separated identifier by identifier.
///   Numeric identifiers compare numerically and sort below alphanumeric ones.
///
/// The previous derived `Ord` had both rules backwards: `1.0.0 < 1.0.0-alpha`
/// and `beta.11 < beta.2`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OJoinVersion {
    pub major: u32,
    pub minor: u32,
    pub patch: u32,
    pub pre: Option<String>,
}
impl OJoinVersion {
    /// Creates a stable (non-pre-release) version.
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        OJoinVersion {
            major,
            minor,
            patch,
            pre: None,
        }
    }
    /// Attaches a pre-release tag such as `"alpha.1"`.
    pub fn with_pre(mut self, pre: impl Into<String>) -> Self {
        self.pre = Some(pre.into());
        self
    }
    /// Whether the version has no pre-release tag.
    pub fn is_stable(&self) -> bool {
        self.pre.is_none()
    }
    /// Whether `self` can stand in for `other`.
    ///
    /// This requires the same major version and a minor version at least as
    /// new. Patch numbers and pre-release tags are ignored.
    pub fn is_compatible_with(&self, other: &OJoinVersion) -> bool {
        self.major == other.major && self.minor >= other.minor
    }
    /// Compares two pre-release tags by SemVer identifier precedence.
    fn cmp_pre_release(a: &str, b: &str) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        let is_numeric = |s: &str| !s.is_empty() && s.bytes().all(|c| c.is_ascii_digit());
        let mut lhs = a.split('.');
        let mut rhs = b.split('.');
        loop {
            let ord = match (lhs.next(), rhs.next()) {
                (None, None) => return Ordering::Equal,
                // A shorter identifier list that is a prefix of the longer one sorts first.
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(x), Some(y)) => match (is_numeric(x), is_numeric(y)) {
                    (true, true) => {
                        // Compare digit strings by magnitude without parsing, so
                        // arbitrarily long identifiers cannot overflow.
                        let (x, y) = (x.trim_start_matches('0'), y.trim_start_matches('0'));
                        x.len().cmp(&y.len()).then_with(|| x.cmp(y))
                    }
                    (true, false) => Ordering::Less,
                    (false, true) => Ordering::Greater,
                    (false, false) => x.cmp(y),
                },
            };
            if ord != Ordering::Equal {
                return ord;
            }
        }
    }
}
impl PartialOrd for OJoinVersion {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OJoinVersion {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        (self.major, self.minor, self.patch)
            .cmp(&(other.major, other.minor, other.patch))
            .then_with(|| match (&self.pre, &other.pre) {
                (None, None) => Ordering::Equal,
                (None, Some(_)) => Ordering::Greater,
                (Some(_), None) => Ordering::Less,
                // The trailing plain string comparison keeps `Ord` consistent
                // with the derived `Eq` for tags SemVer would treat as equal
                // (e.g. "01" vs "1").
                (Some(a), Some(b)) => Self::cmp_pre_release(a, b).then_with(|| a.cmp(b)),
            })
    }
}
546/// Collects OJoin diagnostics.
547#[derive(Debug, Default)]
548pub struct OJoinDiagCollector {
549    pub(super) msgs: Vec<OJoinDiagMsg>,
550}
551impl OJoinDiagCollector {
552    pub fn new() -> Self {
553        OJoinDiagCollector::default()
554    }
555    pub fn emit(&mut self, d: OJoinDiagMsg) {
556        self.msgs.push(d);
557    }
558    pub fn has_errors(&self) -> bool {
559        self.msgs
560            .iter()
561            .any(|d| d.severity == OJoinDiagSeverity::Error)
562    }
563    pub fn errors(&self) -> Vec<&OJoinDiagMsg> {
564        self.msgs
565            .iter()
566            .filter(|d| d.severity == OJoinDiagSeverity::Error)
567            .collect()
568    }
569    pub fn warnings(&self) -> Vec<&OJoinDiagMsg> {
570        self.msgs
571            .iter()
572            .filter(|d| d.severity == OJoinDiagSeverity::Warning)
573            .collect()
574    }
575    pub fn len(&self) -> usize {
576        self.msgs.len()
577    }
578    pub fn is_empty(&self) -> bool {
579        self.msgs.is_empty()
580    }
581    pub fn clear(&mut self) {
582        self.msgs.clear();
583    }
584}
585/// Pipeline profiler for OJoin.
586#[derive(Debug, Default)]
587pub struct OJoinProfiler {
588    pub(super) timings: Vec<OJoinPassTiming>,
589}
590impl OJoinProfiler {
591    pub fn new() -> Self {
592        OJoinProfiler::default()
593    }
594    pub fn record(&mut self, t: OJoinPassTiming) {
595        self.timings.push(t);
596    }
597    pub fn total_elapsed_us(&self) -> u64 {
598        self.timings.iter().map(|t| t.elapsed_us).sum()
599    }
600    pub fn slowest_pass(&self) -> Option<&OJoinPassTiming> {
601        self.timings.iter().max_by_key(|t| t.elapsed_us)
602    }
603    pub fn num_passes(&self) -> usize {
604        self.timings.len()
605    }
606    pub fn profitable_passes(&self) -> Vec<&OJoinPassTiming> {
607        self.timings.iter().filter(|t| t.is_profitable()).collect()
608    }
609}
610#[allow(dead_code)]
611pub struct OJPassRegistry {
612    pub(super) configs: Vec<OJPassConfig>,
613    pub(super) stats: std::collections::HashMap<String, OJPassStats>,
614}
615impl OJPassRegistry {
616    #[allow(dead_code)]
617    pub fn new() -> Self {
618        OJPassRegistry {
619            configs: Vec::new(),
620            stats: std::collections::HashMap::new(),
621        }
622    }
623    #[allow(dead_code)]
624    pub fn register(&mut self, config: OJPassConfig) {
625        self.stats
626            .insert(config.pass_name.clone(), OJPassStats::new());
627        self.configs.push(config);
628    }
629    #[allow(dead_code)]
630    pub fn enabled_passes(&self) -> Vec<&OJPassConfig> {
631        self.configs.iter().filter(|c| c.enabled).collect()
632    }
633    #[allow(dead_code)]
634    pub fn get_stats(&self, name: &str) -> Option<&OJPassStats> {
635        self.stats.get(name)
636    }
637    #[allow(dead_code)]
638    pub fn total_passes(&self) -> usize {
639        self.configs.len()
640    }
641    #[allow(dead_code)]
642    pub fn enabled_count(&self) -> usize {
643        self.enabled_passes().len()
644    }
645    #[allow(dead_code)]
646    pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
647        if let Some(stats) = self.stats.get_mut(name) {
648            stats.record_run(changes, time_ms, iter);
649        }
650    }
651}
652/// Main join point optimizer
653pub struct JoinPointOptimizer {
654    pub(super) config: JoinPointConfig,
655    pub(super) stats: JoinPointStats,
656    pub(super) next_id: u64,
657}
658impl JoinPointOptimizer {
659    /// Create a new join point optimizer with the given configuration
660    pub fn new(config: JoinPointConfig) -> Self {
661        JoinPointOptimizer {
662            config,
663            stats: JoinPointStats::default(),
664            next_id: 1000,
665        }
666    }
667    /// Get the optimization statistics
668    pub fn stats(&self) -> &JoinPointStats {
669        &self.stats
670    }
671    /// Generate a fresh variable ID
672    pub(super) fn fresh_id(&mut self) -> LcnfVarId {
673        let id = self.next_id;
674        self.next_id += 1;
675        LcnfVarId(id)
676    }
677    /// Optimize a single function declaration
678    pub(super) fn optimize_decl(&mut self, decl: &mut LcnfFunDecl) {
679        for _ in 0..self.config.max_iterations {
680            let changes_before = self.stats.total_changes();
681            if self.config.detect_tail_calls {
682                self.detect_tail_calls_in_expr(&mut decl.body, &decl.name);
683            }
684            if self.config.inline_small_joins {
685                self.inline_small_joins(&mut decl.body);
686            }
687            if self.config.eliminate_dead_joins {
688                self.eliminate_dead_joins(&mut decl.body);
689            }
690            if self.config.enable_contification {
691                self.contify_functions(&mut decl.body);
692            }
693            if self.config.float_join_points {
694                self.float_joins(&mut decl.body);
695            }
696            self.stats.iterations += 1;
697            if self.stats.total_changes() == changes_before {
698                break;
699            }
700        }
701    }
702    /// Detect tail calls in an expression and convert App to TailCall where appropriate
703    pub(super) fn detect_tail_calls_in_expr(&mut self, expr: &mut LcnfExpr, _current_fn: &str) {
704        let should_convert = if let LcnfExpr::Let {
705            id,
706            value: LcnfLetValue::App(func, args),
707            body,
708            ..
709        } = &*expr
710        {
711            if let LcnfExpr::Return(LcnfArg::Var(ret_var)) = body.as_ref() {
712                if *ret_var == *id {
713                    Some((func.clone(), args.clone()))
714                } else {
715                    None
716                }
717            } else {
718                None
719            }
720        } else {
721            None
722        };
723        if let Some((func, args)) = should_convert {
724            *expr = LcnfExpr::TailCall(func, args);
725            self.stats.tail_calls_detected += 1;
726            return;
727        }
728        match expr {
729            LcnfExpr::Let { body, .. } => {
730                self.detect_tail_calls_in_expr(body, _current_fn);
731            }
732            LcnfExpr::Case { alts, default, .. } => {
733                for alt in alts.iter_mut() {
734                    self.detect_tail_calls_in_expr(&mut alt.body, _current_fn);
735                }
736                if let Some(def) = default {
737                    self.detect_tail_calls_in_expr(def, _current_fn);
738                }
739            }
740            LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
741        }
742    }
743    /// Inline small join points (those with few instructions)
744    pub(super) fn inline_small_joins(&mut self, expr: &mut LcnfExpr) {
745        let small_joins = self.find_small_joins(expr);
746        if !small_joins.is_empty() {
747            self.apply_join_inlining(expr, &small_joins);
748        }
749    }
750    /// Find let-bound values that are small enough to inline
751    pub(super) fn find_small_joins(&self, expr: &LcnfExpr) -> HashMap<LcnfVarId, LcnfLetValue> {
752        let mut joins = HashMap::new();
753        match expr {
754            LcnfExpr::Let {
755                id, value, body, ..
756            } => {
757                let size = self.value_size(value);
758                if size <= self.config.max_join_size {
759                    joins.insert(*id, value.clone());
760                }
761                joins.extend(self.find_small_joins(body));
762            }
763            LcnfExpr::Case { alts, default, .. } => {
764                for alt in alts {
765                    joins.extend(self.find_small_joins(&alt.body));
766                }
767                if let Some(def) = default {
768                    joins.extend(self.find_small_joins(def));
769                }
770            }
771            _ => {}
772        }
773        joins
774    }
775    /// Compute the "size" of a let-value for inlining decisions
776    pub(super) fn value_size(&self, value: &LcnfLetValue) -> usize {
777        match value {
778            LcnfLetValue::Lit(_)
779            | LcnfLetValue::Erased
780            | LcnfLetValue::FVar(_)
781            | LcnfLetValue::Reset(_)
782            | LcnfLetValue::Reuse(_, _, _, _) => 1,
783            LcnfLetValue::Proj(_, _, _) => 1,
784            LcnfLetValue::App(_, args) => 1 + args.len(),
785            LcnfLetValue::Ctor(_, _, args) => 1 + args.len(),
786        }
787    }
788    /// Apply inlining of small join points: replace FVar references
789    pub(super) fn apply_join_inlining(
790        &mut self,
791        expr: &mut LcnfExpr,
792        joins: &HashMap<LcnfVarId, LcnfLetValue>,
793    ) {
794        match expr {
795            LcnfExpr::Let {
796                id, value, body, ..
797            } => {
798                if let LcnfLetValue::FVar(ref fvar) = value {
799                    if let Some(replacement) = joins.get(fvar) {
800                        if *id != *fvar {
801                            *value = replacement.clone();
802                            self.stats.joins_inlined += 1;
803                        }
804                    }
805                }
806                self.apply_join_inlining(body, joins);
807            }
808            LcnfExpr::Case { alts, default, .. } => {
809                for alt in alts.iter_mut() {
810                    self.apply_join_inlining(&mut alt.body, joins);
811                }
812                if let Some(def) = default {
813                    self.apply_join_inlining(def, joins);
814                }
815            }
816            _ => {}
817        }
818    }
819    /// Eliminate dead join points (unreferenced let-bindings that are pure)
820    pub(super) fn eliminate_dead_joins(&mut self, expr: &mut LcnfExpr) {
821        let used = collect_used_vars(expr);
822        self.remove_dead_lets(expr, &used);
823    }
824    /// Remove let-bindings for variables that are never used
825    pub(super) fn remove_dead_lets(&mut self, expr: &mut LcnfExpr, used: &HashSet<LcnfVarId>) {
826        loop {
827            let mut changed = false;
828            if let LcnfExpr::Let {
829                id, value, body, ..
830            } = expr
831            {
832                if !used.contains(id) && is_pure_value(value) {
833                    *expr = *body.clone();
834                    self.stats.joins_eliminated += 1;
835                    changed = true;
836                }
837            }
838            if !changed {
839                break;
840            }
841        }
842        match expr {
843            LcnfExpr::Let { body, .. } => {
844                self.remove_dead_lets(body, used);
845            }
846            LcnfExpr::Case { alts, default, .. } => {
847                for alt in alts.iter_mut() {
848                    self.remove_dead_lets(&mut alt.body, used);
849                }
850                if let Some(def) = default {
851                    self.remove_dead_lets(def, used);
852                }
853            }
854            _ => {}
855        }
856    }
857    /// Convert functions that are always called in tail position to join points
858    pub(super) fn contify_functions(&mut self, expr: &mut LcnfExpr) {
859        let tail_uses = analyze_tail_uses(expr, true);
860        let candidates: Vec<LcnfVarId> = tail_uses
861            .iter()
862            .filter(|(_, use_kind)| **use_kind == TailUse::TailOnly)
863            .map(|(var, _)| *var)
864            .collect();
865        if !candidates.is_empty() {
866            self.mark_contified(expr, &candidates);
867        }
868    }
869    /// Mark let-bound functions as contified (join points)
870    pub(super) fn mark_contified(&mut self, expr: &mut LcnfExpr, candidates: &[LcnfVarId]) {
871        match expr {
872            LcnfExpr::Let { id, body, .. } => {
873                if candidates.contains(id) {
874                    self.stats.functions_contified += 1;
875                }
876                self.mark_contified(body, candidates);
877            }
878            LcnfExpr::Case { alts, default, .. } => {
879                for alt in alts.iter_mut() {
880                    self.mark_contified(&mut alt.body, candidates);
881                }
882                if let Some(def) = default {
883                    self.mark_contified(def, candidates);
884                }
885            }
886            _ => {}
887        }
888    }
889    /// Float join points closer to their uses
890    pub(super) fn float_joins(&mut self, expr: &mut LcnfExpr) {
891        let moved = self.try_float_into_case(expr);
892        if moved {
893            self.stats.joins_floated += 1;
894        }
895        match expr {
896            LcnfExpr::Let { body, .. } => {
897                self.float_joins(body);
898            }
899            LcnfExpr::Case { alts, default, .. } => {
900                for alt in alts.iter_mut() {
901                    self.float_joins(&mut alt.body);
902                }
903                if let Some(def) = default {
904                    self.float_joins(def);
905                }
906            }
907            _ => {}
908        }
909    }
910    /// Try to float a let-binding into the single case branch that uses it.
911    ///
912    /// Transforms:
913    ///   `let x = v; case n { A -> ..x.. | B -> (no x) }`
914    /// into:
915    ///   `case n { A -> let x = v; ..x.. | B -> (no x) }`
916    ///
917    /// Only floats when exactly one branch references `x`.
918    pub(super) fn try_float_into_case(&mut self, expr: &mut LcnfExpr) -> bool {
919        let can_float = if let LcnfExpr::Let { id, body, .. } = &*expr {
920            if let LcnfExpr::Case { alts, default, .. } = body.as_ref() {
921                let use_count = alts.iter().filter(|a| expr_uses_var(&a.body, *id)).count()
922                    + default
923                        .as_ref()
924                        .map(|d| usize::from(expr_uses_var(d, *id)))
925                        .unwrap_or(0);
926                use_count == 1
927            } else {
928                false
929            }
930        } else {
931            false
932        };
933        if !can_float {
934            return false;
935        }
936        let old = std::mem::replace(expr, LcnfExpr::Unreachable);
937        if let LcnfExpr::Let {
938            id,
939            name,
940            ty,
941            value,
942            body,
943        } = old
944        {
945            if let LcnfExpr::Case {
946                scrutinee,
947                scrutinee_ty,
948                mut alts,
949                mut default,
950            } = *body
951            {
952                if let Some(idx) = alts.iter().position(|a| expr_uses_var(&a.body, id)) {
953                    let old_body = std::mem::replace(&mut alts[idx].body, LcnfExpr::Unreachable);
954                    alts[idx].body = LcnfExpr::Let {
955                        id,
956                        name,
957                        ty,
958                        value,
959                        body: Box::new(old_body),
960                    };
961                } else if let Some(def) = default.take() {
962                    default = Some(Box::new(LcnfExpr::Let {
963                        id,
964                        name,
965                        ty,
966                        value,
967                        body: def,
968                    }));
969                }
970                *expr = LcnfExpr::Case {
971                    scrutinee,
972                    scrutinee_ty,
973                    alts,
974                    default,
975                };
976                return true;
977            }
978        }
979        false
980    }
981}
/// Dominator tree over nodes indexed `0..size`.
///
/// `idom[n]` holds the immediate dominator of node `n`, or `None` when unset.
/// `dom_children` and `dom_depth` are plain public storage: `set_idom` does not
/// update them, so code that needs them must fill them in itself.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJDominatorTree {
    pub idom: Vec<Option<u32>>,
    pub dom_children: Vec<Vec<u32>>,
    pub dom_depth: Vec<u32>,
}
impl OJDominatorTree {
    /// Create a tree for `size` nodes.
    ///
    /// Initially no immediate dominators are recorded, child lists are empty,
    /// and every depth is zero.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        OJDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }
    /// Record `idom` as the immediate dominator of `node`.
    ///
    /// # Panics
    /// Panics if `node` is not a valid index (`node >= size`).
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        self.idom[node] = Some(idom);
    }
    /// Whether `a` dominates `b`, found by walking `b`'s idom chain upwards.
    ///
    /// Every node dominates itself. The walk stops at a node whose idom is
    /// `None` or itself (a root). Out-of-range nodes yield `false` instead of
    /// panicking, consistent with [`OJDominatorTree::depth`]. A malformed
    /// (cyclic) idom chain yields `false` instead of looping forever.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        // A well-formed chain visits each node at most once, so `len` lookups
        // suffice. Bounding the walk protects against cycles in `idom`.
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) => {
                    let parent = parent as usize;
                    if parent == a {
                        return true;
                    }
                    if parent == cur {
                        return false;
                    }
                    cur = parent;
                }
                None => return false,
            }
        }
        false
    }
    /// Recorded dominator-tree depth of `node`, or 0 if `node` is out of range.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
1022#[allow(dead_code)]
1023#[derive(Debug, Clone)]
1024pub struct OJWorklist {
1025    pub(super) items: std::collections::VecDeque<u32>,
1026    pub(super) in_worklist: std::collections::HashSet<u32>,
1027}
1028impl OJWorklist {
1029    #[allow(dead_code)]
1030    pub fn new() -> Self {
1031        OJWorklist {
1032            items: std::collections::VecDeque::new(),
1033            in_worklist: std::collections::HashSet::new(),
1034        }
1035    }
1036    #[allow(dead_code)]
1037    pub fn push(&mut self, item: u32) -> bool {
1038        if self.in_worklist.insert(item) {
1039            self.items.push_back(item);
1040            true
1041        } else {
1042            false
1043        }
1044    }
1045    #[allow(dead_code)]
1046    pub fn pop(&mut self) -> Option<u32> {
1047        let item = self.items.pop_front()?;
1048        self.in_worklist.remove(&item);
1049        Some(item)
1050    }
1051    #[allow(dead_code)]
1052    pub fn is_empty(&self) -> bool {
1053        self.items.is_empty()
1054    }
1055    #[allow(dead_code)]
1056    pub fn len(&self) -> usize {
1057        self.items.len()
1058    }
1059    #[allow(dead_code)]
1060    pub fn contains(&self, item: u32) -> bool {
1061        self.in_worklist.contains(&item)
1062    }
1063}
/// Pass-timing record for the OJoin profiler.
///
/// Stores one pass's elapsed time in microseconds, how many items it
/// processed, and a size measure before and after the pass.
#[derive(Debug, Clone)]
pub struct OJoinPassTiming {
    pub pass_name: String,
    pub elapsed_us: u64,
    pub items_processed: usize,
    pub bytes_before: usize,
    pub bytes_after: usize,
}
impl OJoinPassTiming {
    /// Build a timing record from its raw measurements.
    pub fn new(
        pass_name: impl Into<String>,
        elapsed_us: u64,
        items: usize,
        before: usize,
        after: usize,
    ) -> Self {
        let pass_name = pass_name.into();
        Self {
            pass_name,
            elapsed_us,
            items_processed: items,
            bytes_before: before,
            bytes_after: after,
        }
    }
    /// Items processed per second. Returns 0.0 when no time was recorded.
    pub fn throughput_mps(&self) -> f64 {
        match self.elapsed_us {
            0 => 0.0,
            micros => {
                let seconds = micros as f64 / 1_000_000.0;
                self.items_processed as f64 / seconds
            }
        }
    }
    /// Ratio of the size after the pass to the size before it.
    /// Returns 1.0 when the "before" size is zero.
    pub fn size_ratio(&self) -> f64 {
        if self.bytes_before != 0 {
            self.bytes_after as f64 / self.bytes_before as f64
        } else {
            1.0
        }
    }
    /// A pass counts as profitable if it grew the output by at most 5%.
    pub fn is_profitable(&self) -> bool {
        const MAX_GROWTH_RATIO: f64 = 1.05;
        self.size_ratio() <= MAX_GROWTH_RATIO
    }
}
/// Emission statistics for OJoin.
///
/// Counters for emitted output plus error/warning tallies and elapsed time
/// in milliseconds.
#[derive(Debug, Clone, Default)]
pub struct OJoinEmitStats {
    pub bytes_emitted: usize,
    pub items_emitted: usize,
    pub errors: usize,
    pub warnings: usize,
    pub elapsed_ms: u64,
}
impl OJoinEmitStats {
    /// All counters start at zero.
    pub fn new() -> Self {
        Self::default()
    }
    /// Bytes emitted per second. Returns 0.0 when no time was recorded.
    pub fn throughput_bps(&self) -> f64 {
        if self.elapsed_ms > 0 {
            let seconds = self.elapsed_ms as f64 / 1000.0;
            self.bytes_emitted as f64 / seconds
        } else {
            0.0
        }
    }
    /// `true` when no errors were recorded. Warnings are ignored.
    pub fn is_clean(&self) -> bool {
        matches!(self.errors, 0)
    }
}
1131/// Information about whether a variable is used only in tail position
1132#[derive(Debug, Clone, PartialEq, Eq)]
1133pub enum TailUse {
1134    /// Never used
1135    Unused,
1136    /// Used only in tail position
1137    TailOnly,
1138    /// Used in non-tail position
1139    NonTail,
1140    /// Used in both tail and non-tail positions
1141    Mixed,
1142}
1143impl TailUse {
1144    pub(super) fn merge(&self, other: &TailUse) -> TailUse {
1145        match (self, other) {
1146            (TailUse::Unused, x) | (x, TailUse::Unused) => x.clone(),
1147            (TailUse::TailOnly, TailUse::TailOnly) => TailUse::TailOnly,
1148            (TailUse::NonTail, TailUse::NonTail) => TailUse::NonTail,
1149            _ => TailUse::Mixed,
1150        }
1151    }
1152}
/// Severity of an OJoin diagnostic.
///
/// Variants are declared from least to most severe. Because `PartialOrd`/`Ord`
/// are derived, comparisons follow declaration order:
/// `Note < Warning < Error`. Keep that order when adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OJoinDiagSeverity {
    /// Lowest severity: informational note.
    Note,
    /// Middle severity.
    Warning,
    /// Highest severity.
    Error,
}
/// Per-block liveness sets, indexed by block number.
///
/// `defs`/`uses` are filled through `add_def`/`add_use`. `live_in`/`live_out`
/// are public storage that this type only queries, never computes.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJLivenessInfo {
    pub live_in: Vec<std::collections::HashSet<u32>>,
    pub live_out: Vec<std::collections::HashSet<u32>>,
    pub defs: Vec<std::collections::HashSet<u32>>,
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl OJLivenessInfo {
    /// Create empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty_sets = || vec![std::collections::HashSet::new(); block_count];
        Self {
            live_in: empty_sets(),
            live_out: empty_sets(),
            defs: empty_sets(),
            uses: empty_sets(),
        }
    }
    /// Record that `block` defines `var`. Out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Record that `block` uses `var`. Out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Whether `var` is in the live-in set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        matches!(self.live_in.get(block), Some(set) if set.contains(&var))
    }
    /// Whether `var` is in the live-out set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        matches!(self.live_out.get(block), Some(set) if set.contains(&var))
    }
}
/// A cache record: a string key, a byte payload, a timestamp and a validity flag.
///
/// NOTE(review): nothing in this chunk constructs or reads these entries, so
/// the units of `timestamp` and the rules for clearing `valid` are not
/// established here. Confirm against the users of this type.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJCacheEntry {
    /// Key identifying the entry.
    pub key: String,
    /// Raw payload bytes.
    pub data: Vec<u8>,
    /// Timestamp attached to the entry (units not specified here).
    pub timestamp: u64,
    /// Whether the entry is currently considered valid.
    pub valid: bool,
}
1213/// A fixed-capacity ring buffer of strings (for recent-event logging in OJoin).
1214#[derive(Debug)]
1215pub struct OJoinEventLog {
1216    pub(super) entries: std::collections::VecDeque<String>,
1217    pub(super) capacity: usize,
1218}
1219impl OJoinEventLog {
1220    pub fn new(capacity: usize) -> Self {
1221        OJoinEventLog {
1222            entries: std::collections::VecDeque::with_capacity(capacity),
1223            capacity,
1224        }
1225    }
1226    pub fn push(&mut self, event: impl Into<String>) {
1227        if self.entries.len() >= self.capacity {
1228            self.entries.pop_front();
1229        }
1230        self.entries.push_back(event.into());
1231    }
1232    pub fn iter(&self) -> impl Iterator<Item = &String> {
1233        self.entries.iter()
1234    }
1235    pub fn len(&self) -> usize {
1236        self.entries.len()
1237    }
1238    pub fn is_empty(&self) -> bool {
1239        self.entries.is_empty()
1240    }
1241    pub fn capacity(&self) -> usize {
1242        self.capacity
1243    }
1244    pub fn clear(&mut self) {
1245        self.entries.clear();
1246    }
1247}
/// Aggregate counters across repeated runs of a pass.
///
/// `iterations_used` holds the value from the most recent run only. The
/// other counters accumulate across runs.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct OJPassStats {
    pub total_runs: u32,
    pub successful_runs: u32,
    pub total_changes: u64,
    pub time_ms: u64,
    pub iterations_used: u32,
}
impl OJPassStats {
    /// All counters start at zero.
    #[allow(dead_code)]
    pub fn new() -> Self {
        OJPassStats::default()
    }
    /// Record one run. Every recorded run also counts as successful.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        let OJPassStats {
            total_runs,
            successful_runs,
            total_changes,
            time_ms: total_time,
            iterations_used,
        } = self;
        *total_runs += 1;
        *successful_runs += 1;
        *total_changes += changes;
        *total_time += time_ms;
        *iterations_used = iterations;
    }
    /// Mean number of changes per run, or 0.0 before any run.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / runs as f64,
        }
    }
    /// Fraction of runs that succeeded, or 0.0 before any run.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.successful_runs as f64 / runs as f64,
        }
    }
    /// One-line human-readable summary of the counters.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let (ok, all, changes, ms) = (
            self.successful_runs,
            self.total_runs,
            self.total_changes,
            self.time_ms,
        );
        format!("Runs: {}/{}, Changes: {}, Time: {}ms", ok, all, changes, ms)
    }
}
1292#[allow(dead_code)]
1293#[derive(Debug, Clone)]
1294pub struct OJPassConfig {
1295    pub phase: OJPassPhase,
1296    pub enabled: bool,
1297    pub max_iterations: u32,
1298    pub debug_output: bool,
1299    pub pass_name: String,
1300}
1301impl OJPassConfig {
1302    #[allow(dead_code)]
1303    pub fn new(name: impl Into<String>, phase: OJPassPhase) -> Self {
1304        OJPassConfig {
1305            phase,
1306            enabled: true,
1307            max_iterations: 10,
1308            debug_output: false,
1309            pass_name: name.into(),
1310        }
1311    }
1312    #[allow(dead_code)]
1313    pub fn disabled(mut self) -> Self {
1314        self.enabled = false;
1315        self
1316    }
1317    #[allow(dead_code)]
1318    pub fn with_debug(mut self) -> Self {
1319        self.debug_output = true;
1320        self
1321    }
1322    #[allow(dead_code)]
1323    pub fn max_iter(mut self, n: u32) -> Self {
1324        self.max_iterations = n;
1325        self
1326    }
1327}
/// Phase a pass belongs to.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum OJPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl OJPassPhase {
    /// Lowercase display name of the phase.
    ///
    /// Returns `&'static str` because every name is a literal, so callers may
    /// keep it after the phase value is gone. This is backward compatible with
    /// the previous `&str` return.
    #[allow(dead_code)]
    pub fn name(&self) -> &'static str {
        match self {
            OJPassPhase::Analysis => "analysis",
            OJPassPhase::Transformation => "transformation",
            OJPassPhase::Verification => "verification",
            OJPassPhase::Cleanup => "cleanup",
        }
    }
    /// Whether passes in this phase may change the program
    /// (`Transformation` and `Cleanup`).
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        matches!(self, OJPassPhase::Transformation | OJPassPhase::Cleanup)
    }
}