Skip to main content

oxilean_codegen/mlir_backend/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use std::collections::HashMap;
6
7use std::collections::{HashSet, VecDeque};
8
9/// An MLIR function definition.
10#[derive(Debug, Clone)]
11pub struct MlirFunc {
12    /// Function name (without `@`)
13    pub name: String,
14    /// Function arguments: list of (name, type)
15    pub args: Vec<(String, MlirType)>,
16    /// Return types
17    pub results: Vec<MlirType>,
18    /// Function body (a region)
19    pub body: MlirRegion,
20    /// Extra attributes (e.g., sym_visibility = "private")
21    pub attributes: Vec<(String, MlirAttr)>,
22    /// Whether this is a declaration only (no body)
23    pub is_declaration: bool,
24}
25impl MlirFunc {
26    /// Create a simple function with a body.
27    pub fn new(
28        name: impl Into<String>,
29        args: Vec<(String, MlirType)>,
30        results: Vec<MlirType>,
31        body: MlirRegion,
32    ) -> Self {
33        MlirFunc {
34            name: name.into(),
35            args,
36            results,
37            body,
38            attributes: vec![],
39            is_declaration: false,
40        }
41    }
42    /// Create a function declaration (no body, for extern functions).
43    pub fn declaration(
44        name: impl Into<String>,
45        args: Vec<MlirType>,
46        results: Vec<MlirType>,
47    ) -> Self {
48        let arg_vals = args
49            .into_iter()
50            .enumerate()
51            .map(|(i, t)| (format!("arg{}", i), t))
52            .collect();
53        MlirFunc {
54            name: name.into(),
55            args: arg_vals,
56            results,
57            body: MlirRegion::empty(),
58            attributes: vec![],
59            is_declaration: true,
60        }
61    }
62    /// Emit the function as MLIR text.
63    pub fn emit(&self) -> String {
64        let mut out = String::new();
65        if self.is_declaration {
66            out.push_str("  func.func private @");
67        } else {
68            out.push_str("  func.func @");
69        }
70        out.push_str(&self.name);
71        out.push('(');
72        for (i, (name, ty)) in self.args.iter().enumerate() {
73            if i > 0 {
74                out.push_str(", ");
75            }
76            out.push_str(&format!("%{}: {}", name, ty));
77        }
78        out.push(')');
79        if !self.results.is_empty() {
80            out.push_str(" -> ");
81            if self.results.len() == 1 {
82                out.push_str(&self.results[0].to_string());
83            } else {
84                out.push('(');
85                for (i, r) in self.results.iter().enumerate() {
86                    if i > 0 {
87                        out.push_str(", ");
88                    }
89                    out.push_str(&r.to_string());
90                }
91                out.push(')');
92            }
93        }
94        if !self.attributes.is_empty() {
95            out.push_str(" attributes {");
96            for (i, (k, v)) in self.attributes.iter().enumerate() {
97                if i > 0 {
98                    out.push_str(", ");
99                }
100                out.push_str(&format!("{} = {}", k, v));
101            }
102            out.push('}');
103        }
104        if self.is_declaration {
105            out.push('\n');
106        } else {
107            out.push_str(" {\n");
108            for block in &self.body.blocks {
109                out.push_str(&format!("{}", block));
110            }
111            out.push_str("  }\n");
112        }
113        out
114    }
115}
/// Kind of work a pass performs.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum MLIRPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl MLIRPassPhase {
    /// Lower-case name of the phase.
    #[allow(dead_code)]
    pub fn name(&self) -> &str {
        match *self {
            Self::Analysis => "analysis",
            Self::Transformation => "transformation",
            Self::Verification => "verification",
            Self::Cleanup => "cleanup",
        }
    }
    /// Returns `true` for `Transformation` and `Cleanup`, `false` for
    /// `Analysis` and `Verification`.
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        match *self {
            Self::Transformation | Self::Cleanup => true,
            Self::Analysis | Self::Verification => false,
        }
    }
}
139/// MLIR code generation backend.
140pub struct MlirBackend {
141    pub(super) module: MlirModule,
142    pub(super) ssa: SsaCounter,
143    pub(super) pass_pipeline: Vec<String>,
144}
145impl MlirBackend {
146    /// Create a new MLIR backend.
147    pub fn new() -> Self {
148        MlirBackend {
149            module: MlirModule::new(),
150            ssa: SsaCounter::new(),
151            pass_pipeline: vec![],
152        }
153    }
154    /// Create a backend with a module name.
155    pub fn with_name(name: impl Into<String>) -> Self {
156        MlirBackend {
157            module: MlirModule::named(name),
158            ssa: SsaCounter::new(),
159            pass_pipeline: vec![],
160        }
161    }
162    /// Add a pass to the pipeline.
163    pub fn add_pass(&mut self, pass: impl Into<String>) {
164        self.pass_pipeline.push(pass.into());
165    }
166    /// Add a simple integer add function.
167    pub fn compile_add_func(&mut self, name: &str, bits: u32) {
168        let int_ty = MlirType::Integer(bits, false);
169        let mut builder = MlirBuilder::new();
170        let arg0 = MlirValue::named("arg0", int_ty.clone());
171        let arg1 = MlirValue::named("arg1", int_ty.clone());
172        let sum = builder.addi(arg0.clone(), arg1.clone());
173        builder.return_op(vec![sum]);
174        let block = MlirBlock::entry(vec![arg0, arg1], builder.take_ops());
175        let region = MlirRegion::single_block(block);
176        let func = MlirFunc::new(name, vec![], vec![int_ty.clone()], region);
177        self.module.add_function(func);
178    }
179    /// Compile a declaration to a simple wrapper function.
180    pub fn compile_decl(&mut self, name: &str, arg_types: Vec<MlirType>, ret_type: MlirType) {
181        let args: Vec<(String, MlirType)> = arg_types
182            .into_iter()
183            .enumerate()
184            .map(|(i, t)| (format!("arg{}", i), t))
185            .collect();
186        let mut builder = MlirBuilder::new();
187        let zero = builder.const_int(0, 64);
188        builder.return_op(vec![zero]);
189        let block = MlirBlock::entry(vec![], builder.take_ops());
190        let func = MlirFunc::new(name, args, vec![ret_type], MlirRegion::single_block(block));
191        self.module.add_function(func);
192    }
193    /// Emit the full MLIR module as text.
194    pub fn emit_module(&self) -> String {
195        self.module.emit()
196    }
197    /// Generate the `mlir-opt` pass pipeline string.
198    pub fn run_passes(&self) -> String {
199        if self.pass_pipeline.is_empty() {
200            String::new()
201        } else {
202            format!("mlir-opt --{}", self.pass_pipeline.join(" --"))
203        }
204    }
205    /// Get the module (for further manipulation).
206    pub fn module_mut(&mut self) -> &mut MlirModule {
207        &mut self.module
208    }
209}
210/// Constant folding helper for MLIRExt.
211#[allow(dead_code)]
212#[derive(Debug, Clone, Default)]
213pub struct MLIRExtConstFolder {
214    pub(super) folds: usize,
215    pub(super) failures: usize,
216    pub(super) enabled: bool,
217}
218impl MLIRExtConstFolder {
219    #[allow(dead_code)]
220    pub fn new() -> Self {
221        Self {
222            folds: 0,
223            failures: 0,
224            enabled: true,
225        }
226    }
227    #[allow(dead_code)]
228    pub fn add_i64(&mut self, a: i64, b: i64) -> Option<i64> {
229        self.folds += 1;
230        a.checked_add(b)
231    }
232    #[allow(dead_code)]
233    pub fn sub_i64(&mut self, a: i64, b: i64) -> Option<i64> {
234        self.folds += 1;
235        a.checked_sub(b)
236    }
237    #[allow(dead_code)]
238    pub fn mul_i64(&mut self, a: i64, b: i64) -> Option<i64> {
239        self.folds += 1;
240        a.checked_mul(b)
241    }
242    #[allow(dead_code)]
243    pub fn div_i64(&mut self, a: i64, b: i64) -> Option<i64> {
244        if b == 0 {
245            self.failures += 1;
246            None
247        } else {
248            self.folds += 1;
249            a.checked_div(b)
250        }
251    }
252    #[allow(dead_code)]
253    pub fn rem_i64(&mut self, a: i64, b: i64) -> Option<i64> {
254        if b == 0 {
255            self.failures += 1;
256            None
257        } else {
258            self.folds += 1;
259            a.checked_rem(b)
260        }
261    }
262    #[allow(dead_code)]
263    pub fn neg_i64(&mut self, a: i64) -> Option<i64> {
264        self.folds += 1;
265        a.checked_neg()
266    }
267    #[allow(dead_code)]
268    pub fn shl_i64(&mut self, a: i64, s: u32) -> Option<i64> {
269        if s >= 64 {
270            self.failures += 1;
271            None
272        } else {
273            self.folds += 1;
274            a.checked_shl(s)
275        }
276    }
277    #[allow(dead_code)]
278    pub fn shr_i64(&mut self, a: i64, s: u32) -> Option<i64> {
279        if s >= 64 {
280            self.failures += 1;
281            None
282        } else {
283            self.folds += 1;
284            a.checked_shr(s)
285        }
286    }
287    #[allow(dead_code)]
288    pub fn and_i64(&mut self, a: i64, b: i64) -> i64 {
289        self.folds += 1;
290        a & b
291    }
292    #[allow(dead_code)]
293    pub fn or_i64(&mut self, a: i64, b: i64) -> i64 {
294        self.folds += 1;
295        a | b
296    }
297    #[allow(dead_code)]
298    pub fn xor_i64(&mut self, a: i64, b: i64) -> i64 {
299        self.folds += 1;
300        a ^ b
301    }
302    #[allow(dead_code)]
303    pub fn not_i64(&mut self, a: i64) -> i64 {
304        self.folds += 1;
305        !a
306    }
307    #[allow(dead_code)]
308    pub fn fold_count(&self) -> usize {
309        self.folds
310    }
311    #[allow(dead_code)]
312    pub fn failure_count(&self) -> usize {
313        self.failures
314    }
315    #[allow(dead_code)]
316    pub fn enable(&mut self) {
317        self.enabled = true;
318    }
319    #[allow(dead_code)]
320    pub fn disable(&mut self) {
321        self.enabled = false;
322    }
323    #[allow(dead_code)]
324    pub fn is_enabled(&self) -> bool {
325        self.enabled
326    }
327}
/// One cache record: a string key with its byte payload and bookkeeping fields.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRCacheEntry {
    /// Lookup key of the entry.
    pub key: String,
    /// Raw payload bytes.
    pub data: Vec<u8>,
    /// Timestamp value; presumably the creation/update time — unit not
    /// visible here, confirm against the code that fills it.
    pub timestamp: u64,
    /// Validity flag; presumably `false` once the entry is stale — confirm
    /// against callers.
    pub valid: bool,
}
/// Per-block liveness sets, indexed by block number.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRLivenessInfo {
    /// Variables live on entry to each block.
    pub live_in: Vec<HashSet<u32>>,
    /// Variables live on exit from each block.
    pub live_out: Vec<HashSet<u32>>,
    /// Variables recorded as defined in each block.
    pub defs: Vec<HashSet<u32>>,
    /// Variables recorded as used in each block.
    pub uses: Vec<HashSet<u32>>,
}
impl MLIRLivenessInfo {
    /// Create liveness info with `block_count` empty sets in every table.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty_sets = || vec![HashSet::new(); block_count];
        Self {
            live_in: empty_sets(),
            live_out: empty_sets(),
            defs: empty_sets(),
            uses: empty_sets(),
        }
    }
    /// Record that `var` is defined in `block`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Record that `var` is used in `block`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Whether `var` is in the live-in set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        match self.live_in.get(block) {
            Some(set) => set.contains(&var),
            None => false,
        }
    }
    /// Whether `var` is in the live-out set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        match self.live_out.get(block) {
            Some(set) => set.contains(&var),
            None => false,
        }
    }
}
381#[allow(dead_code)]
382#[derive(Debug, Clone)]
383pub struct MLIRWorklist {
384    pub(super) items: std::collections::VecDeque<u32>,
385    pub(super) in_worklist: std::collections::HashSet<u32>,
386}
387impl MLIRWorklist {
388    #[allow(dead_code)]
389    pub fn new() -> Self {
390        MLIRWorklist {
391            items: std::collections::VecDeque::new(),
392            in_worklist: std::collections::HashSet::new(),
393        }
394    }
395    #[allow(dead_code)]
396    pub fn push(&mut self, item: u32) -> bool {
397        if self.in_worklist.insert(item) {
398            self.items.push_back(item);
399            true
400        } else {
401            false
402        }
403    }
404    #[allow(dead_code)]
405    pub fn pop(&mut self) -> Option<u32> {
406        let item = self.items.pop_front()?;
407        self.in_worklist.remove(&item);
408        Some(item)
409    }
410    #[allow(dead_code)]
411    pub fn is_empty(&self) -> bool {
412        self.items.is_empty()
413    }
414    #[allow(dead_code)]
415    pub fn len(&self) -> usize {
416        self.items.len()
417    }
418    #[allow(dead_code)]
419    pub fn contains(&self, item: u32) -> bool {
420        self.in_worklist.contains(&item)
421    }
422}
/// Pass execution phase for MLIRExt.
///
/// Phases map to the ordinals 0 (`Early`) through 3 (`Finalize`).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MLIRExtPassPhase {
    Early,
    Middle,
    Late,
    Finalize,
}
impl MLIRExtPassPhase {
    /// Whether this is the `Early` phase.
    #[allow(dead_code)]
    pub fn is_early(&self) -> bool {
        *self == Self::Early
    }
    /// Whether this is the `Middle` phase.
    #[allow(dead_code)]
    pub fn is_middle(&self) -> bool {
        *self == Self::Middle
    }
    /// Whether this is the `Late` phase.
    #[allow(dead_code)]
    pub fn is_late(&self) -> bool {
        *self == Self::Late
    }
    /// Whether this is the `Finalize` phase.
    #[allow(dead_code)]
    pub fn is_finalize(&self) -> bool {
        *self == Self::Finalize
    }
    /// Ordinal of the phase: `Early` = 0, `Middle` = 1, `Late` = 2, `Finalize` = 3.
    #[allow(dead_code)]
    pub fn order(&self) -> u32 {
        match *self {
            Self::Early => 0,
            Self::Middle => 1,
            Self::Late => 2,
            Self::Finalize => 3,
        }
    }
    /// Inverse of [`order`](Self::order); `None` for ordinals above 3.
    #[allow(dead_code)]
    pub fn from_order(n: u32) -> Option<Self> {
        let phase = match n {
            0 => Self::Early,
            1 => Self::Middle,
            2 => Self::Late,
            3 => Self::Finalize,
            _ => return None,
        };
        Some(phase)
    }
}
/// The MLIR dialects this backend distinguishes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MlirDialect {
    /// `builtin`: module and function types.
    Builtin,
    /// `arith`: addi, addf, muli, divsi, cmpi, extsi, trunci.
    Arith,
    /// `func`: func.func, func.call, func.return.
    Func,
    /// `cf` (unstructured control flow): cf.br, cf.cond_br, cf.switch.
    CF,
    /// `memref`: memref.alloc, memref.load, memref.store, memref.dealloc.
    MemRef,
    /// `scf` (structured control flow): scf.if, scf.for, scf.while.
    SCF,
    /// `affine`: affine.for, affine.load, affine.store.
    Affine,
    /// `tensor`: tensor.extract, tensor.insert, tensor.reshape.
    Tensor,
    /// `vector`: vector.load, vector.store, vector.broadcast.
    Vector,
    /// `linalg` (linear algebra for ML): linalg.matmul, linalg.generic.
    Linalg,
    /// `gpu`: gpu.launch, gpu.thread_id, gpu.block_dim.
    GPU,
    /// `llvm` (LLVM IR dialect): llvm.add, llvm.mlir.constant, llvm.call.
    LLVM,
    /// `math`: math.sin, math.cos, math.exp, math.log, math.sqrt.
    Math,
    /// Operations on the index type.
    Index,
}
/// Accumulated statistics over the runs of a pass.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRPassStats {
    /// Number of recorded runs.
    pub total_runs: u32,
    /// Number of runs counted as successful.
    pub successful_runs: u32,
    /// Sum of the change counts of all recorded runs.
    pub total_changes: u64,
    /// Sum of the run times, in milliseconds.
    pub time_ms: u64,
    /// Iteration count of the most recently recorded run.
    pub iterations_used: u32,
}
impl MLIRPassStats {
    /// Create zeroed statistics.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Default::default()
    }
    /// Record one run.
    ///
    /// Every recorded run is counted as successful. `changes` and `time_ms`
    /// are added to the running totals; `iterations` replaces the stored
    /// iteration count.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.successful_runs += 1;
        self.total_changes += changes;
        self.time_ms += time_ms;
        self.iterations_used = iterations;
    }
    /// Mean number of changes per recorded run; `0.0` before any run.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / runs as f64,
        }
    }
    /// Fraction of recorded runs counted as successful; `0.0` before any run.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.successful_runs as f64 / runs as f64,
        }
    }
    /// One-line summary: successful/total runs, total changes, total time.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let Self {
            total_runs,
            successful_runs,
            total_changes,
            time_ms,
            ..
        } = self;
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            successful_runs, total_runs, total_changes, time_ms
        )
    }
}
545/// Analysis cache for MLIRExt.
546#[allow(dead_code)]
547#[derive(Debug)]
548pub struct MLIRExtCache {
549    pub(super) entries: Vec<(u64, Vec<u8>, bool, u32)>,
550    pub(super) cap: usize,
551    pub(super) total_hits: u64,
552    pub(super) total_misses: u64,
553}
554impl MLIRExtCache {
555    #[allow(dead_code)]
556    pub fn new(cap: usize) -> Self {
557        Self {
558            entries: Vec::new(),
559            cap,
560            total_hits: 0,
561            total_misses: 0,
562        }
563    }
564    #[allow(dead_code)]
565    pub fn get(&mut self, key: u64) -> Option<&[u8]> {
566        for e in self.entries.iter_mut() {
567            if e.0 == key && e.2 {
568                e.3 += 1;
569                self.total_hits += 1;
570                return Some(&e.1);
571            }
572        }
573        self.total_misses += 1;
574        None
575    }
576    #[allow(dead_code)]
577    pub fn put(&mut self, key: u64, data: Vec<u8>) {
578        if self.entries.len() >= self.cap {
579            self.entries.retain(|e| e.2);
580            if self.entries.len() >= self.cap {
581                self.entries.remove(0);
582            }
583        }
584        self.entries.push((key, data, true, 0));
585    }
586    #[allow(dead_code)]
587    pub fn invalidate(&mut self) {
588        for e in self.entries.iter_mut() {
589            e.2 = false;
590        }
591    }
592    #[allow(dead_code)]
593    pub fn hit_rate(&self) -> f64 {
594        let t = self.total_hits + self.total_misses;
595        if t == 0 {
596            0.0
597        } else {
598            self.total_hits as f64 / t as f64
599        }
600    }
601    #[allow(dead_code)]
602    pub fn live_count(&self) -> usize {
603        self.entries.iter().filter(|e| e.2).count()
604    }
605}
606/// Configuration for MLIRExt passes.
607#[allow(dead_code)]
608#[derive(Debug, Clone)]
609pub struct MLIRExtPassConfig {
610    pub name: String,
611    pub phase: MLIRExtPassPhase,
612    pub enabled: bool,
613    pub max_iterations: usize,
614    pub debug: u32,
615    pub timeout_ms: Option<u64>,
616}
617impl MLIRExtPassConfig {
618    #[allow(dead_code)]
619    pub fn new(name: impl Into<String>) -> Self {
620        Self {
621            name: name.into(),
622            phase: MLIRExtPassPhase::Middle,
623            enabled: true,
624            max_iterations: 100,
625            debug: 0,
626            timeout_ms: None,
627        }
628    }
629    #[allow(dead_code)]
630    pub fn with_phase(mut self, phase: MLIRExtPassPhase) -> Self {
631        self.phase = phase;
632        self
633    }
634    #[allow(dead_code)]
635    pub fn with_max_iter(mut self, n: usize) -> Self {
636        self.max_iterations = n;
637        self
638    }
639    #[allow(dead_code)]
640    pub fn with_debug(mut self, d: u32) -> Self {
641        self.debug = d;
642        self
643    }
644    #[allow(dead_code)]
645    pub fn disabled(mut self) -> Self {
646        self.enabled = false;
647        self
648    }
649    #[allow(dead_code)]
650    pub fn with_timeout(mut self, ms: u64) -> Self {
651        self.timeout_ms = Some(ms);
652        self
653    }
654    #[allow(dead_code)]
655    pub fn is_debug_enabled(&self) -> bool {
656        self.debug > 0
657    }
658}
659/// Counter for generating fresh SSA value names.
660#[derive(Debug, Default)]
661pub struct SsaCounter {
662    pub(super) counter: u32,
663    pub(super) named: HashMap<String, u32>,
664}
665impl SsaCounter {
666    /// Create a new counter.
667    pub fn new() -> Self {
668        SsaCounter::default()
669    }
670    /// Allocate the next numbered SSA value.
671    pub fn next(&mut self, ty: MlirType) -> MlirValue {
672        let id = self.counter;
673        self.counter += 1;
674        MlirValue::numbered(id, ty)
675    }
676    /// Allocate a named SSA value (deduplicated).
677    pub fn named(&mut self, base: &str, ty: MlirType) -> MlirValue {
678        let count = self.named.entry(base.to_string()).or_insert(0);
679        let name = if *count == 0 {
680            base.to_string()
681        } else {
682            format!("{}_{}", base, count)
683        };
684        *count += 1;
685        MlirValue::named(name, ty)
686    }
687    /// Reset the counter.
688    pub fn reset(&mut self) {
689        self.counter = 0;
690        self.named.clear();
691    }
692}
/// Dominator tree over nodes numbered `0..size`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRDominatorTree {
    /// Immediate dominator of each node, `None` when not set.
    pub idom: Vec<Option<u32>>,
    /// Children of each node in the dominator tree.
    pub dom_children: Vec<Vec<u32>>,
    /// Depth of each node in the dominator tree.
    pub dom_depth: Vec<u32>,
}
impl MLIRDominatorTree {
    /// Create a tree for `size` nodes with no dominators set, no children,
    /// and all depths 0.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        MLIRDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }
    /// Set the immediate dominator of `node`.
    ///
    /// Only the `idom` table is updated. A `node` outside the tree is ignored
    /// instead of panicking.
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        if let Some(slot) = self.idom.get_mut(node) {
            *slot = Some(idom);
        }
    }
    /// Whether `a` dominates `b`, decided by walking the `idom` chain from `b`.
    ///
    /// Every node dominates itself. The walk stops with `false` at a node
    /// without an immediate dominator, at a node that is its own immediate
    /// dominator, or at a node index outside the tree.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        // A well-formed chain visits each node at most once, so bounding the
        // walk by the node count keeps a malformed (cyclic) table from
        // looping forever.
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) if parent as usize == a => return true,
                Some(parent) if parent as usize == cur => return false,
                Some(parent) => cur = parent as usize,
                None => return false,
            }
        }
        false
    }
    /// Stored depth of `node`, or 0 when `node` is outside the tree.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
