1use std::collections::{HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone, PartialEq)]
9pub struct MetalField {
10 pub ty: MetalType,
12 pub name: String,
14 pub attr: MetalParamAttr,
16}
17impl MetalField {
18 pub fn new(ty: MetalType, name: impl Into<String>) -> Self {
20 MetalField {
21 ty,
22 name: name.into(),
23 attr: MetalParamAttr::None,
24 }
25 }
26 pub fn with_builtin(ty: MetalType, name: impl Into<String>, b: MetalBuiltin) -> Self {
28 MetalField {
29 ty,
30 name: name.into(),
31 attr: MetalParamAttr::Builtin(b),
32 }
33 }
34 pub(super) fn emit(&self) -> String {
35 format!(" {}{} {};", self.attr, self.ty, self.name)
36 }
37}
38#[derive(Debug)]
40pub struct MetalExtEventLog {
41 pub(super) entries: std::collections::VecDeque<String>,
42 pub(super) capacity: usize,
43}
44impl MetalExtEventLog {
45 pub fn new(capacity: usize) -> Self {
46 MetalExtEventLog {
47 entries: std::collections::VecDeque::with_capacity(capacity),
48 capacity,
49 }
50 }
51 pub fn push(&mut self, event: impl Into<String>) {
52 if self.entries.len() >= self.capacity {
53 self.entries.pop_front();
54 }
55 self.entries.push_back(event.into());
56 }
57 pub fn iter(&self) -> impl Iterator<Item = &String> {
58 self.entries.iter()
59 }
60 pub fn len(&self) -> usize {
61 self.entries.len()
62 }
63 pub fn is_empty(&self) -> bool {
64 self.entries.is_empty()
65 }
66 pub fn capacity(&self) -> usize {
67 self.capacity
68 }
69 pub fn clear(&mut self) {
70 self.entries.clear();
71 }
72}
/// Memory-scope flags attached to a `MetalStmt::Barrier`.
///
/// NOTE(review): variants presumably map to Metal's `mem_flags::mem_none`,
/// `mem_device`, `mem_threadgroup` and `mem_texture`; the lowering is not in
/// this section — confirm in the emitter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemFlags {
    /// No memory flags.
    None,
    /// Device memory scope.
    Device,
    /// Threadgroup memory scope.
    Threadgroup,
    /// Texture memory scope.
    Texture,
}
85#[derive(Debug, Clone)]
87pub struct MetalExtDiagMsg {
88 pub severity: MetalExtDiagSeverity,
89 pub pass: String,
90 pub message: String,
91}
92impl MetalExtDiagMsg {
93 pub fn error(pass: impl Into<String>, msg: impl Into<String>) -> Self {
94 MetalExtDiagMsg {
95 severity: MetalExtDiagSeverity::Error,
96 pass: pass.into(),
97 message: msg.into(),
98 }
99 }
100 pub fn warning(pass: impl Into<String>, msg: impl Into<String>) -> Self {
101 MetalExtDiagMsg {
102 severity: MetalExtDiagSeverity::Warning,
103 pass: pass.into(),
104 message: msg.into(),
105 }
106 }
107 pub fn note(pass: impl Into<String>, msg: impl Into<String>) -> Self {
108 MetalExtDiagMsg {
109 severity: MetalExtDiagSeverity::Note,
110 pass: pass.into(),
111 message: msg.into(),
112 }
113 }
114}
/// Unary operators of the Metal expression AST.
///
/// NOTE(review): the operator token each variant renders to is not in this
/// section; names suggest `-`, `!` and `~` respectively — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetalUnOp {
    /// Arithmetic negation.
    Neg,
    /// Logical negation.
    Not,
    /// Bitwise complement.
    BitNot,
}
/// Binary operators of the Metal expression AST; also used as the operator
/// of `MetalStmt::CompoundAssign`.
///
/// NOTE(review): the token each variant renders to is not in this section;
/// names suggest the usual C-family operators (`And`/`Or` logical,
/// `BitAnd`/`BitOr`/`BitXor` bitwise) — confirm against the emitter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetalBinOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparison.
    Eq,
    Neq,
    Lt,
    Le,
    Gt,
    Ge,
    // Logical.
    And,
    Or,
    // Bitwise and shifts.
    BitAnd,
    BitOr,
    BitXor,
    Shl,
    Shr,
}
/// Counters summarising one emission run.
#[derive(Debug, Clone, Default)]
pub struct MetalExtEmitStats {
    pub bytes_emitted: usize,
    pub items_emitted: usize,
    pub errors: usize,
    pub warnings: usize,
    pub elapsed_ms: u64,
}
impl MetalExtEmitStats {
    /// All-zero statistics; identical to `MetalExtEmitStats::default()`.
    pub fn new() -> Self {
        Self::default()
    }
    /// Bytes emitted per second of `elapsed_ms`, or `0.0` when no time was
    /// recorded (avoids dividing by zero).
    pub fn throughput_bps(&self) -> f64 {
        match self.elapsed_ms {
            0 => 0.0,
            ms => {
                let seconds = ms as f64 / 1000.0;
                self.bytes_emitted as f64 / seconds
            }
        }
    }
    /// `true` when the run recorded no errors; warnings are ignored.
    pub fn is_clean(&self) -> bool {
        self.errors == 0
    }
}
168#[allow(dead_code)]
169#[derive(Debug, Clone)]
170pub struct MetalAnalysisCache {
171 pub(super) entries: std::collections::HashMap<String, MetalCacheEntry>,
172 pub(super) max_size: usize,
173 pub(super) hits: u64,
174 pub(super) misses: u64,
175}
176impl MetalAnalysisCache {
177 #[allow(dead_code)]
178 pub fn new(max_size: usize) -> Self {
179 MetalAnalysisCache {
180 entries: std::collections::HashMap::new(),
181 max_size,
182 hits: 0,
183 misses: 0,
184 }
185 }
186 #[allow(dead_code)]
187 pub fn get(&mut self, key: &str) -> Option<&MetalCacheEntry> {
188 if self.entries.contains_key(key) {
189 self.hits += 1;
190 self.entries.get(key)
191 } else {
192 self.misses += 1;
193 None
194 }
195 }
196 #[allow(dead_code)]
197 pub fn insert(&mut self, key: String, data: Vec<u8>) {
198 if self.entries.len() >= self.max_size {
199 if let Some(oldest) = self.entries.keys().next().cloned() {
200 self.entries.remove(&oldest);
201 }
202 }
203 self.entries.insert(
204 key.clone(),
205 MetalCacheEntry {
206 key,
207 data,
208 timestamp: 0,
209 valid: true,
210 },
211 );
212 }
213 #[allow(dead_code)]
214 pub fn invalidate(&mut self, key: &str) {
215 if let Some(entry) = self.entries.get_mut(key) {
216 entry.valid = false;
217 }
218 }
219 #[allow(dead_code)]
220 pub fn clear(&mut self) {
221 self.entries.clear();
222 }
223 #[allow(dead_code)]
224 pub fn hit_rate(&self) -> f64 {
225 let total = self.hits + self.misses;
226 if total == 0 {
227 return 0.0;
228 }
229 self.hits as f64 / total as f64
230 }
231 #[allow(dead_code)]
232 pub fn size(&self) -> usize {
233 self.entries.len()
234 }
235}
236#[derive(Debug, Default)]
238pub struct MetalExtNameScope {
239 pub(super) declared: std::collections::HashSet<String>,
240 pub(super) depth: usize,
241 pub(super) parent: Option<Box<MetalExtNameScope>>,
242}
243impl MetalExtNameScope {
244 pub fn new() -> Self {
245 MetalExtNameScope::default()
246 }
247 pub fn declare(&mut self, name: impl Into<String>) -> bool {
248 self.declared.insert(name.into())
249 }
250 pub fn is_declared(&self, name: &str) -> bool {
251 self.declared.contains(name)
252 }
253 pub fn push_scope(self) -> Self {
254 MetalExtNameScope {
255 declared: std::collections::HashSet::new(),
256 depth: self.depth + 1,
257 parent: Some(Box::new(self)),
258 }
259 }
260 pub fn pop_scope(self) -> Self {
261 *self.parent.unwrap_or_default()
262 }
263 pub fn depth(&self) -> usize {
264 self.depth
265 }
266 pub fn len(&self) -> usize {
267 self.declared.len()
268 }
269}
270#[derive(Debug, Clone, PartialEq)]
272pub struct MetalStruct {
273 pub name: String,
275 pub fields: Vec<MetalField>,
277}
278impl MetalStruct {
279 pub fn new(name: impl Into<String>) -> Self {
281 MetalStruct {
282 name: name.into(),
283 fields: Vec::new(),
284 }
285 }
286 pub fn add_field(mut self, f: MetalField) -> Self {
288 self.fields.push(f);
289 self
290 }
291}
292#[derive(Debug, Clone, PartialEq)]
294pub struct MetalShader {
295 pub includes: Vec<String>,
297 pub using_namespaces: Vec<String>,
299 pub structs: Vec<MetalStruct>,
301 pub functions: Vec<MetalFunction>,
303 pub constants: Vec<(MetalType, String, MetalExpr)>,
305}
306impl MetalShader {
307 pub fn new() -> Self {
309 MetalShader {
310 includes: vec!["metal_stdlib".to_string()],
311 using_namespaces: vec!["metal".to_string()],
312 structs: Vec::new(),
313 functions: Vec::new(),
314 constants: Vec::new(),
315 }
316 }
317 pub fn add_include(mut self, header: impl Into<String>) -> Self {
319 self.includes.push(header.into());
320 self
321 }
322 pub fn add_namespace(mut self, ns: impl Into<String>) -> Self {
324 self.using_namespaces.push(ns.into());
325 self
326 }
327 pub fn add_struct(mut self, s: MetalStruct) -> Self {
329 self.structs.push(s);
330 self
331 }
332 pub fn add_function(mut self, f: MetalFunction) -> Self {
334 self.functions.push(f);
335 self
336 }
337 pub fn add_constant(mut self, ty: MetalType, name: impl Into<String>, val: MetalExpr) -> Self {
339 self.constants.push((ty, name.into(), val));
340 self
341 }
342}
/// Coarse scheduling phase of a pass. `order()` ranks the phases as
/// Early (0) < Middle (1) < Late (2) < Finalize (3).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MetalExtPassPhase {
    Early,
    Middle,
    Late,
    Finalize,
}
impl MetalExtPassPhase {
    /// `true` for [`MetalExtPassPhase::Early`].
    #[allow(dead_code)]
    pub fn is_early(&self) -> bool {
        *self == Self::Early
    }
    /// `true` for [`MetalExtPassPhase::Middle`].
    #[allow(dead_code)]
    pub fn is_middle(&self) -> bool {
        *self == Self::Middle
    }
    /// `true` for [`MetalExtPassPhase::Late`].
    #[allow(dead_code)]
    pub fn is_late(&self) -> bool {
        *self == Self::Late
    }
    /// `true` for [`MetalExtPassPhase::Finalize`].
    #[allow(dead_code)]
    pub fn is_finalize(&self) -> bool {
        *self == Self::Finalize
    }
    /// Zero-based rank of the phase; inverse of [`from_order`](Self::from_order).
    #[allow(dead_code)]
    pub fn order(&self) -> u32 {
        match *self {
            Self::Early => 0,
            Self::Middle => 1,
            Self::Late => 2,
            Self::Finalize => 3,
        }
    }
    /// The phase whose rank is `n`, or `None` when `n > 3`.
    #[allow(dead_code)]
    pub fn from_order(n: u32) -> Option<Self> {
        // Indexed by rank; must stay in sync with `order`.
        let by_rank = [Self::Early, Self::Middle, Self::Late, Self::Finalize];
        by_rank.get(n as usize).cloned()
    }
}
389#[allow(dead_code)]
391#[derive(Debug, Clone)]
392pub struct MetalExtWorklist {
393 pub(super) items: std::collections::VecDeque<usize>,
394 pub(super) present: Vec<bool>,
395}
396impl MetalExtWorklist {
397 #[allow(dead_code)]
398 pub fn new(capacity: usize) -> Self {
399 Self {
400 items: std::collections::VecDeque::new(),
401 present: vec![false; capacity],
402 }
403 }
404 #[allow(dead_code)]
405 pub fn push(&mut self, id: usize) {
406 if id < self.present.len() && !self.present[id] {
407 self.present[id] = true;
408 self.items.push_back(id);
409 }
410 }
411 #[allow(dead_code)]
412 pub fn push_front(&mut self, id: usize) {
413 if id < self.present.len() && !self.present[id] {
414 self.present[id] = true;
415 self.items.push_front(id);
416 }
417 }
418 #[allow(dead_code)]
419 pub fn pop(&mut self) -> Option<usize> {
420 let id = self.items.pop_front()?;
421 if id < self.present.len() {
422 self.present[id] = false;
423 }
424 Some(id)
425 }
426 #[allow(dead_code)]
427 pub fn is_empty(&self) -> bool {
428 self.items.is_empty()
429 }
430 #[allow(dead_code)]
431 pub fn len(&self) -> usize {
432 self.items.len()
433 }
434 #[allow(dead_code)]
435 pub fn contains(&self, id: usize) -> bool {
436 id < self.present.len() && self.present[id]
437 }
438 #[allow(dead_code)]
439 pub fn drain_all(&mut self) -> Vec<usize> {
440 let v: Vec<usize> = self.items.drain(..).collect();
441 for &id in &v {
442 if id < self.present.len() {
443 self.present[id] = false;
444 }
445 }
446 v
447 }
448}
449#[allow(dead_code)]
451#[derive(Debug, Clone, Default)]
452pub struct MetalExtConstFolder {
453 pub(super) folds: usize,
454 pub(super) failures: usize,
455 pub(super) enabled: bool,
456}
457impl MetalExtConstFolder {
458 #[allow(dead_code)]
459 pub fn new() -> Self {
460 Self {
461 folds: 0,
462 failures: 0,
463 enabled: true,
464 }
465 }
466 #[allow(dead_code)]
467 pub fn add_i64(&mut self, a: i64, b: i64) -> Option<i64> {
468 self.folds += 1;
469 a.checked_add(b)
470 }
471 #[allow(dead_code)]
472 pub fn sub_i64(&mut self, a: i64, b: i64) -> Option<i64> {
473 self.folds += 1;
474 a.checked_sub(b)
475 }
476 #[allow(dead_code)]
477 pub fn mul_i64(&mut self, a: i64, b: i64) -> Option<i64> {
478 self.folds += 1;
479 a.checked_mul(b)
480 }
481 #[allow(dead_code)]
482 pub fn div_i64(&mut self, a: i64, b: i64) -> Option<i64> {
483 if b == 0 {
484 self.failures += 1;
485 None
486 } else {
487 self.folds += 1;
488 a.checked_div(b)
489 }
490 }
491 #[allow(dead_code)]
492 pub fn rem_i64(&mut self, a: i64, b: i64) -> Option<i64> {
493 if b == 0 {
494 self.failures += 1;
495 None
496 } else {
497 self.folds += 1;
498 a.checked_rem(b)
499 }
500 }
501 #[allow(dead_code)]
502 pub fn neg_i64(&mut self, a: i64) -> Option<i64> {
503 self.folds += 1;
504 a.checked_neg()
505 }
506 #[allow(dead_code)]
507 pub fn shl_i64(&mut self, a: i64, s: u32) -> Option<i64> {
508 if s >= 64 {
509 self.failures += 1;
510 None
511 } else {
512 self.folds += 1;
513 a.checked_shl(s)
514 }
515 }
516 #[allow(dead_code)]
517 pub fn shr_i64(&mut self, a: i64, s: u32) -> Option<i64> {
518 if s >= 64 {
519 self.failures += 1;
520 None
521 } else {
522 self.folds += 1;
523 a.checked_shr(s)
524 }
525 }
526 #[allow(dead_code)]
527 pub fn and_i64(&mut self, a: i64, b: i64) -> i64 {
528 self.folds += 1;
529 a & b
530 }
531 #[allow(dead_code)]
532 pub fn or_i64(&mut self, a: i64, b: i64) -> i64 {
533 self.folds += 1;
534 a | b
535 }
536 #[allow(dead_code)]
537 pub fn xor_i64(&mut self, a: i64, b: i64) -> i64 {
538 self.folds += 1;
539 a ^ b
540 }
541 #[allow(dead_code)]
542 pub fn not_i64(&mut self, a: i64) -> i64 {
543 self.folds += 1;
544 !a
545 }
546 #[allow(dead_code)]
547 pub fn fold_count(&self) -> usize {
548 self.folds
549 }
550 #[allow(dead_code)]
551 pub fn failure_count(&self) -> usize {
552 self.failures
553 }
554 #[allow(dead_code)]
555 pub fn enable(&mut self) {
556 self.enabled = true;
557 }
558 #[allow(dead_code)]
559 pub fn disable(&mut self) {
560 self.enabled = false;
561 }
562 #[allow(dead_code)]
563 pub fn is_enabled(&self) -> bool {
564 self.enabled
565 }
566}
567#[allow(dead_code)]
569#[derive(Debug, Clone)]
570pub struct MetalExtPassConfig {
571 pub name: String,
572 pub phase: MetalExtPassPhase,
573 pub enabled: bool,
574 pub max_iterations: usize,
575 pub debug: u32,
576 pub timeout_ms: Option<u64>,
577}
578impl MetalExtPassConfig {
579 #[allow(dead_code)]
580 pub fn new(name: impl Into<String>) -> Self {
581 Self {
582 name: name.into(),
583 phase: MetalExtPassPhase::Middle,
584 enabled: true,
585 max_iterations: 100,
586 debug: 0,
587 timeout_ms: None,
588 }
589 }
590 #[allow(dead_code)]
591 pub fn with_phase(mut self, phase: MetalExtPassPhase) -> Self {
592 self.phase = phase;
593 self
594 }
595 #[allow(dead_code)]
596 pub fn with_max_iter(mut self, n: usize) -> Self {
597 self.max_iterations = n;
598 self
599 }
600 #[allow(dead_code)]
601 pub fn with_debug(mut self, d: u32) -> Self {
602 self.debug = d;
603 self
604 }
605 #[allow(dead_code)]
606 pub fn disabled(mut self) -> Self {
607 self.enabled = false;
608 self
609 }
610 #[allow(dead_code)]
611 pub fn with_timeout(mut self, ms: u64) -> Self {
612 self.timeout_ms = Some(ms);
613 self
614 }
615 #[allow(dead_code)]
616 pub fn is_debug_enabled(&self) -> bool {
617 self.debug > 0
618 }
619}
620#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
622pub enum MetalBuiltin {
623 ThreadPositionInGrid,
625 ThreadPositionInThreadgroup,
627 ThreadgroupPositionInGrid,
629 ThreadsPerThreadgroup,
631 ThreadsPerGrid,
633 ThreadIndexInSimdgroup,
635 SimdgroupIndexInThreadgroup,
637 VertexId,
639 InstanceId,
641 Position,
643 FrontFacing,
645 SampleId,
647 Depth,
649}
650impl MetalBuiltin {
651 pub fn attribute(&self) -> &'static str {
653 match self {
654 MetalBuiltin::ThreadPositionInGrid => "[[thread_position_in_grid]]",
655 MetalBuiltin::ThreadPositionInThreadgroup => "[[thread_position_in_threadgroup]]",
656 MetalBuiltin::ThreadgroupPositionInGrid => "[[threadgroup_position_in_grid]]",
657 MetalBuiltin::ThreadsPerThreadgroup => "[[threads_per_threadgroup]]",
658 MetalBuiltin::ThreadsPerGrid => "[[threads_per_grid]]",
659 MetalBuiltin::ThreadIndexInSimdgroup => "[[thread_index_in_simdgroup]]",
660 MetalBuiltin::SimdgroupIndexInThreadgroup => "[[simdgroup_index_in_threadgroup]]",
661 MetalBuiltin::VertexId => "[[vertex_id]]",
662 MetalBuiltin::InstanceId => "[[instance_id]]",
663 MetalBuiltin::Position => "[[position]]",
664 MetalBuiltin::FrontFacing => "[[front_facing]]",
665 MetalBuiltin::SampleId => "[[sample_id]]",
666 MetalBuiltin::Depth => "[[depth(any)]]",
667 }
668 }
669 pub fn metal_type(&self) -> MetalType {
671 match self {
672 MetalBuiltin::ThreadPositionInGrid
673 | MetalBuiltin::ThreadPositionInThreadgroup
674 | MetalBuiltin::ThreadgroupPositionInGrid
675 | MetalBuiltin::ThreadsPerThreadgroup
676 | MetalBuiltin::ThreadsPerGrid => MetalType::Uint3,
677 MetalBuiltin::ThreadIndexInSimdgroup
678 | MetalBuiltin::SimdgroupIndexInThreadgroup
679 | MetalBuiltin::VertexId
680 | MetalBuiltin::InstanceId
681 | MetalBuiltin::SampleId => MetalType::Uint,
682 MetalBuiltin::Position => MetalType::Float4,
683 MetalBuiltin::FrontFacing => MetalType::Bool,
684 MetalBuiltin::Depth => MetalType::Float,
685 }
686 }
687}
/// A statement in the Metal shader AST.
#[derive(Debug, Clone, PartialEq)]
pub enum MetalStmt {
    /// Local variable declaration.
    VarDecl {
        /// Declared type.
        ty: MetalType,
        /// Variable name.
        name: String,
        /// Initializer expression, if any.
        init: Option<MetalExpr>,
        /// Constant declaration flag (presumably emitted as a `const`
        /// qualifier — confirm in the emitter).
        is_const: bool,
    },
    /// Assignment of `rhs` to `lhs`.
    Assign { lhs: MetalExpr, rhs: MetalExpr },
    /// Compound assignment combining `lhs` and `rhs` with `op`
    /// (presumably rendered as `lhs op= rhs`).
    CompoundAssign {
        lhs: MetalExpr,
        op: MetalBinOp,
        rhs: MetalExpr,
    },
    /// Conditional; `else_body` is `None` when there is no `else` branch.
    IfElse {
        cond: MetalExpr,
        then_body: Vec<MetalStmt>,
        else_body: Option<Vec<MetalStmt>>,
    },
    /// `for` loop with an init statement, a condition and a step expression.
    ForLoop {
        init: Box<MetalStmt>,
        cond: MetalExpr,
        step: MetalExpr,
        body: Vec<MetalStmt>,
    },
    /// `while` loop.
    WhileLoop {
        cond: MetalExpr,
        body: Vec<MetalStmt>,
    },
    /// Return, with an optional value.
    Return(Option<MetalExpr>),
    /// Expression evaluated for its side effects.
    Expr(MetalExpr),
    /// Barrier over the given memory flags (presumably lowered to a
    /// threadgroup barrier — confirm in the emitter).
    Barrier(MemFlags),
    /// Nested block of statements.
    Block(Vec<MetalStmt>),
    /// `break` out of the innermost loop.
    Break,
    /// `continue` with the next iteration of the innermost loop.
    Continue,
}
/// Severity of a diagnostic. The derived ordering follows declaration order,
/// so `Note < Warning < Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MetalExtDiagSeverity {
    /// Informational message.
    Note,
    /// Suspicious but non-fatal condition.
    Warning,
    /// Failure; `MetalExtDiagCollector::has_errors` checks for this severity.
    Error,
}
/// Timing and size measurements for one run of a pass.
#[derive(Debug, Clone)]
pub struct MetalExtPassTiming {
    pub pass_name: String,
    pub elapsed_us: u64,
    pub items_processed: usize,
    pub bytes_before: usize,
    pub bytes_after: usize,
}
impl MetalExtPassTiming {
    /// Records a pass run: `items` processed in `elapsed_us` microseconds,
    /// with output size going from `before` to `after` bytes.
    pub fn new(
        pass_name: impl Into<String>,
        elapsed_us: u64,
        items: usize,
        before: usize,
        after: usize,
    ) -> Self {
        Self {
            pass_name: pass_name.into(),
            elapsed_us,
            items_processed: items,
            bytes_before: before,
            bytes_after: after,
        }
    }
    /// Items processed per second, or `0.0` when no time elapsed.
    ///
    /// NOTE(review): despite the `mps` suffix, no scaling to millions is
    /// applied — confirm the intended unit.
    pub fn throughput_mps(&self) -> f64 {
        if self.elapsed_us > 0 {
            let seconds = self.elapsed_us as f64 / 1_000_000.0;
            self.items_processed as f64 / seconds
        } else {
            0.0
        }
    }
    /// `bytes_after / bytes_before`, or `1.0` when there was no input.
    pub fn size_ratio(&self) -> f64 {
        match self.bytes_before {
            0 => 1.0,
            before => self.bytes_after as f64 / before as f64,
        }
    }
    /// A run counts as profitable when output grew by at most 5%.
    pub fn is_profitable(&self) -> bool {
        const MAX_GROWTH: f64 = 1.05;
        self.size_ratio() <= MAX_GROWTH
    }
}
/// Pipeline stage a Metal function is compiled for.
///
/// NOTE(review): the qualifier each variant emits is not in this section;
/// `Device` presumably marks a non-entry-point helper — confirm in the emitter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetalStage {
    /// Vertex stage.
    Vertex,
    /// Fragment stage.
    Fragment,
    /// Compute kernel.
    Kernel,
    /// Mesh stage.
    Mesh,
    /// Device-side function (see the note above).
    Device,
}
802#[derive(Debug, Default)]
804pub struct MetalExtProfiler {
805 pub(super) timings: Vec<MetalExtPassTiming>,
806}
807impl MetalExtProfiler {
808 pub fn new() -> Self {
809 MetalExtProfiler::default()
810 }
811 pub fn record(&mut self, t: MetalExtPassTiming) {
812 self.timings.push(t);
813 }
814 pub fn total_elapsed_us(&self) -> u64 {
815 self.timings.iter().map(|t| t.elapsed_us).sum()
816 }
817 pub fn slowest_pass(&self) -> Option<&MetalExtPassTiming> {
818 self.timings.iter().max_by_key(|t| t.elapsed_us)
819 }
820 pub fn num_passes(&self) -> usize {
821 self.timings.len()
822 }
823 pub fn profitable_passes(&self) -> Vec<&MetalExtPassTiming> {
824 self.timings.iter().filter(|t| t.is_profitable()).collect()
825 }
826}
827#[derive(Debug, Clone, Default)]
829pub struct MetalExtFeatures {
830 pub(super) flags: std::collections::HashSet<String>,
831}
832impl MetalExtFeatures {
833 pub fn new() -> Self {
834 MetalExtFeatures::default()
835 }
836 pub fn enable(&mut self, flag: impl Into<String>) {
837 self.flags.insert(flag.into());
838 }
839 pub fn disable(&mut self, flag: &str) {
840 self.flags.remove(flag);
841 }
842 pub fn is_enabled(&self, flag: &str) -> bool {
843 self.flags.contains(flag)
844 }
845 pub fn len(&self) -> usize {
846 self.flags.len()
847 }
848 pub fn is_empty(&self) -> bool {
849 self.flags.is_empty()
850 }
851 pub fn union(&self, other: &MetalExtFeatures) -> MetalExtFeatures {
852 MetalExtFeatures {
853 flags: self.flags.union(&other.flags).cloned().collect(),
854 }
855 }
856 pub fn intersection(&self, other: &MetalExtFeatures) -> MetalExtFeatures {
857 MetalExtFeatures {
858 flags: self.flags.intersection(&other.flags).cloned().collect(),
859 }
860 }
861}
/// Stateless constant-folding helpers. Integer operations that can fail
/// return `None` instead of panicking or wrapping.
#[allow(dead_code)]
pub struct MetalConstantFoldingHelper;
impl MetalConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// Truncating `a / b`, or `None` when `b == 0` or for `i64::MIN / -1`.
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already covers both the zero divisor and the overflow.
        a.checked_div(b)
    }
    /// IEEE `a + b`.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// IEEE `a * b`.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`, or `None` for `i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Logical `!a`.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Logical `a && b`.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Logical `a || b`.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`, or `None` when `b >= 64`. Bits shifted out are discarded.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// Arithmetic `a >> b`, or `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// `a % b` (sign follows `a`), or `None` when `b == 0` or for
    /// `i64::MIN % -1`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        // A plain `%` panics on `i64::MIN % -1` ("attempt to calculate the
        // remainder with overflow"); `checked_rem` reports it as `None`.
        a.checked_rem(b)
    }
    /// Bitwise `a & b`.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Bitwise `a | b`.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Bitwise `a ^ b`.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Bitwise complement `!a`.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
