Skip to main content

ringkernel_ir/
lower_msl.rs

1//! IR to MSL (Metal Shading Language) lowering pass.
2//!
3//! Lowers IR to Metal Shading Language for Apple GPU compute.
4
5use std::collections::HashMap;
6use std::fmt::Write;
7
8use crate::{
9    nodes::*, BlockId, CapabilityFlag, Dimension, IrModule, IrNode, IrType, ScalarType, Terminator,
10    ValueId,
11};
12
/// Options controlling how an IR module is rendered as Metal Shading Language.
#[derive(Debug, Clone)]
pub struct MslLoweringConfig {
    /// Target Metal language version as `(major, minor)`, e.g. `(2, 4)` or `(3, 0)`.
    pub metal_version: (u32, u32),
    /// Allow SIMD-group (warp-level) vote and shuffle operations.
    pub simd_groups: bool,
    /// Threadgroup dimensions as `(x, y, z)`.
    pub threadgroup_size: (u32, u32, u32),
    /// Enable indirect command buffer support.
    pub indirect_commands: bool,
    /// Emit Hybrid Logical Clock (HLC) types and intrinsic declarations.
    pub enable_hlc: bool,
    /// Emit kernel-to-kernel (K2K) messaging types and queue/messaging intrinsics.
    pub enable_k2k: bool,
    /// Request debug comments in the generated source.
    pub debug: bool,
}

impl Default for MslLoweringConfig {
    /// Metal 2.4 with 256x1x1 threadgroups and SIMD-group operations enabled;
    /// every optional feature is switched off.
    fn default() -> Self {
        let metal_version = (2, 4);
        let threadgroup_size = (256, 1, 1);
        Self {
            metal_version,
            simd_groups: true,
            threadgroup_size,
            indirect_commands: false,
            enable_hlc: false,
            enable_k2k: false,
            debug: false,
        }
    }
}

impl MslLoweringConfig {
    /// Configuration targeting Metal 3.0, with SIMD groups and indirect command
    /// buffers turned on and all other options at their defaults.
    pub fn metal3() -> Self {
        Self {
            indirect_commands: true,
            simd_groups: true,
            metal_version: (3, 0),
            ..Self::default()
        }
    }

    /// Replace the threadgroup dimensions, returning the updated configuration.
    pub fn with_threadgroup_size(self, x: u32, y: u32, z: u32) -> Self {
        Self {
            threadgroup_size: (x, y, z),
            ..self
        }
    }

