1use std::collections::HashMap;
6use std::fmt::Write;
7
8use crate::{
9 nodes::*, BackendCapabilities, BlockId, Dimension, IrModule, IrNode, IrType, ScalarType,
10 Terminator, ValueId,
11};
12
/// Options controlling how an IR module is rendered as CUDA C++ source.
#[derive(Debug, Clone)]
pub struct CudaLoweringConfig {
    /// Target compute capability, encoded as major*10 + minor (70 = sm_70, 80 = sm_80).
    pub compute_capability: u32,
    /// Emit cooperative-groups includes and grid/block handles (required for grid sync).
    pub cooperative_groups: bool,
    /// Emit the HLC timestamp struct and HLC intrinsic declarations.
    pub enable_hlc: bool,
    /// Emit the control block, queue, and K2K messaging intrinsic declarations.
    pub enable_k2k: bool,
    /// Fast-math flag. NOTE(review): not consulted by the lowering code visible here — confirm intended use.
    pub fast_math: bool,
    /// Debug flag. NOTE(review): not consulted by the lowering code visible here — confirm intended use.
    pub debug: bool,
}
29
30impl Default for CudaLoweringConfig {
31 fn default() -> Self {
32 Self {
33 compute_capability: 70,
34 cooperative_groups: false,
35 enable_hlc: false,
36 enable_k2k: false,
37 fast_math: false,
38 debug: false,
39 }
40 }
41}
42
43impl CudaLoweringConfig {
44 pub fn sm80() -> Self {
46 Self {
47 compute_capability: 80,
48 cooperative_groups: true,
49 ..Default::default()
50 }
51 }
52
53 pub fn with_persistent(mut self) -> Self {
55 self.enable_hlc = true;
56 self.enable_k2k = true;
57 self.cooperative_groups = true;
58 self
59 }
60}
61
/// Stateful pass that renders an `IrModule` into CUDA C++ source text.
pub struct CudaLowering {
    // Codegen options for this run.
    config: CudaLoweringConfig,
    // Accumulated CUDA source.
    output: String,
    // Current indentation depth, in levels (not spaces).
    indent: usize,
    // SSA value id -> C identifier used in the emitted source.
    value_names: HashMap<ValueId, String>,
    // Counter for minting fresh temporary names (t0, t1, ...).
    name_counter: usize,
    // Block id -> label, used as goto targets.
    block_labels: HashMap<BlockId, String>,
}
71
72impl CudaLowering {
73 pub fn new(config: CudaLoweringConfig) -> Self {
75 Self {
76 config,
77 output: String::new(),
78 indent: 0,
79 value_names: HashMap::new(),
80 name_counter: 0,
81 block_labels: HashMap::new(),
82 }
83 }
84
    /// Runs the full lowering pipeline and returns the generated CUDA source.
    ///
    /// The phases run in a fixed order: capability validation, then includes,
    /// then feature-dependent type definitions, then the kernel itself.
    pub fn lower(mut self, module: &IrModule) -> Result<String, LoweringError> {
        self.check_capabilities(module)?;

        self.emit_includes();

        self.emit_type_definitions(module);

        self.emit_kernel(module)?;

        Ok(self.output)
    }
101
102 fn check_capabilities(&self, module: &IrModule) -> Result<(), LoweringError> {
103 let cuda_caps = BackendCapabilities::cuda_sm80();
104
105 let unsupported = cuda_caps.unsupported(&module.required_capabilities);
106 if !unsupported.is_empty() {
107 return Err(LoweringError::UnsupportedCapability(
108 unsupported
109 .iter()
110 .map(|c| format!("{}", c))
111 .collect::<Vec<_>>()
112 .join(", "),
113 ));
114 }
115
116 Ok(())
117 }
118
119 fn emit_includes(&mut self) {
120 self.emit_line("// Generated by ringkernel-ir CUDA lowering");
121 self.emit_line("#include <cuda_runtime.h>");
122 self.emit_line("#include <stdint.h>");
123
124 if self.config.cooperative_groups {
125 self.emit_line("#include <cooperative_groups.h>");
126 self.emit_line("namespace cg = cooperative_groups;");
127 }
128
129 self.emit_line("");
130 }
131
132 fn emit_type_definitions(&mut self, _module: &IrModule) {
133 if self.config.enable_hlc {
135 self.emit_line("// HLC Timestamp");
136 self.emit_line("struct HlcTimestamp {");
137 self.indent += 1;
138 self.emit_line("uint64_t physical;");
139 self.emit_line("uint64_t logical;");
140 self.emit_line("uint64_t node_id;");
141 self.indent -= 1;
142 self.emit_line("};");
143 self.emit_line("");
144 }
145
146 if self.config.enable_k2k {
148 self.emit_line("// Control Block");
149 self.emit_line("struct ControlBlock {");
150 self.indent += 1;
151 self.emit_line("uint32_t is_active;");
152 self.emit_line("uint32_t should_terminate;");
153 self.emit_line("uint32_t has_terminated;");
154 self.emit_line("uint32_t _pad1;");
155 self.emit_line("uint64_t messages_processed;");
156 self.emit_line("uint64_t messages_in_flight;");
157 self.emit_line("uint64_t input_head;");
158 self.emit_line("uint64_t input_tail;");
159 self.emit_line("uint64_t output_head;");
160 self.emit_line("uint64_t output_tail;");
161 self.emit_line("uint32_t input_capacity;");
162 self.emit_line("uint32_t output_capacity;");
163 self.emit_line("uint32_t input_mask;");
164 self.emit_line("uint32_t output_mask;");
165 self.indent -= 1;
166 self.emit_line("};");
167 self.emit_line("");
168
169 self.emit_line("// Queue Intrinsics (provided by runtime)");
171 self.emit_line("__device__ bool __ringkernel_k2h_enqueue(const void* msg);");
172 self.emit_line("__device__ void* __ringkernel_h2k_dequeue();");
173 self.emit_line("__device__ bool __ringkernel_h2k_is_empty();");
174 self.emit_line("");
175
176 self.emit_line("// K2K Messaging Intrinsics (provided by runtime)");
178 self.emit_line(
179 "__device__ bool __ringkernel_k2k_send(uint64_t target_id, const void* msg);",
180 );
181 self.emit_line("__device__ void* __ringkernel_k2k_recv();");
182 self.emit_line("struct K2KOptionalMsg { bool valid; void* data; };");
183 self.emit_line("__device__ K2KOptionalMsg __ringkernel_k2k_try_recv();");
184 self.emit_line("");
185 }
186
187 if self.config.enable_hlc {
189 self.emit_line("// HLC Intrinsics (provided by runtime)");
190 self.emit_line("__device__ uint64_t __ringkernel_hlc_now();");
191 self.emit_line("__device__ uint64_t __ringkernel_hlc_tick();");
192 self.emit_line("__device__ uint64_t __ringkernel_hlc_update(uint64_t incoming);");
193 self.emit_line("");
194 }
195 }
196
197 fn emit_kernel(&mut self, module: &IrModule) -> Result<(), LoweringError> {
198 self.assign_names(module);
200
201 let kernel_attr = if self.config.cooperative_groups {
203 "__global__ void __launch_bounds__(256)"
204 } else {
205 "__global__ void"
206 };
207
208 write!(self.output, "{} {}(", kernel_attr, module.name).unwrap();
209
210 for (i, param) in module.parameters.iter().enumerate() {
212 if i > 0 {
213 write!(self.output, ", ").unwrap();
214 }
215 let ty = self.lower_type(¶m.ty);
216 write!(self.output, "{} {}", ty, param.name).unwrap();
217 }
218
219 self.emit_line(") {");
220 self.indent += 1;
221
222 if self.config.cooperative_groups {
224 self.emit_line("cg::grid_group grid = cg::this_grid();");
225 self.emit_line("cg::thread_block block = cg::this_thread_block();");
226 self.emit_line("");
227 }
228
229 self.emit_block(module, module.entry_block)?;
231
232 for block_id in module.blocks.keys() {
234 if *block_id != module.entry_block {
235 self.emit_block(module, *block_id)?;
236 }
237 }
238
239 self.indent -= 1;
240 self.emit_line("}");
241
242 Ok(())
243 }
244
245 fn assign_names(&mut self, module: &IrModule) {
246 for param in &module.parameters {
248 self.value_names.insert(param.value_id, param.name.clone());
249 }
250
251 for (block_id, block) in &module.blocks {
253 self.block_labels.insert(*block_id, block.label.clone());
254 }
255 }
256
257 fn emit_block(&mut self, module: &IrModule, block_id: BlockId) -> Result<(), LoweringError> {
258 let block = module
259 .blocks
260 .get(&block_id)
261 .ok_or(LoweringError::UndefinedBlock(block_id))?;
262
263 if block_id != module.entry_block {
265 self.emit_line(&format!("{}: {{", block.label));
266 self.indent += 1;
267 }
268
269 for inst in &block.instructions {
271 self.emit_instruction(module, &inst.result, &inst.result_type, &inst.node)?;
272 }
273
274 if let Some(term) = &block.terminator {
276 self.emit_terminator(term)?;
277 }
278
279 if block_id != module.entry_block {
280 self.indent -= 1;
281 self.emit_line("}");
282 }
283
284 Ok(())
285 }
286
287 fn emit_instruction(
288 &mut self,
289 _module: &IrModule,
290 result: &ValueId,
291 result_type: &IrType,
292 node: &IrNode,
293 ) -> Result<(), LoweringError> {
294 let result_name = self.get_or_create_name(*result);
295 let ty = self.lower_type(result_type);
296
297 match node {
298 IrNode::Constant(c) => {
300 let val = self.lower_constant(c);
301 self.emit_line(&format!("{} {} = {};", ty, result_name, val));
302 }
303
304 IrNode::BinaryOp(op, lhs, rhs) => {
306 let lhs_name = self.get_value_name(*lhs);
307 let rhs_name = self.get_value_name(*rhs);
308 let expr = self.lower_binary_op(op, &lhs_name, &rhs_name);
309 self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
310 }
311
312 IrNode::UnaryOp(op, val) => {
314 let val_name = self.get_value_name(*val);
315 let expr = self.lower_unary_op(op, &val_name);
316 self.emit_line(&format!("{} {} = {};", ty, result_name, expr));
317 }
318
319 IrNode::Compare(op, lhs, rhs) => {
321 let lhs_name = self.get_value_name(*lhs);
322 let rhs_name = self.get_value_name(*rhs);
323 let cmp_op = self.lower_compare_op(op);
324 self.emit_line(&format!(
325 "bool {} = {} {} {};",
326 result_name, lhs_name, cmp_op, rhs_name
327 ));
328 }
329
330 IrNode::Load(ptr) => {
332 let ptr_name = self.get_value_name(*ptr);
333 self.emit_line(&format!("{} {} = *{};", ty, result_name, ptr_name));
334 }
335
336 IrNode::Store(ptr, val) => {
337 let ptr_name = self.get_value_name(*ptr);
338 let val_name = self.get_value_name(*val);
339 self.emit_line(&format!("*{} = {};", ptr_name, val_name));
340 }
341
342 IrNode::GetElementPtr(ptr, indices) => {
343 let ptr_name = self.get_value_name(*ptr);
344 let idx_name = self.get_value_name(indices[0]);
345 self.emit_line(&format!(
346 "{} {} = &{}[{}];",
347 ty, result_name, ptr_name, idx_name
348 ));
349 }
350
351 IrNode::SharedAlloc(elem_ty, count) => {
352 let elem = self.lower_type(elem_ty);
353 self.emit_line(&format!("__shared__ {} {}[{}];", elem, result_name, count));
354 }
355
356 IrNode::ThreadId(dim) => {
358 let idx = self.lower_dimension(dim, "threadIdx");
359 self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
360 }
361
362 IrNode::BlockId(dim) => {
363 let idx = self.lower_dimension(dim, "blockIdx");
364 self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
365 }
366
367 IrNode::BlockDim(dim) => {
368 let idx = self.lower_dimension(dim, "blockDim");
369 self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
370 }
371
372 IrNode::GridDim(dim) => {
373 let idx = self.lower_dimension(dim, "gridDim");
374 self.emit_line(&format!("{} {} = {};", ty, result_name, idx));
375 }
376
377 IrNode::GlobalThreadId(dim) => {
378 let block_idx = self.lower_dimension(dim, "blockIdx");
379 let block_dim = self.lower_dimension(dim, "blockDim");
380 let thread_idx = self.lower_dimension(dim, "threadIdx");
381 self.emit_line(&format!(
382 "{} {} = {} * {} + {};",
383 ty, result_name, block_idx, block_dim, thread_idx
384 ));
385 }
386
387 IrNode::WarpId => {
388 self.emit_line(&format!("{} {} = threadIdx.x / 32;", ty, result_name));
389 }
390
391 IrNode::LaneId => {
392 self.emit_line(&format!("{} {} = threadIdx.x % 32;", ty, result_name));
393 }
394
395 IrNode::Barrier => {
397 self.emit_line("__syncthreads();");
398 }
399
400 IrNode::MemoryFence(scope) => {
401 let fence = match scope {
402 MemoryScope::Thread => "__threadfence_block()",
403 MemoryScope::Threadgroup => "__threadfence_block()",
404 MemoryScope::Device => "__threadfence()",
405 MemoryScope::System => "__threadfence_system()",
406 };
407 self.emit_line(&format!("{};", fence));
408 }
409
410 IrNode::GridSync => {
411 if self.config.cooperative_groups {
412 self.emit_line("grid.sync();");
413 } else {
414 return Err(LoweringError::RequiresCooperativeGroups);
415 }
416 }
417
418 IrNode::Atomic(op, ptr, val) => {
420 let ptr_name = self.get_value_name(*ptr);
421 let val_name = self.get_value_name(*val);
422 let atomic_fn = match op {
423 AtomicOp::Add => "atomicAdd",
424 AtomicOp::Sub => "atomicSub",
425 AtomicOp::Exchange => "atomicExch",
426 AtomicOp::Min => "atomicMin",
427 AtomicOp::Max => "atomicMax",
428 AtomicOp::And => "atomicAnd",
429 AtomicOp::Or => "atomicOr",
430 AtomicOp::Xor => "atomicXor",
431 AtomicOp::Load => {
432 self.emit_line(&format!(
433 "{} {} = atomicAdd({}, 0);",
434 ty, result_name, ptr_name
435 ));
436 return Ok(());
437 }
438 AtomicOp::Store => {
439 self.emit_line(&format!("atomicExch({}, {});", ptr_name, val_name));
440 return Ok(());
441 }
442 };
443 self.emit_line(&format!(
444 "{} {} = {}({}, {});",
445 ty, result_name, atomic_fn, ptr_name, val_name
446 ));
447 }
448
449 IrNode::AtomicCas(ptr, expected, desired) => {
450 let ptr_name = self.get_value_name(*ptr);
451 let exp_name = self.get_value_name(*expected);
452 let des_name = self.get_value_name(*desired);
453 self.emit_line(&format!(
454 "{} {} = atomicCAS({}, {}, {});",
455 ty, result_name, ptr_name, exp_name, des_name
456 ));
457 }
458
459 IrNode::WarpVote(op, val) => {
461 let val_name = self.get_value_name(*val);
462 let vote_fn = match op {
463 WarpVoteOp::All => "__all_sync(0xFFFFFFFF, ",
464 WarpVoteOp::Any => "__any_sync(0xFFFFFFFF, ",
465 WarpVoteOp::Ballot => "__ballot_sync(0xFFFFFFFF, ",
466 };
467 self.emit_line(&format!(
468 "{} {} = {}{})",
469 ty, result_name, vote_fn, val_name
470 ));
471 }
472
473 IrNode::WarpShuffle(op, val, lane) => {
474 let val_name = self.get_value_name(*val);
475 let lane_name = self.get_value_name(*lane);
476 let shfl_fn = match op {
477 WarpShuffleOp::Index => "__shfl_sync(0xFFFFFFFF, ",
478 WarpShuffleOp::Up => "__shfl_up_sync(0xFFFFFFFF, ",
479 WarpShuffleOp::Down => "__shfl_down_sync(0xFFFFFFFF, ",
480 WarpShuffleOp::Xor => "__shfl_xor_sync(0xFFFFFFFF, ",
481 };
482 self.emit_line(&format!(
483 "{} {} = {}{}, {})",
484 ty, result_name, shfl_fn, val_name, lane_name
485 ));
486 }
487
488 IrNode::Select(cond, then_val, else_val) => {
490 let cond_name = self.get_value_name(*cond);
491 let then_name = self.get_value_name(*then_val);
492 let else_name = self.get_value_name(*else_val);
493 self.emit_line(&format!(
494 "{} {} = {} ? {} : {};",
495 ty, result_name, cond_name, then_name, else_name
496 ));
497 }
498
499 IrNode::Math(op, args) => {
501 let fn_name = self.lower_math_op(op);
502 let args_str: Vec<String> = args.iter().map(|a| self.get_value_name(*a)).collect();
503 self.emit_line(&format!(
504 "{} {} = {}({});",
505 ty,
506 result_name,
507 fn_name,
508 args_str.join(", ")
509 ));
510 }
511
512 IrNode::Parameter(_) | IrNode::Undef | IrNode::Phi(_) => {}
514
515 IrNode::K2HEnqueue(value) => {
521 let val_name = self.get_value_name(*value);
522 self.emit_line(&format!(
524 "{} {} = __ringkernel_k2h_enqueue({});",
525 ty, result_name, val_name
526 ));
527 }
528
529 IrNode::H2KDequeue => {
531 self.emit_line(&format!(
533 "{} {} = __ringkernel_h2k_dequeue();",
534 ty, result_name
535 ));
536 }
537
538 IrNode::H2KIsEmpty => {
540 self.emit_line(&format!(
542 "{} {} = __ringkernel_h2k_is_empty();",
543 ty, result_name
544 ));
545 }
546
547 IrNode::K2KSend(target_id, message) => {
549 let target_name = self.get_value_name(*target_id);
550 let msg_name = self.get_value_name(*message);
551 self.emit_line(&format!(
553 "{} {} = __ringkernel_k2k_send({}, {});",
554 ty, result_name, target_name, msg_name
555 ));
556 }
557
558 IrNode::K2KRecv => {
560 self.emit_line(&format!(
562 "{} {} = __ringkernel_k2k_recv();",
563 ty, result_name
564 ));
565 }
566
567 IrNode::K2KTryRecv => {
569 self.emit_line(&format!(
571 "{} {} = __ringkernel_k2k_try_recv();",
572 ty, result_name
573 ));
574 }
575
576 IrNode::HlcNow => {
582 self.emit_line(&format!("{} {} = __ringkernel_hlc_now();", ty, result_name));
584 }
585
586 IrNode::HlcTick => {
588 self.emit_line(&format!(
590 "{} {} = __ringkernel_hlc_tick();",
591 ty, result_name
592 ));
593 }
594
595 IrNode::HlcUpdate(incoming) => {
597 let incoming_name = self.get_value_name(*incoming);
598 self.emit_line(&format!(
600 "{} {} = __ringkernel_hlc_update({});",
601 ty, result_name, incoming_name
602 ));
603 }
604
605 _ => {
606 self.emit_line(&format!("// Unhandled: {:?}", node));
607 }
608 }
609
610 Ok(())
611 }
612
613 fn emit_terminator(&mut self, term: &Terminator) -> Result<(), LoweringError> {
614 match term {
615 Terminator::Return(None) => {
616 self.emit_line("return;");
617 }
618 Terminator::Return(Some(val)) => {
619 let val_name = self.get_value_name(*val);
620 self.emit_line(&format!("return {};", val_name));
621 }
622 Terminator::Branch(target) => {
623 let label = self.block_labels.get(target).cloned().unwrap_or_default();
624 self.emit_line(&format!("goto {};", label));
625 }
626 Terminator::CondBranch(cond, then_block, else_block) => {
627 let cond_name = self.get_value_name(*cond);
628 let then_label = self
629 .block_labels
630 .get(then_block)
631 .cloned()
632 .unwrap_or_default();
633 let else_label = self
634 .block_labels
635 .get(else_block)
636 .cloned()
637 .unwrap_or_default();
638 self.emit_line(&format!(
639 "if ({}) goto {}; else goto {};",
640 cond_name, then_label, else_label
641 ));
642 }
643 Terminator::Switch(val, default, cases) => {
644 let val_name = self.get_value_name(*val);
645 self.emit_line(&format!("switch ({}) {{", val_name));
646 self.indent += 1;
647 for (case_val, target) in cases {
648 let case_str = self.lower_constant(case_val);
649 let label = self.block_labels.get(target).cloned().unwrap_or_default();
650 self.emit_line(&format!("case {}: goto {};", case_str, label));
651 }
652 let default_label = self.block_labels.get(default).cloned().unwrap_or_default();
653 self.emit_line(&format!("default: goto {};", default_label));
654 self.indent -= 1;
655 self.emit_line("}");
656 }
657 Terminator::Unreachable => {
658 self.emit_line("__builtin_unreachable();");
659 }
660 }
661 Ok(())
662 }
663
664 fn lower_type(&self, ty: &IrType) -> String {
665 match ty {
666 IrType::Void => "void".to_string(),
667 IrType::Scalar(s) => self.lower_scalar_type(s),
668 IrType::Vector(v) => format!("{}{}", self.lower_scalar_type(&v.element), v.count),
669 IrType::Ptr(inner) => format!("{}*", self.lower_type(inner)),
670 IrType::Array(inner, size) => format!("{}[{}]", self.lower_type(inner), size),
671 IrType::Slice(inner) => format!("{}*", self.lower_type(inner)),
672 IrType::Struct(s) => s.name.clone(),
673 IrType::Function(_) => "void*".to_string(), }
675 }
676
677 fn lower_scalar_type(&self, ty: &ScalarType) -> String {
678 match ty {
679 ScalarType::Bool => "bool",
680 ScalarType::I8 => "int8_t",
681 ScalarType::I16 => "int16_t",
682 ScalarType::I32 => "int32_t",
683 ScalarType::I64 => "int64_t",
684 ScalarType::U8 => "uint8_t",
685 ScalarType::U16 => "uint16_t",
686 ScalarType::U32 => "uint32_t",
687 ScalarType::U64 => "uint64_t",
688 ScalarType::F16 => "__half",
689 ScalarType::F32 => "float",
690 ScalarType::F64 => "double",
691 }
692 .to_string()
693 }
694
695 fn lower_constant(&self, c: &ConstantValue) -> String {
696 match c {
697 ConstantValue::Bool(b) => if *b { "true" } else { "false" }.to_string(),
698 ConstantValue::I32(v) => format!("{}", v),
699 ConstantValue::I64(v) => format!("{}LL", v),
700 ConstantValue::U32(v) => format!("{}u", v),
701 ConstantValue::U64(v) => format!("{}ull", v),
702 ConstantValue::F32(v) => format!("{}f", v),
703 ConstantValue::F64(v) => format!("{}", v),
704 ConstantValue::Null => "nullptr".to_string(),
705 ConstantValue::Array(elems) => {
706 let elems_str: Vec<String> = elems.iter().map(|e| self.lower_constant(e)).collect();
707 format!("{{{}}}", elems_str.join(", "))
708 }
709 ConstantValue::Struct(fields) => {
710 let fields_str: Vec<String> =
711 fields.iter().map(|f| self.lower_constant(f)).collect();
712 format!("{{{}}}", fields_str.join(", "))
713 }
714 }
715 }
716
717 fn lower_binary_op(&self, op: &BinaryOp, lhs: &str, rhs: &str) -> String {
718 match op {
719 BinaryOp::Add => format!("{} + {}", lhs, rhs),
720 BinaryOp::Sub => format!("{} - {}", lhs, rhs),
721 BinaryOp::Mul => format!("{} * {}", lhs, rhs),
722 BinaryOp::Div => format!("{} / {}", lhs, rhs),
723 BinaryOp::Rem => format!("{} % {}", lhs, rhs),
724 BinaryOp::And => format!("{} & {}", lhs, rhs),
725 BinaryOp::Or => format!("{} | {}", lhs, rhs),
726 BinaryOp::Xor => format!("{} ^ {}", lhs, rhs),
727 BinaryOp::Shl => format!("{} << {}", lhs, rhs),
728 BinaryOp::Shr => format!("{} >> {}", lhs, rhs),
729 BinaryOp::Sar => format!("{} >> {}", lhs, rhs), BinaryOp::Fma => format!("fma({}, {}, 0.0f)", lhs, rhs), BinaryOp::Pow => format!("pow({}, {})", lhs, rhs),
732 BinaryOp::Min => format!("min({}, {})", lhs, rhs),
733 BinaryOp::Max => format!("max({}, {})", lhs, rhs),
734 }
735 }
736
737 fn lower_unary_op(&self, op: &UnaryOp, val: &str) -> String {
738 match op {
739 UnaryOp::Neg => format!("-{}", val),
740 UnaryOp::Not => format!("~{}", val),
741 UnaryOp::LogicalNot => format!("!{}", val),
742 UnaryOp::Abs => format!("abs({})", val),
743 UnaryOp::Sqrt => format!("sqrt({})", val),
744 UnaryOp::Rsqrt => format!("rsqrt({})", val),
745 UnaryOp::Floor => format!("floor({})", val),
746 UnaryOp::Ceil => format!("ceil({})", val),
747 UnaryOp::Round => format!("round({})", val),
748 UnaryOp::Trunc => format!("trunc({})", val),
749 UnaryOp::Sign => format!("copysign(1.0f, {})", val),
750 }
751 }
752
753 fn lower_compare_op(&self, op: &CompareOp) -> &'static str {
754 match op {
755 CompareOp::Eq => "==",
756 CompareOp::Ne => "!=",
757 CompareOp::Lt => "<",
758 CompareOp::Le => "<=",
759 CompareOp::Gt => ">",
760 CompareOp::Ge => ">=",
761 }
762 }
763
764 fn lower_dimension(&self, dim: &Dimension, prefix: &str) -> String {
765 match dim {
766 Dimension::X => format!("{}.x", prefix),
767 Dimension::Y => format!("{}.y", prefix),
768 Dimension::Z => format!("{}.z", prefix),
769 }
770 }
771
772 fn lower_math_op(&self, op: &MathOp) -> &'static str {
773 match op {
774 MathOp::Sin => "sin",
775 MathOp::Cos => "cos",
776 MathOp::Tan => "tan",
777 MathOp::Asin => "asin",
778 MathOp::Acos => "acos",
779 MathOp::Atan => "atan",
780 MathOp::Atan2 => "atan2",
781 MathOp::Sinh => "sinh",
782 MathOp::Cosh => "cosh",
783 MathOp::Tanh => "tanh",
784 MathOp::Exp => "exp",
785 MathOp::Exp2 => "exp2",
786 MathOp::Log => "log",
787 MathOp::Log2 => "log2",
788 MathOp::Log10 => "log10",
789 MathOp::Lerp => "lerp",
790 MathOp::Clamp => "clamp",
791 MathOp::Step => "step",
792 MathOp::SmoothStep => "smoothstep",
793 MathOp::Fract => "fract",
794 MathOp::CopySign => "copysign",
795 }
796 }
797
798 fn get_value_name(&self, id: ValueId) -> String {
799 self.value_names
800 .get(&id)
801 .cloned()
802 .unwrap_or_else(|| format!("v{}", id.raw()))
803 }
804
805 fn get_or_create_name(&mut self, id: ValueId) -> String {
806 if let Some(name) = self.value_names.get(&id) {
807 return name.clone();
808 }
809 let name = format!("t{}", self.name_counter);
810 self.name_counter += 1;
811 self.value_names.insert(id, name.clone());
812 name
813 }
814
    /// Writes one line to the output, prefixed with the current indentation
    /// (the indent unit string repeated once per nesting level).
    fn emit_line(&mut self, line: &str) {
        let indent = " ".repeat(self.indent);
        writeln!(self.output, "{}{}", indent, line).unwrap();
    }
