1use std::collections::HashMap;
6
7use std::collections::{HashSet, VecDeque};
8
9#[derive(Debug, Clone)]
11pub struct MlirFunc {
12 pub name: String,
14 pub args: Vec<(String, MlirType)>,
16 pub results: Vec<MlirType>,
18 pub body: MlirRegion,
20 pub attributes: Vec<(String, MlirAttr)>,
22 pub is_declaration: bool,
24}
25impl MlirFunc {
26 pub fn new(
28 name: impl Into<String>,
29 args: Vec<(String, MlirType)>,
30 results: Vec<MlirType>,
31 body: MlirRegion,
32 ) -> Self {
33 MlirFunc {
34 name: name.into(),
35 args,
36 results,
37 body,
38 attributes: vec![],
39 is_declaration: false,
40 }
41 }
42 pub fn declaration(
44 name: impl Into<String>,
45 args: Vec<MlirType>,
46 results: Vec<MlirType>,
47 ) -> Self {
48 let arg_vals = args
49 .into_iter()
50 .enumerate()
51 .map(|(i, t)| (format!("arg{}", i), t))
52 .collect();
53 MlirFunc {
54 name: name.into(),
55 args: arg_vals,
56 results,
57 body: MlirRegion::empty(),
58 attributes: vec![],
59 is_declaration: true,
60 }
61 }
62 pub fn emit(&self) -> String {
64 let mut out = String::new();
65 if self.is_declaration {
66 out.push_str(" func.func private @");
67 } else {
68 out.push_str(" func.func @");
69 }
70 out.push_str(&self.name);
71 out.push('(');
72 for (i, (name, ty)) in self.args.iter().enumerate() {
73 if i > 0 {
74 out.push_str(", ");
75 }
76 out.push_str(&format!("%{}: {}", name, ty));
77 }
78 out.push(')');
79 if !self.results.is_empty() {
80 out.push_str(" -> ");
81 if self.results.len() == 1 {
82 out.push_str(&self.results[0].to_string());
83 } else {
84 out.push('(');
85 for (i, r) in self.results.iter().enumerate() {
86 if i > 0 {
87 out.push_str(", ");
88 }
89 out.push_str(&r.to_string());
90 }
91 out.push(')');
92 }
93 }
94 if !self.attributes.is_empty() {
95 out.push_str(" attributes {");
96 for (i, (k, v)) in self.attributes.iter().enumerate() {
97 if i > 0 {
98 out.push_str(", ");
99 }
100 out.push_str(&format!("{} = {}", k, v));
101 }
102 out.push('}');
103 }
104 if self.is_declaration {
105 out.push('\n');
106 } else {
107 out.push_str(" {\n");
108 for block in &self.body.blocks {
109 out.push_str(&format!("{}", block));
110 }
111 out.push_str(" }\n");
112 }
113 out
114 }
115}
/// Coarse classification of a pass by what it does to the IR.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum MLIRPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl MLIRPassPhase {
    /// Lower-case name of the phase.
    #[allow(dead_code)]
    pub fn name(&self) -> &str {
        use MLIRPassPhase::*;
        let label = match *self {
            Analysis => "analysis",
            Transformation => "transformation",
            Verification => "verification",
            Cleanup => "cleanup",
        };
        label
    }

    /// True for phases that may change the IR (`Transformation`, `Cleanup`).
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        match self {
            MLIRPassPhase::Transformation | MLIRPassPhase::Cleanup => true,
            MLIRPassPhase::Analysis | MLIRPassPhase::Verification => false,
        }
    }
}
139pub struct MlirBackend {
141 pub(super) module: MlirModule,
142 pub(super) ssa: SsaCounter,
143 pub(super) pass_pipeline: Vec<String>,
144}
145impl MlirBackend {
146 pub fn new() -> Self {
148 MlirBackend {
149 module: MlirModule::new(),
150 ssa: SsaCounter::new(),
151 pass_pipeline: vec![],
152 }
153 }
154 pub fn with_name(name: impl Into<String>) -> Self {
156 MlirBackend {
157 module: MlirModule::named(name),
158 ssa: SsaCounter::new(),
159 pass_pipeline: vec![],
160 }
161 }
162 pub fn add_pass(&mut self, pass: impl Into<String>) {
164 self.pass_pipeline.push(pass.into());
165 }
166 pub fn compile_add_func(&mut self, name: &str, bits: u32) {
168 let int_ty = MlirType::Integer(bits, false);
169 let mut builder = MlirBuilder::new();
170 let arg0 = MlirValue::named("arg0", int_ty.clone());
171 let arg1 = MlirValue::named("arg1", int_ty.clone());
172 let sum = builder.addi(arg0.clone(), arg1.clone());
173 builder.return_op(vec![sum]);
174 let block = MlirBlock::entry(vec![arg0, arg1], builder.take_ops());
175 let region = MlirRegion::single_block(block);
176 let func = MlirFunc::new(name, vec![], vec![int_ty.clone()], region);
177 self.module.add_function(func);
178 }
179 pub fn compile_decl(&mut self, name: &str, arg_types: Vec<MlirType>, ret_type: MlirType) {
181 let args: Vec<(String, MlirType)> = arg_types
182 .into_iter()
183 .enumerate()
184 .map(|(i, t)| (format!("arg{}", i), t))
185 .collect();
186 let mut builder = MlirBuilder::new();
187 let zero = builder.const_int(0, 64);
188 builder.return_op(vec![zero]);
189 let block = MlirBlock::entry(vec![], builder.take_ops());
190 let func = MlirFunc::new(name, args, vec![ret_type], MlirRegion::single_block(block));
191 self.module.add_function(func);
192 }
193 pub fn emit_module(&self) -> String {
195 self.module.emit()
196 }
197 pub fn run_passes(&self) -> String {
199 if self.pass_pipeline.is_empty() {
200 String::new()
201 } else {
202 format!("mlir-opt --{}", self.pass_pipeline.join(" --"))
203 }
204 }
205 pub fn module_mut(&mut self) -> &mut MlirModule {
207 &mut self.module
208 }
209}
210#[allow(dead_code)]
212#[derive(Debug, Clone, Default)]
213pub struct MLIRExtConstFolder {
214 pub(super) folds: usize,
215 pub(super) failures: usize,
216 pub(super) enabled: bool,
217}
218impl MLIRExtConstFolder {
219 #[allow(dead_code)]
220 pub fn new() -> Self {
221 Self {
222 folds: 0,
223 failures: 0,
224 enabled: true,
225 }
226 }
227 #[allow(dead_code)]
228 pub fn add_i64(&mut self, a: i64, b: i64) -> Option<i64> {
229 self.folds += 1;
230 a.checked_add(b)
231 }
232 #[allow(dead_code)]
233 pub fn sub_i64(&mut self, a: i64, b: i64) -> Option<i64> {
234 self.folds += 1;
235 a.checked_sub(b)
236 }
237 #[allow(dead_code)]
238 pub fn mul_i64(&mut self, a: i64, b: i64) -> Option<i64> {
239 self.folds += 1;
240 a.checked_mul(b)
241 }
242 #[allow(dead_code)]
243 pub fn div_i64(&mut self, a: i64, b: i64) -> Option<i64> {
244 if b == 0 {
245 self.failures += 1;
246 None
247 } else {
248 self.folds += 1;
249 a.checked_div(b)
250 }
251 }
252 #[allow(dead_code)]
253 pub fn rem_i64(&mut self, a: i64, b: i64) -> Option<i64> {
254 if b == 0 {
255 self.failures += 1;
256 None
257 } else {
258 self.folds += 1;
259 a.checked_rem(b)
260 }
261 }
262 #[allow(dead_code)]
263 pub fn neg_i64(&mut self, a: i64) -> Option<i64> {
264 self.folds += 1;
265 a.checked_neg()
266 }
267 #[allow(dead_code)]
268 pub fn shl_i64(&mut self, a: i64, s: u32) -> Option<i64> {
269 if s >= 64 {
270 self.failures += 1;
271 None
272 } else {
273 self.folds += 1;
274 a.checked_shl(s)
275 }
276 }
277 #[allow(dead_code)]
278 pub fn shr_i64(&mut self, a: i64, s: u32) -> Option<i64> {
279 if s >= 64 {
280 self.failures += 1;
281 None
282 } else {
283 self.folds += 1;
284 a.checked_shr(s)
285 }
286 }
287 #[allow(dead_code)]
288 pub fn and_i64(&mut self, a: i64, b: i64) -> i64 {
289 self.folds += 1;
290 a & b
291 }
292 #[allow(dead_code)]
293 pub fn or_i64(&mut self, a: i64, b: i64) -> i64 {
294 self.folds += 1;
295 a | b
296 }
297 #[allow(dead_code)]
298 pub fn xor_i64(&mut self, a: i64, b: i64) -> i64 {
299 self.folds += 1;
300 a ^ b
301 }
302 #[allow(dead_code)]
303 pub fn not_i64(&mut self, a: i64) -> i64 {
304 self.folds += 1;
305 !a
306 }
307 #[allow(dead_code)]
308 pub fn fold_count(&self) -> usize {
309 self.folds
310 }
311 #[allow(dead_code)]
312 pub fn failure_count(&self) -> usize {
313 self.failures
314 }
315 #[allow(dead_code)]
316 pub fn enable(&mut self) {
317 self.enabled = true;
318 }
319 #[allow(dead_code)]
320 pub fn disable(&mut self) {
321 self.enabled = false;
322 }
323 #[allow(dead_code)]
324 pub fn is_enabled(&self) -> bool {
325 self.enabled
326 }
327}
/// A single cached blob, stored by `MLIRAnalysisCache`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRCacheEntry {
    /// The key this entry is stored under (duplicated from the map key).
    pub key: String,
    /// Cached payload bytes; their encoding is not established here.
    pub data: Vec<u8>,
    /// Caller-assigned ordering/time value; units are not established in this chunk.
    pub timestamp: u64,
    /// Cleared by `MLIRAnalysisCache::invalidate`; entries start out valid.
    pub valid: bool,
}
/// Per-block liveness sets over `u32` variable ids, indexed by block number.
///
/// All four vectors have one set per block. Out-of-range block indices are
/// ignored by the mutators and report "not live" from the queries.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRLivenessInfo {
    pub live_in: Vec<std::collections::HashSet<u32>>,
    pub live_out: Vec<std::collections::HashSet<u32>>,
    pub defs: Vec<std::collections::HashSet<u32>>,
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl MLIRLivenessInfo {
    /// Creates empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty_sets = || -> Vec<std::collections::HashSet<u32>> {
            (0..block_count)
                .map(|_| std::collections::HashSet::new())
                .collect()
        };
        MLIRLivenessInfo {
            live_in: empty_sets(),
            live_out: empty_sets(),
            defs: empty_sets(),
            uses: empty_sets(),
        }
    }

    /// Records that `var` is defined in `block` (no-op if `block` is out of range).
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }

    /// Records that `var` is used in `block` (no-op if `block` is out of range).
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }

    /// Whether `var` is in the live-in set of `block`.
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        matches!(self.live_in.get(block), Some(set) if set.contains(&var))
    }

    /// Whether `var` is in the live-out set of `block`.
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        matches!(self.live_out.get(block), Some(set) if set.contains(&var))
    }
}
381#[allow(dead_code)]
382#[derive(Debug, Clone)]
383pub struct MLIRWorklist {
384 pub(super) items: std::collections::VecDeque<u32>,
385 pub(super) in_worklist: std::collections::HashSet<u32>,
386}
387impl MLIRWorklist {
388 #[allow(dead_code)]
389 pub fn new() -> Self {
390 MLIRWorklist {
391 items: std::collections::VecDeque::new(),
392 in_worklist: std::collections::HashSet::new(),
393 }
394 }
395 #[allow(dead_code)]
396 pub fn push(&mut self, item: u32) -> bool {
397 if self.in_worklist.insert(item) {
398 self.items.push_back(item);
399 true
400 } else {
401 false
402 }
403 }
404 #[allow(dead_code)]
405 pub fn pop(&mut self) -> Option<u32> {
406 let item = self.items.pop_front()?;
407 self.in_worklist.remove(&item);
408 Some(item)
409 }
410 #[allow(dead_code)]
411 pub fn is_empty(&self) -> bool {
412 self.items.is_empty()
413 }
414 #[allow(dead_code)]
415 pub fn len(&self) -> usize {
416 self.items.len()
417 }
418 #[allow(dead_code)]
419 pub fn contains(&self, item: u32) -> bool {
420 self.in_worklist.contains(&item)
421 }
422}
/// Ordered scheduling phase of a pass: `Early < Middle < Late < Finalize`.
///
/// The explicit discriminants are the values returned by [`MLIRExtPassPhase::order`].
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MLIRExtPassPhase {
    Early = 0,
    Middle = 1,
    Late = 2,
    Finalize = 3,
}
impl MLIRExtPassPhase {
    /// True for `Early`.
    #[allow(dead_code)]
    pub fn is_early(&self) -> bool {
        *self == Self::Early
    }