/// A type in the Metal shading-language AST.
///
/// Variant names mirror Metal type names (e.g. `Float4` presumably renders as
/// `float4`). The rendering comes from a `Display` impl outside this section,
/// which `MetalField::emit` relies on.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MetalType {
    // Scalar types.
    Bool,
    Half,
    Float,
    Int,
    Uint,
    Short,
    Ushort,
    Char,
    Uchar,
    Long,
    Ulong,
    // Vector types (element type and lane count are encoded in the name).
    Float2,
    Float3,
    Float4,
    Half2,
    Half3,
    Half4,
    Int2,
    Int3,
    Int4,
    Uint2,
    Uint3,
    Uint4,
    // Square float matrices.
    Float2x2,
    Float3x3,
    Float4x4,
    /// Fixed-size array: element type and length.
    Array(Box<MetalType>, usize),
    /// A user-defined struct, referenced by name.
    Struct(String),
    /// Texture parameterised by the boxed type (presumably the component
    /// type — confirm against the `Display` impl).
    Texture(Box<MetalType>),
    /// Sampler object.
    Sampler,
    /// Pointer to the boxed type in the given address space.
    Pointer(Box<MetalType>, MetalAddressSpace),
    /// `void`.
    Void,
}
/// Per-block liveness sets over `u32` variable ids, indexed by block number.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MetalLivenessInfo {
    pub live_in: Vec<std::collections::HashSet<u32>>,
    pub live_out: Vec<std::collections::HashSet<u32>>,
    pub defs: Vec<std::collections::HashSet<u32>>,
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl MetalLivenessInfo {
    /// Empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty_sets = || vec![std::collections::HashSet::new(); block_count];
        MetalLivenessInfo {
            live_in: empty_sets(),
            live_out: empty_sets(),
            defs: empty_sets(),
            uses: empty_sets(),
        }
    }
    /// Records that `block` defines `var`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Records that `block` uses `var`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Whether `var` is in the live-in set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        self.live_in.get(block).map_or(false, |set| set.contains(&var))
    }
    /// Whether `var` is in the live-out set of `block` (`false` if out of range).
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        self.live_out.get(block).map_or(false, |set| set.contains(&var))
    }
}
1055#[allow(dead_code)]
1057#[derive(Debug, Clone)]
1058pub struct MetalExtDepGraph {
1059 pub(super) n: usize,
1060 pub(super) adj: Vec<Vec<usize>>,
1061 pub(super) rev: Vec<Vec<usize>>,
1062 pub(super) edge_count: usize,
1063}
1064impl MetalExtDepGraph {
1065 #[allow(dead_code)]
1066 pub fn new(n: usize) -> Self {
1067 Self {
1068 n,
1069 adj: vec![Vec::new(); n],
1070 rev: vec![Vec::new(); n],
1071 edge_count: 0,
1072 }
1073 }
1074 #[allow(dead_code)]
1075 pub fn add_edge(&mut self, from: usize, to: usize) {
1076 if from < self.n && to < self.n {
1077 if !self.adj[from].contains(&to) {
1078 self.adj[from].push(to);
1079 self.rev[to].push(from);
1080 self.edge_count += 1;
1081 }
1082 }
1083 }
1084 #[allow(dead_code)]
1085 pub fn succs(&self, n: usize) -> &[usize] {
1086 self.adj.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1087 }
1088 #[allow(dead_code)]
1089 pub fn preds(&self, n: usize) -> &[usize] {
1090 self.rev.get(n).map(|v| v.as_slice()).unwrap_or(&[])
1091 }
1092 #[allow(dead_code)]
1093 pub fn topo_sort(&self) -> Option<Vec<usize>> {
1094 let mut deg: Vec<usize> = (0..self.n).map(|i| self.rev[i].len()).collect();
1095 let mut q: std::collections::VecDeque<usize> =
1096 (0..self.n).filter(|&i| deg[i] == 0).collect();
1097 let mut out = Vec::with_capacity(self.n);
1098 while let Some(u) = q.pop_front() {
1099 out.push(u);
1100 for &v in &self.adj[u] {
1101 deg[v] -= 1;
1102 if deg[v] == 0 {
1103 q.push_back(v);
1104 }
1105 }
1106 }
1107 if out.len() == self.n {
1108 Some(out)
1109 } else {
1110 None
1111 }
1112 }
1113 #[allow(dead_code)]
1114 pub fn has_cycle(&self) -> bool {
1115 self.topo_sort().is_none()
1116 }
1117 #[allow(dead_code)]
1118 pub fn reachable(&self, start: usize) -> Vec<usize> {
1119 let mut vis = vec![false; self.n];
1120 let mut stk = vec![start];
1121 let mut out = Vec::new();
1122 while let Some(u) = stk.pop() {
1123 if u < self.n && !vis[u] {
1124 vis[u] = true;
1125 out.push(u);
1126 for &v in &self.adj[u] {
1127 if !vis[v] {
1128 stk.push(v);
1129 }
1130 }
1131 }
1132 }
1133 out
1134 }
1135 #[allow(dead_code)]
1136 pub fn scc(&self) -> Vec<Vec<usize>> {
1137 let mut visited = vec![false; self.n];
1138 let mut order = Vec::new();
1139 for i in 0..self.n {
1140 if !visited[i] {
1141 let mut stk = vec![(i, 0usize)];
1142 while let Some((u, idx)) = stk.last_mut() {
1143 if !visited[*u] {
1144 visited[*u] = true;
1145 }
1146 if *idx < self.adj[*u].len() {
1147 let v = self.adj[*u][*idx];
1148 *idx += 1;
1149 if !visited[v] {
1150 stk.push((v, 0));
1151 }
1152 } else {
1153 order.push(*u);
1154 stk.pop();
1155 }
1156 }
1157 }
1158 }
1159 let mut comp = vec![usize::MAX; self.n];
1160 let mut components: Vec<Vec<usize>> = Vec::new();
1161 for &start in order.iter().rev() {
1162 if comp[start] == usize::MAX {
1163 let cid = components.len();
1164 let mut component = Vec::new();
1165 let mut stk = vec![start];
1166 while let Some(u) = stk.pop() {
1167 if comp[u] == usize::MAX {
1168 comp[u] = cid;
1169 component.push(u);
1170 for &v in &self.rev[u] {
1171 if comp[v] == usize::MAX {
1172 stk.push(v);
1173 }
1174 }
1175 }
1176 }
1177 components.push(component);
1178 }
1179 }
1180 components
1181 }
1182 #[allow(dead_code)]
1183 pub fn node_count(&self) -> usize {
1184 self.n
1185 }
1186 #[allow(dead_code)]
1187 pub fn edge_count(&self) -> usize {
1188 self.edge_count
1189 }
1190}
/// A `major.minor.patch[-pre]` version number.
///
/// Ordering compares `major`, `minor` and `patch` numerically. For equal
/// numbers, a pre-release (`pre = Some(..)`) sorts *before* the matching
/// stable release, as SemVer 2.0.0 precedence requires. The previously
/// derived ordering put `None` first and so ranked `1.0.0` below
/// `1.0.0-alpha`. Two pre-release tags are compared as plain strings, not
/// with SemVer's identifier-by-identifier rule.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MetalExtVersion {
    pub major: u32,
    pub minor: u32,
    pub patch: u32,
    pub pre: Option<String>,
}
impl PartialOrd for MetalExtVersion {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MetalExtVersion {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        (self.major, self.minor, self.patch)
            .cmp(&(other.major, other.minor, other.patch))
            .then_with(|| match (&self.pre, &other.pre) {
                (None, None) => Ordering::Equal,
                // A stable release outranks any pre-release of the same numbers.
                (None, Some(_)) => Ordering::Greater,
                (Some(_), None) => Ordering::Less,
                (Some(a), Some(b)) => a.cmp(b),
            })
    }
}
impl MetalExtVersion {
    /// A stable version `major.minor.patch`.
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        MetalExtVersion {
            major,
            minor,
            patch,
            pre: None,
        }
    }
    /// Attaches a pre-release tag, replacing any existing one.
    pub fn with_pre(mut self, pre: impl Into<String>) -> Self {
        self.pre = Some(pre.into());
        self
    }
    /// `true` when there is no pre-release tag.
    pub fn is_stable(&self) -> bool {
        self.pre.is_none()
    }
    /// Same `major` and `self.minor >= other.minor`; `patch` and `pre` are
    /// not considered.
    pub fn is_compatible_with(&self, other: &MetalExtVersion) -> bool {
        self.major == other.major && self.minor >= other.minor
    }
}
1219#[derive(Debug, Default)]
1221pub struct MetalExtSourceBuffer {
1222 pub(super) buf: String,
1223 pub(super) indent_level: usize,
1224 pub(super) indent_str: String,
1225}
1226impl MetalExtSourceBuffer {
1227 pub fn new() -> Self {
1228 MetalExtSourceBuffer {
1229 buf: String::new(),
1230 indent_level: 0,
1231 indent_str: " ".to_string(),
1232 }
1233 }
1234 pub fn with_indent(mut self, indent: impl Into<String>) -> Self {
1235 self.indent_str = indent.into();
1236 self
1237 }
1238 pub fn push_line(&mut self, line: &str) {
1239 for _ in 0..self.indent_level {
1240 self.buf.push_str(&self.indent_str);
1241 }
1242 self.buf.push_str(line);
1243 self.buf.push('\n');
1244 }
1245 pub fn push_raw(&mut self, s: &str) {
1246 self.buf.push_str(s);
1247 }
1248 pub fn indent(&mut self) {
1249 self.indent_level += 1;
1250 }
1251 pub fn dedent(&mut self) {
1252 self.indent_level = self.indent_level.saturating_sub(1);
1253 }
1254 pub fn as_str(&self) -> &str {
1255 &self.buf
1256 }
1257 pub fn len(&self) -> usize {
1258 self.buf.len()
1259 }
1260 pub fn is_empty(&self) -> bool {
1261 self.buf.is_empty()
1262 }
1263 pub fn line_count(&self) -> usize {
1264 self.buf.lines().count()
1265 }
1266 pub fn into_string(self) -> String {
1267 self.buf
1268 }
1269 pub fn reset(&mut self) {
1270 self.buf.clear();
1271 self.indent_level = 0;
1272 }
1273}
/// Per-block liveness data over `usize` variable ids, stored as small
/// duplicate-free vectors indexed by block number.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MetalExtLiveness {
    pub live_in: Vec<Vec<usize>>,
    pub live_out: Vec<Vec<usize>>,
    pub defs: Vec<Vec<usize>>,
    pub uses: Vec<Vec<usize>>,
}
impl MetalExtLiveness {
    /// Empty data for `n` blocks.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        let blank: Vec<Vec<usize>> = vec![Vec::new(); n];
        Self {
            live_in: blank.clone(),
            live_out: blank.clone(),
            defs: blank.clone(),
            uses: blank,
        }
    }
    /// Whether `v` is live on entry to block `b` (`false` if out of range).
    #[allow(dead_code)]
    pub fn live_in(&self, b: usize, v: usize) -> bool {
        self.live_in.get(b).map_or(false, |vars| vars.contains(&v))
    }
    /// Whether `v` is live on exit from block `b` (`false` if out of range).
    #[allow(dead_code)]
    pub fn live_out(&self, b: usize, v: usize) -> bool {
        self.live_out.get(b).map_or(false, |vars| vars.contains(&v))
    }
    /// Records a definition of `v` in block `b`, once; out-of-range blocks
    /// are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, b: usize, v: usize) {
        match self.defs.get_mut(b) {
            Some(vars) if !vars.contains(&v) => vars.push(v),
            _ => {}
        }
    }
    /// Records a use of `v` in block `b`, once; out-of-range blocks are
    /// ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, b: usize, v: usize) {
        match self.uses.get_mut(b) {
            Some(vars) if !vars.contains(&v) => vars.push(v),
            _ => {}
        }
    }
    /// Whether block `b` uses `v`.
    #[allow(dead_code)]
    pub fn var_is_used_in_block(&self, b: usize, v: usize) -> bool {
        self.uses.get(b).map_or(false, |vars| vars.contains(&v))
    }
    /// Whether block `b` defines `v`.
    #[allow(dead_code)]
    pub fn var_is_def_in_block(&self, b: usize, v: usize) -> bool {
        self.defs.get(b).map_or(false, |vars| vars.contains(&v))
    }
}
1329#[derive(Debug, Default)]
1331pub struct MetalExtDiagCollector {
1332 pub(super) msgs: Vec<MetalExtDiagMsg>,
1333}
1334impl MetalExtDiagCollector {
1335 pub fn new() -> Self {
1336 MetalExtDiagCollector::default()
1337 }
1338 pub fn emit(&mut self, d: MetalExtDiagMsg) {
1339 self.msgs.push(d);
1340 }
1341 pub fn has_errors(&self) -> bool {
1342 self.msgs
1343 .iter()
1344 .any(|d| d.severity == MetalExtDiagSeverity::Error)
1345 }
1346 pub fn errors(&self) -> Vec<&MetalExtDiagMsg> {
1347 self.msgs
1348 .iter()
1349 .filter(|d| d.severity == MetalExtDiagSeverity::Error)
1350 .collect()
1351 }
1352 pub fn warnings(&self) -> Vec<&MetalExtDiagMsg> {
1353 self.msgs
1354 .iter()
1355 .filter(|d| d.severity == MetalExtDiagSeverity::Warning)
1356 .collect()
1357 }
1358 pub fn len(&self) -> usize {
1359 self.msgs.len()
1360 }
1361 pub fn is_empty(&self) -> bool {
1362 self.msgs.is_empty()
1363 }
1364 pub fn clear(&mut self) {
1365 self.msgs.clear();
1366 }
1367}
/// Key for incremental-compilation lookups: a content hash paired with a
/// configuration hash.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MetalExtIncrKey {
    pub content_hash: u64,
    pub config_hash: u64,
}
impl MetalExtIncrKey {
    /// Pairs a content hash with a configuration hash.
    pub fn new(content: u64, config: u64) -> Self {
        Self {
            content_hash: content,
            config_hash: config,
        }
    }
    /// Folds both hashes into one `u64`.
    pub fn combined_hash(&self) -> u64 {
        // 0x9e3779b97f4a7c15 is the 64-bit golden-ratio constant used in
        // Fibonacci hashing; multiplication wraps on overflow.
        const GOLDEN: u64 = 0x9e37_79b9_7f4a_7c15;
        let mixed = self.content_hash.wrapping_mul(GOLDEN);
        mixed ^ self.config_hash
    }
    /// `true` when both hashes are equal (same as `==`).
    pub fn matches(&self, other: &MetalExtIncrKey) -> bool {
        // Derived `PartialEq` compares exactly the two hash fields.
        *self == *other
    }
}
1388#[allow(dead_code)]
1389pub struct MetalPassRegistry {
1390 pub(super) configs: Vec<MetalPassConfig>,
1391 pub(super) stats: std::collections::HashMap<String, MetalPassStats>,
1392}
1393impl MetalPassRegistry {
1394 #[allow(dead_code)]
1395 pub fn new() -> Self {
1396 MetalPassRegistry {
1397 configs: Vec::new(),
1398 stats: std::collections::HashMap::new(),
1399 }
1400 }
1401 #[allow(dead_code)]
1402 pub fn register(&mut self, config: MetalPassConfig) {
1403 self.stats
1404 .insert(config.pass_name.clone(), MetalPassStats::new());
1405 self.configs.push(config);
1406 }
1407 #[allow(dead_code)]
1408 pub fn enabled_passes(&self) -> Vec<&MetalPassConfig> {
1409 self.configs.iter().filter(|c| c.enabled).collect()
1410 }
1411 #[allow(dead_code)]
1412 pub fn get_stats(&self, name: &str) -> Option<&MetalPassStats> {
1413 self.stats.get(name)
1414 }
1415 #[allow(dead_code)]
1416 pub fn total_passes(&self) -> usize {
1417 self.configs.len()
1418 }
1419 #[allow(dead_code)]
1420 pub fn enabled_count(&self) -> usize {
1421 self.enabled_passes().len()
1422 }
1423 #[allow(dead_code)]
1424 pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
1425 if let Some(stats) = self.stats.get_mut(name) {
1426 stats.record_run(changes, time_ms, iter);
1427 }
1428 }
1429}
/// Dominator tree over nodes `0..size`.
///
/// Each node stores its immediate dominator. `dom_children` and `dom_depth`
/// are plain storage: `set_idom` does not update them.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MetalDominatorTree {
    /// Immediate dominator of each node, or `None` if unset. A node whose idom
    /// is itself is treated as a root by `dominates`.
    pub idom: Vec<Option<u32>>,
    /// Dominator-tree children per node; maintained by the caller.
    pub dom_children: Vec<Vec<u32>>,
    /// Depth per node; maintained by the caller.
    pub dom_depth: Vec<u32>,
}
impl MetalDominatorTree {
    /// Creates a tree for `size` nodes with no dominators assigned.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        MetalDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }
    /// Sets the immediate dominator of `node`.
    ///
    /// # Panics
    /// Panics if `node >= size`.
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        self.idom[node] = Some(idom);
    }
    /// Returns `true` if `a` dominates `b`; every node dominates itself.
    ///
    /// The walk follows `idom` links upward from `b`. A well-formed chain
    /// visits each node at most once, so it is capped at `idom.len()` steps.
    /// A malformed, cyclic chain therefore yields `false` instead of hanging.
    /// Out-of-range nodes are treated as having no dominator.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) if parent as usize == a => return true,
                // Self-dominated node: reached the root without meeting `a`.
                Some(parent) if parent as usize == cur => return false,
                Some(parent) => cur = parent as usize,
                None => return false,
            }
        }
        false
    }
    /// Stored depth of `node`, or 0 when out of range.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