    /// Turn on the persistent-kernel features (HLC and K2K messaging).
    pub fn with_persistent(self) -> Self {
        Self {
            enable_k2k: true,
            enable_hlc: true,
            ..self
        }
    }
}
70
71/// MSL code generator.
72pub struct MslLowering {
73    config: MslLoweringConfig,
74    output: String,
75    indent: usize,
76    value_names: HashMap<ValueId, String>,
77    name_counter: usize,
78    block_labels: HashMap<BlockId, String>,
79}
80
81impl MslLowering {
82    /// Create a new MSL lowering pass.
83    pub fn new(config: MslLoweringConfig) -> Self {
84        Self {
85            config,
86            output: String::new(),
87            indent: 0,
88            value_names: HashMap::new(),
89            name_counter: 0,
90            block_labels: HashMap::new(),
91        }
92    }
93
94    /// Lower an IR module to MSL code.
95    pub fn lower(mut self, module: &IrModule) -> Result<String, MslLoweringError> {
96        // Check capabilities
97        self.check_capabilities(module)?;
98
99        // Generate header
100        self.emit_header();
101
102        // Generate type definitions
103        self.emit_type_definitions(module);
104
105        // Generate kernel
106        self.emit_kernel(module)?;
107
108        Ok(self.output)
109    }
110
111    fn check_capabilities(&self, module: &IrModule) -> Result<(), MslLoweringError> {
112        // Metal doesn't support f64
113        if module.required_capabilities.has(CapabilityFlag::Float64) {
114            return Err(MslLoweringError::UnsupportedCapability(
115                "f64 not supported in Metal (will downcast to f32)".to_string(),
116            ));
117        }
118
119        // Metal doesn't have true cooperative groups for grid sync
120        if module
121            .required_capabilities
122            .has(CapabilityFlag::CooperativeGroups)
123        {
124            return Err(MslLoweringError::UnsupportedCapability(
125                "Grid-wide sync not supported in Metal".to_string(),
126            ));
127        }
128
129        Ok(())
130    }
131
132    fn emit_header(&mut self) {
133        self.emit_line("// Generated by ringkernel-ir MSL lowering");
134        self.emit_line("#include <metal_stdlib>");
135        self.emit_line("#include <simdgroup_matrix>");
136        self.emit_line("using namespace metal;");
137        self.emit_line("");
138    }
139
140    fn emit_type_definitions(&mut self, _module: &IrModule) {
141        // HLC timestamp type
142        if self.config.enable_hlc {
143            self.emit_line("// HLC Timestamp");
144            self.emit_line("struct HlcTimestamp {");
145            self.indent += 1;
146            self.emit_line("uint64_t physical;");
147            self.emit_line("uint64_t logical;");
148            self.emit_line("uint64_t node_id;");
149            self.indent -= 1;
150            self.emit_line("};");
151            self.emit_line("");
152
153            // HLC intrinsic declarations
154            self.emit_line("// HLC Intrinsics (provided by runtime)");
155            self.emit_line("uint64_t ringkernel_hlc_now();");
156            self.emit_line("uint64_t ringkernel_hlc_tick();");
157            self.emit_line("uint64_t ringkernel_hlc_update(uint64_t incoming);");
158            self.emit_line("");
159        }
160
161        // K2K messaging types and intrinsics
162        if self.config.enable_k2k {
163            // Control block
164            self.emit_line("// Control Block");
165            self.emit_line("struct ControlBlock {");
166            self.indent += 1;
167            self.emit_line("uint32_t is_active;");
168            self.emit_line("uint32_t should_terminate;");
169            self.emit_line("uint32_t has_terminated;");
170            self.emit_line("uint32_t _pad1;");
171            self.emit_line("uint64_t messages_processed;");
172            self.emit_line("uint64_t messages_in_flight;");
173            self.emit_line("uint64_t input_head;");
174            self.emit_line("uint64_t input_tail;");
175            self.emit_line("uint64_t output_head;");
176            self.emit_line("uint64_t output_tail;");
177            self.emit_line("uint32_t input_capacity;");
178            self.emit_line("uint32_t output_capacity;");
179            self.emit_line("uint32_t input_mask;");
180            self.emit_line("uint32_t output_mask;");
181            self.indent -= 1;
182            self.emit_line("};");
183            self.emit_line("");
184
185            // K2H/H2K queue intrinsic declarations
186            self.emit_line("// Queue Intrinsics (provided by runtime)");
187            self.emit_line("bool ringkernel_k2h_enqueue(const device void* msg);");
188            self.emit_line("device void* ringkernel_h2k_dequeue();");
189            self.emit_line("bool ringkernel_h2k_is_empty();");
190            self.emit_line("");
191
192            // K2K messaging intrinsic declarations
193            self.emit_line("// K2K Messaging Intrinsics (provided by runtime)");
194            self.emit_line("bool ringkernel_k2k_send(uint64_t target_id, const device void* msg);");
195            self.emit_line("device void* ringkernel_k2k_recv();");
196            self.emit_line("struct K2KOptionalMsg { bool valid; device void* data; };");
197            self.emit_line("K2KOptionalMsg ringkernel_k2k_try_recv();");
198            self.emit_line("");
199        }
200    }
201
202    fn emit_kernel(&mut self, module: &IrModule) -> Result<(), MslLoweringError> {
203        // Assign names
204        self.assign_names(module);
205
206        // Kernel signature (threadgroup size set at dispatch time)
207        self.emit_line("kernel void");
208        writeln!(self.output, "{}(", module.name).unwrap();
209        self.indent += 1;
210
211        // Parameters
212        for (buffer_idx, param) in module.parameters.iter().enumerate() {
213            let ty = self.lower_type(&param.ty);
214            let qualifier = if param.ty.is_ptr() {
215                "device"
216            } else {
217                "constant"
218            };
219            self.emit_line(&format!(
220                "{} {}& {} [[buffer({})]],",
221                qualifier, ty, param.name, buffer_idx
222            ));
223        }
224
225        // Built-in arguments
226        self.emit_line("uint3 thread_position_in_grid [[thread_position_in_grid]],");
227        self.emit_line("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],");
228        self.emit_line("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],");
229        self.emit_line("uint3 threads_per_threadgroup [[threads_per_threadgroup]],");
230        self.emit_line("uint3 threadgroups_per_grid [[threadgroups_per_grid]],");
231        self.emit_line("uint thread_index_in_simdgroup [[thread_index_in_simdgroup]],");
232        self.emit_line("uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]]");
233
234        self.indent -= 1;
235        self.emit_line(") {");
236        self.indent += 1;
237
238        // Emit blocks
239        self.emit_block(module, module.entry_block)?;
240
241        self.indent -= 1;
242        self.emit_line("}");
243
244        Ok(())
245    }
246
247    fn assign_names(&mut self, module: &IrModule) {
248        for param in &module.parameters {
249            self.value_names.insert(param.value_id, param.name.clone());
250        }
251
252        for (block_id, block) in &module.blocks {
253            self.block_labels.insert(*block_id, block.label.clone());
254        }
255    }
256
257    fn emit_block(&mut self, module: &IrModule, block_id: BlockId) -> Result<(), MslLoweringError> {
258        let block = module
259            .blocks
260            .get(&block_id)
261            .ok_or(MslLoweringError::UndefinedBlock(block_id))?;
262
263        // Block label (skip for entry)
264        if block_id != module.entry_block {
265            self.emit_line(&format!("{}: {{", block.label));
266            self.indent += 1;
267        }
268
269        // Instructions
270        for inst in &block.instructions {
271            self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
272        }
273
274        // Terminator
275        if let Some(term) = &block.terminator {
276            self.emit_terminator(module, term)?;
277        }
278
279        if block_id != module.entry_block {
280            self.indent -= 1;
281            self.emit_line("}");
282        }
283
284        Ok(())
285    }
286
287    fn emit_instruction(
288        &mut self,
289        _module: &IrModule,
290        result: &ValueId,
291        result_type: &IrType,
292        node: &IrNode,
293    ) -> Result<(), MslLoweringError> {
294        let result_name = self.get_or_create_name(*result);
295        let ty = self.lower_type(result_type);
296
297        match node {
298            // Constants
299            IrNode::Constant(c) => {
300                let val = self.lower_constant(c);
301                self.emit_line(&format!("{} {} = {};", ty, result_name, val));
302            }
303
304            // Binary operations
305            IrNode::BinaryOp(op, lhs, rhs) => {
306                let lhs_name = self.get_value_name(*lhs);
307                let rhs_name = self.get_value_name(*rhs);
308                let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
309                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
310            }
311
312            // Unary operations
313            IrNode::UnaryOp(op, val) => {
314                let val_name = self.get_value_name(*val);
315                let expr = self.lower_unary_op(op, &val_name);
316                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
317            }
318
319            // Comparisons
320            IrNode::Compare(op, lhs, rhs) => {
321                let lhs_name = self.get_value_name(*lhs);
322                let rhs_name = self.get_value_name(*rhs);
323                let cmp_op = self.lower_compare_op(op);
324                self.emit_line(&format!(
325                    "bool {} = {} {} {};",
326                    result_name, lhs_name, cmp_op, rhs_name
327                ));
328            }
329
330            // Memory operations
331            IrNode::Load(ptr) => {
332                let ptr_name = self.get_value_name(*ptr);
333                self.emit_line(&format!("{} {} = {};", ty, result_name, ptr_name));
334            }
335
336            IrNode::Store(ptr, val) => {
337                let ptr_name = self.get_value_name(*ptr);
338                let val_name = self.get_value_name(*val);
339                self.emit_line(&format!("{} = {};", ptr_name, val_name));
340            }
341
342            IrNode::GetElementPtr(ptr, indices) => {
343                let ptr_name = self.get_value_name(*ptr);
344                let idx_name = self.get_value_name(indices[0]);
345                self.emit_line(&format!(
346                    "{} {} = {}[{}];",
347                    ty, result_name, ptr_name, idx_name
348                ));
349            }
350
351            IrNode::SharedAlloc(elem_ty, count) => {
352                let elem = self.lower_type(elem_ty);
353                self.emit_line(&format!("threadgroup {} {}[{}];", elem, result_name, count));
354            }
355
356            // GPU indexing
357            IrNode::ThreadId(dim) => {
358                let idx = self.lower_dimension(dim, "thread_position_in_threadgroup");
359                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
360            }
361
362            IrNode::BlockId(dim) => {
363                let idx = self.lower_dimension(dim, "threadgroup_position_in_grid");
364                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
365            }
366
367            IrNode::BlockDim(dim) => {
368                let idx = self.lower_dimension(dim, "threads_per_threadgroup");
369                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
370            }
371
372            IrNode::GridDim(dim) => {
373                let idx = self.lower_dimension(dim, "threadgroups_per_grid");
374                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
375            }
376
377            IrNode::GlobalThreadId(dim) => {
378                let idx = self.lower_dimension(dim, "thread_position_in_grid");
379                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
380            }
381
382            IrNode::WarpId => {
383                self.emit_line(&format!(
384                    "{} {} = simdgroup_index_in_threadgroup;",
385                    ty, result_name
386                ));
387            }
388
389            IrNode::LaneId => {
390                self.emit_line(&format!(
391                    "{} {} = thread_index_in_simdgroup;",
392                    ty, result_name
393                ));
394            }
395
396            // Synchronization
397            IrNode::Barrier => {
398                self.emit_line("threadgroup_barrier(mem_flags::mem_threadgroup);");
399            }
400
401            IrNode::MemoryFence(scope) => {
402                let fence = match scope {
403                    MemoryScope::Thread => "threadgroup_barrier(mem_flags::mem_none)",
404                    MemoryScope::Threadgroup => "threadgroup_barrier(mem_flags::mem_threadgroup)",
405                    MemoryScope::Device => "threadgroup_barrier(mem_flags::mem_device)",
406                    MemoryScope::System => "threadgroup_barrier(mem_flags::mem_device)",
407                };
408                self.emit_line(&format!("{};", fence));
409            }
410
411            IrNode::GridSync => {
412                return Err(MslLoweringError::UnsupportedOperation(
413                    "Grid sync not supported in Metal".to_string(),
414                ));
415            }
416
417            // Atomics
418            IrNode::Atomic(op, ptr, val) => {
419                let ptr_name = self.get_value_name(*ptr);
420                let val_name = self.get_value_name(*val);
421                let atomic_fn = match op {
422                    AtomicOp::Add => "atomic_fetch_add_explicit",
423                    AtomicOp::Sub => "atomic_fetch_sub_explicit",
424                    AtomicOp::Exchange => "atomic_exchange_explicit",
425                    AtomicOp::Min => "atomic_fetch_min_explicit",
426                    AtomicOp::Max => "atomic_fetch_max_explicit",
427                    AtomicOp::And => "atomic_fetch_and_explicit",
428                    AtomicOp::Or => "atomic_fetch_or_explicit",
429                    AtomicOp::Xor => "atomic_fetch_xor_explicit",
430                    AtomicOp::Load => {
431                        self.emit_line(&format!(
432                            "{} {} = atomic_load_explicit(&{}, memory_order_relaxed);",
433                            ty, result_name, ptr_name
434                        ));
435                        return Ok(());
436                    }
437                    AtomicOp::Store => {
438                        self.emit_line(&format!(
439                            "atomic_store_explicit(&{}, {}, memory_order_relaxed);",
440                            ptr_name, val_name
441                        ));
442                        return Ok(());
443                    }
444                };
445                self.emit_line(&format!(
446                    "{} {} = {}(&{}, {}, memory_order_relaxed);",
447                    ty, result_name, atomic_fn, ptr_name, val_name
448                ));
449            }
450
451            IrNode::AtomicCas(ptr, expected, desired) => {
452                let ptr_name = self.get_value_name(*ptr);
453                let exp_name = self.get_value_name(*expected);
454                let des_name = self.get_value_name(*desired);
455                self.emit_line(&format!("{} {} = {};", ty, result_name, exp_name));
456                self.emit_line(&format!(
457                    "atomic_compare_exchange_weak_explicit(&{}, &{}, {}, memory_order_relaxed, memory_order_relaxed);",
458                    ptr_name, result_name, des_name
459                ));
460            }
461
462            // SIMD group operations
463            IrNode::WarpVote(op, val) => {
464                if !self.config.simd_groups {
465                    return Err(MslLoweringError::UnsupportedOperation(
466                        "SIMD group operations require simd_groups feature".to_string(),
467                    ));
468                }
469                let val_name = self.get_value_name(*val);
470                let vote_fn = match op {
471                    WarpVoteOp::All => "simd_all",
472                    WarpVoteOp::Any => "simd_any",
473                    WarpVoteOp::Ballot => "simd_ballot",
474                };
475                self.emit_line(&format!(
476                    "{} {} = {}({});",
477                    ty, result_name, vote_fn, val_name
478                ));
479            }
480
481            IrNode::WarpShuffle(op, val, lane) => {
482                if !self.config.simd_groups {
483                    return Err(MslLoweringError::UnsupportedOperation(
484                        "SIMD shuffle requires simd_groups feature".to_string(),
485                    ));
486                }
487                let val_name = self.get_value_name(*val);
488                let lane_name = self.get_value_name(*lane);
489                let shfl_fn = match op {
490                    WarpShuffleOp::Index => "simd_shuffle",
491                    WarpShuffleOp::Up => "simd_shuffle_up",
492                    WarpShuffleOp::Down => "simd_shuffle_down",
493                    WarpShuffleOp::Xor => "simd_shuffle_xor",
494                };
495                self.emit_line(&format!(
496                    "{} {} = {}({}, {});",
497                    ty, result_name, shfl_fn, val_name, lane_name
498                ));
499            }
500
501            // Select
502            IrNode::Select(cond, then_val, else_val) => {
503                let cond_name = self.get_value_name(*cond);
504                let then_name = self.get_value_name(*then_val);
505                let else_name = self.get_value_name(*else_val);
506                self.emit_line(&format!(
507                    "{} {} = select({}, {}, {});",
508                    ty, result_name, else_name, then_name, cond_name
509                ));
510            }
511
512            // Math functions
513            IrNode::Math(op, args) => {
514                let fn_name = self.lower_math_op(op);
515                let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
516                self.emit_line(&format!(
517                    "{} {} = {}({});",
518                    ty,
519                    result_name,
520                    fn_name,
521                    args_str.join(", ")
522                ));
523            }
524
525            // Skip nodes that don't produce MSL output
526            IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}
527
528            // ========================================================================
529            // Messaging Operations
530            // ========================================================================
531
532            // K2H (Kernel-to-Host) enqueue
533            IrNode::K2HEnqueue(value) => {
534                let val_name = self.get_value_name(*value);
535                // Enqueue returns success status (bool)
536                self.emit_line(&format!(
537                    "{} {} = ringkernel_k2h_enqueue({});",
538                    ty, result_name, val_name
539                ));
540            }
541
542            // H2K (Host-to-Kernel) dequeue
543            IrNode::H2KDequeue => {
544                // Dequeue returns the message struct
545                self.emit_line(&format!(
546                    "{} {} = ringkernel_h2k_dequeue();",
547                    ty, result_name
548                ));
549            }
550
551            // H2K queue empty check
552            IrNode::H2KIsEmpty => {
553                // Returns true if queue is empty
554                self.emit_line(&format!(
555                    "{} {} = ringkernel_h2k_is_empty();",
556                    ty, result_name
557                ));
558            }
559
560            // K2K (Kernel-to-Kernel) send
561            IrNode::K2KSend(target_id, message) => {
562                let target_name = self.get_value_name(*target_id);
563                let msg_name = self.get_value_name(*message);
564                // Send returns success status (bool)
565                self.emit_line(&format!(
566                    "{} {} = ringkernel_k2k_send({}, {});",
567                    ty, result_name, target_name, msg_name
568                ));
569            }
570
571            // K2K blocking receive
572            IrNode::K2KRecv => {
573                // Blocking receive returns the message struct
574                self.emit_line(&format!("{} {} = ringkernel_k2k_recv();", ty, result_name));
575            }
576
577            // K2K non-blocking try receive
578            IrNode::K2KTryRecv => {
579                // Try receive returns optional message (use .valid field to check)
580                self.emit_line(&format!(
581                    "{} {} = ringkernel_k2k_try_recv();",
582                    ty, result_name
583                ));
584            }
585
586            // ========================================================================
587            // HLC (Hybrid Logical Clock) Operations
588            // ========================================================================
589
590            // Get current HLC time
591            IrNode::HlcNow => {
592                // Returns current HLC timestamp (uint64_t)
593                self.emit_line(&format!("{} {} = ringkernel_hlc_now();", ty, result_name));
594            }
595
596            // Tick HLC and return new time
597            IrNode::HlcTick => {
598                // Increments logical counter and returns new timestamp
599                self.emit_line(&format!("{} {} = ringkernel_hlc_tick();", ty, result_name));
600            }
601
602            // Update HLC from incoming timestamp
603            IrNode::HlcUpdate(incoming) => {
604                let incoming_name = self.get_value_name(*incoming);
605                // Updates HLC using max(local, incoming) + 1 rule
606                self.emit_line(&format!(
607                    "{} {} = ringkernel_hlc_update({});",
608                    ty, result_name, incoming_name
609                ));
610            }
611
612            _ => {
613                self.emit_line(&format!("// Unhandled: {:?}", node));
614            }
615        }
616
617        Ok(())
618    }
619
620    fn emit_terminator(
621        &mut self,
622        _module: &IrModule,
623        term: &Terminator,
624    ) -> Result<(), MslLoweringError> {
625        match term {
626            Terminator::Return(None) => {
627                self.emit_line("return;");
628            }
629            Terminator::Return(Some(val)) => {
630                let val_name = self.get_value_name(*val);
631                self.emit_line(&format!("// Return: {}", val_name));
632                self.emit_line("return;");
633            }
634            Terminator::Branch(target) => {
635                let label = self.block_labels.get(target).cloned().unwrap_or_default();
636                self.emit_line(&format!("goto {};", label));
637            }
638            Terminator::CondBranch(cond, then_block, else_block) => {
639                let cond_name = self.get_value_name(*cond);
640                let then_label = self
641                    .block_labels
642                    .get(then_block)
643                    .cloned()
644                    .unwrap_or_default();
645                let else_label = self
646                    .block_labels
647                    .get(else_block)
648                    .cloned()
649                    .unwrap_or_default();
650                self.emit_line(&format!(
651                    "if ({}) goto {}; else goto {};",
652                    cond_name, then_label, else_label
653                ));
654            }
655            Terminator::Switch(val, default, cases) => {
656                let val_name = self.get_value_name(*val);
657                self.emit_line(&format!("switch ({}) {{", val_name));
658                self.indent += 1;
659                for (case_val, target) in cases {
660                    let case_str = self.lower_constant(case_val);
661                    let label = self.block_labels.get(target).cloned().unwrap_or_default();
662                    self.emit_line(&format!("case {}: goto {};", case_str, label));
663                }
664                let default_label = self.block_labels.get(default).cloned().unwrap_or_default();
665                self.emit_line(&format!("default: goto {};", default_label));
666                self.indent -= 1;
667                self.emit_line("}");
668            }
669            Terminator::Unreachable => {
670                self.emit_line("// unreachable");
671            }
672        }
673        Ok(())
674    }
675
676    fn lower_type(&self, ty: &IrType) -> String {
677        match ty {
678            IrType::Void => "void".to_string(),
679            IrType::Scalar(s) => self.lower_scalar_type(s),
680            IrType::Vector(v) => format!("{}{}", self.lower_scalar_type(&v.element), v.count),
681            IrType::Ptr(inner) => format!("device {}*", self.lower_type(inner)),
682            IrType::Array(inner, size) => format!("array<{}, {}>", self.lower_type(inner), size),
683            IrType::Slice(inner) => format!("device {}*", self.lower_type(inner)),
684            IrType::Struct(s) => s.name.clone(),
685            IrType::Function(_) => "void*".to_string(),
686        }
687    }
688
689    fn lower_scalar_type(&self, ty: &ScalarType) -> String {
690        match ty {
691            ScalarType::Bool => "bool",
692            ScalarType::I8 => "char",
693            ScalarType::I16 => "short",
694            ScalarType::I32 => "int",
695            ScalarType::I64 => "long",
696            ScalarType::U8 => "uchar",
697            ScalarType::U16 => "ushort",
698            ScalarType::U32 => "uint",
699            ScalarType::U64 => "ulong",
700            ScalarType::F16 => "half",
701            ScalarType::F32 => "float",
702            ScalarType::F64 => "float", // Metal doesn't support f64
703        }
704        .to_string()
705    }
706
707    fn lower_constant(&self, c: &ConstantValue) -> String {
708        match c {
709            ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
710            ConstantValue::I32(v) => format!("{}", v),
711            ConstantValue::I64(v) => format!("{}L", v),
712            ConstantValue::U32(v) => format!("{}u", v),
713            ConstantValue::U64(v) => format!("{}uL", v),
714            ConstantValue::F32(v) => format!("{}f", v),
715            ConstantValue::F64(v) => format!("{}f", *v as f32), // Downcast
716            ConstantValue::Null => "nullptr".to_string(),
717            ConstantValue::Array(elems) => {
718                let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
719                format!("{{{}}}", elems_str.join(", "))
720            }
721            ConstantValue::Struct(fields) => {
722                let fields_str: Vec<String> =
723                    fields.iter().map(|f| self.lower_constant(f)).collect();
724                format!("{{{}}}", fields_str.join(", "))
725            }
726        }
727    }
728
729    fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
730        match op {
731            BinaryOp::Add => format!("{} + {}", lhs, rhs),
732            BinaryOp::Sub => format!("{} - {}", lhs, rhs),
733            BinaryOp::Mul => format!("{} * {}", lhs, rhs),
734            BinaryOp::Div => format!("{} / {}", lhs, rhs),
735            BinaryOp::Rem => format!("{} % {}", lhs, rhs),
736            BinaryOp::And => format!("{} & {}", lhs, rhs),
737            BinaryOp::Or => format!("{} | {}", lhs, rhs),
738            BinaryOp::Xor => format!("{} ^ {}", lhs, rhs),
739            BinaryOp::Shl => format!("{} << {}", lhs, rhs),
740            BinaryOp::Shr => format!("{} >> {}", lhs, rhs),
741            BinaryOp::Sar => format!("{} >> {}", lhs, rhs),
742            BinaryOp::Fma => format!("fma({}, {}, 0.0f)", lhs, rhs),
743            BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
744            BinaryOp::Min => format!("min({}, {})", lhs, rhs),
745            BinaryOp::Max => format!("max({}, {})", lhs, rhs),
746        }
747    }
748
749    fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
750        match op {
751            UnaryOp::Neg => format!("-{}", val),
752            UnaryOp::Not => format!("~{}", val),
753            UnaryOp::LogicalNot => format!("!{}", val),
754            UnaryOp::Abs => format!("abs({})", val),
755            UnaryOp::Sqrt => format!("sqrt({})", val),
756            UnaryOp::Rsqrt => format!("rsqrt({})", val),
757            UnaryOp::Floor => format!("floor({})", val),
758            UnaryOp::Ceil => format!("ceil({})", val),
759            UnaryOp::Round => format!("round({})", val),
760            UnaryOp::Trunc => format!("trunc({})", val),
761            UnaryOp::Sign => format!("sign({})", val),
762        }
763    }
764
765    fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
766        match op {
767            CompareOp::Eq => "==",
768            CompareOp::Ne => "!=",
769            CompareOp::Lt => "<",
770            CompareOp::Le => "<=",
771            CompareOp::Gt => ">",
772            CompareOp::Ge => ">=",
773        }
774    }
775
776    fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
777        match dim {
778            Dimension::X => format!("{}.x", prefix),
779            Dimension::Y => format!("{}.y", prefix),
780            Dimension::Z => format!("{}.z", prefix),
781        }
782    }
783
784    fn lower_math_op(&self, op: &MathOp) -> &'static str {
785        match op {
786            MathOp::Sin => "sin",
787            MathOp::Cos => "cos",
788            MathOp::Tan => "tan",
789            MathOp::Asin => "asin",
790            MathOp::Acos => "acos",
791            MathOp::Atan => "atan",
792            MathOp::Atan2 => "atan2",
793            MathOp::Sinh => "sinh",
794            MathOp::Cosh => "cosh",
795            MathOp::Tanh => "tanh",
796            MathOp::Exp => "exp",
797            MathOp::Exp2 => "exp2",
798            MathOp::Log => "log",
799            MathOp::Log2 => "log2",
800            MathOp::Log10 => "log10",
801            MathOp::Lerp => "mix",
802            MathOp::Clamp => "clamp",
803            MathOp::Step => "step",
804            MathOp::SmoothStep => "smoothstep",
805            MathOp::Fract => "fract",
806            MathOp::CopySign => "copysign",
807        }
808    }
809
810    fn get_value_name(&self, id: ValueId) -> String {
811        self.value_names
812            .get(&id)
813            .cloned()
814            .unwrap_or_else(|| format!("v{}", id.raw()))
815    }
816
817    fn get_or_create_name(&mut self, id: ValueId) -> String {
818        if let Some(name) = self.value_names.get(&id) {
819            return name.clone();
820        }
821        let name = format!("t{}", self.name_counter);
822        self.name_counter += 1;
823        self.value_names.insert(id, name.clone());
824        name
825    }
826
827    fn emit_line(&mut self, line: &str) {
828        let indent = "    ".repeat(self.indent);
829        writeln!(self.output, "{}{}", indent, line).unwrap();
830    }
831}
832
833/// MSL lowering errors.
834#[derive(Debug, Clone)]
835pub enum MslLoweringError {
836    /// Unsupported capability.
837    UnsupportedCapability(String),
838    /// Unsupported operation.
839    UnsupportedOperation(String),
840    /// Undefined block reference.
841    UndefinedBlock(BlockId),
842    /// Undefined value reference.
843    UndefinedValue(ValueId),
844    /// Type error.
845    TypeError(String),
846}
847
848impl std::fmt::Display for MslLoweringError {
849    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
850        match self {
851            MslLoweringError::UnsupportedCapability(cap) => {
852                write!(f, "Unsupported capability: {}", cap)
853            }
854            MslLoweringError::UnsupportedOperation(op) => {
855                write!(f, "Unsupported operation: {}", op)
856            }
857            MslLoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
858            MslLoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
859            MslLoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
860        }
861    }
862}
863
864impl std::error::Error for MslLoweringError {}
865
866/// Convenience function to lower IR to MSL.
867pub fn lower_to_msl(module: &IrModule) -> Result<String, MslLoweringError> {
868    MslLowering::new(MslLoweringConfig::default()).lower(module)
869}
870
871/// Lower IR to MSL with custom config.
872pub fn lower_to_msl_with_config(
873    module: &IrModule,
874    config: MslLoweringConfig,
875) -> Result<String, MslLoweringError> {
876    MslLowering::new(config).lower(module)
877}
878
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// Asserts that every snippet in `needles` occurs somewhere in `msl`.
    fn assert_contains_all(msl: &str, needles: &[&str]) {
        for needle in needles {
            assert!(msl.contains(needle));
        }
    }

    #[test]
    fn test_lower_simple_kernel() {
        let mut builder = IrBuilder::new("add_one");
        let _x = builder.parameter("x", IrType::ptr(IrType::F32));
        let _n = builder.parameter("n", IrType::I32);
        let _idx = builder.global_thread_id(Dimension::X);
        builder.ret();

        let msl = lower_to_msl(&builder.build()).unwrap();

        assert_contains_all(&msl, &["kernel void", "add_one", "thread_position_in_grid"]);
    }

    #[test]
    fn test_lower_with_threadgroup_memory() {
        let mut builder = IrBuilder::new("reduce");
        let _shared = builder.shared_alloc(IrType::F32, 256);
        builder.barrier();
        builder.ret();

        let msl = lower_to_msl(&builder.build()).unwrap();

        assert_contains_all(&msl, &["threadgroup float", "threadgroup_barrier"]);
    }

    #[test]
    fn test_lower_with_simd_ops() {
        let mut builder = IrBuilder::new("simd");
        let _flag = builder.const_bool(true);
        builder.ret();

        let msl = lower_to_msl_with_config(&builder.build(), MslLoweringConfig::metal3()).unwrap();

        assert_contains_all(&msl, &["#include <metal_stdlib>"]);
    }

    #[test]
    fn test_lower_with_atomics() {
        let mut builder = IrBuilder::new("atomic");
        let counter = builder.parameter("counter", IrType::ptr(IrType::U32));
        let increment = builder.const_u32(1);
        let _previous = builder.atomic_add(counter, increment);
        builder.ret();

        let msl = lower_to_msl(&builder.build()).unwrap();

        assert_contains_all(&msl, &["atomic_fetch_add_explicit"]);
    }

    #[test]
    fn test_lower_rejects_grid_sync() {
        let mut builder = IrBuilder::new("grid");
        builder.grid_sync();
        builder.ret();

        let outcome = lower_to_msl(&builder.build());

        assert!(outcome.is_err());
    }
}
964}