733/// Builder for constructing MLIR operations conveniently.
734pub struct MlirBuilder {
735    pub(super) ssa: SsaCounter,
736    pub(super) ops: Vec<MlirOp>,
737}
738impl MlirBuilder {
739    /// Create a new builder.
740    pub fn new() -> Self {
741        MlirBuilder {
742            ssa: SsaCounter::new(),
743            ops: vec![],
744        }
745    }
746    /// Emit `arith.constant` for integer.
747    pub fn const_int(&mut self, value: i64, bits: u32) -> MlirValue {
748        let ty = MlirType::Integer(bits, false);
749        let result = self.ssa.next(ty.clone());
750        let mut op = MlirOp::unary_result(
751            result.clone(),
752            "arith.constant",
753            vec![],
754            vec![("value".to_string(), MlirAttr::Integer(value, ty))],
755        );
756        op.type_annotations = vec![result.ty.clone()];
757        self.ops.push(op);
758        result
759    }
760    /// Emit `arith.constant` for float.
761    pub fn const_float(&mut self, value: f64, bits: u32) -> MlirValue {
762        let ty = MlirType::Float(bits);
763        let result = self.ssa.next(ty.clone());
764        let op = MlirOp::unary_result(
765            result.clone(),
766            "arith.constant",
767            vec![],
768            vec![("value".to_string(), MlirAttr::Float(value))],
769        );
770        self.ops.push(op);
771        result
772    }
773    /// Emit `arith.addi`.
774    pub fn addi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
775        let ty = lhs.ty.clone();
776        let result = self.ssa.next(ty.clone());
777        let mut op = MlirOp::unary_result(result.clone(), "arith.addi", vec![lhs, rhs], vec![]);
778        op.type_annotations = vec![ty];
779        self.ops.push(op);
780        result
781    }
782    /// Emit `arith.subi`.
783    pub fn subi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
784        let ty = lhs.ty.clone();
785        let result = self.ssa.next(ty.clone());
786        let mut op = MlirOp::unary_result(result.clone(), "arith.subi", vec![lhs, rhs], vec![]);
787        op.type_annotations = vec![ty];
788        self.ops.push(op);
789        result
790    }
791    /// Emit `arith.muli`.
792    pub fn muli(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
793        let ty = lhs.ty.clone();
794        let result = self.ssa.next(ty.clone());
795        let mut op = MlirOp::unary_result(result.clone(), "arith.muli", vec![lhs, rhs], vec![]);
796        op.type_annotations = vec![ty];
797        self.ops.push(op);
798        result
799    }
800    /// Emit `arith.divsi` (signed integer division).
801    pub fn divsi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
802        let ty = lhs.ty.clone();
803        let result = self.ssa.next(ty.clone());
804        let mut op = MlirOp::unary_result(result.clone(), "arith.divsi", vec![lhs, rhs], vec![]);
805        op.type_annotations = vec![ty];
806        self.ops.push(op);
807        result
808    }
809    /// Emit `arith.addf`.
810    pub fn addf(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
811        let ty = lhs.ty.clone();
812        let result = self.ssa.next(ty.clone());
813        let mut op = MlirOp::unary_result(result.clone(), "arith.addf", vec![lhs, rhs], vec![]);
814        op.type_annotations = vec![ty];
815        self.ops.push(op);
816        result
817    }
818    /// Emit `arith.mulf`.
819    pub fn mulf(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
820        let ty = lhs.ty.clone();
821        let result = self.ssa.next(ty.clone());
822        let mut op = MlirOp::unary_result(result.clone(), "arith.mulf", vec![lhs, rhs], vec![]);
823        op.type_annotations = vec![ty];
824        self.ops.push(op);
825        result
826    }
827    /// Emit `arith.cmpi`.
828    pub fn cmpi(&mut self, pred: CmpiPred, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
829        let result = self.ssa.next(MlirType::Integer(1, false));
830        let mut op = MlirOp::unary_result(
831            result.clone(),
832            "arith.cmpi",
833            vec![lhs.clone(), rhs],
834            vec![("predicate".to_string(), MlirAttr::Str(pred.to_string()))],
835        );
836        op.type_annotations = vec![lhs.ty];
837        self.ops.push(op);
838        result
839    }
840    /// Emit `arith.extsi` (sign-extend integer).
841    pub fn extsi(&mut self, val: MlirValue, target_bits: u32) -> MlirValue {
842        let result = self.ssa.next(MlirType::Integer(target_bits, false));
843        let src_ty = val.ty.clone();
844        let dst_ty = result.ty.clone();
845        let mut op = MlirOp::unary_result(result.clone(), "arith.extsi", vec![val], vec![]);
846        op.type_annotations = vec![src_ty, dst_ty];
847        self.ops.push(op);
848        result
849    }
850    /// Emit `arith.trunci` (truncate integer).
851    pub fn trunci(&mut self, val: MlirValue, target_bits: u32) -> MlirValue {
852        let result = self.ssa.next(MlirType::Integer(target_bits, false));
853        let src_ty = val.ty.clone();
854        let dst_ty = result.ty.clone();
855        let mut op = MlirOp::unary_result(result.clone(), "arith.trunci", vec![val], vec![]);
856        op.type_annotations = vec![src_ty, dst_ty];
857        self.ops.push(op);
858        result
859    }
860    /// Emit `math.sin`.
861    pub fn sin(&mut self, val: MlirValue) -> MlirValue {
862        let ty = val.ty.clone();
863        let result = self.ssa.next(ty.clone());
864        let mut op = MlirOp::unary_result(result.clone(), "math.sin", vec![val], vec![]);
865        op.type_annotations = vec![ty];
866        self.ops.push(op);
867        result
868    }
869    /// Emit `math.cos`.
870    pub fn cos(&mut self, val: MlirValue) -> MlirValue {
871        let ty = val.ty.clone();
872        let result = self.ssa.next(ty.clone());
873        let mut op = MlirOp::unary_result(result.clone(), "math.cos", vec![val], vec![]);
874        op.type_annotations = vec![ty];
875        self.ops.push(op);
876        result
877    }
878    /// Emit `math.exp`.
879    pub fn exp(&mut self, val: MlirValue) -> MlirValue {
880        let ty = val.ty.clone();
881        let result = self.ssa.next(ty.clone());
882        let mut op = MlirOp::unary_result(result.clone(), "math.exp", vec![val], vec![]);
883        op.type_annotations = vec![ty];
884        self.ops.push(op);
885        result
886    }
887    /// Emit `math.log`.
888    pub fn log(&mut self, val: MlirValue) -> MlirValue {
889        let ty = val.ty.clone();
890        let result = self.ssa.next(ty.clone());
891        let mut op = MlirOp::unary_result(result.clone(), "math.log", vec![val], vec![]);
892        op.type_annotations = vec![ty];
893        self.ops.push(op);
894        result
895    }
896    /// Emit `math.sqrt`.
897    pub fn sqrt(&mut self, val: MlirValue) -> MlirValue {
898        let ty = val.ty.clone();
899        let result = self.ssa.next(ty.clone());
900        let mut op = MlirOp::unary_result(result.clone(), "math.sqrt", vec![val], vec![]);
901        op.type_annotations = vec![ty];
902        self.ops.push(op);
903        result
904    }
905    /// Emit `memref.alloc`.
906    pub fn alloc(&mut self, elem_ty: MlirType, dims: Vec<i64>) -> MlirValue {
907        let memref_ty = MlirType::MemRef(Box::new(elem_ty), dims, AffineMap::Constant);
908        let result = self.ssa.next(memref_ty.clone());
909        let op = MlirOp::unary_result(result.clone(), "memref.alloc", vec![], vec![]);
910        self.ops.push(op);
911        result
912    }
913    /// Emit `memref.dealloc`.
914    pub fn dealloc(&mut self, memref: MlirValue) {
915        let op = MlirOp::void_op("memref.dealloc", vec![memref], vec![]);
916        self.ops.push(op);
917    }
918    /// Emit `func.return`.
919    pub fn return_op(&mut self, values: Vec<MlirValue>) {
920        let op = MlirOp::void_op("func.return", values, vec![]);
921        self.ops.push(op);
922    }
923    /// Emit `func.call`.
924    pub fn call(
925        &mut self,
926        callee: &str,
927        args: Vec<MlirValue>,
928        result_types: Vec<MlirType>,
929    ) -> Vec<MlirValue> {
930        let results: Vec<MlirValue> = result_types.into_iter().map(|t| self.ssa.next(t)).collect();
931        let mut op = MlirOp {
932            results: results.clone(),
933            op_name: "func.call".to_string(),
934            operands: args,
935            regions: vec![],
936            successors: vec![],
937            attributes: vec![("callee".to_string(), MlirAttr::Symbol(callee.to_string()))],
938            type_annotations: vec![],
939        };
940        op.type_annotations = results.iter().map(|r| r.ty.clone()).collect();
941        self.ops.push(op);
942        results
943    }
944    /// Take the accumulated ops.
945    pub fn take_ops(&mut self) -> Vec<MlirOp> {
946        std::mem::take(&mut self.ops)
947    }
948    /// Build a basic block from accumulated ops.
949    pub fn finish_block(&mut self, args: Vec<MlirValue>) -> MlirBlock {
950        let ops = self.take_ops();
951        MlirBlock::entry(args, ops)
952    }
953}
/// MLIR attribute representation.
///
/// Each variant's comment shows the textual MLIR form it is meant to model.
/// NOTE(review): the rendering itself is not in this block — presumably a
/// `Display` impl defined elsewhere; confirm before relying on exact output.
#[derive(Debug, Clone, PartialEq)]
pub enum MlirAttr {
    /// Integer attribute: `42 : i64` (the value together with its type)
    Integer(i64, MlirType),
    /// Float attribute: `3.14 : f64` (only the value is stored; no type is carried)
    Float(f64),
    /// String attribute: `"hello"`
    Str(String),
    /// Type attribute: `i32`
    Type(MlirType),
    /// Array attribute: `[1, 2, 3]` (elements are themselves attributes)
    Array(Vec<MlirAttr>),
    /// Dictionary attribute: `{key = val, ...}` (kept as an ordered list of pairs)
    Dict(Vec<(String, MlirAttr)>),
    /// Affine map: `affine_map<(d0) -> (d0)>` (held as a plain string)
    AffineMap(String),
    /// Unit attribute (presence marker)
    Unit,
    /// Boolean attribute
    Bool(bool),
    /// Symbol reference: `@name`
    Symbol(String),
    /// Dense elements: `dense<[1.0, 2.0]> : tensor<2xf32>` (elements plus the shaped type)
    Dense(Vec<MlirAttr>, MlirType),
}
/// MLIR type system.
///
/// Types are compared structurally (`PartialEq`) and nest through `Box`
/// where a variant wraps an element type.
#[derive(Debug, Clone, PartialEq)]
pub enum MlirType {
    /// Integer of the given bit width: `i1`, `i8`, `i16`, `i32`, `i64`.
    /// The `bool` is the signedness flag: `false` means signless, i.e. `iN`.
    /// With `true`, the type is displayed as `si<N>` (for annotation only).
    Integer(u32, bool),
    /// Float types: `f16`, `f32`, `f64`, `f80`, `f128`, `bf16`
    /// NOTE(review): the payload is a single bit width, so how `bf16` is told
    /// apart from `f16` is not visible here — confirm against the formatter.
    Float(u32),
    /// Index type (platform-dependent integer, pointer-sized)
    Index,
    /// MemRef type: `memref<NxMxT, affine_map>` or `memref<?xT>`.
    /// Fields are: element type, dimensions, layout map.
    MemRef(Box<MlirType>, Vec<i64>, AffineMap),
    /// Ranked tensor: `tensor<2x3xf32>` or `tensor<?x4xi64>`.
    /// Fields are: dimensions, element type (note the order differs from `MemRef`).
    Tensor(Vec<i64>, Box<MlirType>),
    /// Vector type (always statically shaped): `vector<4xf32>`
    Vector(Vec<u64>, Box<MlirType>),
    /// Tuple type: `tuple<i32, f64>`
    Tuple(Vec<MlirType>),
    /// None type
    NoneType,
    /// Custom / opaque type (e.g., from external dialect), held as a plain string
    Custom(String),
    /// Function type: `(i32, i64) -> f32` (argument types, then result types)
    FuncType(Vec<MlirType>, Vec<MlirType>),
    /// Complex type: `complex<f32>`
    Complex(Box<MlirType>),
    /// Unranked memref: `memref<*xT>`
    UnrankedMemRef(Box<MlirType>),
}
/// Liveness analysis for MLIRExt.
///
/// Holds four tables indexed by block number; every entry is a list of
/// variable ids. This type only stores and queries the tables — it does not
/// compute liveness itself.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRExtLiveness {
    /// Variables live on entry to each block.
    pub live_in: Vec<Vec<usize>>,
    /// Variables live on exit from each block.
    pub live_out: Vec<Vec<usize>>,
    /// Variables defined in each block.
    pub defs: Vec<Vec<usize>>,
    /// Variables used in each block.
    pub uses: Vec<Vec<usize>>,
}
impl MLIRExtLiveness {
    /// Create tables for `n` blocks, every list starting out empty.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        let blank: Vec<Vec<usize>> = vec![Vec::new(); n];
        Self {
            live_in: blank.clone(),
            live_out: blank.clone(),
            defs: blank.clone(),
            uses: blank,
        }
    }
    /// `true` when block `b` exists in `table` and its list holds `v`.
    fn holds(table: &[Vec<usize>], b: usize, v: usize) -> bool {
        matches!(table.get(b), Some(vars) if vars.contains(&v))
    }
    /// Append `v` to block `b`'s list unless it is already there; an
    /// out-of-range `b` is ignored.
    fn record(table: &mut [Vec<usize>], b: usize, v: usize) {
        match table.get_mut(b) {
            Some(vars) if !vars.contains(&v) => vars.push(v),
            _ => {}
        }
    }
    /// Whether `v` is live on entry to block `b` (`false` for unknown blocks).
    #[allow(dead_code)]
    pub fn live_in(&self, b: usize, v: usize) -> bool {
        Self::holds(&self.live_in, b, v)
    }
    /// Whether `v` is live on exit from block `b` (`false` for unknown blocks).
    #[allow(dead_code)]
    pub fn live_out(&self, b: usize, v: usize) -> bool {
        Self::holds(&self.live_out, b, v)
    }
    /// Record that block `b` defines `v`; duplicates and unknown blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, b: usize, v: usize) {
        Self::record(&mut self.defs, b, v);
    }
    /// Record that block `b` uses `v`; duplicates and unknown blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, b: usize, v: usize) {
        Self::record(&mut self.uses, b, v);
    }
    /// Whether block `b` has a recorded use of `v`.
    #[allow(dead_code)]
    pub fn var_is_used_in_block(&self, b: usize, v: usize) -> bool {
        Self::holds(&self.uses, b, v)
    }
    /// Whether block `b` has a recorded definition of `v`.
    #[allow(dead_code)]
    pub fn var_is_def_in_block(&self, b: usize, v: usize) -> bool {
        Self::holds(&self.defs, b, v)
    }
}
1065/// An MLIR region (contains a list of basic blocks).
1066#[derive(Debug, Clone)]
1067pub struct MlirRegion {
1068    /// Blocks in this region
1069    pub blocks: Vec<MlirBlock>,
1070}
1071impl MlirRegion {
1072    /// Create a region with a single entry block.
1073    pub fn single_block(block: MlirBlock) -> Self {
1074        MlirRegion {
1075            blocks: vec![block],
1076        }
1077    }
1078    /// Create an empty region.
1079    pub fn empty() -> Self {
1080        MlirRegion { blocks: vec![] }
1081    }
1082}
/// Statistics for MLIRExt passes.
///
/// Plain counters updated through the small recording methods below; all
/// fields start at zero / `false`.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRExtPassStats {
    /// Number of iterations recorded via [`iterate`](Self::iterate).
    pub iterations: usize,
    /// Set once any modification has been recorded.
    pub changed: bool,
    /// Number of nodes recorded via [`visit`](Self::visit).
    pub nodes_visited: usize,
    /// Number of nodes recorded via [`modify`](Self::modify).
    pub nodes_modified: usize,
    /// Accumulated time in milliseconds.
    pub time_ms: u64,
    /// Memory figure in bytes; merged by taking the maximum.
    pub memory_bytes: usize,
    /// Number of errors recorded via [`error`](Self::error).
    pub errors: usize,
}
impl MLIRExtPassStats {
    /// Create zeroed statistics.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Count one visited node.
    #[allow(dead_code)]
    pub fn visit(&mut self) {
        self.nodes_visited += 1;
    }
    /// Count one modified node and flag the stats as changed.
    #[allow(dead_code)]
    pub fn modify(&mut self) {
        self.changed = true;
        self.nodes_modified += 1;
    }
    /// Count one iteration.
    #[allow(dead_code)]
    pub fn iterate(&mut self) {
        self.iterations += 1;
    }
    /// Count one error.
    #[allow(dead_code)]
    pub fn error(&mut self) {
        self.errors += 1;
    }
    /// Ratio of modified to visited nodes; `0.0` when nothing was visited.
    #[allow(dead_code)]
    pub fn efficiency(&self) -> f64 {
        match self.nodes_visited {
            0 => 0.0,
            visited => self.nodes_modified as f64 / visited as f64,
        }
    }
    /// Fold `o` into `self`: counters and time are summed, `changed` is
    /// OR-ed, and `memory_bytes` keeps the larger of the two values.
    #[allow(dead_code)]
    pub fn merge(&mut self, o: &MLIRExtPassStats) {
        // Exhaustive destructuring: adding a field forces this method to be revisited.
        let MLIRExtPassStats {
            iterations,
            changed,
            nodes_visited,
            nodes_modified,
            time_ms,
            memory_bytes,
            errors,
        } = *o;
        self.iterations += iterations;
        self.changed |= changed;
        self.nodes_visited += nodes_visited;
        self.nodes_modified += nodes_modified;
        self.time_ms += time_ms;
        self.memory_bytes = usize::max(self.memory_bytes, memory_bytes);
        self.errors += errors;
    }
}
1136/// An MLIR basic block.
1137#[derive(Debug, Clone)]
1138pub struct MlirBlock {
1139    /// Block label (None for entry block)
1140    pub label: Option<String>,
1141    /// Block arguments: (value, type) pairs
1142    pub arguments: Vec<MlirValue>,
1143    /// Operations in this block
1144    pub body: Vec<MlirOp>,
1145    /// Terminator operation (explicit for clarity, also included in body)
1146    pub terminator: Option<MlirOp>,
1147}
1148impl MlirBlock {
1149    /// Create an entry block (no label).
1150    pub fn entry(arguments: Vec<MlirValue>, body: Vec<MlirOp>) -> Self {
1151        MlirBlock {
1152            label: None,
1153            arguments,
1154            body,
1155            terminator: None,
1156        }
1157    }
1158    /// Create a labeled block.
1159    pub fn labeled(label: impl Into<String>, arguments: Vec<MlirValue>, body: Vec<MlirOp>) -> Self {
1160        MlirBlock {
1161            label: Some(label.into()),
1162            arguments,
1163            body,
1164            terminator: None,
1165        }
1166    }
1167}
1168/// Worklist for MLIRExt.
1169#[allow(dead_code)]
1170#[derive(Debug, Clone)]
1171pub struct MLIRExtWorklist {
1172    pub(super) items: std::collections::VecDeque<usize>,
1173    pub(super) present: Vec<bool>,
1174}
1175impl MLIRExtWorklist {
1176    #[allow(dead_code)]
1177    pub fn new(capacity: usize) -> Self {
1178        Self {
1179            items: std::collections::VecDeque::new(),
1180            present: vec![false; capacity],
1181        }
1182    }
1183    #[allow(dead_code)]
1184    pub fn push(&mut self, id: usize) {
1185        if id < self.present.len() && !self.present[id] {
1186            self.present[id] = true;
1187            self.items.push_back(id);
1188        }
1189    }
1190    #[allow(dead_code)]
1191    pub fn push_front(&mut self, id: usize) {
1192        if id < self.present.len() && !self.present[id] {
1193            self.present[id] = true;
1194            self.items.push_front(id);
1195        }
1196    }
1197    #[allow(dead_code)]
1198    pub fn pop(&mut self) -> Option<usize> {
1199        let id = self.items.pop_front()?;
1200        if id < self.present.len() {
1201            self.present[id] = false;
1202        }
1203        Some(id)
1204    }
1205    #[allow(dead_code)]
1206    pub fn is_empty(&self) -> bool {
1207        self.items.is_empty()
1208    }
1209    #[allow(dead_code)]
1210    pub fn len(&self) -> usize {
1211        self.items.len()
1212    }
1213    #[allow(dead_code)]
1214    pub fn contains(&self, id: usize) -> bool {
1215        id < self.present.len() && self.present[id]
1216    }
1217    #[allow(dead_code)]
1218    pub fn drain_all(&mut self) -> Vec<usize> {
1219        let v: Vec<usize> = self.items.drain(..).collect();
1220        for &id in &v {
1221            if id < self.present.len() {
1222                self.present[id] = false;
1223            }
1224        }
1225        v
1226    }
1227}
1228/// A single MLIR operation.
1229#[derive(Debug, Clone)]
1230pub struct MlirOp {
1231    /// SSA result values (may be empty for side-effecting ops)
1232    pub results: Vec<MlirValue>,
1233    /// Operation name: `arith.addi`, `func.return`, etc.
1234    pub op_name: String,
1235    /// Operands (SSA values used by this op)
1236    pub operands: Vec<MlirValue>,
1237    /// Nested regions (for scf.if, scf.for, func.func, etc.)
1238    pub regions: Vec<MlirRegion>,
1239    /// Successor block labels (for cf.br, cf.cond_br)
1240    pub successors: Vec<String>,
1241    /// Named attributes: `{value = 42 : i64}`
1242    pub attributes: Vec<(String, MlirAttr)>,
1243    /// Type annotations if needed (e.g., the result types for arith ops)
1244    pub type_annotations: Vec<MlirType>,
1245}
1246impl MlirOp {
1247    /// Create a simple op with one result.
1248    pub fn unary_result(
1249        result: MlirValue,
1250        op_name: impl Into<String>,
1251        operands: Vec<MlirValue>,
1252        attrs: Vec<(String, MlirAttr)>,
1253    ) -> Self {
1254        MlirOp {
1255            results: vec![result],
1256            op_name: op_name.into(),
1257            operands,
1258            regions: vec![],
1259            successors: vec![],
1260            attributes: attrs,
1261            type_annotations: vec![],
1262        }
1263    }
1264    /// Create a void op (no results).
1265    pub fn void_op(
1266        op_name: impl Into<String>,
1267        operands: Vec<MlirValue>,
1268        attrs: Vec<(String, MlirAttr)>,
1269    ) -> Self {
1270        MlirOp {
1271            results: vec![],
1272            op_name: op_name.into(),
1273            operands,
1274            regions: vec![],
1275            successors: vec![],
1276            attributes: attrs,
1277            type_annotations: vec![],
1278        }
1279    }
1280}
1281#[allow(dead_code)]
1282#[derive(Debug, Clone)]
1283pub struct MLIRAnalysisCache {
1284    pub(super) entries: std::collections::HashMap<String, MLIRCacheEntry>,
1285    pub(super) max_size: usize,
1286    pub(super) hits: u64,
1287    pub(super) misses: u64,
1288}
1289impl MLIRAnalysisCache {
1290    #[allow(dead_code)]
1291    pub fn new(max_size: usize) -> Self {
1292        MLIRAnalysisCache {
1293            entries: std::collections::HashMap::new(),
1294            max_size,
1295            hits: 0,
1296            misses: 0,
1297        }
1298    }
1299    #[allow(dead_code)]
1300    pub fn get(&mut self, key: &str) -> Option<&MLIRCacheEntry> {
1301        if self.entries.contains_key(key) {
1302            self.hits += 1;
1303            self.entries.get(key)
1304        } else {
1305            self.misses += 1;
1306            None
1307        }
1308    }
1309    #[allow(dead_code)]
1310    pub fn insert(&mut self, key: String, data: Vec<u8>) {
1311        if self.entries.len() >= self.max_size {
1312            if let Some(oldest) = self.entries.keys().next().cloned() {
1313                self.entries.remove(&oldest);
1314            }
1315        }
1316        self.entries.insert(
1317            key.clone(),
1318            MLIRCacheEntry {
1319                key,
1320                data,
1321                timestamp: 0,
1322                valid: true,
1323            },
1324        );
1325    }
1326    #[allow(dead_code)]
1327    pub fn invalidate(&mut self, key: &str) {
1328        if let Some(entry) = self.entries.get_mut(key) {
1329            entry.valid = false;
1330        }
1331    }
1332    #[allow(dead_code)]
1333    pub fn clear(&mut self) {
1334        self.entries.clear();
1335    }
1336    #[allow(dead_code)]
1337    pub fn hit_rate(&self) -> f64 {
1338        let total = self.hits + self.misses;
1339        if total == 0 {
1340            return 0.0;
1341        }
1342        self.hits as f64 / total as f64
1343    }
1344    #[allow(dead_code)]
1345    pub fn size(&self) -> usize {
1346        self.entries.len()
1347    }
1348}
1349#[allow(dead_code)]
1350#[derive(Debug, Clone)]
1351pub struct MLIRPassConfig {
1352    pub phase: MLIRPassPhase,
1353    pub enabled: bool,
1354    pub max_iterations: u32,
1355    pub debug_output: bool,
1356    pub pass_name: String,
1357}
1358impl MLIRPassConfig {
1359    #[allow(dead_code)]
1360    pub fn new(name: impl Into<String>, phase: MLIRPassPhase) -> Self {
1361        MLIRPassConfig {
1362            phase,
1363            enabled: true,
1364            max_iterations: 10,
1365            debug_output: false,
1366            pass_name: name.into(),
1367        }
1368    }
1369    #[allow(dead_code)]
1370    pub fn disabled(mut self) -> Self {
1371        self.enabled = false;
1372        self
1373    }
1374    #[allow(dead_code)]
1375    pub fn with_debug(mut self) -> Self {
1376        self.debug_output = true;
1377        self
1378    }
1379    #[allow(dead_code)]
1380    pub fn max_iter(mut self, n: u32) -> Self {
1381        self.max_iterations = n;
1382        self
1383    }
1384}
/// Float comparison predicates for `arith.cmpf`.
///
/// Variants prefixed with `O` are the ordered comparisons and those prefixed
/// with `U` the unordered ones, as described per variant below.
/// NOTE(review): only equality and the ordered relational predicates are
/// listed here; confirm whether further `arith.cmpf` predicates are needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpfPred {
    /// Ordered equal
    Oeq,
    /// Ordered not equal
    One,
    /// Ordered less than
    Olt,
    /// Ordered less than or equal
    Ole,
    /// Ordered greater than
    Ogt,
    /// Ordered greater than or equal
    Oge,
    /// Unordered equal
    Ueq,
    /// Unordered not equal
    Une,
}
/// Affine map representation for MemRef and Affine dialect operations.
///
/// Used as the layout component of `MlirType::MemRef`.
#[derive(Debug, Clone, PartialEq)]
pub enum AffineMap {
    /// Identity map: `(d0, d1) -> (d0, d1)`
    /// (presumably the `usize` is the number of dimensions — TODO confirm)
    Identity(usize),
    /// Constant map: `() -> ()`
    Constant,
    /// Custom affine map expression string
    Custom(String),
}
1415/// MLIR SSA value (operand).
1416#[derive(Debug, Clone, PartialEq)]
1417pub struct MlirValue {
1418    /// The name/id of the SSA value (without the `%` prefix)
1419    pub name: String,
1420    /// Type of the value
1421    pub ty: MlirType,
1422}
1423impl MlirValue {
1424    /// Create a numbered SSA value.
1425    pub fn numbered(id: u32, ty: MlirType) -> Self {
1426        MlirValue {
1427            name: id.to_string(),
1428            ty,
1429        }
1430    }
1431    /// Create a named SSA value.
1432    pub fn named(name: impl Into<String>, ty: MlirType) -> Self {
1433        MlirValue {
1434            name: name.into(),
1435            ty,
1436        }
1437    }
1438}
1439/// Dependency graph for MLIRExt.
1440#[allow(dead_code)]
1441#[derive(Debug, Clone)]
1442pub struct MLIRExtDepGraph {
1443    pub(super) n: usize,
1444    pub(super) adj: Vec<Vec<usize>>,
1445    pub(super) rev: Vec<Vec<usize>>,
1446    pub(super) edge_count: usize,
1447}
1448impl MLIRExtDepGraph {
1449    #[allow(dead_code)]
1450    pub fn new(n: usize) -> Self {
1451        Self {
1452            n,
1453            adj: vec![Vec::new(); n],
1454            rev: vec![Vec::new(); n],
1455            edge_count: 0,
1456        }
1457    }
1458    #[allow(dead_code)]
1459    pub fn add_edge(&mut self, from: usize, to: usize) {
1460        if from < self.n && to < self.n {
1461            if !self.adj[from].contains(&to) {
1462                self.adj[from].push(to);
1463                self.rev[to].push(from);
1464                self.edge_count += 1;
1465            }
1466        }
1467    }
1468    #[allow(dead_code)]
1469    pub fn succs(&self, n: usize) -> &[usize] {
1470        self.adj.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1471    }
1472    #[allow(dead_code)]
1473    pub fn preds(&self, n: usize) -> &[usize] {
1474        self.rev.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1475    }
1476    #[allow(dead_code)]
1477    pub fn topo_sort(&self) -> Option<Vec<usize>> {
1478        let mut deg: Vec<usize> = (0..self.n).map(|i| self.rev[i].len()).collect();
1479        let mut q: std::collections::VecDeque<usize> =
1480            (0..self.n).filter(|&i| deg[i] == 0).collect();
1481        let mut out = Vec::with_capacity(self.n);
1482        while let Some(u) = q.pop_front() {
1483            out.push(u);
1484            for &v in &self.adj[u] {
1485                deg[v] -= 1;
1486                if deg[v] == 0 {
1487                    q.push_back(v);
1488                }
1489            }
1490        }
1491        if out.len() == self.n {
1492            Some(out)
1493        } else {
1494            None
1495        }
1496    }
1497    #[allow(dead_code)]
1498    pub fn has_cycle(&self) -> bool {
1499        self.topo_sort().is_none()
1500    }
1501    #[allow(dead_code)]
1502    pub fn reachable(&self, start: usize) -> Vec<usize> {
1503        let mut vis = vec![false; self.n];
1504        let mut stk = vec![start];
1505        let mut out = Vec::new();
1506        while let Some(u) = stk.pop() {
1507            if u < self.n && !vis[u] {
1508                vis[u] = true;
1509                out.push(u);
1510                for &v in &self.adj[u] {
1511                    if !vis[v] {
1512                        stk.push(v);
1513                    }
1514                }
1515            }
1516        }
1517        out
1518    }
1519    #[allow(dead_code)]
1520    pub fn scc(&self) -> Vec<Vec<usize>> {
1521        let mut visited = vec![false; self.n];
1522        let mut order = Vec::new();
1523        for i in 0..self.n {
1524            if !visited[i] {
1525                let mut stk = vec![(i, 0usize)];
1526                while let Some((u, idx)) = stk.last_mut() {
1527                    if !visited[*u] {
1528                        visited[*u] = true;
1529                    }
1530                    if *idx < self.adj[*u].len() {
1531                        let v = self.adj[*u][*idx];
1532                        *idx += 1;
1533                        if !visited[v] {
1534                            stk.push((v, 0));
1535                        }
1536                    } else {
1537                        order.push(*u);
1538                        stk.pop();
1539                    }
1540                }
1541            }
1542        }
1543        let mut comp = vec![usize::MAX; self.n];
1544        let mut components: Vec<Vec<usize>> = Vec::new();
1545        for &start in order.iter().rev() {
1546            if comp[start] == usize::MAX {
1547                let cid = components.len();
1548                let mut component = Vec::new();
1549                let mut stk = vec![start];
1550                while let Some(u) = stk.pop() {
1551                    if comp[u] == usize::MAX {
1552                        comp[u] = cid;
1553                        component.push(u);
1554                        for &v in &self.rev[u] {
1555                            if comp[v] == usize::MAX {
1556                                stk.push(v);
1557                            }
1558                        }
1559                    }
1560                }
1561                components.push(component);
1562            }
1563        }
1564        components
1565    }
1566    #[allow(dead_code)]
1567    pub fn node_count(&self) -> usize {
1568        self.n
1569    }
1570    #[allow(dead_code)]
1571    pub fn edge_count(&self) -> usize {
1572        self.edge_count
1573    }
1574}
1575/// Top-level MLIR module.
1576#[derive(Debug, Clone)]
1577pub struct MlirModule {
1578    /// Optional module name/attribute
1579    pub name: Option<String>,
1580    /// Function definitions
1581    pub functions: Vec<MlirFunc>,
1582    /// Global variables
1583    pub globals: Vec<MlirGlobal>,
1584    /// Dialect requirements (for `mlir-opt` pass specification)
1585    pub required_dialects: Vec<MlirDialect>,
1586}
1587impl MlirModule {
1588    /// Create a new empty MLIR module.
1589    pub fn new() -> Self {
1590        MlirModule {
1591            name: None,
1592            functions: vec![],
1593            globals: vec![],
1594            required_dialects: vec![],
1595        }
1596    }
1597    /// Create a module with a name.
1598    pub fn named(name: impl Into<String>) -> Self {
1599        MlirModule {
1600            name: Some(name.into()),
1601            functions: vec![],
1602            globals: vec![],
1603            required_dialects: vec![],
1604        }
1605    }
1606    /// Add a function to the module.
1607    pub fn add_function(&mut self, func: MlirFunc) {
1608        self.functions.push(func);
1609    }
1610    /// Add a global to the module.
1611    pub fn add_global(&mut self, global: MlirGlobal) {
1612        self.globals.push(global);
1613    }
1614    /// Generate textual MLIR format.
1615    pub fn emit(&self) -> String {
1616        let mut out = String::new();
1617        if let Some(name) = &self.name {
1618            out.push_str(&format!("module @{} {{\n", name));
1619        } else {
1620            out.push_str("module {\n");
1621        }
1622        for global in &self.globals {
1623            out.push_str(&global.emit());
1624        }
1625        for func in &self.functions {
1626            out.push_str(&func.emit());
1627        }
1628        out.push_str("}\n");
1629        out
1630    }
1631}
/// Arithmetic comparison predicates for `arith.cmpi`.
///
/// Equality predicates come first, followed by the signed (`S*`) and then
/// the unsigned (`U*`) relational predicates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpiPred {
    /// Equal
    Eq,
    /// Not equal
    Ne,
    /// Signed less than
    Slt,
    /// Signed less than or equal
    Sle,
    /// Signed greater than
    Sgt,
    /// Signed greater than or equal
    Sge,
    /// Unsigned less than
    Ult,
    /// Unsigned less than or equal
    Ule,
    /// Unsigned greater than
    Ugt,
    /// Unsigned greater than or equal
    Uge,
}
1656/// Dominator tree for MLIRExt.
1657#[allow(dead_code)]
1658#[derive(Debug, Clone)]
1659pub struct MLIRExtDomTree {
1660    pub(super) idom: Vec<Option<usize>>,
1661    pub(super) children: Vec<Vec<usize>>,
1662    pub(super) depth: Vec<usize>,
1663}
1664impl MLIRExtDomTree {
1665    #[allow(dead_code)]
1666    pub fn new(n: usize) -> Self {
1667        Self {
1668            idom: vec![None; n],
1669            children: vec![Vec::new(); n],
1670            depth: vec![0; n],
1671        }
1672    }
1673    #[allow(dead_code)]
1674    pub fn set_idom(&mut self, node: usize, dom: usize) {
1675        if node < self.idom.len() {
1676            self.idom[node] = Some(dom);
1677            if dom < self.children.len() {
1678                self.children[dom].push(node);
1679            }
1680            self.depth[node] = if dom < self.depth.len() {
1681                self.depth[dom] + 1
1682            } else {
1683                1
1684            };
1685        }
1686    }
1687    #[allow(dead_code)]
1688    pub fn dominates(&self, a: usize, mut b: usize) -> bool {
1689        if a == b {
1690            return true;
1691        }
1692        let n = self.idom.len();
1693        for _ in 0..n {
1694            match self.idom.get(b).copied().flatten() {
1695                None => return false,
1696                Some(p) if p == a => return true,
1697                Some(p) if p == b => return false,
1698                Some(p) => b = p,
1699            }
1700        }
1701        false
1702    }
1703    #[allow(dead_code)]
1704    pub fn children_of(&self, n: usize) -> &[usize] {
1705        self.children.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1706    }
1707    #[allow(dead_code)]
1708    pub fn depth_of(&self, n: usize) -> usize {
1709        self.depth.get(n).copied().unwrap_or(0)
1710    }
1711    #[allow(dead_code)]
1712    pub fn lca(&self, mut a: usize, mut b: usize) -> usize {
1713        let n = self.idom.len();
1714        for _ in 0..(2 * n) {
1715            if a == b {
1716                return a;
1717            }
1718            if self.depth_of(a) > self.depth_of(b) {
1719                a = self.idom.get(a).and_then(|x| *x).unwrap_or(a);
1720            } else {
1721                b = self.idom.get(b).and_then(|x| *x).unwrap_or(b);
1722            }
1723        }
1724        0
1725    }
1726}
1727/// MLIR global variable.
1728#[derive(Debug, Clone)]
1729pub struct MlirGlobal {
1730    /// Global name (without `@`)
1731    pub name: String,
1732    /// Type of the global
1733    pub ty: MlirType,
1734    /// Initial value (attribute)
1735    pub initial_value: Option<MlirAttr>,
1736    /// Whether this is a constant
1737    pub is_constant: bool,
1738    /// Linkage: public, private, etc.
1739    pub linkage: String,
1740}
1741impl MlirGlobal {
1742    /// Create a simple global constant.
1743    pub fn constant(name: impl Into<String>, ty: MlirType, value: MlirAttr) -> Self {
1744        MlirGlobal {
1745            name: name.into(),
1746            ty,
1747            initial_value: Some(value),
1748            is_constant: true,
1749            linkage: "public".to_string(),
1750        }
1751    }
1752    /// Emit the global as MLIR text.
1753    pub fn emit(&self) -> String {
1754        let mut out = String::new();
1755        out.push_str("  memref.global ");
1756        if self.is_constant {
1757            out.push_str("constant ");
1758        }
1759        out.push_str(&format!("@{} : {}", self.name, self.ty));
1760        if let Some(v) = &self.initial_value {
1761            out.push_str(&format!(" = {}", v));
1762        }
1763        out.push('\n');
1764        out
1765    }
1766}
1767/// Pass registry for MLIRExt.
1768#[allow(dead_code)]
1769#[derive(Debug, Default)]
1770pub struct MLIRExtPassRegistry {
1771    pub(super) configs: Vec<MLIRExtPassConfig>,
1772    pub(super) stats: Vec<MLIRExtPassStats>,
1773}
1774impl MLIRExtPassRegistry {
1775    #[allow(dead_code)]
1776    pub fn new() -> Self {
1777        Self::default()
1778    }
1779    #[allow(dead_code)]
1780    pub fn register(&mut self, c: MLIRExtPassConfig) {
1781        self.stats.push(MLIRExtPassStats::new());
1782        self.configs.push(c);
1783    }
1784    #[allow(dead_code)]
1785    pub fn len(&self) -> usize {
1786        self.configs.len()
1787    }
1788    #[allow(dead_code)]
1789    pub fn is_empty(&self) -> bool {
1790        self.configs.is_empty()
1791    }
1792    #[allow(dead_code)]
1793    pub fn get(&self, i: usize) -> Option<&MLIRExtPassConfig> {
1794        self.configs.get(i)
1795    }
1796    #[allow(dead_code)]
1797    pub fn get_stats(&self, i: usize) -> Option<&MLIRExtPassStats> {
1798        self.stats.get(i)
1799    }
1800    #[allow(dead_code)]
1801    pub fn enabled_passes(&self) -> Vec<&MLIRExtPassConfig> {
1802        self.configs.iter().filter(|c| c.enabled).collect()
1803    }
1804    #[allow(dead_code)]
1805    pub fn passes_in_phase(&self, ph: &MLIRExtPassPhase) -> Vec<&MLIRExtPassConfig> {
1806        self.configs
1807            .iter()
1808            .filter(|c| c.enabled && &c.phase == ph)
1809            .collect()
1810    }
1811    #[allow(dead_code)]
1812    pub fn total_nodes_visited(&self) -> usize {
1813        self.stats.iter().map(|s| s.nodes_visited).sum()
1814    }
1815    #[allow(dead_code)]
1816    pub fn any_changed(&self) -> bool {
1817        self.stats.iter().any(|s| s.changed)
1818    }
1819}
1820#[allow(dead_code)]
1821#[derive(Debug, Clone)]
1822pub struct MLIRDepGraph {
1823    pub(super) nodes: Vec<u32>,
1824    pub(super) edges: Vec<(u32, u32)>,
1825}
1826impl MLIRDepGraph {
1827    #[allow(dead_code)]
1828    pub fn new() -> Self {
1829        MLIRDepGraph {
1830            nodes: Vec::new(),
1831            edges: Vec::new(),
1832        }
1833    }
1834    #[allow(dead_code)]
1835    pub fn add_node(&mut self, id: u32) {
1836        if !self.nodes.contains(&id) {
1837            self.nodes.push(id);
1838        }
1839    }
1840    #[allow(dead_code)]
1841    pub fn add_dep(&mut self, dep: u32, dependent: u32) {
1842        self.add_node(dep);
1843        self.add_node(dependent);
1844        self.edges.push((dep, dependent));
1845    }
1846    #[allow(dead_code)]
1847    pub fn dependents_of(&self, node: u32) -> Vec<u32> {
1848        self.edges
1849            .iter()
1850            .filter(|(d, _)| *d == node)
1851            .map(|(_, dep)| *dep)
1852            .collect()
1853    }
1854    #[allow(dead_code)]
1855    pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
1856        self.edges
1857            .iter()
1858            .filter(|(_, dep)| *dep == node)
1859            .map(|(d, _)| *d)
1860            .collect()
1861    }
1862    #[allow(dead_code)]
1863    pub fn topological_sort(&self) -> Vec<u32> {
1864        let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
1865        for &n in &self.nodes {
1866            in_degree.insert(n, 0);
1867        }
1868        for (_, dep) in &self.edges {
1869            *in_degree.entry(*dep).or_insert(0) += 1;
1870        }
1871        let mut queue: std::collections::VecDeque<u32> = self
1872            .nodes
1873            .iter()
1874            .filter(|&&n| in_degree[&n] == 0)
1875            .copied()
1876            .collect();
1877        let mut result = Vec::new();
1878        while let Some(node) = queue.pop_front() {
1879            result.push(node);
1880            for dep in self.dependents_of(node) {
1881                let cnt = in_degree.entry(dep).or_insert(0);
1882                *cnt = cnt.saturating_sub(1);
1883                if *cnt == 0 {
1884                    queue.push_back(dep);
1885                }
1886            }
1887        }
1888        result
1889    }
1890    #[allow(dead_code)]
1891    pub fn has_cycle(&self) -> bool {
1892        self.topological_sort().len() < self.nodes.len()
1893    }
1894}
1895#[allow(dead_code)]
1896pub struct MLIRPassRegistry {
1897    pub(super) configs: Vec<MLIRPassConfig>,
1898    pub(super) stats: std::collections::HashMap<String, MLIRPassStats>,
1899}
1900impl MLIRPassRegistry {
1901    #[allow(dead_code)]
1902    pub fn new() -> Self {
1903        MLIRPassRegistry {
1904            configs: Vec::new(),
1905            stats: std::collections::HashMap::new(),
1906        }
1907    }
1908    #[allow(dead_code)]
1909    pub fn register(&mut self, config: MLIRPassConfig) {
1910        self.stats
1911            .insert(config.pass_name.clone(), MLIRPassStats::new());
1912        self.configs.push(config);
1913    }
1914    #[allow(dead_code)]
1915    pub fn enabled_passes(&self) -> Vec<&MLIRPassConfig> {
1916        self.configs.iter().filter(|c| c.enabled).collect()
1917    }
1918    #[allow(dead_code)]
1919    pub fn get_stats(&self, name: &str) -> Option<&MLIRPassStats> {
1920        self.stats.get(name)
1921    }
1922    #[allow(dead_code)]
1923    pub fn total_passes(&self) -> usize {
1924        self.configs.len()
1925    }
1926    #[allow(dead_code)]
1927    pub fn enabled_count(&self) -> usize {
1928        self.enabled_passes().len()
1929    }
1930    #[allow(dead_code)]
1931    pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
1932        if let Some(stats) = self.stats.get_mut(name) {
1933            stats.record_run(changes, time_ms, iter);
1934        }
1935    }
1936}
/// Stateless namespace of constant-folding helpers for integer, float and
/// boolean operands.
///
/// Integer helpers that can overflow or are undefined for some operands
/// return `Option<i64>`; `None` means the operation cannot be folded.
#[allow(dead_code)]
pub struct MLIRConstantFoldingHelper;
impl MLIRConstantFoldingHelper {
    /// Folds `a + b`; `None` on `i64` overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// Folds `a - b`; `None` on `i64` overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// Folds `a * b`; `None` on `i64` overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// Folds the truncating division `a / b`.
    ///
    /// Returns `None` when `b == 0` or when the quotient overflows
    /// (`i64::MIN / -1`).
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already covers both the zero divisor and the overflow case.
        a.checked_div(b)
    }
    /// Folds `a + b` using ordinary `f64` addition.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// Folds `a * b` using ordinary `f64` multiplication.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// Folds `-a`; `None` when `a == i64::MIN`, whose negation overflows.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Folds logical negation.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Folds logical conjunction.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Folds logical disjunction.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// Folds `a << b`; `None` when `b` is at least the bit width of `i64` (64).
    ///
    /// Bits shifted out of the value are discarded without being reported.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// Folds the sign-preserving shift `a >> b`; `None` when `b` is at least
    /// the bit width of `i64` (64).
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// Folds the remainder `a % b` (result takes the sign of `a`).
    ///
    /// Returns `None` when `b == 0` or when the operation overflows
    /// (`i64::MIN % -1`), mirroring [`Self::fold_div_i64`].
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        // A plain `a % b` panics for `i64::MIN % -1`; `checked_rem` reports it
        // as `None` so the caller simply declines to fold.
        a.checked_rem(b)
    }
    /// Folds bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Folds bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Folds bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Folds bitwise complement.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}