819}
820
/// Errors produced while lowering an IR module to CUDA source.
#[derive(Debug, Clone)]
pub enum LoweringError {
    /// The module requires capabilities the CUDA backend does not provide
    /// (comma-separated capability list).
    UnsupportedCapability(String),
    /// An emission referenced a block id not present in the module.
    UndefinedBlock(BlockId),
    /// A value id had no definition. NOTE(review): not constructed by the
    /// code paths visible in this file — confirm it is still reachable.
    UndefinedValue(ValueId),
    /// A grid-wide operation was requested without cooperative groups enabled.
    RequiresCooperativeGroups,
    /// Generic type-level inconsistency in the input IR.
    TypeError(String),
}
835
836impl std::fmt::Display for LoweringError {
837 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
838 match self {
839 LoweringError::UnsupportedCapability(cap) => {
840 write!(f, "Unsupported capability: {}", cap)
841 }
842 LoweringError::UndefinedBlock(id) => write!(f, "Undefined block: {}", id),
843 LoweringError::UndefinedValue(id) => write!(f, "Undefined value: {}", id),
844 LoweringError::RequiresCooperativeGroups => {
845 write!(f, "Operation requires cooperative groups")
846 }
847 LoweringError::TypeError(msg) => write!(f, "Type error: {}", msg),
848 }
849 }
850}
851
// Marker impl so LoweringError can be boxed/propagated as a standard error.
impl std::error::Error for LoweringError {}
853
854pub fn lower_to_cuda(module: &IrModule) -> Result<String, LoweringError> {
856 CudaLowering::new(CudaLoweringConfig::default()).lower(module)
857}
858
859pub fn lower_to_cuda_with_config(
861 module: &IrModule,
862 config: CudaLoweringConfig,
863) -> Result<String, LoweringError> {
864 CudaLowering::new(config).lower(module)
865}
866
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    // Builds a bounds-checked "x[i] += 1.0" kernel and checks the emitted
    // signature, parameter types, and global-thread-id expression.
    #[test]
    fn test_lower_simple_kernel() {
        let mut builder = IrBuilder::new("add_one");

        let x = builder.parameter("x", IrType::ptr(IrType::F32));
        let n = builder.parameter("n", IrType::I32);

        let idx = builder.global_thread_id(Dimension::X);
        let in_bounds = builder.lt(idx, n);

        let then_block = builder.create_block("body");
        let end_block = builder.create_block("end");

        builder.cond_branch(in_bounds, then_block, end_block);

        // Body: load, add 1.0, store back.
        builder.switch_to_block(then_block);
        let one = builder.const_f32(1.0);
        let ptr = builder.gep(x, vec![idx]);
        let val = builder.load(ptr);
        let result = builder.add(val, one);
        builder.store(ptr, result);
        builder.branch(end_block);

        builder.switch_to_block(end_block);
        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("__global__ void add_one"));
        assert!(cuda.contains("float* x"));
        assert!(cuda.contains("int32_t n"));
        assert!(cuda.contains("blockIdx.x * blockDim.x + threadIdx.x"));
    }

    // Shared allocation and barrier should lower to __shared__ and
    // __syncthreads respectively.
    #[test]
    fn test_lower_with_shared_memory() {
        let mut builder = IrBuilder::new("reduce");

        let _x = builder.parameter("x", IrType::ptr(IrType::F32));

        let shared = builder.shared_alloc(IrType::F32, 256);
        let _ = shared;

        builder.barrier();
        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("__shared__ float"));
        assert!(cuda.contains("__syncthreads()"));
    }

    // Atomic add lowers to the CUDA atomicAdd intrinsic.
    #[test]
    fn test_lower_with_atomics() {
        let mut builder = IrBuilder::new("atomic_add");

        let counter = builder.parameter("counter", IrType::ptr(IrType::U32));

        let one = builder.const_u32(1);
        let _old = builder.atomic_add(counter, one);

        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("atomicAdd"));
    }

    // GridSync must fail without cooperative groups and succeed with the
    // sm80 preset (which enables them).
    #[test]
    fn test_lower_with_cooperative_groups() {
        let mut builder = IrBuilder::new("grid_reduce");
        builder.grid_sync();
        builder.ret();

        let module = builder.build();

        let result = lower_to_cuda(&module);
        assert!(result.is_err());

        let config = CudaLoweringConfig::sm80();
        let cuda = lower_to_cuda_with_config(&module, config).unwrap();

        assert!(cuda.contains("cooperative_groups"));
        assert!(cuda.contains("grid.sync()"));
    }

    // Spot-checks the infix and function-call renderings of binary ops.
    #[test]
    fn test_lower_binary_ops() {
        let mut builder = IrBuilder::new("math");

        let a = builder.const_f32(1.0);
        let b = builder.const_f32(2.0);

        let _sum = builder.add(a, b);
        let _diff = builder.sub(a, b);
        let _prod = builder.mul(a, b);
        let _quot = builder.div(a, b);
        let _min = builder.min(a, b);
        let _max = builder.max(a, b);

        builder.ret();

        let module = builder.build();
        let cuda = lower_to_cuda(&module).unwrap();

        assert!(cuda.contains("+"));
        assert!(cuda.contains("-"));
        assert!(cuda.contains("*"));
        assert!(cuda.contains("/"));
        assert!(cuda.contains("min("));
        assert!(cuda.contains("max("));
    }
}