    /// True for `Middle`.
    #[allow(dead_code)]
    pub fn is_middle(&self) -> bool {
        *self == Self::Middle
    }

    /// True for `Late`.
    #[allow(dead_code)]
    pub fn is_late(&self) -> bool {
        *self == Self::Late
    }

    /// True for `Finalize`.
    #[allow(dead_code)]
    pub fn is_finalize(&self) -> bool {
        *self == Self::Finalize
    }

    /// Position of the phase in the schedule (0 through 3).
    #[allow(dead_code)]
    pub fn order(&self) -> u32 {
        // Fieldless enum: the cast yields the explicit discriminant above.
        self.clone() as u32
    }

    /// Inverse of [`MLIRExtPassPhase::order`]; `None` for values above 3.
    #[allow(dead_code)]
    pub fn from_order(n: u32) -> Option<Self> {
        const BY_ORDER: [MLIRExtPassPhase; 4] = [
            MLIRExtPassPhase::Early,
            MLIRExtPassPhase::Middle,
            MLIRExtPassPhase::Late,
            MLIRExtPassPhase::Finalize,
        ];
        BY_ORDER.get(n as usize).cloned()
    }
}
/// MLIR dialects known to this backend.
///
/// NOTE(review): each variant is assumed to stand for the upstream MLIR dialect
/// with the same lower-cased name. No variant-to-namespace mapping is visible
/// in this chunk, so confirm against wherever the enum is rendered.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MlirDialect {
    /// Presumably `builtin`.
    Builtin,
    /// Presumably `arith`.
    Arith,
    /// Presumably `func`.
    Func,
    /// Presumably `cf`.
    CF,
    /// Presumably `memref`.
    MemRef,
    /// Presumably `scf`.
    SCF,
    /// Presumably `affine`.
    Affine,
    /// Presumably `tensor`.
    Tensor,
    /// Presumably `vector`.
    Vector,
    /// Presumably `linalg`.
    Linalg,
    /// Presumably `gpu`.
    GPU,
    /// Presumably `llvm`.
    LLVM,
    /// Presumably `math`.
    Math,
    /// Presumably `index`.
    Index,
}
/// Aggregate statistics across runs of one pass.
///
/// `iterations_used` holds the value from the most recent run; the other
/// counters accumulate. `record_run` counts every run as successful.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRPassStats {
    pub total_runs: u32,
    pub successful_runs: u32,
    pub total_changes: u64,
    pub time_ms: u64,
    pub iterations_used: u32,
}
impl MLIRPassStats {
    /// All counters zero.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one successful run with its change count, duration and iteration count.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.iterations_used = iterations;
        self.time_ms += time_ms;
        self.total_changes += changes;
        self.successful_runs += 1;
        self.total_runs += 1;
    }

    /// Mean changes per run, or 0.0 before any run.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / runs as f64,
        }
    }

    /// Fraction of runs that succeeded, or 0.0 before any run.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.successful_runs as f64 / runs as f64,
        }
    }

    /// One-line human-readable summary.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let Self {
            successful_runs,
            total_runs,
            total_changes,
            time_ms,
            ..
        } = self;
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            successful_runs, total_runs, total_changes, time_ms
        )
    }
}
545#[allow(dead_code)]
547#[derive(Debug)]
548pub struct MLIRExtCache {
549 pub(super) entries: Vec<(u64, Vec<u8>, bool, u32)>,
550 pub(super) cap: usize,
551 pub(super) total_hits: u64,
552 pub(super) total_misses: u64,
553}
554impl MLIRExtCache {
555 #[allow(dead_code)]
556 pub fn new(cap: usize) -> Self {
557 Self {
558 entries: Vec::new(),
559 cap,
560 total_hits: 0,
561 total_misses: 0,
562 }
563 }
564 #[allow(dead_code)]
565 pub fn get(&mut self, key: u64) -> Option<&[u8]> {
566 for e in self.entries.iter_mut() {
567 if e.0 == key && e.2 {
568 e.3 += 1;
569 self.total_hits += 1;
570 return Some(&e.1);
571 }
572 }
573 self.total_misses += 1;
574 None
575 }
576 #[allow(dead_code)]
577 pub fn put(&mut self, key: u64, data: Vec<u8>) {
578 if self.entries.len() >= self.cap {
579 self.entries.retain(|e| e.2);
580 if self.entries.len() >= self.cap {
581 self.entries.remove(0);
582 }
583 }
584 self.entries.push((key, data, true, 0));
585 }
586 #[allow(dead_code)]
587 pub fn invalidate(&mut self) {
588 for e in self.entries.iter_mut() {
589 e.2 = false;
590 }
591 }
592 #[allow(dead_code)]
593 pub fn hit_rate(&self) -> f64 {
594 let t = self.total_hits + self.total_misses;
595 if t == 0 {
596 0.0
597 } else {
598 self.total_hits as f64 / t as f64
599 }
600 }
601 #[allow(dead_code)]
602 pub fn live_count(&self) -> usize {
603 self.entries.iter().filter(|e| e.2).count()
604 }
605}
606#[allow(dead_code)]
608#[derive(Debug, Clone)]
609pub struct MLIRExtPassConfig {
610 pub name: String,
611 pub phase: MLIRExtPassPhase,
612 pub enabled: bool,
613 pub max_iterations: usize,
614 pub debug: u32,
615 pub timeout_ms: Option<u64>,
616}
617impl MLIRExtPassConfig {
618 #[allow(dead_code)]
619 pub fn new(name: impl Into<String>) -> Self {
620 Self {
621 name: name.into(),
622 phase: MLIRExtPassPhase::Middle,
623 enabled: true,
624 max_iterations: 100,
625 debug: 0,
626 timeout_ms: None,
627 }
628 }
629 #[allow(dead_code)]
630 pub fn with_phase(mut self, phase: MLIRExtPassPhase) -> Self {
631 self.phase = phase;
632 self
633 }
634 #[allow(dead_code)]
635 pub fn with_max_iter(mut self, n: usize) -> Self {
636 self.max_iterations = n;
637 self
638 }
639 #[allow(dead_code)]
640 pub fn with_debug(mut self, d: u32) -> Self {
641 self.debug = d;
642 self
643 }
644 #[allow(dead_code)]
645 pub fn disabled(mut self) -> Self {
646 self.enabled = false;
647 self
648 }
649 #[allow(dead_code)]
650 pub fn with_timeout(mut self, ms: u64) -> Self {
651 self.timeout_ms = Some(ms);
652 self
653 }
654 #[allow(dead_code)]
655 pub fn is_debug_enabled(&self) -> bool {
656 self.debug > 0
657 }
658}
659#[derive(Debug, Default)]
661pub struct SsaCounter {
662 pub(super) counter: u32,
663 pub(super) named: HashMap<String, u32>,
664}
665impl SsaCounter {
666 pub fn new() -> Self {
668 SsaCounter::default()
669 }
670 pub fn next(&mut self, ty: MlirType) -> MlirValue {
672 let id = self.counter;
673 self.counter += 1;
674 MlirValue::numbered(id, ty)
675 }
676 pub fn named(&mut self, base: &str, ty: MlirType) -> MlirValue {
678 let count = self.named.entry(base.to_string()).or_insert(0);
679 let name = if *count == 0 {
680 base.to_string()
681 } else {
682 format!("{}_{}", base, count)
683 };
684 *count += 1;
685 MlirValue::named(name, ty)
686 }
687 pub fn reset(&mut self) {
689 self.counter = 0;
690 self.named.clear();
691 }
692}
/// Dominator tree stored as an immediate-dominator array indexed by node id.
///
/// `idom[n]` is the immediate dominator of `n`. A root is marked either by
/// `None` or by pointing at itself. `dom_children` and `dom_depth` are plain
/// storage; `set_idom` does not maintain them.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MLIRDominatorTree {
    pub idom: Vec<Option<u32>>,
    pub dom_children: Vec<Vec<u32>>,
    pub dom_depth: Vec<u32>,
}
impl MLIRDominatorTree {
    /// Creates a tree of `size` nodes with no idoms, no children and depth 0.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        MLIRDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }

