//! Lowering of ringkernel IR modules to Metal Shading Language (MSL).

use std::collections::HashMap;
use std::fmt::Write;

use crate::{
    nodes::*, BlockId, CapabilityFlag, Dimension, IrModule, IrNode, IrType, ScalarType, Terminator,
    ValueId,
};

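/// Configuration for lowering ringkernel IR to Metal Shading Language (MSL).
///
/// Illustrative usage only (the builder methods are defined below in this
/// file); not taken from the original documentation:
///
/// ```ignore
/// let config = MslLoweringConfig::metal3()
///     .with_threadgroup_size(128, 1, 1)
///     .with_persistent();
/// ```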
#[derive(Debug, Clone)]
pub struct MslLoweringConfig {
    /// Target Metal language version as (major, minor).
    pub metal_version: (u32, u32),
    /// Emit SIMD-group (warp-level) operations.
    pub simd_groups: bool,
    /// Default threadgroup dimensions.
    pub threadgroup_size: (u32, u32, u32),
    /// Allow indirect command buffer usage.
    pub indirect_commands: bool,
    /// Emit the HLC timestamp struct and intrinsic declarations.
    pub enable_hlc: bool,
    /// Emit the control block and K2K messaging intrinsic declarations.
    pub enable_k2k: bool,
    /// Enable debug-oriented output.
    pub debug: bool,
}

impl Default for MslLoweringConfig {
    fn default() -> Self {
        Self {
            metal_version: (2, 4),
            threadgroup_size: (256, 1, 1),
            simd_groups: true,
            indirect_commands: false,
            enable_hlc: false,
            enable_k2k: false,
            debug: false,
        }
    }
}

impl MslLoweringConfig {
    /// Configuration targeting Metal 3.0 with SIMD groups and indirect commands enabled.
    pub fn metal3() -> Self {
        Self {
            metal_version: (3, 0),
            simd_groups: true,
            indirect_commands: true,
            ..Default::default()
        }
    }

    /// Set the threadgroup dimensions.
    pub fn with_threadgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.threadgroup_size = (x, y, z);
        self
    }

    /// Enable the persistent-kernel features (HLC timestamps and K2K messaging).
    pub fn with_persistent(mut self) -> Self {
        self.enable_hlc = true;
        self.enable_k2k = true;
        self
    }
}

/// Lowers an [`IrModule`] to MSL source text.
pub struct MslLowering {
    config: MslLoweringConfig,
    output: String,
    indent: usize,
    value_names: HashMap<ValueId, String>,
    name_counter: usize,
    block_labels: HashMap<BlockId, String>,
}

impl MslLowering {
    pub fn new(config: MslLoweringConfig) -> Self {
        Self {
            config,
            output: String::new(),
            indent: 0,
            value_names: HashMap::new(),
            name_counter: 0,
            block_labels: HashMap::new(),
        }
    }

    pub fn lower(mut self, module: &IrModule) -> Result<String, MslLoweringError> {
        self.check_capabilities(module)?;

        self.emit_header();

        self.emit_type_definitions(module);

        self.emit_kernel(module)?;

        Ok(self.output)
    }

    fn check_capabilities(&self, module: &IrModule) -> Result<(), MslLoweringError> {
        if module.required_capabilities.has(CapabilityFlag::Float64) {
            return Err(MslLoweringError::UnsupportedCapability(
                "f64 is not supported in Metal; lower to f32 first".to_string(),
            ));
        }

        if module
            .required_capabilities
            .has(CapabilityFlag::CooperativeGroups)
        {
            return Err(MslLoweringError::UnsupportedCapability(
                "Grid-wide sync not supported in Metal".to_string(),
            ));
        }

        Ok(())
    }

    fn emit_header(&mut self) {
        self.emit_line("// Generated by ringkernel-ir MSL lowering");
        self.emit_line("#include <metal_stdlib>");
        // simdgroup_matrix support lives in <metal_simdgroup_matrix>.
        self.emit_line("#include <metal_simdgroup_matrix>");
        self.emit_line("using namespace metal;");
        self.emit_line("");
    }

    fn emit_type_definitions(&mut self, _module: &IrModule) {
        if self.config.enable_hlc {
            self.emit_line("// HLC Timestamp");
            self.emit_line("struct HlcTimestamp {");
            self.indent += 1;
            self.emit_line("uint64_t physical;");
            self.emit_line("uint64_t logical;");
            self.emit_line("uint64_t node_id;");
            self.indent -= 1;
            self.emit_line("};");
            self.emit_line("");

            self.emit_line("// HLC Intrinsics (provided by runtime)");
            self.emit_line("uint64_t ringkernel_hlc_now();");
            self.emit_line("uint64_t ringkernel_hlc_tick();");
            self.emit_line("uint64_t ringkernel_hlc_update(uint64_t incoming);");
            self.emit_line("");
        }

        if self.config.enable_k2k {
            self.emit_line("// Control Block");
            self.emit_line("struct ControlBlock {");
            self.indent += 1;
            self.emit_line("uint32_t is_active;");
            self.emit_line("uint32_t should_terminate;");
            self.emit_line("uint32_t has_terminated;");
            self.emit_line("uint32_t _pad1;");
            self.emit_line("uint64_t messages_processed;");
            self.emit_line("uint64_t messages_in_flight;");
            self.emit_line("uint64_t input_head;");
            self.emit_line("uint64_t input_tail;");
            self.emit_line("uint64_t output_head;");
            self.emit_line("uint64_t output_tail;");
            self.emit_line("uint32_t input_capacity;");
            self.emit_line("uint32_t output_capacity;");
            self.emit_line("uint32_t input_mask;");
            self.emit_line("uint32_t output_mask;");
            self.indent -= 1;
            self.emit_line("};");
            self.emit_line("");

            self.emit_line("// Queue Intrinsics (provided by runtime)");
            self.emit_line("bool ringkernel_k2h_enqueue(const device void* msg);");
            self.emit_line("device void* ringkernel_h2k_dequeue();");
            self.emit_line("bool ringkernel_h2k_is_empty();");
            self.emit_line("");

            self.emit_line("// K2K Messaging Intrinsics (provided by runtime)");
            self.emit_line("bool ringkernel_k2k_send(uint64_t target_id, const device void* msg);");
            self.emit_line("device void* ringkernel_k2k_recv();");
            self.emit_line("struct K2KOptionalMsg { bool valid; device void* data; };");
            self.emit_line("K2KOptionalMsg ringkernel_k2k_try_recv();");
            self.emit_line("");
        }
    }

    fn emit_kernel(&mut self, module: &IrModule) -> Result<(), MslLoweringError> {
        self.assign_names(module);

        self.emit_line("kernel void");
        writeln!(self.output, "{}(", module.name).unwrap();
        self.indent += 1;

        for (buffer_idx, param) in module.parameters.iter().enumerate() {
            let ty = self.lower_type(&param.ty);
            // Pointer parameters already lower to `device T*`; scalars are
            // bound by constant reference.
            if param.ty.is_ptr() {
                self.emit_line(&format!("{} {} [[buffer({})]],", ty, param.name, buffer_idx));
            } else {
                self.emit_line(&format!(
                    "constant {}& {} [[buffer({})]],",
                    ty, param.name, buffer_idx
                ));
            }
        }

        self.emit_line("uint3 thread_position_in_grid [[thread_position_in_grid]],");
        self.emit_line("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],");
        self.emit_line("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],");
        self.emit_line("uint3 threads_per_threadgroup [[threads_per_threadgroup]],");
        self.emit_line("uint3 threadgroups_per_grid [[threadgroups_per_grid]],");
        self.emit_line("uint thread_index_in_simdgroup [[thread_index_in_simdgroup]],");
        self.emit_line("uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]]");

        self.indent -= 1;
        self.emit_line(") {");
        self.indent += 1;

        self.emit_block(module, module.entry_block)?;

        self.indent -= 1;
        self.emit_line("}");

        Ok(())
    }

    fn assign_names(&mut self, module: &IrModule) {
        for param in &module.parameters {
            self.value_names.insert(param.value_id, param.name.clone());
        }

        for (block_id, block) in &module.blocks {
            self.block_labels.insert(*block_id, block.label.clone());
        }
    }

    fn emit_block(&mut self, module: &IrModule, block_id: BlockId) -> Result<(), MslLoweringError> {
        let block = module
            .blocks
            .get(&block_id)
            .ok_or(MslLoweringError::UndefinedBlock(block_id))?;

        if block_id != module.entry_block {
            self.emit_line(&format!("{}: {{", block.label));
            self.indent += 1;
        }

        for inst in &block.instructions {
            self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
        }

        if let Some(term) = &block.terminator {
            self.emit_terminator(module, term)?;
        }

        if block_id != module.entry_block {
            self.indent -= 1;
            self.emit_line("}");
        }

        Ok(())
    }

    fn emit_instruction(
        &mut self,
        _module: &IrModule,
        result: &ValueId,
        result_type: &IrType,
        node: &IrNode,
    ) -> Result<(), MslLoweringError> {
        let result_name = self.get_or_create_name(*result);
        let ty = self.lower_type(result_type);

        match node {
            IrNode::Constant(c) => {
                let val = self.lower_constant(c);
                self.emit_line(&format!("{} {} = {};", ty, result_name, val));
            }

            IrNode::BinaryOp(op, lhs, rhs) => {
                let lhs_name = self.get_value_name(*lhs);
                let rhs_name = self.get_value_name(*rhs);
                let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
            }

            IrNode::UnaryOp(op, val) => {
                let val_name = self.get_value_name(*val);
                let expr = self.lower_unary_op(op, &val_name);
                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
            }

            IrNode::Compare(op, lhs, rhs) => {
                let lhs_name = self.get_value_name(*lhs);
                let rhs_name = self.get_value_name(*rhs);
                let cmp_op = self.lower_compare_op(op);
                self.emit_line(&format!(
                    "bool {} = {} {} {};",
                    result_name, lhs_name, cmp_op, rhs_name
                ));
            }

            IrNode::Load(ptr) => {
                let ptr_name = self.get_value_name(*ptr);
                self.emit_line(&format!("{} {} = {};", ty, result_name, ptr_name));
            }

            IrNode::Store(ptr, val) => {
                let ptr_name = self.get_value_name(*ptr);
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("{} = {};", ptr_name, val_name));
            }

            IrNode::GetElementPtr(ptr, indices) => {
                let ptr_name = self.get_value_name(*ptr);
                let idx_name = self.get_value_name(indices[0]);
                self.emit_line(&format!(
                    "{} {} = {}[{}];",
                    ty, result_name, ptr_name, idx_name
                ));
            }

            IrNode::SharedAlloc(elem_ty, count) => {
                let elem = self.lower_type(elem_ty);
                self.emit_line(&format!("threadgroup {} {}[{}];", elem, result_name, count));
            }

            IrNode::ThreadId(dim) => {
                let idx = self.lower_dimension(dim, "thread_position_in_threadgroup");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::BlockId(dim) => {
                let idx = self.lower_dimension(dim, "threadgroup_position_in_grid");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::BlockDim(dim) => {
                let idx = self.lower_dimension(dim, "threads_per_threadgroup");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::GridDim(dim) => {
                let idx = self.lower_dimension(dim, "threadgroups_per_grid");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::GlobalThreadId(dim) => {
                let idx = self.lower_dimension(dim, "thread_position_in_grid");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::WarpId => {
                self.emit_line(&format!(
                    "{} {} = simdgroup_index_in_threadgroup;",
                    ty, result_name
                ));
            }

            IrNode::LaneId => {
                self.emit_line(&format!(
                    "{} {} = thread_index_in_simdgroup;",
                    ty, result_name
                ));
            }

            IrNode::Barrier => {
                self.emit_line("threadgroup_barrier(mem_flags::mem_threadgroup);");
            }

            IrNode::MemoryFence(scope) => {
                let fence = match scope {
                    MemoryScope::Thread => "threadgroup_barrier(mem_flags::mem_none)",
                    MemoryScope::Threadgroup => "threadgroup_barrier(mem_flags::mem_threadgroup)",
                    MemoryScope::Device => "threadgroup_barrier(mem_flags::mem_device)",
                    // MSL has no system-scope fence; fall back to device scope.
                    MemoryScope::System => "threadgroup_barrier(mem_flags::mem_device)",
                };
                self.emit_line(&format!("{};", fence));
            }

            IrNode::GridSync => {
                return Err(MslLoweringError::UnsupportedOperation(
                    "Grid sync not supported in Metal".to_string(),
                ));
            }

            IrNode::Atomic(op, ptr, val) => {
                let ptr_name = self.get_value_name(*ptr);
                let val_name = self.get_value_name(*val);
                let atomic_fn = match op {
                    AtomicOp::Add => "atomic_fetch_add_explicit",
                    AtomicOp::Sub => "atomic_fetch_sub_explicit",
                    AtomicOp::Exchange => "atomic_exchange_explicit",
                    AtomicOp::Min => "atomic_fetch_min_explicit",
                    AtomicOp::Max => "atomic_fetch_max_explicit",
                    AtomicOp::And => "atomic_fetch_and_explicit",
                    AtomicOp::Or => "atomic_fetch_or_explicit",
                    AtomicOp::Xor => "atomic_fetch_xor_explicit",
                    AtomicOp::Load => {
                        self.emit_line(&format!(
                            "{} {} = atomic_load_explicit(&{}, memory_order_relaxed);",
                            ty, result_name, ptr_name
                        ));
                        return Ok(());
                    }
                    AtomicOp::Store => {
                        self.emit_line(&format!(
                            "atomic_store_explicit(&{}, {}, memory_order_relaxed);",
                            ptr_name, val_name
                        ));
                        return Ok(());
                    }
                };
                self.emit_line(&format!(
                    "{} {} = {}(&{}, {}, memory_order_relaxed);",
                    ty, result_name, atomic_fn, ptr_name, val_name
                ));
            }

            IrNode::AtomicCas(ptr, expected, desired) => {
                let ptr_name = self.get_value_name(*ptr);
                let exp_name = self.get_value_name(*expected);
                let des_name = self.get_value_name(*desired);
                self.emit_line(&format!("{} {} = {};", ty, result_name, exp_name));
                self.emit_line(&format!(
                    "atomic_compare_exchange_weak_explicit(&{}, &{}, {}, memory_order_relaxed, memory_order_relaxed);",
                    ptr_name, result_name, des_name
                ));
            }

            IrNode::WarpVote(op, val) => {
                if !self.config.simd_groups {
                    return Err(MslLoweringError::UnsupportedOperation(
                        "SIMD group operations require simd_groups feature".to_string(),
                    ));
                }
                let val_name = self.get_value_name(*val);
                let vote_fn = match op {
                    WarpVoteOp::All => "simd_all",
                    WarpVoteOp::Any => "simd_any",
                    WarpVoteOp::Ballot => "simd_ballot",
                };
                self.emit_line(&format!(
                    "{} {} = {}({});",
                    ty, result_name, vote_fn, val_name
                ));
            }

            IrNode::WarpShuffle(op, val, lane) => {
                if !self.config.simd_groups {
                    return Err(MslLoweringError::UnsupportedOperation(
                        "SIMD shuffle requires simd_groups feature".to_string(),
                    ));
                }
                let val_name = self.get_value_name(*val);
                let lane_name = self.get_value_name(*lane);
                let shfl_fn = match op {
                    WarpShuffleOp::Index => "simd_shuffle",
                    WarpShuffleOp::Up => "simd_shuffle_up",
                    WarpShuffleOp::Down => "simd_shuffle_down",
                    WarpShuffleOp::Xor => "simd_shuffle_xor",
                };
                self.emit_line(&format!(
                    "{} {} = {}({}, {});",
                    ty, result_name, shfl_fn, val_name, lane_name
                ));
            }

            IrNode::Select(cond, then_val, else_val) => {
                let cond_name = self.get_value_name(*cond);
                let then_name = self.get_value_name(*then_val);
                let else_name = self.get_value_name(*else_val);
                // metal::select(a, b, c) yields b when c is true, so pass (else, then, cond).
                self.emit_line(&format!(
                    "{} {} = select({}, {}, {});",
                    ty, result_name, else_name, then_name, cond_name
                ));
            }

            IrNode::Math(op, args) => {
                let fn_name = self.lower_math_op(op);
                let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
                self.emit_line(&format!(
                    "{} {} = {}({});",
                    ty,
                    result_name,
                    fn_name,
                    args_str.join(", ")
                ));
            }

            IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}

            IrNode::K2HEnqueue(value) => {
                let val_name = self.get_value_name(*value);
                self.emit_line(&format!(
                    "{} {} = ringkernel_k2h_enqueue({});",
                    ty, result_name, val_name
                ));
            }

            IrNode::H2KDequeue => {
                self.emit_line(&format!(
                    "{} {} = ringkernel_h2k_dequeue();",
                    ty, result_name
                ));
            }

            IrNode::H2KIsEmpty => {
                self.emit_line(&format!(
                    "{} {} = ringkernel_h2k_is_empty();",
                    ty, result_name
                ));
            }

            IrNode::K2KSend(target_id, message) => {
                let target_name = self.get_value_name(*target_id);
                let msg_name = self.get_value_name(*message);
                self.emit_line(&format!(
                    "{} {} = ringkernel_k2k_send({}, {});",
                    ty, result_name, target_name, msg_name
                ));
            }

            IrNode::K2KRecv => {
                self.emit_line(&format!("{} {} = ringkernel_k2k_recv();", ty, result_name));
            }

            IrNode::K2KTryRecv => {
                self.emit_line(&format!(
                    "{} {} = ringkernel_k2k_try_recv();",
                    ty, result_name
                ));
            }

            IrNode::HlcNow => {
                self.emit_line(&format!("{} {} = ringkernel_hlc_now();", ty, result_name));
            }

            IrNode::HlcTick => {
                self.emit_line(&format!("{} {} = ringkernel_hlc_tick();", ty, result_name));
            }

            IrNode::HlcUpdate(incoming) => {
                let incoming_name = self.get_value_name(*incoming);
                self.emit_line(&format!(
                    "{} {} = ringkernel_hlc_update({});",
                    ty, result_name, incoming_name
                ));
            }

            _ => {
                self.emit_line(&format!("// Unhandled: {:?}", node));
            }
        }

        Ok(())
    }

    fn emit_terminator(
        &mut self,
        _module: &IrModule,
        term: &Terminator,
    ) -> Result<(), MslLoweringError> {
        match term {
            Terminator::Return(None) => {
                self.emit_line("return;");
            }
            Terminator::Return(Some(val)) => {
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("// Return: {}", val_name));
                self.emit_line("return;");
            }
            Terminator::Branch(target) => {
                let label = self.block_labels.get(target).cloned().unwrap_or_default();
                self.emit_line(&format!("goto {};", label));
            }
            Terminator::CondBranch(cond, then_block, else_block) => {
                let cond_name = self.get_value_name(*cond);
                let then_label = self
                    .block_labels
                    .get(then_block)
                    .cloned()
                    .unwrap_or_default();
                let else_label = self
                    .block_labels
                    .get(else_block)
                    .cloned()
                    .unwrap_or_default();
                self.emit_line(&format!(
                    "if ({}) goto {}; else goto {};",
                    cond_name, then_label, else_label
                ));
            }
            Terminator::Switch(val, default, cases) => {
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("switch ({}) {{", val_name));
                self.indent += 1;
                for (case_val, target) in cases {
                    let case_str = self.lower_constant(case_val);
                    let label = self.block_labels.get(target).cloned().unwrap_or_default();
                    self.emit_line(&format!("case {}: goto {};", case_str, label));
                }
                let default_label = self.block_labels.get(default).cloned().unwrap_or_default();
                self.emit_line(&format!("default: goto {};", default_label));
                self.indent -= 1;
                self.emit_line("}");
            }
            Terminator::Unreachable => {
                self.emit_line("// unreachable");
            }
        }
        Ok(())
    }

    fn lower_type(&self, ty: &IrType) -> String {
        match ty {
            IrType::Void => "void".to_string(),
            IrType::Scalar(s) => self.lower_scalar_type(s),
            IrType::Vector(v) => format!("{}{}", self.lower_scalar_type(&v.element), v.count),
            IrType::Ptr(inner) => format!("device {}*", self.lower_type(inner)),
            IrType::Array(inner, size) => format!("array<{}, {}>", self.lower_type(inner), size),
            IrType::Slice(inner) => format!("device {}*", self.lower_type(inner)),
            IrType::Struct(s) => s.name.clone(),
            IrType::Function(_) => "void*".to_string(),
        }
    }

    fn lower_scalar_type(&self, ty: &ScalarType) -> String {
        match ty {
            ScalarType::Bool => "bool",
            ScalarType::I8 => "char",
            ScalarType::I16 => "short",
            ScalarType::I32 => "int",
            ScalarType::I64 => "long",
            ScalarType::U8 => "uchar",
            ScalarType::U16 => "ushort",
            ScalarType::U32 => "uint",
            ScalarType::U64 => "ulong",
            ScalarType::F16 => "half",
            ScalarType::F32 => "float",
            // Metal has no f64; approximate with float.
            ScalarType::F64 => "float",
        }
        .to_string()
    }

    fn lower_constant(&self, c: &ConstantValue) -> String {
        match c {
            ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
            ConstantValue::I32(v) => format!("{}", v),
            ConstantValue::I64(v) => format!("{}L", v),
            ConstantValue::U32(v) => format!("{}u", v),
            ConstantValue::U64(v) => format!("{}uL", v),
            ConstantValue::F32(v) => format!("{}f", v),
            ConstantValue::F64(v) => format!("{}f", *v as f32),
            ConstantValue::Null => "nullptr".to_string(),
            ConstantValue::Array(elems) => {
                let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
                format!("{{{}}}", elems_str.join(", "))
            }
            ConstantValue::Struct(fields) => {
                let fields_str: Vec<String> =
                    fields.iter().map(|f| self.lower_constant(f)).collect();
                format!("{{{}}}", fields_str.join(", "))
            }
        }
    }

    fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
        match op {
            BinaryOp::Add => format!("{} + {}", lhs, rhs),
            BinaryOp::Sub => format!("{} - {}", lhs, rhs),
            BinaryOp::Mul => format!("{} * {}", lhs, rhs),
            BinaryOp::Div => format!("{} / {}", lhs, rhs),
            BinaryOp::Rem => format!("{} % {}", lhs, rhs),
            BinaryOp::And => format!("{} & {}", lhs, rhs),
            BinaryOp::Or => format!("{} | {}", lhs, rhs),
            BinaryOp::Xor => format!("{} ^ {}", lhs, rhs),
            BinaryOp::Shl => format!("{} << {}", lhs, rhs),
            BinaryOp::Shr => format!("{} >> {}", lhs, rhs),
            // Arithmetic shift: `>>` on a signed operand.
            BinaryOp::Sar => format!("{} >> {}", lhs, rhs),
            // Binary Fma carries no addend operand, so a zero addend is emitted.
            BinaryOp::Fma => format!("fma({}, {}, 0.0f)", lhs, rhs),
            BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
            BinaryOp::Min => format!("min({}, {})", lhs, rhs),
            BinaryOp::Max => format!("max({}, {})", lhs, rhs),
        }
    }

    fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
        match op {
            UnaryOp::Neg => format!("-{}", val),
            UnaryOp::Not => format!("~{}", val),
            UnaryOp::LogicalNot => format!("!{}", val),
            UnaryOp::Abs => format!("abs({})", val),
            UnaryOp::Sqrt => format!("sqrt({})", val),
            UnaryOp::Rsqrt => format!("rsqrt({})", val),
            UnaryOp::Floor => format!("floor({})", val),
            UnaryOp::Ceil => format!("ceil({})", val),
            UnaryOp::Round => format!("round({})", val),
            UnaryOp::Trunc => format!("trunc({})", val),
            UnaryOp::Sign => format!("sign({})", val),
        }
    }

    fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
        match op {
            CompareOp::Eq => "==",
            CompareOp::Ne => "!=",
            CompareOp::Lt => "<",
            CompareOp::Le => "<=",
            CompareOp::Gt => ">",
            CompareOp::Ge => ">=",
        }
    }

    fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
        match dim {
            Dimension::X => format!("{}.x", prefix),
            Dimension::Y => format!("{}.y", prefix),
            Dimension::Z => format!("{}.z", prefix),
        }
    }

    fn lower_math_op(&self, op: &MathOp) -> &'static str {
        match op {
            MathOp::Sin => "sin",
            MathOp::Cos => "cos",
            MathOp::Tan => "tan",
            MathOp::Asin => "asin",
            MathOp::Acos => "acos",
            MathOp::Atan => "atan",
            MathOp::Atan2 => "atan2",
            MathOp::Sinh => "sinh",
            MathOp::Cosh => "cosh",
            MathOp::Tanh => "tanh",
            MathOp::Exp => "exp",
            MathOp::Exp2 => "exp2",
            MathOp::Log => "log",
            MathOp::Log2 => "log2",
            MathOp::Log10 => "log10",
            MathOp::Lerp => "mix",
            MathOp::Clamp => "clamp",
            MathOp::Step => "step",
            MathOp::SmoothStep => "smoothstep",
            MathOp::Fract => "fract",
            MathOp::CopySign => "copysign",
        }
    }

    fn get_value_name(&self, id: ValueId) -> String {
        self.value_names
            .get(&id)
            .cloned()
            .unwrap_or_else(|| format!("v{}", id.raw()))
    }

    fn get_or_create_name(&mut self, id: ValueId) -> String {
        if let Some(name) = self.value_names.get(&id) {
            return name.clone();
        }
        let name = format!("t{}", self.name_counter);
        self.name_counter += 1;
        self.value_names.insert(id, name.clone());
        name
    }

    fn emit_line(&mut self, line: &str) {
        let indent = " ".repeat(self.indent);
        writeln!(self.output, "{}{}", indent, line).unwrap();
    }
}

/// Errors produced while lowering IR to MSL.
#[derive(Debug, Clone)]
pub enum MslLoweringError {
    /// The module requires a capability Metal does not provide.
    UnsupportedCapability(String),
    /// An IR operation has no MSL equivalent.
    UnsupportedOperation(String),
    /// A branch target refers to a block that does not exist.
    UndefinedBlock(BlockId),
    /// An operand refers to a value that was never defined.
    UndefinedValue(ValueId),
    /// A type could not be lowered.
    TypeError(String),
}

impl std::fmt::Display for MslLoweringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MslLoweringError::UnsupportedCapability(cap) => {
                write!(f, "Unsupported capability: {}", cap)
            }
            MslLoweringError::UnsupportedOperation(op) => {
                write!(f, "Unsupported operation: {}", op)
            }
            MslLoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
            MslLoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
            MslLoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
        }
    }
}

impl std::error::Error for MslLoweringError {}

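/// Lower `module` to MSL source using [`MslLoweringConfig::default`].
///
/// Illustrative sketch only, mirroring the unit tests at the bottom of this
/// file; the `ringkernel_ir` crate path and the exact `IrBuilder` calls are
/// assumptions, so the example is not compiled as a doctest:
///
/// ```ignore
/// use ringkernel_ir::{lower_to_msl, IrBuilder};
///
/// let mut builder = IrBuilder::new("noop");
/// builder.ret();
/// let msl = lower_to_msl(&builder.build()).unwrap();
/// assert!(msl.contains("kernel void"));
/// ```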
pub fn lower_to_msl(module: &IrModule) -> Result<String, MslLoweringError> {
    MslLowering::new(MslLoweringConfig::default()).lower(module)
}

/// Lower `module` to MSL source with an explicit [`MslLoweringConfig`].
pub fn lower_to_msl_with_config(
    module: &IrModule,
    config: MslLoweringConfig,
) -> Result<String, MslLoweringError> {
    MslLowering::new(config).lower(module)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    #[test]
    fn test_lower_simple_kernel() {
        let mut builder = IrBuilder::new("add_one");

        let _x = builder.parameter("x", IrType::ptr(IrType::F32));
        let _n = builder.parameter("n", IrType::I32);

        let idx = builder.global_thread_id(Dimension::X);
        let _ = idx;

        builder.ret();

        let module = builder.build();
        let msl = lower_to_msl(&module).unwrap();

        assert!(msl.contains("kernel void"));
        assert!(msl.contains("add_one"));
        assert!(msl.contains("thread_position_in_grid"));
    }

    #[test]
    fn test_lower_with_threadgroup_memory() {
        let mut builder = IrBuilder::new("reduce");

        let shared = builder.shared_alloc(IrType::F32, 256);
        let _ = shared;

        builder.barrier();
        builder.ret();

        let module = builder.build();
        let msl = lower_to_msl(&module).unwrap();

        assert!(msl.contains("threadgroup float"));
        assert!(msl.contains("threadgroup_barrier"));
    }

    #[test]
    fn test_lower_with_simd_ops() {
        let mut builder = IrBuilder::new("simd");

        let val = builder.const_bool(true);
        let _ = val;

        builder.ret();

        let module = builder.build();
        let config = MslLoweringConfig::metal3();
        let msl = lower_to_msl_with_config(&module, config).unwrap();

        assert!(msl.contains("#include <metal_stdlib>"));
    }

    #[test]
    fn test_lower_with_atomics() {
        let mut builder = IrBuilder::new("atomic");

        let counter = builder.parameter("counter", IrType::ptr(IrType::U32));
        let one = builder.const_u32(1);
        let _old = builder.atomic_add(counter, one);

        builder.ret();

        let module = builder.build();
        let msl = lower_to_msl(&module).unwrap();

        assert!(msl.contains("atomic_fetch_add_explicit"));
    }

    #[test]
    fn test_lower_rejects_grid_sync() {
        let mut builder = IrBuilder::new("grid");
        builder.grid_sync();
        builder.ret();

        let module = builder.build();
        let result = lower_to_msl(&module);

        assert!(result.is_err());
    }
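
    // Added check (not in the original file): with `with_persistent()`, the
    // lowering should declare the HLC and K2K helper types emitted by
    // `emit_type_definitions`. Builder usage mirrors the tests above.
    #[test]
    fn test_lower_with_persistent_config() {
        let mut builder = IrBuilder::new("persistent");
        builder.ret();

        let module = builder.build();
        let config = MslLoweringConfig::default().with_persistent();
        let msl = lower_to_msl_with_config(&module, config).unwrap();

        assert!(msl.contains("struct HlcTimestamp"));
        assert!(msl.contains("struct ControlBlock"));
        assert!(msl.contains("ringkernel_k2k_send"));
    }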
}