1470#[allow(dead_code)]
1471#[derive(Debug, Clone)]
1472pub struct MetalPassConfig {
1473 pub phase: MetalPassPhase,
1474 pub enabled: bool,
1475 pub max_iterations: u32,
1476 pub debug_output: bool,
1477 pub pass_name: String,
1478}
1479impl MetalPassConfig {
1480 #[allow(dead_code)]
1481 pub fn new(name: impl Into<String>, phase: MetalPassPhase) -> Self {
1482 MetalPassConfig {
1483 phase,
1484 enabled: true,
1485 max_iterations: 10,
1486 debug_output: false,
1487 pass_name: name.into(),
1488 }
1489 }
1490 #[allow(dead_code)]
1491 pub fn disabled(mut self) -> Self {
1492 self.enabled = false;
1493 self
1494 }
1495 #[allow(dead_code)]
1496 pub fn with_debug(mut self) -> Self {
1497 self.debug_output = true;
1498 self
1499 }
1500 #[allow(dead_code)]
1501 pub fn max_iter(mut self, n: u32) -> Self {
1502 self.max_iterations = n;
1503 self
1504 }
1505}
1506pub struct MetalBackend {
1508 pub(super) indent_width: usize,
1509}
1510impl MetalBackend {
1511 pub fn new() -> Self {
1513 MetalBackend { indent_width: 4 }
1514 }
1515 pub fn with_indent(indent_width: usize) -> Self {
1517 MetalBackend { indent_width }
1518 }
1519 pub(super) fn indent(&self, depth: usize) -> String {
1520 " ".repeat(self.indent_width * depth)
1521 }
1522 pub fn emit_expr(&self, expr: &MetalExpr) -> String {
1524 expr.emit()
1525 }
1526 pub fn emit_stmt(&self, stmt: &MetalStmt, depth: usize) -> String {
1528 let ind = self.indent(depth);
1529 match stmt {
1530 MetalStmt::VarDecl {
1531 ty,
1532 name,
1533 init,
1534 is_const,
1535 } => {
1536 let const_kw = if *is_const { "const " } else { "" };
1537 match init {
1538 Some(expr) => {
1539 format!("{}{}{} {} = {};", ind, const_kw, ty, name, expr.emit())
1540 }
1541 None => format!("{}{}{} {};", ind, const_kw, ty, name),
1542 }
1543 }
1544 MetalStmt::Assign { lhs, rhs } => {
1545 format!("{}{} = {};", ind, lhs.emit(), rhs.emit())
1546 }
1547 MetalStmt::CompoundAssign { lhs, op, rhs } => {
1548 format!("{}{} {}= {};", ind, lhs.emit(), op, rhs.emit())
1549 }
1550 MetalStmt::IfElse {
1551 cond,
1552 then_body,
1553 else_body,
1554 } => self.emit_if_else(cond, then_body, else_body.as_deref(), depth),
1555 MetalStmt::ForLoop {
1556 init,
1557 cond,
1558 step,
1559 body,
1560 } => self.emit_for_loop(init, cond, step, body, depth),
1561 MetalStmt::WhileLoop { cond, body } => self.emit_while(cond, body, depth),
1562 MetalStmt::Return(Some(expr)) => format!("{}return {};", ind, expr.emit()),
1563 MetalStmt::Return(None) => format!("{}return;", ind),
1564 MetalStmt::Expr(expr) => format!("{}{};", ind, expr.emit()),
1565 MetalStmt::Barrier(flags) => {
1566 format!("{}threadgroup_barrier({});", ind, flags)
1567 }
1568 MetalStmt::Block(stmts) => {
1569 let mut out = format!("{}{{\n", ind);
1570 for s in stmts {
1571 out.push_str(&self.emit_stmt(s, depth + 1));
1572 out.push('\n');
1573 }
1574 out.push_str(&format!("{}}}", ind));
1575 out
1576 }
1577 MetalStmt::Break => format!("{}break;", ind),
1578 MetalStmt::Continue => format!("{}continue;", ind),
1579 }
1580 }
1581 pub(super) fn emit_if_else(
1582 &self,
1583 cond: &MetalExpr,
1584 then_body: &[MetalStmt],
1585 else_body: Option<&[MetalStmt]>,
1586 depth: usize,
1587 ) -> String {
1588 let ind = self.indent(depth);
1589 let mut out = format!("{}if ({}) {{\n", ind, cond.emit());
1590 for s in then_body {
1591 out.push_str(&self.emit_stmt(s, depth + 1));
1592 out.push('\n');
1593 }
1594 out.push_str(&format!("{}}}", ind));
1595 if let Some(eb) = else_body {
1596 out.push_str(" else {\n");
1597 for s in eb {
1598 out.push_str(&self.emit_stmt(s, depth + 1));
1599 out.push('\n');
1600 }
1601 out.push_str(&format!("{}}}", ind));
1602 }
1603 out
1604 }
1605 pub(super) fn emit_for_loop(
1606 &self,
1607 init: &MetalStmt,
1608 cond: &MetalExpr,
1609 step: &MetalExpr,
1610 body: &[MetalStmt],
1611 depth: usize,
1612 ) -> String {
1613 let ind = self.indent(depth);
1614 let init_str = self.emit_stmt(init, 0).trim().to_string();
1615 let init_header = init_str.trim_end_matches(';');
1616 let mut out = format!(
1617 "{}for ({}; {}; {}) {{\n",
1618 ind,
1619 init_header,
1620 cond.emit(),
1621 step.emit()
1622 );
1623 for s in body {
1624 out.push_str(&self.emit_stmt(s, depth + 1));
1625 out.push('\n');
1626 }
1627 out.push_str(&format!("{}}}", ind));
1628 out
1629 }
1630 pub(super) fn emit_while(&self, cond: &MetalExpr, body: &[MetalStmt], depth: usize) -> String {
1631 let ind = self.indent(depth);
1632 let mut out = format!("{}while ({}) {{\n", ind, cond.emit());
1633 for s in body {
1634 out.push_str(&self.emit_stmt(s, depth + 1));
1635 out.push('\n');
1636 }
1637 out.push_str(&format!("{}}}", ind));
1638 out
1639 }
1640 pub(super) fn emit_struct(&self, s: &MetalStruct) -> String {
1641 let mut out = format!("struct {} {{\n", s.name);
1642 for field in &s.fields {
1643 out.push_str(&field.emit());
1644 out.push('\n');
1645 }
1646 out.push_str("};");
1647 out
1648 }
1649 pub(super) fn emit_function(&self, f: &MetalFunction) -> String {
1650 let stage_str = format!("{}", f.stage);
1651 let inline_str = if f.is_inline { "inline " } else { "" };
1652 let stage_prefix = if stage_str.is_empty() {
1653 String::new()
1654 } else {
1655 format!("{}\n", stage_str)
1656 };
1657 let params: Vec<String> = f.params.iter().map(|p| p.emit()).collect();
1658 let mut out = format!(
1659 "{}{}{} {}({}) {{\n",
1660 stage_prefix,
1661 inline_str,
1662 f.return_type,
1663 f.name,
1664 params.join(",\n ")
1665 );
1666 for s in &f.body {
1667 out.push_str(&self.emit_stmt(s, 1));
1668 out.push('\n');
1669 }
1670 out.push('}');
1671 out
1672 }
1673 pub fn emit_shader(&self, shader: &MetalShader) -> String {
1675 let mut out = String::new();
1676 for inc in &shader.includes {
1677 out.push_str(&format!("#include <{}>\n", inc));
1678 }
1679 if !shader.includes.is_empty() {
1680 out.push('\n');
1681 }
1682 for ns in &shader.using_namespaces {
1683 out.push_str(&format!("using namespace {};\n", ns));
1684 }
1685 if !shader.using_namespaces.is_empty() {
1686 out.push('\n');
1687 }
1688 for (ty, name, val) in &shader.constants {
1689 out.push_str(&format!("constant {} {} = {};\n", ty, name, val.emit()));
1690 }
1691 if !shader.constants.is_empty() {
1692 out.push('\n');
1693 }
1694 for s in &shader.structs {
1695 out.push_str(&self.emit_struct(s));
1696 out.push_str("\n\n");
1697 }
1698 for f in &shader.functions {
1699 out.push_str(&self.emit_function(f));
1700 out.push_str("\n\n");
1701 }
1702 out
1703 }
1704}
1705#[derive(Debug, Clone, Default)]
1707pub struct MetalExtConfig {
1708 pub(super) entries: std::collections::HashMap<String, String>,
1709}
1710impl MetalExtConfig {
1711 pub fn new() -> Self {
1712 MetalExtConfig::default()
1713 }
1714 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
1715 self.entries.insert(key.into(), value.into());
1716 }
1717 pub fn get(&self, key: &str) -> Option<&str> {
1718 self.entries.get(key).map(|s| s.as_str())
1719 }
1720 pub fn get_bool(&self, key: &str) -> bool {
1721 matches!(self.get(key), Some("true") | Some("1") | Some("yes"))
1722 }
1723 pub fn get_int(&self, key: &str) -> Option<i64> {
1724 self.get(key)?.parse().ok()
1725 }
1726 pub fn len(&self) -> usize {
1727 self.entries.len()
1728 }
1729 pub fn is_empty(&self) -> bool {
1730 self.entries.is_empty()
1731 }
1732}
1733#[derive(Debug, Clone, PartialEq)]
1735pub struct MetalParam {
1736 pub ty: MetalType,
1738 pub name: String,
1740 pub attr: MetalParamAttr,
1742}
1743impl MetalParam {
1744 pub fn new(ty: MetalType, name: impl Into<String>) -> Self {
1746 MetalParam {
1747 ty,
1748 name: name.into(),
1749 attr: MetalParamAttr::None,
1750 }
1751 }
1752 pub fn buffer(ty: MetalType, name: impl Into<String>, index: u32) -> Self {
1754 MetalParam {
1755 ty,
1756 name: name.into(),
1757 attr: MetalParamAttr::Buffer(index),
1758 }
1759 }
1760 pub fn texture(ty: MetalType, name: impl Into<String>, index: u32) -> Self {
1762 MetalParam {
1763 ty,
1764 name: name.into(),
1765 attr: MetalParamAttr::Texture(index),
1766 }
1767 }
1768 pub fn builtin(b: MetalBuiltin) -> Self {
1770 let ty = b.metal_type();
1771 let name = format!("{:?}", b).to_lowercase();
1772 MetalParam {
1773 ty,
1774 name,
1775 attr: MetalParamAttr::Builtin(b),
1776 }
1777 }
1778 pub(super) fn emit(&self) -> String {
1779 format!("{} {}{}", self.ty, self.name, self.attr)
1780 }
1781}
/// Cumulative statistics for a single pass.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MetalPassStats {
    pub total_runs: u32,
    pub successful_runs: u32,
    pub total_changes: u64,
    pub time_ms: u64,
    pub iterations_used: u32,
}
impl MetalPassStats {
    /// Zeroed statistics.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Records one run, which is always counted as successful.
    ///
    /// Changes and time accumulate; `iterations_used` is overwritten with the
    /// latest run's value.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.successful_runs += 1;
        self.total_changes += changes;
        self.time_ms += time_ms;
        self.iterations_used = iterations;
    }
    /// Mean number of changes per run; 0.0 before any run.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / runs as f64,
        }
    }
    /// Fraction of runs that succeeded; 0.0 before any run.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.successful_runs as f64 / runs as f64,
        }
    }
    /// One-line human-readable summary.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let Self {
            total_runs,
            successful_runs,
            total_changes,
            time_ms,
            ..
        } = self;
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            successful_runs, total_runs, total_changes, time_ms
        )
    }
}
1826#[derive(Debug, Clone, PartialEq)]
1828pub struct MetalFunction {
1829 pub name: String,
1831 pub stage: MetalStage,
1833 pub params: Vec<MetalParam>,
1835 pub return_type: MetalType,
1837 pub body: Vec<MetalStmt>,
1839 pub is_inline: bool,
1841}
1842impl MetalFunction {
1843 pub fn new(name: impl Into<String>, stage: MetalStage, return_type: MetalType) -> Self {
1845 MetalFunction {
1846 name: name.into(),
1847 stage,
1848 params: Vec::new(),
1849 return_type,
1850 body: Vec::new(),
1851 is_inline: false,
1852 }
1853 }
1854 pub fn kernel(name: impl Into<String>) -> Self {
1856 MetalFunction::new(name, MetalStage::Kernel, MetalType::Void)
1857 }
1858 pub fn vertex(name: impl Into<String>, return_type: MetalType) -> Self {
1860 MetalFunction::new(name, MetalStage::Vertex, return_type)
1861 }
1862 pub fn fragment(name: impl Into<String>, return_type: MetalType) -> Self {
1864 MetalFunction::new(name, MetalStage::Fragment, return_type)
1865 }
1866 pub fn device_fn(name: impl Into<String>, return_type: MetalType) -> Self {
1868 MetalFunction::new(name, MetalStage::Device, return_type)
1869 }
1870 pub fn with_inline(mut self) -> Self {
1872 self.is_inline = true;
1873 self
1874 }
1875 pub fn add_param(mut self, p: MetalParam) -> Self {
1877 self.params.push(p);
1878 self
1879 }
1880 pub fn add_stmt(mut self, s: MetalStmt) -> Self {
1882 self.body.push(s);
1883 self
1884 }
1885}
/// Metal address spaces.
///
/// The variant names appear to mirror MSL address-space qualifiers. How they
/// are rendered is not visible here; NOTE(review): confirm against the
/// `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetalAddressSpace {
    /// Presumably MSL `device`.
    Device,
    /// Presumably MSL `constant`.
    Constant,
    /// Presumably MSL `threadgroup`.
    Threadgroup,
    /// Presumably MSL `threadgroup_imageblock`.
    ThreadgroupImageblock,
    /// Presumably MSL `ray_data`.
    RayData,
    /// Presumably MSL `object_data`.
    ObjectData,
    /// Presumably MSL `thread`.
    Thread,
}
/// Binding attribute attached to a parameter or struct field.
///
/// `MetalParam::emit` appends its `Display` output directly after the
/// parameter name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MetalParamAttr {
    /// Buffer binding index (built by `MetalParam::buffer`).
    Buffer(u32),
    /// Texture binding index (built by `MetalParam::texture`).
    Texture(u32),
    /// Sampler binding index.
    Sampler(u32),
    /// Stage-input marker; presumably renders as `[[stage_in]]`, TODO confirm.
    StageIn,
    /// Builtin value (built by `MetalParam::builtin`).
    Builtin(MetalBuiltin),
    /// No attribute.
    None,
}
/// A keyed blob of cached bytes with a timestamp and a validity flag.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MetalCacheEntry {
    /// Lookup key.
    pub key: String,
    /// Cached payload.
    pub data: Vec<u8>,
    /// Creation/update time; units are not established here (TODO confirm).
    pub timestamp: u64,
    /// Whether the entry is still usable.
    pub valid: bool,
}
1928#[allow(dead_code)]
1929#[derive(Debug, Clone)]
1930pub struct MetalWorklist {
1931 pub(super) items: std::collections::VecDeque<u32>,
1932 pub(super) in_worklist: std::collections::HashSet<u32>,
1933}
1934impl MetalWorklist {
1935 #[allow(dead_code)]
1936 pub fn new() -> Self {
1937 MetalWorklist {
1938 items: std::collections::VecDeque::new(),
1939 in_worklist: std::collections::HashSet::new(),
1940 }
1941 }
1942 #[allow(dead_code)]
1943 pub fn push(&mut self, item: u32) -> bool {
1944 if self.in_worklist.insert(item) {
1945 self.items.push_back(item);
1946 true
1947 } else {
1948 false
1949 }
1950 }
1951 #[allow(dead_code)]
1952 pub fn pop(&mut self) -> Option<u32> {
1953 let item = self.items.pop_front()?;
1954 self.in_worklist.remove(&item);
1955 Some(item)
1956 }
1957 #[allow(dead_code)]
1958 pub fn is_empty(&self) -> bool {
1959 self.items.is_empty()
1960 }
1961 #[allow(dead_code)]
1962 pub fn len(&self) -> usize {
1963 self.items.len()
1964 }
1965 #[allow(dead_code)]
1966 pub fn contains(&self, item: u32) -> bool {
1967 self.in_worklist.contains(&item)
1968 }
1969}
/// Counters collected while a pass runs.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct MetalExtPassStats {
    pub iterations: usize,
    pub changed: bool,
    pub nodes_visited: usize,
    pub nodes_modified: usize,
    pub time_ms: u64,
    pub memory_bytes: usize,
    pub errors: usize,
}
impl MetalExtPassStats {
    /// Zeroed counters.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Counts one visited node.
    #[allow(dead_code)]
    pub fn visit(&mut self) {
        self.nodes_visited += 1;
    }
    /// Counts one modified node and marks the pass as having changed something.
    #[allow(dead_code)]
    pub fn modify(&mut self) {
        self.changed = true;
        self.nodes_modified += 1;
    }
    /// Counts one iteration.
    #[allow(dead_code)]
    pub fn iterate(&mut self) {
        self.iterations += 1;
    }
    /// Counts one error.
    #[allow(dead_code)]
    pub fn error(&mut self) {
        self.errors += 1;
    }
    /// Modified / visited ratio; 0.0 when nothing was visited.
    #[allow(dead_code)]
    pub fn efficiency(&self) -> f64 {
        if self.nodes_visited > 0 {
            self.nodes_modified as f64 / self.nodes_visited as f64
        } else {
            0.0
        }
    }
    /// Folds `o` into `self`.
    ///
    /// Counters and time are summed and `changed` is OR-ed.
    /// `memory_bytes` keeps the peak of the two values.
    #[allow(dead_code)]
    pub fn merge(&mut self, o: &MetalExtPassStats) {
        self.iterations += o.iterations;
        self.nodes_visited += o.nodes_visited;
        self.nodes_modified += o.nodes_modified;
        self.time_ms += o.time_ms;
        self.errors += o.errors;
        self.changed = self.changed || o.changed;
        self.memory_bytes = std::cmp::max(self.memory_bytes, o.memory_bytes);
    }
}
/// Metal Shading Language expression tree; rendered by `MetalExpr::emit`.
#[derive(Debug, Clone, PartialEq)]
pub enum MetalExpr {
    /// Integer literal, emitted in decimal.
    LitInt(i64),
    /// Float literal, emitted with six decimals and an `f` suffix.
    LitFloat(f64),
    /// `true` / `false`.
    LitBool(bool),
    /// Variable reference, emitted verbatim.
    Var(String),
    /// Builtin value, emitted as its `Debug` name lowercased.
    Builtin(MetalBuiltin),
    /// `base[index]`.
    Index(Box<MetalExpr>, Box<MetalExpr>),
    /// `base.field`.
    Member(Box<MetalExpr>, String),
    /// `base->field`.
    PtrMember(Box<MetalExpr>, String),
    /// C-style cast to the given type.
    Cast(MetalType, Box<MetalExpr>),
    /// `name(args...)`.
    Call(String, Vec<MetalExpr>),
    /// Parenthesised binary operation `(lhs op rhs)`.
    BinOp(Box<MetalExpr>, MetalBinOp, Box<MetalExpr>),
    /// Parenthesised prefix unary operation `(op expr)`.
    UnOp(MetalUnOp, Box<MetalExpr>),
    /// `(cond ? then : else)`.
    Ternary(Box<MetalExpr>, Box<MetalExpr>, Box<MetalExpr>),
    /// `simd_sum(value)`.
    SimdSum(Box<MetalExpr>),
    /// `simd_shuffle_down(value, delta)`.
    SimdShuffleDown(Box<MetalExpr>, Box<MetalExpr>),
    /// `simd_broadcast(value, lane)`.
    SimdBroadcast(Box<MetalExpr>, Box<MetalExpr>),
    /// `atomic_fetch_add_explicit(atomic, value, memory_order_relaxed)`.
    AtomicFetchAdd(Box<MetalExpr>, Box<MetalExpr>),
    /// `threadgroup_barrier(flags)` used as an expression (no trailing `;`).
    ThreadgroupBarrier(MemFlags),
    /// Bit reinterpretation `as_type<T>(expr)`.
    AsType(MetalType, Box<MetalExpr>),
    /// `select(a, b, cond)`.
    Select(Box<MetalExpr>, Box<MetalExpr>, Box<MetalExpr>),
    /// `dot(a, b)`.
    Dot(Box<MetalExpr>, Box<MetalExpr>),
    /// `length(v)`.
    Length(Box<MetalExpr>),
    /// `normalize(v)`.
    Normalize(Box<MetalExpr>),
    /// `clamp(value, lo, hi)`.
    Clamp(Box<MetalExpr>, Box<MetalExpr>, Box<MetalExpr>),
}
2075impl MetalExpr {
2076 pub(super) fn emit(&self) -> String {
2077 match self {
2078 MetalExpr::LitInt(n) => n.to_string(),
2079 MetalExpr::LitFloat(f) => format!("{:.6}f", f),
2080 MetalExpr::LitBool(b) => if *b { "true" } else { "false" }.to_string(),
2081 MetalExpr::Var(name) => name.clone(),
2082 MetalExpr::Builtin(b) => format!("{:?}", b).to_lowercase(),
2083 MetalExpr::Index(base, idx) => format!("{}[{}]", base.emit(), idx.emit()),
2084 MetalExpr::Member(base, field) => format!("{}.{}", base.emit(), field),
2085 MetalExpr::PtrMember(base, field) => format!("{}->{}", base.emit(), field),
2086 MetalExpr::Cast(ty, expr) => format!("(({})({})))", ty, expr.emit()),
2087 MetalExpr::Call(name, args) => {
2088 let arg_strs: Vec<String> = args.iter().map(|a| a.emit()).collect();
2089 format!("{}({})", name, arg_strs.join(", "))
2090 }
2091 MetalExpr::BinOp(lhs, op, rhs) => {
2092 format!("({} {} {})", lhs.emit(), op, rhs.emit())
2093 }
2094 MetalExpr::UnOp(op, expr) => format!("({}{})", op, expr.emit()),
2095 MetalExpr::Ternary(cond, then, els) => {
2096 format!("({} ? {} : {})", cond.emit(), then.emit(), els.emit())
2097 }
2098 MetalExpr::SimdSum(val) => format!("simd_sum({})", val.emit()),
2099 MetalExpr::SimdShuffleDown(val, delta) => {
2100 format!("simd_shuffle_down({}, {})", val.emit(), delta.emit())
2101 }
2102 MetalExpr::SimdBroadcast(val, lane) => {
2103 format!("simd_broadcast({}, {})", val.emit(), lane.emit())
2104 }
2105 MetalExpr::AtomicFetchAdd(atom, val) => {
2106 format!(
2107 "atomic_fetch_add_explicit({}, {}, memory_order_relaxed)",
2108 atom.emit(),
2109 val.emit()
2110 )
2111 }
2112 MetalExpr::ThreadgroupBarrier(flags) => {
2113 format!("threadgroup_barrier({})", flags)
2114 }
2115 MetalExpr::AsType(ty, expr) => format!("as_type<{}>({})", ty, expr.emit()),
2116 MetalExpr::Select(a, b, cond) => {
2117 format!("select({}, {}, {})", a.emit(), b.emit(), cond.emit())
2118 }
2119 MetalExpr::Dot(a, b) => format!("dot({}, {})", a.emit(), b.emit()),
2120 MetalExpr::Length(v) => format!("length({})", v.emit()),
2121 MetalExpr::Normalize(v) => format!("normalize({})", v.emit()),
2122 MetalExpr::Clamp(val, lo, hi) => {
2123 format!("clamp({}, {}, {})", val.emit(), lo.emit(), hi.emit())
2124 }
2125 }
2126 }
2127}
/// The phase of the pipeline a pass belongs to.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum MetalPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl MetalPassPhase {
    /// Lowercase display name of the phase.
    #[allow(dead_code)]
    pub fn name(&self) -> &str {
        match *self {
            Self::Analysis => "analysis",
            Self::Transformation => "transformation",
            Self::Verification => "verification",
            Self::Cleanup => "cleanup",
        }
    }
    /// Returns `true` for phases whose passes may change the IR
    /// (transformation and cleanup).
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        match *self {
            Self::Analysis | Self::Verification => false,
            Self::Transformation | Self::Cleanup => true,
        }
    }
}
2154#[allow(dead_code)]
2156#[derive(Debug, Default)]
2157pub struct MetalExtPassRegistry {
2158 pub(super) configs: Vec<MetalExtPassConfig>,
2159 pub(super) stats: Vec<MetalExtPassStats>,
2160}
2161impl MetalExtPassRegistry {
2162 #[allow(dead_code)]
2163 pub fn new() -> Self {
2164 Self::default()
2165 }
2166 #[allow(dead_code)]
2167 pub fn register(&mut self, c: MetalExtPassConfig) {
2168 self.stats.push(MetalExtPassStats::new());
2169 self.configs.push(c);
2170 }
2171 #[allow(dead_code)]
2172 pub fn len(&self) -> usize {
2173 self.configs.len()
2174 }
2175 #[allow(dead_code)]
2176 pub fn is_empty(&self) -> bool {
2177 self.configs.is_empty()
2178 }
2179 #[allow(dead_code)]
2180 pub fn get(&self, i: usize) -> Option<&MetalExtPassConfig> {
2181 self.configs.get(i)
2182 }
2183 #[allow(dead_code)]
2184 pub fn get_stats(&self, i: usize) -> Option<&MetalExtPassStats> {
2185 self.stats.get(i)
2186 }
2187 #[allow(dead_code)]
2188 pub fn enabled_passes(&self) -> Vec<&MetalExtPassConfig> {
2189 self.configs.iter().filter(|c| c.enabled).collect()
2190 }
2191 #[allow(dead_code)]
2192 pub fn passes_in_phase(&self, ph: &MetalExtPassPhase) -> Vec<&MetalExtPassConfig> {
2193 self.configs
2194 .iter()
2195 .filter(|c| c.enabled && &c.phase == ph)
2196 .collect()
2197 }
2198 #[allow(dead_code)]
2199 pub fn total_nodes_visited(&self) -> usize {
2200 self.stats.iter().map(|s| s.nodes_visited).sum()
2201 }
2202 #[allow(dead_code)]
2203 pub fn any_changed(&self) -> bool {
2204 self.stats.iter().any(|s| s.changed)
2205 }
2206}
2207#[allow(dead_code)]
2209#[derive(Debug)]
2210pub struct MetalExtCache {
2211 pub(super) entries: Vec<(u64, Vec<u8>, bool, u32)>,
2212 pub(super) cap: usize,
2213 pub(super) total_hits: u64,
2214 pub(super) total_misses: u64,
2215}
2216impl MetalExtCache {
2217 #[allow(dead_code)]
2218 pub fn new(cap: usize) -> Self {
2219 Self {
2220 entries: Vec::new(),
2221 cap,
2222 total_hits: 0,
2223 total_misses: 0,
2224 }
2225 }
2226 #[allow(dead_code)]
2227 pub fn get(&mut self, key: u64) -> Option<&[u8]> {
2228 for e in self.entries.iter_mut() {
2229 if e.0 == key && e.2 {
2230 e.3 += 1;
2231 self.total_hits += 1;
2232 return Some(&e.1);
2233 }
2234 }
2235 self.total_misses += 1;
2236 None
2237 }
2238 #[allow(dead_code)]
2239 pub fn put(&mut self, key: u64, data: Vec<u8>) {
2240 if self.entries.len() >= self.cap {
2241 self.entries.retain(|e| e.2);
2242 if self.entries.len() >= self.cap {
2243 self.entries.remove(0);
2244 }
2245 }
2246 self.entries.push((key, data, true, 0));
2247 }
2248 #[allow(dead_code)]
2249 pub fn invalidate(&mut self) {
2250 for e in self.entries.iter_mut() {
2251 e.2 = false;
2252 }
2253 }
2254 #[allow(dead_code)]
2255 pub fn hit_rate(&self) -> f64 {
2256 let t = self.total_hits + self.total_misses;
2257 if t == 0 {
2258 0.0
2259 } else {
2260 self.total_hits as f64 / t as f64
2261 }
2262 }
2263 #[allow(dead_code)]
2264 pub fn live_count(&self) -> usize {
2265 self.entries.iter().filter(|e| e.2).count()
2266 }
2267}
2268#[allow(dead_code)]
2269#[derive(Debug, Clone)]
2270pub struct MetalDepGraph {
2271 pub(super) nodes: Vec<u32>,
2272 pub(super) edges: Vec<(u32, u32)>,
2273}
2274impl MetalDepGraph {
2275 #[allow(dead_code)]
2276 pub fn new() -> Self {
2277 MetalDepGraph {
2278 nodes: Vec::new(),
2279 edges: Vec::new(),
2280 }
2281 }
2282 #[allow(dead_code)]
2283 pub fn add_node(&mut self, id: u32) {
2284 if !self.nodes.contains(&id) {
2285 self.nodes.push(id);
2286 }
2287 }
2288 #[allow(dead_code)]
2289 pub fn add_dep(&mut self, dep: u32, dependent: u32) {
2290 self.add_node(dep);
2291 self.add_node(dependent);
2292 self.edges.push((dep, dependent));
2293 }
2294 #[allow(dead_code)]
2295 pub fn dependents_of(&self, node: u32) -> Vec<u32> {
2296 self.edges
2297 .iter()
2298 .filter(|(d, _)| *d == node)
2299 .map(|(_, dep)| *dep)
2300 .collect()
2301 }
2302 #[allow(dead_code)]
2303 pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
2304 self.edges
2305 .iter()
2306 .filter(|(_, dep)| *dep == node)
2307 .map(|(d, _)| *d)
2308 .collect()
2309 }
2310 #[allow(dead_code)]
2311 pub fn topological_sort(&self) -> Vec<u32> {
2312 let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
2313 for &n in &self.nodes {
2314 in_degree.insert(n, 0);
2315 }
2316 for (_, dep) in &self.edges {
2317 *in_degree.entry(*dep).or_insert(0) += 1;
2318 }
2319 let mut queue: std::collections::VecDeque<u32> = self
2320 .nodes
2321 .iter()
2322 .filter(|&&n| in_degree[&n] == 0)
2323 .copied()
2324 .collect();
2325 let mut result = Vec::new();
2326 while let Some(node) = queue.pop_front() {
2327 result.push(node);
2328 for dep in self.dependents_of(node) {
2329 let cnt = in_degree.entry(dep).or_insert(0);
2330 *cnt = cnt.saturating_sub(1);
2331 if *cnt == 0 {
2332 queue.push_back(dep);
2333 }
2334 }
2335 }
2336 result
2337 }
2338 #[allow(dead_code)]
2339 pub fn has_cycle(&self) -> bool {
2340 self.topological_sort().len() < self.nodes.len()
2341 }
2342}
2343#[allow(dead_code)]
2345#[derive(Debug, Clone)]
2346pub struct MetalExtDomTree {
2347 pub(super) idom: Vec<Option<usize>>,
2348 pub(super) children: Vec<Vec<usize>>,
2349 pub(super) depth: Vec<usize>,
2350}
2351impl MetalExtDomTree {
2352 #[allow(dead_code)]
2353 pub fn new(n: usize) -> Self {
2354 Self {
2355 idom: vec![None; n],
2356 children: vec![Vec::new(); n],
2357 depth: vec![0; n],
2358 }
2359 }
2360 #[allow(dead_code)]
2361 pub fn set_idom(&mut self, node: usize, dom: usize) {
2362 if node < self.idom.len() {
2363 self.idom[node] = Some(dom);
2364 if dom < self.children.len() {
2365 self.children[dom].push(node);
2366 }
2367 self.depth[node] = if dom < self.depth.len() {
2368 self.depth[dom] + 1
2369 } else {
2370 1
2371 };
2372 }
2373 }
2374 #[allow(dead_code)]
2375 pub fn dominates(&self, a: usize, mut b: usize) -> bool {
2376 if a == b {
2377 return true;
2378 }
2379 let n = self.idom.len();
2380 for _ in 0..n {
2381 match self.idom.get(b).copied().flatten() {
2382 None => return false,
2383 Some(p) if p == a => return true,
2384 Some(p) if p == b => return false,
2385 Some(p) => b = p,
2386 }
2387 }
2388 false
2389 }
2390 #[allow(dead_code)]
2391 pub fn children_of(&self, n: usize) -> &[usize] {
2392 self.children.get(n).map(|v| v.as_slice()).unwrap_or(&[])
2393 }
2394 #[allow(dead_code)]
2395 pub fn depth_of(&self, n: usize) -> usize {
2396 self.depth.get(n).copied().unwrap_or(0)
2397 }
2398 #[allow(dead_code)]
2399 pub fn lca(&self, mut a: usize, mut b: usize) -> usize {
2400 let n = self.idom.len();
2401 for _ in 0..(2 * n) {
2402 if a == b {
2403 return a;
2404 }
2405 if self.depth_of(a) > self.depth_of(b) {
2406 a = self.idom.get(a).and_then(|x| *x).unwrap_or(a);
2407 } else {
2408 b = self.idom.get(b).and_then(|x| *x).unwrap_or(b);
2409 }
2410 }
2411 0
2412 }
2413}
2414#[derive(Debug, Default)]
2416pub struct MetalExtIdGen {
2417 pub(super) next: u32,
2418}
2419impl MetalExtIdGen {
2420 pub fn new() -> Self {
2421 MetalExtIdGen::default()
2422 }
2423 pub fn next_id(&mut self) -> u32 {
2424 let id = self.next;
2425 self.next += 1;
2426 id
2427 }
2428 pub fn peek_next(&self) -> u32 {
2429 self.next
2430 }
2431 pub fn reset(&mut self) {
2432 self.next = 0;
2433 }
2434 pub fn skip(&mut self, n: u32) {
2435 self.next += n;
2436 }
2437}