    /// Sets the immediate dominator of `node`.
    ///
    /// # Panics
    /// Panics if `node` is not a valid index (`node >= size`).
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        self.idom[node] = Some(idom);
    }

    /// Whether `a` dominates `b`. Every node dominates itself.
    ///
    /// Walks the idom chain upward from `b`. Returns `false` in these cases:
    /// - the chain reaches a root;
    /// - the chain leaves the valid index range;
    /// - the chain fails to terminate within `idom.len()` steps (a malformed,
    ///   cyclic idom chain).
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        // A well-formed path from `b` to the root visits each node at most
        // once, so `idom.len()` steps always suffice; the bound only stops
        // infinite loops on corrupted input.
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) => {
                    let parent = parent as usize;
                    if parent == a {
                        return true;
                    }
                    if parent == cur {
                        return false;
                    }
                    cur = parent;
                }
                None => return false,
            }
        }
        false
    }

    /// Stored depth of `node`, or 0 if `node` is out of range.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
733pub struct MlirBuilder {
735 pub(super) ssa: SsaCounter,
736 pub(super) ops: Vec<MlirOp>,
737}
738impl MlirBuilder {
739 pub fn new() -> Self {
741 MlirBuilder {
742 ssa: SsaCounter::new(),
743 ops: vec![],
744 }
745 }
746 pub fn const_int(&mut self, value: i64, bits: u32) -> MlirValue {
748 let ty = MlirType::Integer(bits, false);
749 let result = self.ssa.next(ty.clone());
750 let mut op = MlirOp::unary_result(
751 result.clone(),
752 "arith.constant",
753 vec![],
754 vec![("value".to_string(), MlirAttr::Integer(value, ty))],
755 );
756 op.type_annotations = vec![result.ty.clone()];
757 self.ops.push(op);
758 result
759 }
760 pub fn const_float(&mut self, value: f64, bits: u32) -> MlirValue {
762 let ty = MlirType::Float(bits);
763 let result = self.ssa.next(ty.clone());
764 let op = MlirOp::unary_result(
765 result.clone(),
766 "arith.constant",
767 vec![],
768 vec![("value".to_string(), MlirAttr::Float(value))],
769 );
770 self.ops.push(op);
771 result
772 }
773 pub fn addi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
775 let ty = lhs.ty.clone();
776 let result = self.ssa.next(ty.clone());
777 let mut op = MlirOp::unary_result(result.clone(), "arith.addi", vec![lhs, rhs], vec![]);
778 op.type_annotations = vec![ty];
779 self.ops.push(op);
780 result
781 }
782 pub fn subi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
784 let ty = lhs.ty.clone();
785 let result = self.ssa.next(ty.clone());
786 let mut op = MlirOp::unary_result(result.clone(), "arith.subi", vec![lhs, rhs], vec![]);
787 op.type_annotations = vec![ty];
788 self.ops.push(op);
789 result
790 }
791 pub fn muli(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
793 let ty = lhs.ty.clone();
794 let result = self.ssa.next(ty.clone());
795 let mut op = MlirOp::unary_result(result.clone(), "arith.muli", vec![lhs, rhs], vec![]);
796 op.type_annotations = vec![ty];
797 self.ops.push(op);
798 result
799 }
800 pub fn divsi(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
802 let ty = lhs.ty.clone();
803 let result = self.ssa.next(ty.clone());
804 let mut op = MlirOp::unary_result(result.clone(), "arith.divsi", vec![lhs, rhs], vec![]);
805 op.type_annotations = vec![ty];
806 self.ops.push(op);
807 result
808 }
809 pub fn addf(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
811 let ty = lhs.ty.clone();
812 let result = self.ssa.next(ty.clone());
813 let mut op = MlirOp::unary_result(result.clone(), "arith.addf", vec![lhs, rhs], vec![]);
814 op.type_annotations = vec![ty];
815 self.ops.push(op);
816 result
817 }
818 pub fn mulf(&mut self, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
820 let ty = lhs.ty.clone();
821 let result = self.ssa.next(ty.clone());
822 let mut op = MlirOp::unary_result(result.clone(), "arith.mulf", vec![lhs, rhs], vec![]);
823 op.type_annotations = vec![ty];
824 self.ops.push(op);
825 result
826 }
827 pub fn cmpi(&mut self, pred: CmpiPred, lhs: MlirValue, rhs: MlirValue) -> MlirValue {
829 let result = self.ssa.next(MlirType::Integer(1, false));
830 let mut op = MlirOp::unary_result(
831 result.clone(),
832 "arith.cmpi",
833 vec![lhs.clone(), rhs],
834 vec![("predicate".to_string(), MlirAttr::Str(pred.to_string()))],
835 );
836 op.type_annotations = vec![lhs.ty];
837 self.ops.push(op);
838 result
839 }
840 pub fn extsi(&mut self, val: MlirValue, target_bits: u32) -> MlirValue {
842 let result = self.ssa.next(MlirType::Integer(target_bits, false));
843 let src_ty = val.ty.clone();
844 let dst_ty = result.ty.clone();
845 let mut op = MlirOp::unary_result(result.clone(), "arith.extsi", vec![val], vec![]);
846 op.type_annotations = vec![src_ty, dst_ty];
847 self.ops.push(op);
848 result
849 }
850 pub fn trunci(&mut self, val: MlirValue, target_bits: u32) -> MlirValue {
852 let result = self.ssa.next(MlirType::Integer(target_bits, false));
853 let src_ty = val.ty.clone();
854 let dst_ty = result.ty.clone();
855 let mut op = MlirOp::unary_result(result.clone(), "arith.trunci", vec![val], vec![]);
856 op.type_annotations = vec![src_ty, dst_ty];
857 self.ops.push(op);
858 result
859 }
860 pub fn sin(&mut self, val: MlirValue) -> MlirValue {
862 let ty = val.ty.clone();
863 let result = self.ssa.next(ty.clone());
864 let mut op = MlirOp::unary_result(result.clone(), "math.sin", vec![val], vec![]);
865 op.type_annotations = vec![ty];
866 self.ops.push(op);
867 result
868 }
869 pub fn cos(&mut self, val: MlirValue) -> MlirValue {
871 let ty = val.ty.clone();
872 let result = self.ssa.next(ty.clone());
873 let mut op = MlirOp::unary_result(result.clone(), "math.cos", vec![val], vec![]);
874 op.type_annotations = vec![ty];
875 self.ops.push(op);
876 result
877 }
878 pub fn exp(&mut self, val: MlirValue) -> MlirValue {
880 let ty = val.ty.clone();
881 let result = self.ssa.next(ty.clone());
882 let mut op = MlirOp::unary_result(result.clone(), "math.exp", vec![val], vec![]);
883 op.type_annotations = vec![ty];
884 self.ops.push(op);
885 result
886 }
887 pub fn log(&mut self, val: MlirValue) -> MlirValue {
889 let ty = val.ty.clone();
890 let result = self.ssa.next(ty.clone());
891 let mut op = MlirOp::unary_result(result.clone(), "math.log", vec![val], vec![]);
892 op.type_annotations = vec![ty];
893 self.ops.push(op);
894 result
895 }
896 pub fn sqrt(&mut self, val: MlirValue) -> MlirValue {
898 let ty = val.ty.clone();
899 let result = self.ssa.next(ty.clone());
900 let mut op = MlirOp::unary_result(result.clone(), "math.sqrt", vec![val], vec![]);
901 op.type_annotations = vec![ty];
902 self.ops.push(op);
903 result
904 }
905 pub fn alloc(&mut self, elem_ty: MlirType, dims: Vec<i64>) -> MlirValue {
907 let memref_ty = MlirType::MemRef(Box::new(elem_ty), dims, AffineMap::Constant);
908 let result = self.ssa.next(memref_ty.clone());
909 let op = MlirOp::unary_result(result.clone(), "memref.alloc", vec![], vec![]);
910 self.ops.push(op);
911 result
912 }
913 pub fn dealloc(&mut self, memref: MlirValue) {
915 let op = MlirOp::void_op("memref.dealloc", vec![memref], vec![]);
916 self.ops.push(op);
917 }
918 pub fn return_op(&mut self, values: Vec<MlirValue>) {
920 let op = MlirOp::void_op("func.return", values, vec![]);
921 self.ops.push(op);
922 }
923 pub fn call(
925 &mut self,
926 callee: &str,
927 args: Vec<MlirValue>,
928 result_types: Vec<MlirType>,
929 ) -> Vec<MlirValue> {
930 let results: Vec<MlirValue> = result_types.into_iter().map(|t| self.ssa.next(t)).collect();
931 let mut op = MlirOp {
932 results: results.clone(),
933 op_name: "func.call".to_string(),
934 operands: args,
935 regions: vec![],
936 successors: vec![],
937 attributes: vec![("callee".to_string(), MlirAttr::Symbol(callee.to_string()))],
938 type_annotations: vec![],
939 };
940 op.type_annotations = results.iter().map(|r| r.ty.clone()).collect();
941 self.ops.push(op);
942 results
943 }
944 pub fn take_ops(&mut self) -> Vec<MlirOp> {
946 std::mem::take(&mut self.ops)
947 }
948 pub fn finish_block(&mut self, args: Vec<MlirValue>) -> MlirBlock {
950 let ops = self.take_ops();
951 MlirBlock::entry(args, ops)
952 }
953}
/// An attribute value attached to an op or a function.
#[derive(Debug, Clone, PartialEq)]
pub enum MlirAttr {
    /// Integer literal together with its type (built by `MlirBuilder::const_int`).
    Integer(i64, MlirType),
    /// Floating-point literal; carries no type of its own.
    Float(f64),
    /// String payload (used for the `arith.cmpi` predicate in this file).
    Str(String),
    /// A type used as an attribute value.
    Type(MlirType),
    /// Ordered list of attributes.
    Array(Vec<MlirAttr>),
    /// Key/value pairs kept in insertion order.
    Dict(Vec<(String, MlirAttr)>),
    /// Affine map given as raw text (presumably printed verbatim — TODO confirm).
    AffineMap(String),
    /// Presumably MLIR's presence-only unit attribute.
    Unit,
    /// Boolean literal.
    Bool(bool),
    /// Symbol reference (used for the `func.call` callee in this file).
    Symbol(String),
    /// Element values plus a type; presumably a dense-elements attribute — TODO confirm.
    Dense(Vec<MlirAttr>, MlirType),
}
/// An MLIR type.
#[derive(Debug, Clone, PartialEq)]
pub enum MlirType {
    /// Integer of the given bit width.
    /// NOTE(review): the `bool` is presumably a signedness flag. Every
    /// constructor in this file passes `false`; confirm its meaning.
    Integer(u32, bool),
    /// Floating-point type of the given bit width.
    Float(u32),
    /// Presumably MLIR's `index` type.
    Index,
    /// MemRef: element type, dimension sizes, layout map.
    /// (`MlirBuilder::alloc` passes `AffineMap::Constant` as the layout.)
    MemRef(Box<MlirType>, Vec<i64>, AffineMap),
    /// Tensor: presumably dimension sizes, then element type.
    /// Note the field order is reversed relative to `MemRef`.
    Tensor(Vec<i64>, Box<MlirType>),
    /// Vector: presumably dimension sizes, then element type.
    Vector(Vec<u64>, Box<MlirType>),
    /// Tuple of types.
    Tuple(Vec<MlirType>),
    /// Presumably MLIR's `none` type.
    NoneType,
    /// Type given as raw text.
    Custom(String),
    /// Function type; presumably (inputs, results) — TODO confirm.
    FuncType(Vec<MlirType>, Vec<MlirType>),
    /// Complex number over the given element type.
    Complex(Box<MlirType>),
    /// MemRef of unknown rank over the given element type.
    UnrankedMemRef(Box<MlirType>),
}
/// Per-block liveness over `usize` variable ids, stored as small
/// duplicate-free vectors.
///
/// Indices outside the block range are ignored by the mutators and read as
/// "absent" by the queries.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRExtLiveness {
    pub live_in: Vec<Vec<usize>>,
    pub live_out: Vec<Vec<usize>>,
    pub defs: Vec<Vec<usize>>,
    pub uses: Vec<Vec<usize>>,
}
impl MLIRExtLiveness {
    /// Empty sets for `n` blocks.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        let blank: Vec<Vec<usize>> = vec![Vec::new(); n];
        Self {
            live_in: blank.clone(),
            live_out: blank.clone(),
            defs: blank.clone(),
            uses: blank,
        }
    }

    /// Whether `v` is live on entry to block `b`.
    #[allow(dead_code)]
    pub fn live_in(&self, b: usize, v: usize) -> bool {
        matches!(self.live_in.get(b), Some(vars) if vars.contains(&v))
    }

    /// Whether `v` is live on exit from block `b`.
    #[allow(dead_code)]
    pub fn live_out(&self, b: usize, v: usize) -> bool {
        matches!(self.live_out.get(b), Some(vars) if vars.contains(&v))
    }

    /// Records a definition of `v` in block `b` (once).
    #[allow(dead_code)]
    pub fn add_def(&mut self, b: usize, v: usize) {
        if let Some(vars) = self.defs.get_mut(b).filter(|vars| !vars.contains(&v)) {
            vars.push(v);
        }
    }

    /// Records a use of `v` in block `b` (once).
    #[allow(dead_code)]
    pub fn add_use(&mut self, b: usize, v: usize) {
        if let Some(vars) = self.uses.get_mut(b).filter(|vars| !vars.contains(&v)) {
            vars.push(v);
        }
    }

    /// Whether `v` was recorded as used in block `b`.
    #[allow(dead_code)]
    pub fn var_is_used_in_block(&self, b: usize, v: usize) -> bool {
        matches!(self.uses.get(b), Some(vars) if vars.contains(&v))
    }

    /// Whether `v` was recorded as defined in block `b`.
    #[allow(dead_code)]
    pub fn var_is_def_in_block(&self, b: usize, v: usize) -> bool {
        matches!(self.defs.get(b), Some(vars) if vars.contains(&v))
    }
}
1065#[derive(Debug, Clone)]
1067pub struct MlirRegion {
1068 pub blocks: Vec<MlirBlock>,
1070}
1071impl MlirRegion {
1072 pub fn single_block(block: MlirBlock) -> Self {
1074 MlirRegion {
1075 blocks: vec![block],
1076 }
1077 }
1078 pub fn empty() -> Self {
1080 MlirRegion { blocks: vec![] }
1081 }
1082}
/// Counters collected while a single pass runs.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MLIRExtPassStats {
    pub iterations: usize,
    pub changed: bool,
    pub nodes_visited: usize,
    pub nodes_modified: usize,
    pub time_ms: u64,
    pub memory_bytes: usize,
    pub errors: usize,
}
impl MLIRExtPassStats {
    /// All counters zero, `changed == false`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts one visited node.
    #[allow(dead_code)]
    pub fn visit(&mut self) {
        self.nodes_visited += 1;
    }

