ringkernel_ir/lower_cuda.rs

//! IR to CUDA lowering pass.
//!
//! Lowers IR to CUDA C code for compilation with nvcc.
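//!
//! A minimal end-to-end sketch, mirroring the tests at the bottom of this
//! file (crate-root re-exports and the kernel name are illustrative):
//!
//! ```rust,ignore
//! use ringkernel_ir::{lower_to_cuda, IrBuilder};
//!
//! let mut builder = IrBuilder::new("noop");
//! builder.ret();
//! let module = builder.build();
//!
//! // The default config targets SM 7.0 with no persistent-kernel features.
//! let cuda = lower_to_cuda(&module).unwrap();
//! assert!(cuda.contains("__global__ void noop"));
//! ```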

use std::collections::HashMap;
use std::fmt::Write;

use crate::{
    nodes::*, BackendCapabilities, BlockId, Dimension, IrModule, IrNode, IrType, ScalarType,
    Terminator, ValueId,
};

/// CUDA lowering configuration.
#[derive(Debug, Clone)]
pub struct CudaLoweringConfig {
    /// Target compute capability (e.g., 80 for SM 8.0).
    pub compute_capability: u32,
    /// Enable cooperative groups.
    pub cooperative_groups: bool,
    /// Enable HLC (Hybrid Logical Clocks).
    pub enable_hlc: bool,
    /// Enable K2K messaging.
    pub enable_k2k: bool,
    /// Use fast math.
    pub fast_math: bool,
    /// Generate debug info.
    pub debug: bool,
}

impl Default for CudaLoweringConfig {
    fn default() -> Self {
        Self {
            compute_capability: 70,
            cooperative_groups: false,
            enable_hlc: false,
            enable_k2k: false,
            fast_math: false,
            debug: false,
        }
    }
}

impl CudaLoweringConfig {
    /// Create config for SM 8.0+.
    pub fn sm80() -> Self {
        Self {
            compute_capability: 80,
            cooperative_groups: true,
            ..Default::default()
        }
    }

    /// Enable persistent kernel features.
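    ///
    /// A typical persistent-kernel setup (sketch):
    ///
    /// ```rust,ignore
    /// let config = CudaLoweringConfig::sm80().with_persistent();
    /// assert!(config.enable_hlc && config.enable_k2k && config.cooperative_groups);
    /// ```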
    pub fn with_persistent(mut self) -> Self {
        self.enable_hlc = true;
        self.enable_k2k = true;
        self.cooperative_groups = true;
        self
    }
}

/// CUDA code generator.
pub struct CudaLowering {
    config: CudaLoweringConfig,
    output: String,
    indent: usize,
    value_names: HashMap<ValueId, String>,
    name_counter: usize,
    block_labels: HashMap<BlockId, String>,
}

impl CudaLowering {
    /// Create a new CUDA lowering pass.
    pub fn new(config: CudaLoweringConfig) -> Self {
        Self {
            config,
            output: String::new(),
            indent: 0,
            value_names: HashMap::new(),
            name_counter: 0,
            block_labels: HashMap::new(),
        }
    }

    /// Lower an IR module to CUDA code.
    pub fn lower(mut self, module: &IrModule) -> Result<String, LoweringError> {
        // Check capabilities
        self.check_capabilities(module)?;

        // Generate includes
        self.emit_includes();

        // Generate type definitions
        self.emit_type_definitions(module);

        // Generate kernel
        self.emit_kernel(module)?;

        Ok(self.output)
    }

    fn check_capabilities(&self, module: &IrModule) -> Result<(), LoweringError> {
        // NOTE: capabilities are checked against the SM 8.0 feature set
        // regardless of `config.compute_capability`.
        let cuda_caps = BackendCapabilities::cuda_sm80();

        let unsupported = cuda_caps.unsupported(&module.required_capabilities);
        if !unsupported.is_empty() {
            return Err(LoweringError::UnsupportedCapability(
                unsupported
                    .iter()
                    .map(|c| format!("{}", c))
                    .collect::<Vec<_>>()
                    .join(", "),
            ));
        }

        Ok(())
    }

    fn emit_includes(&mut self) {
        self.emit_line("// Generated by ringkernel-ir CUDA lowering");
        self.emit_line("#include <cuda_runtime.h>");
        self.emit_line("#include <stdint.h>");

        if self.config.cooperative_groups {
            self.emit_line("#include <cooperative_groups.h>");
            self.emit_line("namespace cg = cooperative_groups;");
        }

        self.emit_line("");
    }

    fn emit_type_definitions(&mut self, _module: &IrModule) {
        // HLC timestamp type
        if self.config.enable_hlc {
            self.emit_line("// HLC Timestamp");
            self.emit_line("struct HlcTimestamp {");
            self.indent += 1;
            self.emit_line("uint64_t physical;");
            self.emit_line("uint64_t logical;");
            self.emit_line("uint64_t node_id;");
            self.indent -= 1;
            self.emit_line("};");
            self.emit_line("");
        }

        // Control block for persistent kernels
        if self.config.enable_k2k {
            self.emit_line("// Control Block");
            self.emit_line("struct ControlBlock {");
            self.indent += 1;
            self.emit_line("uint32_t is_active;");
            self.emit_line("uint32_t should_terminate;");
            self.emit_line("uint32_t has_terminated;");
            self.emit_line("uint32_t _pad1;");
            self.emit_line("uint64_t messages_processed;");
            self.emit_line("uint64_t messages_in_flight;");
            self.emit_line("uint64_t input_head;");
            self.emit_line("uint64_t input_tail;");
            self.emit_line("uint64_t output_head;");
            self.emit_line("uint64_t output_tail;");
            self.emit_line("uint32_t input_capacity;");
            self.emit_line("uint32_t output_capacity;");
            self.emit_line("uint32_t input_mask;");
            self.emit_line("uint32_t output_mask;");
            self.indent -= 1;
            self.emit_line("};");
            self.emit_line("");

            // K2H/H2K queue intrinsic declarations
            self.emit_line("// Queue Intrinsics (provided by runtime)");
            self.emit_line("__device__ bool __ringkernel_k2h_enqueue(const void* msg);");
            self.emit_line("__device__ void* __ringkernel_h2k_dequeue();");
            self.emit_line("__device__ bool __ringkernel_h2k_is_empty();");
            self.emit_line("");

            // K2K messaging intrinsic declarations
            self.emit_line("// K2K Messaging Intrinsics (provided by runtime)");
            self.emit_line(
                "__device__ bool __ringkernel_k2k_send(uint64_t target_id, const void* msg);",
            );
            self.emit_line("__device__ void* __ringkernel_k2k_recv();");
            self.emit_line("struct K2KOptionalMsg { bool valid; void* data; };");
            self.emit_line("__device__ K2KOptionalMsg __ringkernel_k2k_try_recv();");
            self.emit_line("");
        }

        // HLC intrinsic declarations
        if self.config.enable_hlc {
            self.emit_line("// HLC Intrinsics (provided by runtime)");
            self.emit_line("__device__ uint64_t __ringkernel_hlc_now();");
            self.emit_line("__device__ uint64_t __ringkernel_hlc_tick();");
            self.emit_line("__device__ uint64_t __ringkernel_hlc_update(uint64_t incoming);");
            self.emit_line("");
        }
    }

    fn emit_kernel(&mut self, module: &IrModule) -> Result<(), LoweringError> {
        // Assign names to values and blocks
        self.assign_names(module);

        // Kernel signature
        let kernel_attr = if self.config.cooperative_groups {
            "__global__ void __launch_bounds__(256)"
        } else {
            "__global__ void"
        };
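        // Example emitted signature (kernel/parameter names are illustrative):
        //   __global__ void __launch_bounds__(256) my_kernel(float* x, int32_t n) {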

        write!(self.output, "{} {}(", kernel_attr, module.name).unwrap();

        // Parameters
        for (i, param) in module.parameters.iter().enumerate() {
            if i > 0 {
                write!(self.output, ", ").unwrap();
            }
            let ty = self.lower_type(&param.ty);
            write!(self.output, "{} {}", ty, param.name).unwrap();
        }

        self.emit_line(") {");
        self.indent += 1;

        // Cooperative groups setup
        if self.config.cooperative_groups {
            self.emit_line("cg::grid_group grid = cg::this_grid();");
            self.emit_line("cg::thread_block block = cg::this_thread_block();");
            self.emit_line("");
        }

        // Emit blocks
        self.emit_block(module, module.entry_block)?;

        // Emit other blocks
        for block_id in module.blocks.keys() {
            if *block_id != module.entry_block {
                self.emit_block(module, *block_id)?;
            }
        }

        self.indent -= 1;
        self.emit_line("}");

        Ok(())
    }

    fn assign_names(&mut self, module: &IrModule) {
        // Assign names to parameters
        for param in &module.parameters {
            self.value_names.insert(param.value_id, param.name.clone());
        }

        // Assign names to blocks
        for (block_id, block) in &module.blocks {
            self.block_labels.insert(*block_id, block.label.clone());
        }
    }

    fn emit_block(&mut self, module: &IrModule, block_id: BlockId) -> Result<(), LoweringError> {
        let block = module
            .blocks
            .get(&block_id)
            .ok_or(LoweringError::UndefinedBlock(block_id))?;

        // Block label (skip for entry)
        if block_id != module.entry_block {
            self.emit_line(&format!("{}: {{", block.label));
            self.indent += 1;
        }

        // Instructions
        for inst in &block.instructions {
            self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
        }

        // Terminator
        if let Some(term) = &block.terminator {
            self.emit_terminator(term)?;
        }

        if block_id != module.entry_block {
            self.indent -= 1;
            self.emit_line("}");
        }

        Ok(())
    }

    fn emit_instruction(
        &mut self,
        _module: &IrModule,
        result: &ValueId,
        result_type: &IrType,
        node: &IrNode,
    ) -> Result<(), LoweringError> {
        let result_name = self.get_or_create_name(*result);
        let ty = self.lower_type(result_type);

        match node {
            // Constants
            IrNode::Constant(c) => {
                let val = self.lower_constant(c);
                self.emit_line(&format!("{} {} = {};", ty, result_name, val));
            }

            // Binary operations
            IrNode::BinaryOp(op, lhs, rhs) => {
                let lhs_name = self.get_value_name(*lhs);
                let rhs_name = self.get_value_name(*rhs);
                let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
            }

            // Unary operations
            IrNode::UnaryOp(op, val) => {
                let val_name = self.get_value_name(*val);
                let expr = self.lower_unary_op(op, &val_name);
                self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
            }

            // Comparisons
            IrNode::Compare(op, lhs, rhs) => {
                let lhs_name = self.get_value_name(*lhs);
                let rhs_name = self.get_value_name(*rhs);
                let cmp_op = self.lower_compare_op(op);
                self.emit_line(&format!(
                    "bool {} = {} {} {};",
                    result_name, lhs_name, cmp_op, rhs_name
                ));
            }

            // Memory operations
            IrNode::Load(ptr) => {
                let ptr_name = self.get_value_name(*ptr);
                self.emit_line(&format!("{} {} = *{};", ty, result_name, ptr_name));
            }

            IrNode::Store(ptr, val) => {
                let ptr_name = self.get_value_name(*ptr);
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("*{} = {};", ptr_name, val_name));
            }

            IrNode::GetElementPtr(ptr, indices) => {
                let ptr_name = self.get_value_name(*ptr);
                // Note: only the first index is lowered; nested GEP indices
                // are not yet supported by this pass.
                let idx_name = self.get_value_name(indices[0]);
                self.emit_line(&format!(
                    "{} {} = &{}[{}];",
                    ty, result_name, ptr_name, idx_name
                ));
            }

            IrNode::SharedAlloc(elem_ty, count) => {
                let elem = self.lower_type(elem_ty);
                self.emit_line(&format!("__shared__ {} {}[{}];", elem, result_name, count));
            }

            // GPU indexing
            IrNode::ThreadId(dim) => {
                let idx = self.lower_dimension(dim, "threadIdx");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::BlockId(dim) => {
                let idx = self.lower_dimension(dim, "blockIdx");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::BlockDim(dim) => {
                let idx = self.lower_dimension(dim, "blockDim");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::GridDim(dim) => {
                let idx = self.lower_dimension(dim, "gridDim");
                self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
            }

            IrNode::GlobalThreadId(dim) => {
                let block_idx = self.lower_dimension(dim, "blockIdx");
                let block_dim = self.lower_dimension(dim, "blockDim");
                let thread_idx = self.lower_dimension(dim, "threadIdx");
                self.emit_line(&format!(
                    "{} {} = {} * {} + {};",
                    ty, result_name, block_idx, block_dim, thread_idx
                ));
            }

            IrNode::WarpId => {
                self.emit_line(&format!("{} {} = threadIdx.x / 32;", ty, result_name));
            }

            IrNode::LaneId => {
                self.emit_line(&format!("{} {} = threadIdx.x % 32;", ty, result_name));
            }

            // Synchronization
            IrNode::Barrier => {
                self.emit_line("__syncthreads();");
            }

            IrNode::MemoryFence(scope) => {
                let fence = match scope {
                    MemoryScope::Thread => "__threadfence_block()",
                    MemoryScope::Threadgroup => "__threadfence_block()",
                    MemoryScope::Device => "__threadfence()",
                    MemoryScope::System => "__threadfence_system()",
                };
                self.emit_line(&format!("{};", fence));
            }

            IrNode::GridSync => {
                if self.config.cooperative_groups {
                    self.emit_line("grid.sync();");
                } else {
                    return Err(LoweringError::RequiresCooperativeGroups);
                }
            }

            // Atomics
            IrNode::Atomic(op, ptr, val) => {
                let ptr_name = self.get_value_name(*ptr);
                let val_name = self.get_value_name(*val);
                let atomic_fn = match op {
                    AtomicOp::Add => "atomicAdd",
                    AtomicOp::Sub => "atomicSub",
                    AtomicOp::Exchange => "atomicExch",
                    AtomicOp::Min => "atomicMin",
                    AtomicOp::Max => "atomicMax",
                    AtomicOp::And => "atomicAnd",
                    AtomicOp::Or => "atomicOr",
                    AtomicOp::Xor => "atomicXor",
                    AtomicOp::Load => {
                        self.emit_line(&format!(
                            "{} {} = atomicAdd({}, 0);",
                            ty, result_name, ptr_name
                        ));
                        return Ok(());
                    }
                    AtomicOp::Store => {
                        self.emit_line(&format!("atomicExch({}, {});", ptr_name, val_name));
                        return Ok(());
                    }
                };
                self.emit_line(&format!(
                    "{} {} = {}({}, {});",
                    ty, result_name, atomic_fn, ptr_name, val_name
                ));
            }

            IrNode::AtomicCas(ptr, expected, desired) => {
                let ptr_name = self.get_value_name(*ptr);
                let exp_name = self.get_value_name(*expected);
                let des_name = self.get_value_name(*desired);
                self.emit_line(&format!(
                    "{} {} = atomicCAS({}, {}, {});",
                    ty, result_name, ptr_name, exp_name, des_name
                ));
            }

            // Warp operations
            IrNode::WarpVote(op, val) => {
                let val_name = self.get_value_name(*val);
                let vote_fn = match op {
                    WarpVoteOp::All => "__all_sync(0xFFFFFFFF, ",
                    WarpVoteOp::Any => "__any_sync(0xFFFFFFFF, ",
                    WarpVoteOp::Ballot => "__ballot_sync(0xFFFFFFFF, ",
                };
                self.emit_line(&format!(
                    "{} {} = {}{});",
                    ty, result_name, vote_fn, val_name
                ));
            }

            IrNode::WarpShuffle(op, val, lane) => {
                let val_name = self.get_value_name(*val);
                let lane_name = self.get_value_name(*lane);
                let shfl_fn = match op {
                    WarpShuffleOp::Index => "__shfl_sync(0xFFFFFFFF, ",
                    WarpShuffleOp::Up => "__shfl_up_sync(0xFFFFFFFF, ",
                    WarpShuffleOp::Down => "__shfl_down_sync(0xFFFFFFFF, ",
                    WarpShuffleOp::Xor => "__shfl_xor_sync(0xFFFFFFFF, ",
                };
                self.emit_line(&format!(
                    "{} {} = {}{}, {});",
                    ty, result_name, shfl_fn, val_name, lane_name
                ));
            }

            // Select
            IrNode::Select(cond, then_val, else_val) => {
                let cond_name = self.get_value_name(*cond);
                let then_name = self.get_value_name(*then_val);
                let else_name = self.get_value_name(*else_val);
                self.emit_line(&format!(
                    "{} {} = {} ? {} : {};",
                    ty, result_name, cond_name, then_name, else_name
                ));
            }

            // Math functions
            IrNode::Math(op, args) => {
                let fn_name = self.lower_math_op(op);
                let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
                self.emit_line(&format!(
                    "{} {} = {}({});",
                    ty,
                    result_name,
                    fn_name,
                    args_str.join(", ")
                ));
            }

            // Skip nodes that don't produce CUDA output
            IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}

            // ========================================================================
            // Messaging Operations
            // ========================================================================

            // K2H (Kernel-to-Host) enqueue
            IrNode::K2HEnqueue(value) => {
                let val_name = self.get_value_name(*value);
                // Enqueue returns success status (bool)
                self.emit_line(&format!(
                    "{} {} = __ringkernel_k2h_enqueue({});",
                    ty, result_name, val_name
                ));
            }

            // H2K (Host-to-Kernel) dequeue
            IrNode::H2KDequeue => {
                // Dequeue returns the message struct
                self.emit_line(&format!(
                    "{} {} = __ringkernel_h2k_dequeue();",
                    ty, result_name
                ));
            }

            // H2K queue empty check
            IrNode::H2KIsEmpty => {
                // Returns true if queue is empty
                self.emit_line(&format!(
                    "{} {} = __ringkernel_h2k_is_empty();",
                    ty, result_name
                ));
            }

            // K2K (Kernel-to-Kernel) send
            IrNode::K2KSend(target_id, message) => {
                let target_name = self.get_value_name(*target_id);
                let msg_name = self.get_value_name(*message);
                // Send returns success status (bool)
                self.emit_line(&format!(
                    "{} {} = __ringkernel_k2k_send({}, {});",
                    ty, result_name, target_name, msg_name
                ));
            }

            // K2K blocking receive
            IrNode::K2KRecv => {
                // Blocking receive returns the message struct
                self.emit_line(&format!(
                    "{} {} = __ringkernel_k2k_recv();",
                    ty, result_name
                ));
            }

            // K2K non-blocking try receive
            IrNode::K2KTryRecv => {
                // Try receive returns optional message (use .valid field to check)
                self.emit_line(&format!(
                    "{} {} = __ringkernel_k2k_try_recv();",
                    ty, result_name
                ));
            }

            // ========================================================================
            // HLC (Hybrid Logical Clock) Operations
            // ========================================================================

            // Get current HLC time
            IrNode::HlcNow => {
                // Returns current HLC timestamp (u64)
                self.emit_line(&format!("{} {} = __ringkernel_hlc_now();", ty, result_name));
            }

            // Tick HLC and return new time
            IrNode::HlcTick => {
                // Increments logical counter and returns new timestamp
                self.emit_line(&format!(
                    "{} {} = __ringkernel_hlc_tick();",
                    ty, result_name
                ));
            }

            // Update HLC from incoming timestamp
            IrNode::HlcUpdate(incoming) => {
                let incoming_name = self.get_value_name(*incoming);
                // Updates HLC using max(local, incoming) + 1 rule
                self.emit_line(&format!(
                    "{} {} = __ringkernel_hlc_update({});",
                    ty, result_name, incoming_name
                ));
            }

            _ => {
                self.emit_line(&format!("// Unhandled: {:?}", node));
            }
        }

        Ok(())
    }

    fn emit_terminator(&mut self, term: &Terminator) -> Result<(), LoweringError> {
        match term {
            Terminator::Return(None) => {
                self.emit_line("return;");
            }
            Terminator::Return(Some(val)) => {
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("return {};", val_name));
            }
            Terminator::Branch(target) => {
                // A missing label means the terminator references an undefined
                // block; surface that as an error instead of emitting an
                // invalid `goto ;`.
                let label = self
                    .block_labels
                    .get(target)
                    .cloned()
                    .ok_or(LoweringError::UndefinedBlock(*target))?;
                self.emit_line(&format!("goto {};", label));
            }
            Terminator::CondBranch(cond, then_block, else_block) => {
                let cond_name = self.get_value_name(*cond);
                let then_label = self
                    .block_labels
                    .get(then_block)
                    .cloned()
                    .ok_or(LoweringError::UndefinedBlock(*then_block))?;
                let else_label = self
                    .block_labels
                    .get(else_block)
                    .cloned()
                    .ok_or(LoweringError::UndefinedBlock(*else_block))?;
                self.emit_line(&format!(
                    "if ({}) goto {}; else goto {};",
                    cond_name, then_label, else_label
                ));
            }
            Terminator::Switch(val, default, cases) => {
                let val_name = self.get_value_name(*val);
                self.emit_line(&format!("switch ({}) {{", val_name));
                self.indent += 1;
                for (case_val, target) in cases {
                    let case_str = self.lower_constant(case_val);
                    let label = self
                        .block_labels
                        .get(target)
                        .cloned()
                        .ok_or(LoweringError::UndefinedBlock(*target))?;
                    self.emit_line(&format!("case {}: goto {};", case_str, label));
                }
                let default_label = self
                    .block_labels
                    .get(default)
                    .cloned()
                    .ok_or(LoweringError::UndefinedBlock(*default))?;
                self.emit_line(&format!("default: goto {};", default_label));
                self.indent -= 1;
                self.emit_line("}");
            }
            Terminator::Unreachable => {
                self.emit_line("__builtin_unreachable();");
            }
        }
        Ok(())
    }

    fn lower_type(&self, ty: &IrType) -> String {
        match ty {
            IrType::Void => "void".to_string(),
            IrType::Scalar(s) => self.lower_scalar_type(s),
            IrType::Vector(v) => format!("{}{}", self.lower_scalar_type(&v.element), v.count),
            IrType::Ptr(inner) => format!("{}*", self.lower_type(inner)),
            IrType::Array(inner, size) => format!("{}[{}]", self.lower_type(inner), size),
            IrType::Slice(inner) => format!("{}*", self.lower_type(inner)),
            IrType::Struct(s) => s.name.clone(),
            IrType::Function(_) => "void*".to_string(), // Function pointers
        }
    }

    fn lower_scalar_type(&self, ty: &ScalarType) -> String {
        match ty {
            ScalarType::Bool => "bool",
            ScalarType::I8 => "int8_t",
            ScalarType::I16 => "int16_t",
            ScalarType::I32 => "int32_t",
            ScalarType::I64 => "int64_t",
            ScalarType::U8 => "uint8_t",
            ScalarType::U16 => "uint16_t",
            ScalarType::U32 => "uint32_t",
            ScalarType::U64 => "uint64_t",
            ScalarType::F16 => "__half",
            ScalarType::F32 => "float",
            ScalarType::F64 => "double",
        }
        .to_string()
    }

    fn lower_constant(&self, c: &ConstantValue) -> String {
        match c {
            ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
            ConstantValue::I32(v) => format!("{}", v),
            ConstantValue::I64(v) => format!("{}LL", v),
            ConstantValue::U32(v) => format!("{}u", v),
            ConstantValue::U64(v) => format!("{}ull", v),
            // Use `{:?}` so whole floats keep their decimal point
            // (`1.0f` rather than the invalid C literal `1f`).
            ConstantValue::F32(v) => format!("{:?}f", v),
            ConstantValue::F64(v) => format!("{:?}", v),
            ConstantValue::Null => "nullptr".to_string(),
            ConstantValue::Array(elems) => {
                let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
                format!("{{{}}}", elems_str.join(", "))
            }
            ConstantValue::Struct(fields) => {
                let fields_str: Vec<String> =
                    fields.iter().map(|f| self.lower_constant(f)).collect();
                format!("{{{}}}", fields_str.join(", "))
            }
        }
    }

    fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
        match op {
            BinaryOp::Add => format!("{} + {}", lhs, rhs),
            BinaryOp::Sub => format!("{} - {}", lhs, rhs),
            BinaryOp::Mul => format!("{} * {}", lhs, rhs),
            BinaryOp::Div => format!("{} / {}", lhs, rhs),
            BinaryOp::Rem => format!("{} % {}", lhs, rhs),
            BinaryOp::And => format!("{} & {}", lhs, rhs),
            BinaryOp::Or => format!("{} | {}", lhs, rhs),
            BinaryOp::Xor => format!("{} ^ {}", lhs, rhs),
            BinaryOp::Shl => format!("{} << {}", lhs, rhs),
            BinaryOp::Shr => format!("{} >> {}", lhs, rhs),
            // Signed `>>` is an arithmetic shift under nvcc, so Sar lowers
            // to the same operator.
            BinaryOp::Sar => format!("{} >> {}", lhs, rhs),
            // The binary IR node carries no addend, so fma(a, b, 0) == a * b;
            // a true FMA would need a third operand.
            BinaryOp::Fma => format!("fma({}, {}, 0.0f)", lhs, rhs),
            BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
            BinaryOp::Min => format!("min({}, {})", lhs, rhs),
            BinaryOp::Max => format!("max({}, {})", lhs, rhs),
        }
    }

    fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
        match op {
            UnaryOp::Neg => format!("-{}", val),
            UnaryOp::Not => format!("~{}", val),
            UnaryOp::LogicalNot => format!("!{}", val),
            UnaryOp::Abs => format!("abs({})", val),
            UnaryOp::Sqrt => format!("sqrt({})", val),
            UnaryOp::Rsqrt => format!("rsqrt({})", val),
            UnaryOp::Floor => format!("floor({})", val),
            UnaryOp::Ceil => format!("ceil({})", val),
            UnaryOp::Round => format!("round({})", val),
            UnaryOp::Trunc => format!("trunc({})", val),
            UnaryOp::Sign => format!("copysign(1.0f, {})", val),
        }
    }

    fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
        match op {
            CompareOp::Eq => "==",
            CompareOp::Ne => "!=",
            CompareOp::Lt => "<",
            CompareOp::Le => "<=",
            CompareOp::Gt => ">",
            CompareOp::Ge => ">=",
        }
    }

    fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
        match dim {
            Dimension::X => format!("{}.x", prefix),
            Dimension::Y => format!("{}.y", prefix),
            Dimension::Z => format!("{}.z", prefix),
        }
    }

    fn lower_math_op(&self, op: &MathOp) -> &'static str {
        match op {
            MathOp::Sin => "sin",
            MathOp::Cos => "cos",
            MathOp::Tan => "tan",
            MathOp::Asin => "asin",
            MathOp::Acos => "acos",
            MathOp::Atan => "atan",
            MathOp::Atan2 => "atan2",
            MathOp::Sinh => "sinh",
            MathOp::Cosh => "cosh",
            MathOp::Tanh => "tanh",
            MathOp::Exp => "exp",
            MathOp::Exp2 => "exp2",
            MathOp::Log => "log",
            MathOp::Log2 => "log2",
            MathOp::Log10 => "log10",
            // Note: lerp/clamp/step/smoothstep/fract are not CUDA C built-ins;
            // the runtime prelude is expected to supply these helpers.
            MathOp::Lerp => "lerp",
            MathOp::Clamp => "clamp",
            MathOp::Step => "step",
            MathOp::SmoothStep => "smoothstep",
            MathOp::Fract => "fract",
            MathOp::CopySign => "copysign",
        }
    }

    fn get_value_name(&self, id: ValueId) -> String {
        self.value_names
            .get(&id)
            .cloned()
            .unwrap_or_else(|| format!("v{}", id.raw()))
    }

    fn get_or_create_name(&mut self, id: ValueId) -> String {
        if let Some(name) = self.value_names.get(&id) {
            return name.clone();
        }
        let name = format!("t{}", self.name_counter);
        self.name_counter += 1;
        self.value_names.insert(id, name.clone());
        name
    }

    fn emit_line(&mut self, line: &str) {
        let indent = "    ".repeat(self.indent);
        writeln!(self.output, "{}{}", indent, line).unwrap();
    }
}

/// Lowering errors.
#[derive(Debug, Clone)]
pub enum LoweringError {
    /// Unsupported capability.
    UnsupportedCapability(String),
    /// Undefined block reference.
    UndefinedBlock(BlockId),
    /// Undefined value reference.
    UndefinedValue(ValueId),
    /// Requires cooperative groups.
    RequiresCooperativeGroups,
    /// Type error.
    TypeError(String),
}

impl std::fmt::Display for LoweringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LoweringError::UnsupportedCapability(cap) => {
                write!(f, "Unsupported capability: {}", cap)
            }
            LoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
            LoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
            LoweringError::RequiresCooperativeGroups => {
                write!(f, "Operation requires cooperative groups")
            }
            LoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
        }
    }
}

impl std::error::Error for LoweringError {}

/// Convenience function to lower IR to CUDA.
pub fn lower_to_cuda(module: &IrModule) -> Result<String, LoweringError> {
    CudaLowering::new(CudaLoweringConfig::default()).lower(module)
}

/// Lower IR to CUDA with custom config.
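///
/// A sketch using the persistent-kernel configuration (`sm80` plus HLC/K2K):
///
/// ```rust,ignore
/// let config = CudaLoweringConfig::sm80().with_persistent();
/// let cuda = lower_to_cuda_with_config(&module, config)?;
/// // With K2K enabled, the control block definition is emitted.
/// assert!(cuda.contains("struct ControlBlock"));
/// ```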
pub fn lower_to_cuda_with_config(
    module: &IrModule,
    config: CudaLoweringConfig,
) -> Result<String, LoweringError> {
    CudaLowering::new(config).lower(module)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    #[test]
    fn test_lower_simple_kernel() {
        let mut builder = IrBuilder::new("add_one");

        let x = builder.parameter("x", IrType::ptr(IrType::F32));
        let n = builder.parameter("n", IrType::I32);

        let idx = builder.global_thread_id(Dimension::X);
        let in_bounds = builder.lt(idx, n);

        let then_block = builder.create_block("body");
        let end_block = builder.create_block("end");

        builder.cond_branch(in_bounds, then_block, end_block);

        builder.switch_to_block(then_block);
        let one = builder.const_f32(1.0);
        let ptr = builder.gep(x, vec![idx]);
        let val = builder.load(ptr);
        let result = builder.add(val, one);
        builder.store(ptr, result);
        builder.branch(end_block);

        builder.switch_to_block(end_block);
        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("__global__ void add_one"));
        assert!(cuda.contains("float* x"));
        assert!(cuda.contains("int32_t n"));
        assert!(cuda.contains("blockIdx.x * blockDim.x + threadIdx.x"));
    }

    #[test]
    fn test_lower_with_shared_memory() {
        let mut builder = IrBuilder::new("reduce");

        let _x = builder.parameter("x", IrType::ptr(IrType::F32));

        let shared = builder.shared_alloc(IrType::F32, 256);
        let _ = shared;

        builder.barrier();
        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("__shared__ float"));
        assert!(cuda.contains("__syncthreads()"));
    }

    #[test]
    fn test_lower_with_atomics() {
        let mut builder = IrBuilder::new("atomic_add");

        let counter = builder.parameter("counter", IrType::ptr(IrType::U32));

        let one = builder.const_u32(1);
        let _old = builder.atomic_add(counter, one);

        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("atomicAdd"));
    }

    #[test]
    fn test_lower_with_cooperative_groups() {
        let mut builder = IrBuilder::new("grid_reduce");
        builder.grid_sync();
        builder.ret();

        let module = builder.build();

        // Without cooperative groups, should fail
        let result = lower_to_cuda(&module);
        assert!(result.is_err());

        // With cooperative groups, should succeed
        let config = CudaLoweringConfig::sm80();
        let cuda = lower_to_cuda_with_config(&module, config).unwrap();

        assert!(cuda.contains("cooperative_groups"));
        assert!(cuda.contains("grid.sync()"));
    }

    #[test]
    fn test_lower_binary_ops() {
        let mut builder = IrBuilder::new("math");

        let a = builder.const_f32(1.0);
        let b = builder.const_f32(2.0);

        let _sum = builder.add(a, b);
        let _diff = builder.sub(a, b);
        let _prod = builder.mul(a, b);
        let _quot = builder.div(a, b);
        let _min = builder.min(a, b);
        let _max = builder.max(a, b);

        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("+"));
        assert!(cuda.contains("-"));
        assert!(cuda.contains("*"));
        assert!(cuda.contains("/"));
        assert!(cuda.contains("min("));
        assert!(cuda.contains("max("));
    }
}