    /// Counts one modified node and marks the pass as having changed the IR.
    #[allow(dead_code)]
    pub fn modify(&mut self) {
        self.changed = true;
        self.nodes_modified += 1;
    }

    /// Counts one iteration.
    #[allow(dead_code)]
    pub fn iterate(&mut self) {
        self.iterations += 1;
    }

    /// Counts one error.
    #[allow(dead_code)]
    pub fn error(&mut self) {
        self.errors += 1;
    }

    /// Modified nodes divided by visited nodes, or 0.0 if none were visited.
    #[allow(dead_code)]
    pub fn efficiency(&self) -> f64 {
        match self.nodes_visited {
            0 => 0.0,
            visited => self.nodes_modified as f64 / visited as f64,
        }
    }

    /// Folds `o` into `self`.
    ///
    /// Counters add up and `changed` is OR-ed. `memory_bytes` keeps the peak
    /// (maximum) of the two.
    #[allow(dead_code)]
    pub fn merge(&mut self, o: &MLIRExtPassStats) {
        let MLIRExtPassStats {
            iterations,
            changed,
            nodes_visited,
            nodes_modified,
            time_ms,
            memory_bytes,
            errors,
        } = o;
        self.iterations += *iterations;
        self.changed |= *changed;
        self.nodes_visited += *nodes_visited;
        self.nodes_modified += *nodes_modified;
        self.time_ms += *time_ms;
        self.memory_bytes = self.memory_bytes.max(*memory_bytes);
        self.errors += *errors;
    }
}
1136#[derive(Debug, Clone)]
1138pub struct MlirBlock {
1139 pub label: Option<String>,
1141 pub arguments: Vec<MlirValue>,
1143 pub body: Vec<MlirOp>,
1145 pub terminator: Option<MlirOp>,
1147}
1148impl MlirBlock {
1149 pub fn entry(arguments: Vec<MlirValue>, body: Vec<MlirOp>) -> Self {
1151 MlirBlock {
1152 label: None,
1153 arguments,
1154 body,
1155 terminator: None,
1156 }
1157 }
1158 pub fn labeled(label: impl Into<String>, arguments: Vec<MlirValue>, body: Vec<MlirOp>) -> Self {
1160 MlirBlock {
1161 label: Some(label.into()),
1162 arguments,
1163 body,
1164 terminator: None,
1165 }
1166 }
1167}
1168#[allow(dead_code)]
1170#[derive(Debug, Clone)]
1171pub struct MLIRExtWorklist {
1172 pub(super) items: std::collections::VecDeque<usize>,
1173 pub(super) present: Vec<bool>,
1174}
1175impl MLIRExtWorklist {
1176 #[allow(dead_code)]
1177 pub fn new(capacity: usize) -> Self {
1178 Self {
1179 items: std::collections::VecDeque::new(),
1180 present: vec![false; capacity],
1181 }
1182 }
1183 #[allow(dead_code)]
1184 pub fn push(&mut self, id: usize) {
1185 if id < self.present.len() && !self.present[id] {
1186 self.present[id] = true;
1187 self.items.push_back(id);
1188 }
1189 }
1190 #[allow(dead_code)]
1191 pub fn push_front(&mut self, id: usize) {
1192 if id < self.present.len() && !self.present[id] {
1193 self.present[id] = true;
1194 self.items.push_front(id);
1195 }
1196 }
1197 #[allow(dead_code)]
1198 pub fn pop(&mut self) -> Option<usize> {
1199 let id = self.items.pop_front()?;
1200 if id < self.present.len() {
1201 self.present[id] = false;
1202 }
1203 Some(id)
1204 }
1205 #[allow(dead_code)]
1206 pub fn is_empty(&self) -> bool {
1207 self.items.is_empty()
1208 }
1209 #[allow(dead_code)]
1210 pub fn len(&self) -> usize {
1211 self.items.len()
1212 }
1213 #[allow(dead_code)]
1214 pub fn contains(&self, id: usize) -> bool {
1215 id < self.present.len() && self.present[id]
1216 }
1217 #[allow(dead_code)]
1218 pub fn drain_all(&mut self) -> Vec<usize> {
1219 let v: Vec<usize> = self.items.drain(..).collect();
1220 for &id in &v {
1221 if id < self.present.len() {
1222 self.present[id] = false;
1223 }
1224 }
1225 v
1226 }
1227}
1228#[derive(Debug, Clone)]
1230pub struct MlirOp {
1231 pub results: Vec<MlirValue>,
1233 pub op_name: String,
1235 pub operands: Vec<MlirValue>,
1237 pub regions: Vec<MlirRegion>,
1239 pub successors: Vec<String>,
1241 pub attributes: Vec<(String, MlirAttr)>,
1243 pub type_annotations: Vec<MlirType>,
1245}
1246impl MlirOp {
1247 pub fn unary_result(
1249 result: MlirValue,
1250 op_name: impl Into<String>,
1251 operands: Vec<MlirValue>,
1252 attrs: Vec<(String, MlirAttr)>,
1253 ) -> Self {
1254 MlirOp {
1255 results: vec![result],
1256 op_name: op_name.into(),
1257 operands,
1258 regions: vec![],
1259 successors: vec![],
1260 attributes: attrs,
1261 type_annotations: vec![],
1262 }
1263 }
1264 pub fn void_op(
1266 op_name: impl Into<String>,
1267 operands: Vec<MlirValue>,
1268 attrs: Vec<(String, MlirAttr)>,
1269 ) -> Self {
1270 MlirOp {
1271 results: vec![],
1272 op_name: op_name.into(),
1273 operands,
1274 regions: vec![],
1275 successors: vec![],
1276 attributes: attrs,
1277 type_annotations: vec![],
1278 }
1279 }
1280}
1281#[allow(dead_code)]
1282#[derive(Debug, Clone)]
1283pub struct MLIRAnalysisCache {
1284 pub(super) entries: std::collections::HashMap<String, MLIRCacheEntry>,
1285 pub(super) max_size: usize,
1286 pub(super) hits: u64,
1287 pub(super) misses: u64,
1288}
1289impl MLIRAnalysisCache {
1290 #[allow(dead_code)]
1291 pub fn new(max_size: usize) -> Self {
1292 MLIRAnalysisCache {
1293 entries: std::collections::HashMap::new(),
1294 max_size,
1295 hits: 0,
1296 misses: 0,
1297 }
1298 }
1299 #[allow(dead_code)]
1300 pub fn get(&mut self, key: &str) -> Option<&MLIRCacheEntry> {
1301 if self.entries.contains_key(key) {
1302 self.hits += 1;
1303 self.entries.get(key)
1304 } else {
1305 self.misses += 1;
1306 None
1307 }
1308 }
1309 #[allow(dead_code)]
1310 pub fn insert(&mut self, key: String, data: Vec<u8>) {
1311 if self.entries.len() >= self.max_size {
1312 if let Some(oldest) = self.entries.keys().next().cloned() {
1313 self.entries.remove(&oldest);
1314 }
1315 }
1316 self.entries.insert(
1317 key.clone(),
1318 MLIRCacheEntry {
1319 key,
1320 data,
1321 timestamp: 0,
1322 valid: true,
1323 },
1324 );
1325 }
1326 #[allow(dead_code)]
1327 pub fn invalidate(&mut self, key: &str) {
1328 if let Some(entry) = self.entries.get_mut(key) {
1329 entry.valid = false;
1330 }
1331 }
1332 #[allow(dead_code)]
1333 pub fn clear(&mut self) {
1334 self.entries.clear();
1335 }
1336 #[allow(dead_code)]
1337 pub fn hit_rate(&self) -> f64 {
1338 let total = self.hits + self.misses;
1339 if total == 0 {
1340 return 0.0;
1341 }
1342 self.hits as f64 / total as f64
1343 }
1344 #[allow(dead_code)]
1345 pub fn size(&self) -> usize {
1346 self.entries.len()
1347 }
1348}
1349#[allow(dead_code)]
1350#[derive(Debug, Clone)]
1351pub struct MLIRPassConfig {
1352 pub phase: MLIRPassPhase,
1353 pub enabled: bool,
1354 pub max_iterations: u32,
1355 pub debug_output: bool,
1356 pub pass_name: String,
1357}
1358impl MLIRPassConfig {
1359 #[allow(dead_code)]
1360 pub fn new(name: impl Into<String>, phase: MLIRPassPhase) -> Self {
1361 MLIRPassConfig {
1362 phase,
1363 enabled: true,
1364 max_iterations: 10,
1365 debug_output: false,
1366 pass_name: name.into(),
1367 }
1368 }
1369 #[allow(dead_code)]
1370 pub fn disabled(mut self) -> Self {
1371 self.enabled = false;
1372 self
1373 }
1374 #[allow(dead_code)]
1375 pub fn with_debug(mut self) -> Self {
1376 self.debug_output = true;
1377 self
1378 }
1379 #[allow(dead_code)]
1380 pub fn max_iter(mut self, n: u32) -> Self {
1381 self.max_iterations = n;
1382 self
1383 }
1384}
/// Floating-point comparison predicates.
///
/// NOTE(review): variant names appear to mirror the `arith.cmpf` predicate
/// mnemonics (`oeq`, `one`, ...); the per-variant notes below assume the
/// standard MLIR meanings ("ordered" = neither operand is NaN,
/// "unordered" = at least one operand is NaN) — confirm against the printer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpfPred {
    /// Ordered and equal.
    Oeq,
    /// Ordered and not equal.
    One,
    /// Ordered and less than.
    Olt,
    /// Ordered and less than or equal.
    Ole,
    /// Ordered and greater than.
    Ogt,
    /// Ordered and greater than or equal.
    Oge,
    /// Unordered or equal.
    Ueq,
    /// Unordered or not equal.
    Une,
}
/// Descriptor for an affine map attached to an operation.
#[derive(Debug, Clone, PartialEq)]
pub enum AffineMap {
    /// Identity map. NOTE(review): the `usize` is presumably the number of dimensions — confirm.
    Identity(usize),
    /// Constant map. Carries no payload, so any constant value must be supplied elsewhere.
    Constant,
    /// Map given verbatim as text; presumably already in MLIR affine-map syntax — confirm.
    Custom(String),
}
1415#[derive(Debug, Clone, PartialEq)]
1417pub struct MlirValue {
1418 pub name: String,
1420 pub ty: MlirType,
1422}
1423impl MlirValue {
1424 pub fn numbered(id: u32, ty: MlirType) -> Self {
1426 MlirValue {
1427 name: id.to_string(),
1428 ty,
1429 }
1430 }
1431 pub fn named(name: impl Into<String>, ty: MlirType) -> Self {
1433 MlirValue {
1434 name: name.into(),
1435 ty,
1436 }
1437 }
1438}
1439#[allow(dead_code)]
1441#[derive(Debug, Clone)]
1442pub struct MLIRExtDepGraph {
1443 pub(super) n: usize,
1444 pub(super) adj: Vec<Vec<usize>>,
1445 pub(super) rev: Vec<Vec<usize>>,
1446 pub(super) edge_count: usize,
1447}
1448impl MLIRExtDepGraph {
1449 #[allow(dead_code)]
1450 pub fn new(n: usize) -> Self {
1451 Self {
1452 n,
1453 adj: vec![Vec::new(); n],
1454 rev: vec![Vec::new(); n],
1455 edge_count: 0,
1456 }
1457 }
1458 #[allow(dead_code)]
1459 pub fn add_edge(&mut self, from: usize, to: usize) {
1460 if from < self.n && to < self.n {
1461 if !self.adj[from].contains(&to) {
1462 self.adj[from].push(to);
1463 self.rev[to].push(from);
1464 self.edge_count += 1;
1465 }
1466 }
1467 }
1468 #[allow(dead_code)]
1469 pub fn succs(&self, n: usize) -> &[usize] {
1470 self.adj.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1471 }
1472 #[allow(dead_code)]
1473 pub fn preds(&self, n: usize) -> &[usize] {
1474 self.rev.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1475 }
1476 #[allow(dead_code)]
1477 pub fn topo_sort(&self) -> Option<Vec<usize>> {
1478 let mut deg: Vec<usize> = (0..self.n).map(|i| self.rev[i].len()).collect();
1479 let mut q: std::collections::VecDeque<usize> =
1480 (0..self.n).filter(|&i| deg[i] == 0).collect();
1481 let mut out = Vec::with_capacity(self.n);
1482 while let Some(u) = q.pop_front() {
1483 out.push(u);
1484 for &v in &self.adj[u] {
1485 deg[v] -= 1;
1486 if deg[v] == 0 {
1487 q.push_back(v);
1488 }
1489 }
1490 }
1491 if out.len() == self.n {
1492 Some(out)
1493 } else {
1494 None
1495 }
1496 }
1497 #[allow(dead_code)]
1498 pub fn has_cycle(&self) -> bool {
1499 self.topo_sort().is_none()
1500 }
1501 #[allow(dead_code)]
1502 pub fn reachable(&self, start: usize) -> Vec<usize> {
1503 let mut vis = vec![false; self.n];
1504 let mut stk = vec![start];
1505 let mut out = Vec::new();
1506 while let Some(u) = stk.pop() {
1507 if u < self.n && !vis[u] {
1508 vis[u] = true;
1509 out.push(u);
1510 for &v in &self.adj[u] {
1511 if !vis[v] {
1512 stk.push(v);
1513 }
1514 }
1515 }
1516 }
1517 out
1518 }
1519 #[allow(dead_code)]
1520 pub fn scc(&self) -> Vec<Vec<usize>> {
1521 let mut visited = vec![false; self.n];
1522 let mut order = Vec::new();
1523 for i in 0..self.n {
1524 if !visited[i] {
1525 let mut stk = vec![(i, 0usize)];
1526 while let Some((u, idx)) = stk.last_mut() {
1527 if !visited[*u] {
1528 visited[*u] = true;
1529 }
1530 if *idx < self.adj[*u].len() {
1531 let v = self.adj[*u][*idx];
1532 *idx += 1;
1533 if !visited[v] {
1534 stk.push((v, 0));
1535 }
1536 } else {
1537 order.push(*u);
1538 stk.pop();
1539 }
1540 }
1541 }
1542 }
1543 let mut comp = vec![usize::MAX; self.n];
1544 let mut components: Vec<Vec<usize>> = Vec::new();
1545 for &start in order.iter().rev() {
1546 if comp[start] == usize::MAX {
1547 let cid = components.len();
1548 let mut component = Vec::new();
1549 let mut stk = vec![start];
1550 while let Some(u) = stk.pop() {
1551 if comp[u] == usize::MAX {
1552 comp[u] = cid;
1553 component.push(u);
1554 for &v in &self.rev[u] {
1555 if comp[v] == usize::MAX {
1556 stk.push(v);
1557 }
1558 }
1559 }
1560 }
1561 components.push(component);
1562 }
1563 }
1564 components
1565 }
1566 #[allow(dead_code)]
1567 pub fn node_count(&self) -> usize {
1568 self.n
1569 }
1570 #[allow(dead_code)]
1571 pub fn edge_count(&self) -> usize {
1572 self.edge_count
1573 }
1574}
1575#[derive(Debug, Clone)]
1577pub struct MlirModule {
1578 pub name: Option<String>,
1580 pub functions: Vec<MlirFunc>,
1582 pub globals: Vec<MlirGlobal>,
1584 pub required_dialects: Vec<MlirDialect>,
1586}
1587impl MlirModule {
1588 pub fn new() -> Self {
1590 MlirModule {
1591 name: None,
1592 functions: vec![],
1593 globals: vec![],
1594 required_dialects: vec![],
1595 }
1596 }
1597 pub fn named(name: impl Into<String>) -> Self {
1599 MlirModule {
1600 name: Some(name.into()),
1601 functions: vec![],
1602 globals: vec![],
1603 required_dialects: vec![],
1604 }
1605 }
1606 pub fn add_function(&mut self, func: MlirFunc) {
1608 self.functions.push(func);
1609 }
1610 pub fn add_global(&mut self, global: MlirGlobal) {
1612 self.globals.push(global);
1613 }
1614 pub fn emit(&self) -> String {
1616 let mut out = String::new();
1617 if let Some(name) = &self.name {
1618 out.push_str(&format!("module @{} {{\n", name));
1619 } else {
1620 out.push_str("module {\n");
1621 }
1622 for global in &self.globals {
1623 out.push_str(&global.emit());
1624 }
1625 for func in &self.functions {
1626 out.push_str(&func.emit());
1627 }
1628 out.push_str("}\n");
1629 out
1630 }
1631}
/// Integer comparison predicates.
///
/// NOTE(review): variant names appear to mirror the `arith.cmpi` predicate
/// mnemonics; the `s`/`u` prefixes are assumed to mean signed/unsigned
/// comparison, as in MLIR — confirm against the printer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpiPred {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Signed less than.
    Slt,
    /// Signed less than or equal.
    Sle,
    /// Signed greater than.
    Sgt,
    /// Signed greater than or equal.
    Sge,
    /// Unsigned less than.
    Ult,
    /// Unsigned less than or equal.
    Ule,
    /// Unsigned greater than.
    Ugt,
    /// Unsigned greater than or equal.
    Uge,
}
1656#[allow(dead_code)]
1658#[derive(Debug, Clone)]
1659pub struct MLIRExtDomTree {
1660 pub(super) idom: Vec<Option<usize>>,
1661 pub(super) children: Vec<Vec<usize>>,
1662 pub(super) depth: Vec<usize>,
1663}
1664impl MLIRExtDomTree {
1665 #[allow(dead_code)]
1666 pub fn new(n: usize) -> Self {
1667 Self {
1668 idom: vec![None; n],
1669 children: vec![Vec::new(); n],
1670 depth: vec![0; n],
1671 }
1672 }
1673 #[allow(dead_code)]
1674 pub fn set_idom(&mut self, node: usize, dom: usize) {
1675 if node < self.idom.len() {
1676 self.idom[node] = Some(dom);
1677 if dom < self.children.len() {
1678 self.children[dom].push(node);
1679 }
1680 self.depth[node] = if dom < self.depth.len() {
1681 self.depth[dom] + 1
1682 } else {
1683 1
1684 };
1685 }
1686 }
1687 #[allow(dead_code)]
1688 pub fn dominates(&self, a: usize, mut b: usize) -> bool {
1689 if a == b {
1690 return true;
1691 }
1692 let n = self.idom.len();
1693 for _ in 0..n {
1694 match self.idom.get(b).copied().flatten() {
1695 None => return false,
1696 Some(p) if p == a => return true,
1697 Some(p) if p == b => return false,
1698 Some(p) => b = p,
1699 }
1700 }
1701 false
1702 }
1703 #[allow(dead_code)]
1704 pub fn children_of(&self, n: usize) -> &[usize] {
1705 self.children.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1706 }
1707 #[allow(dead_code)]
1708 pub fn depth_of(&self, n: usize) -> usize {
1709 self.depth.get(n).copied().unwrap_or(0)
1710 }
1711 #[allow(dead_code)]
1712 pub fn lca(&self, mut a: usize, mut b: usize) -> usize {
1713 let n = self.idom.len();
1714 for _ in 0..(2 * n) {
1715 if a == b {
1716 return a;
1717 }
1718 if self.depth_of(a) > self.depth_of(b) {
1719 a = self.idom.get(a).and_then(|x| *x).unwrap_or(a);
1720 } else {
1721 b = self.idom.get(b).and_then(|x| *x).unwrap_or(b);
1722 }
1723 }
1724 0
1725 }
1726}
1727#[derive(Debug, Clone)]
1729pub struct MlirGlobal {
1730 pub name: String,
1732 pub ty: MlirType,
1734 pub initial_value: Option<MlirAttr>,
1736 pub is_constant: bool,
1738 pub linkage: String,
1740}
1741impl MlirGlobal {
1742 pub fn constant(name: impl Into<String>, ty: MlirType, value: MlirAttr) -> Self {
1744 MlirGlobal {
1745 name: name.into(),
1746 ty,
1747 initial_value: Some(value),
1748 is_constant: true,
1749 linkage: "public".to_string(),
1750 }
1751 }
1752 pub fn emit(&self) -> String {
1754 let mut out = String::new();
1755 out.push_str(" memref.global ");
1756 if self.is_constant {
1757 out.push_str("constant ");
1758 }
1759 out.push_str(&format!("@{} : {}", self.name, self.ty));
1760 if let Some(v) = &self.initial_value {
1761 out.push_str(&format!(" = {}", v));
1762 }
1763 out.push('\n');
1764 out
1765 }
1766}
1767#[allow(dead_code)]
1769#[derive(Debug, Default)]
1770pub struct MLIRExtPassRegistry {
1771 pub(super) configs: Vec<MLIRExtPassConfig>,
1772 pub(super) stats: Vec<MLIRExtPassStats>,
1773}
1774impl MLIRExtPassRegistry {
1775 #[allow(dead_code)]
1776 pub fn new() -> Self {
1777 Self::default()
1778 }
1779 #[allow(dead_code)]
1780 pub fn register(&mut self, c: MLIRExtPassConfig) {
1781 self.stats.push(MLIRExtPassStats::new());
1782 self.configs.push(c);
1783 }
1784 #[allow(dead_code)]
1785 pub fn len(&self) -> usize {
1786 self.configs.len()
1787 }
1788 #[allow(dead_code)]
1789 pub fn is_empty(&self) -> bool {
1790 self.configs.is_empty()
1791 }
1792 #[allow(dead_code)]
1793 pub fn get(&self, i: usize) -> Option<&MLIRExtPassConfig> {
1794 self.configs.get(i)
1795 }
1796 #[allow(dead_code)]
1797 pub fn get_stats(&self, i: usize) -> Option<&MLIRExtPassStats> {
1798 self.stats.get(i)
1799 }
1800 #[allow(dead_code)]
1801 pub fn enabled_passes(&self) -> Vec<&MLIRExtPassConfig> {
1802 self.configs.iter().filter(|c| c.enabled).collect()
1803 }
1804 #[allow(dead_code)]
1805 pub fn passes_in_phase(&self, ph: &MLIRExtPassPhase) -> Vec<&MLIRExtPassConfig> {
1806 self.configs
1807 .iter()
1808 .filter(|c| c.enabled && &c.phase == ph)
1809 .collect()
1810 }
1811 #[allow(dead_code)]
1812 pub fn total_nodes_visited(&self) -> usize {
1813 self.stats.iter().map(|s| s.nodes_visited).sum()
1814 }
1815 #[allow(dead_code)]
1816 pub fn any_changed(&self) -> bool {
1817 self.stats.iter().any(|s| s.changed)
1818 }
1819}
1820#[allow(dead_code)]
1821#[derive(Debug, Clone)]
1822pub struct MLIRDepGraph {
1823 pub(super) nodes: Vec<u32>,
1824 pub(super) edges: Vec<(u32, u32)>,
1825}
1826impl MLIRDepGraph {
1827 #[allow(dead_code)]
1828 pub fn new() -> Self {
1829 MLIRDepGraph {
1830 nodes: Vec::new(),
1831 edges: Vec::new(),
1832 }
1833 }
1834 #[allow(dead_code)]
1835 pub fn add_node(&mut self, id: u32) {
1836 if !self.nodes.contains(&id) {
1837 self.nodes.push(id);
1838 }
1839 }
1840 #[allow(dead_code)]
1841 pub fn add_dep(&mut self, dep: u32, dependent: u32) {
1842 self.add_node(dep);
1843 self.add_node(dependent);
1844 self.edges.push((dep, dependent));
1845 }
1846 #[allow(dead_code)]
1847 pub fn dependents_of(&self, node: u32) -> Vec<u32> {
1848 self.edges
1849 .iter()
1850 .filter(|(d, _)| *d == node)
1851 .map(|(_, dep)| *dep)
1852 .collect()
1853 }
1854 #[allow(dead_code)]
1855 pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
1856 self.edges
1857 .iter()
1858 .filter(|(_, dep)| *dep == node)
1859 .map(|(d, _)| *d)
1860 .collect()
1861 }
1862 #[allow(dead_code)]
1863 pub fn topological_sort(&self) -> Vec<u32> {
1864 let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
1865 for &n in &self.nodes {
1866 in_degree.insert(n, 0);
1867 }
1868 for (_, dep) in &self.edges {
1869 *in_degree.entry(*dep).or_insert(0) += 1;
1870 }
1871 let mut queue: std::collections::VecDeque<u32> = self
1872 .nodes
1873 .iter()
1874 .filter(|&&n| in_degree[&n] == 0)
1875 .copied()
1876 .collect();
1877 let mut result = Vec::new();
1878 while let Some(node) = queue.pop_front() {
1879 result.push(node);
1880 for dep in self.dependents_of(node) {
1881 let cnt = in_degree.entry(dep).or_insert(0);
1882 *cnt = cnt.saturating_sub(1);
1883 if *cnt == 0 {
1884 queue.push_back(dep);
1885 }
1886 }
1887 }
1888 result
1889 }
1890 #[allow(dead_code)]
1891 pub fn has_cycle(&self) -> bool {
1892 self.topological_sort().len() < self.nodes.len()
1893 }
1894}
1895#[allow(dead_code)]
1896pub struct MLIRPassRegistry {
1897 pub(super) configs: Vec<MLIRPassConfig>,
1898 pub(super) stats: std::collections::HashMap<String, MLIRPassStats>,
1899}
1900impl MLIRPassRegistry {
1901 #[allow(dead_code)]
1902 pub fn new() -> Self {
1903 MLIRPassRegistry {
1904 configs: Vec::new(),
1905 stats: std::collections::HashMap::new(),
1906 }
1907 }
1908 #[allow(dead_code)]
1909 pub fn register(&mut self, config: MLIRPassConfig) {
1910 self.stats
1911 .insert(config.pass_name.clone(), MLIRPassStats::new());
1912 self.configs.push(config);
1913 }
1914 #[allow(dead_code)]
1915 pub fn enabled_passes(&self) -> Vec<&MLIRPassConfig> {
1916 self.configs.iter().filter(|c| c.enabled).collect()
1917 }
1918 #[allow(dead_code)]
1919 pub fn get_stats(&self, name: &str) -> Option<&MLIRPassStats> {
1920 self.stats.get(name)
1921 }
1922 #[allow(dead_code)]
1923 pub fn total_passes(&self) -> usize {
1924 self.configs.len()
1925 }
1926 #[allow(dead_code)]
1927 pub fn enabled_count(&self) -> usize {
1928 self.enabled_passes().len()
1929 }
1930 #[allow(dead_code)]
1931 pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
1932 if let Some(stats) = self.stats.get_mut(name) {
1933 stats.record_run(changes, time_ms, iter);
1934 }
1935 }
1936}
/// Stateless helpers that constant-fold scalar operations.
///
/// Integer folds that can fail return `None` when the result is not
/// representable: overflow, division/remainder by zero, or a shift amount of 64
/// or more. `None` signals that the operation should be left unfolded.
#[allow(dead_code)]
pub struct MLIRConstantFoldingHelper;
impl MLIRConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }

    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }

    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }

    /// `a / b` truncating toward zero, or `None` if `b == 0` or for `i64::MIN / -1`.
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already rejects both division by zero and overflow.
        a.checked_div(b)
    }

    /// IEEE-754 `a + b`.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }

    /// IEEE-754 `a * b`.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }

    /// `-a`, or `None` for `i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }

    /// Logical NOT.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }

    /// Logical AND.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }

    /// Logical OR.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }

    /// `a << b`; `None` only when `b >= 64`. Bits shifted out are discarded, not
    /// reported as overflow.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }

    /// Arithmetic (sign-extending) `a >> b`; `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }

    /// `a % b`, with the sign of `a`. Returns `None` if `b == 0` or for `i64::MIN % -1`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        // Plain `%` panics on `i64::MIN % -1` even in release builds;
        // `checked_rem` reports that overflow (and division by zero) as `None`.
        a.checked_rem(b)
    }

    /// Bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }

    /// Bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }

    /// Bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }

    /// Bitwise NOT.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}