1use crate::{nodes::*, BlockId, IrModule, IrNode, IrType, Terminator, ValueId};
6use std::fmt::Write;
7
/// Renders an [`IrModule`] as human-readable textual IR.
///
/// The printer accumulates text into an internal buffer; call
/// [`IrPrinter::print`] to render a module and take the resulting `String`.
#[derive(Debug)]
pub struct IrPrinter {
    /// Current nesting depth; each level contributes one space of indentation.
    indent: usize,
    /// Text produced so far.
    output: String,
}
13
14impl IrPrinter {
15 pub fn new() -> Self {
17 Self {
18 indent: 0,
19 output: String::new(),
20 }
21 }
22
23 pub fn print(mut self, module: &IrModule) -> String {
25 self.print_module(module);
26 self.output
27 }
28
29 fn print_module(&mut self, module: &IrModule) {
30 writeln!(self.output, "; RingKernel IR Module: {}", module.name).unwrap();
32 writeln!(
33 self.output,
34 "; Capabilities: {:?}",
35 module.required_capabilities.flags()
36 )
37 .unwrap();
38 writeln!(self.output).unwrap();
39
40 self.print_line("define kernel @");
42 write!(self.output, "{}(", module.name).unwrap();
43 for (i, param) in module.parameters.iter().enumerate() {
44 if i > 0 {
45 write!(self.output, ", ").unwrap();
46 }
47 write!(self.output, "{} %{}", param.ty, param.name).unwrap();
48 }
49 writeln!(self.output, ") {{").unwrap();
50
51 self.indent += 1;
52
53 self.print_block(module, module.entry_block);
55 for block_id in module.blocks.keys() {
56 if *block_id != module.entry_block {
57 self.print_block(module, *block_id);
58 }
59 }
60
61 self.indent -= 1;
62 self.print_line("}");
63 }
64
65 fn print_block(&mut self, module: &IrModule, block_id: BlockId) {
66 let block = match module.blocks.get(&block_id) {
67 Some(b) => b,
68 None => return,
69 };
70
71 writeln!(self.output).unwrap();
73 writeln!(self.output, "{}:", block.label).unwrap();
74
75 for inst in &block.instructions {
77 self.print_instruction(module, inst.result, &inst.result_type, &inst.node);
78 }
79
80 if let Some(term) = &block.terminator {
82 self.print_terminator(term);
83 }
84 }
85
86 fn print_instruction(
87 &mut self,
88 _module: &IrModule,
89 result: ValueId,
90 ty: &IrType,
91 node: &IrNode,
92 ) {
93 let indent = " ".repeat(self.indent);
94
95 let node_str = match node {
96 IrNode::Constant(c) => format!("{} = const {}", result, format_constant(c)),
98 IrNode::Parameter(idx) => format!("{} = param {}", result, idx),
99 IrNode::Undef => format!("{} = undef", result),
100
101 IrNode::BinaryOp(op, lhs, rhs) => {
103 format!("{} = {} {} {}, {}", result, op, ty, lhs, rhs)
104 }
105
106 IrNode::UnaryOp(op, val) => {
108 format!("{} = {} {} {}", result, op, ty, val)
109 }
110
111 IrNode::Compare(op, lhs, rhs) => {
113 format!("{} = cmp {} {}, {}", result, op, lhs, rhs)
114 }
115
116 IrNode::Cast(kind, val, target_ty) => {
118 format!("{} = cast {:?} {} to {}", result, kind, val, target_ty)
119 }
120
121 IrNode::Load(ptr) => format!("{} = load {}", result, ptr),
123 IrNode::Store(ptr, val) => format!("store {}, {}", ptr, val),
124 IrNode::GetElementPtr(ptr, indices) => {
125 let indices_str: Vec<String> = indices.iter().map(|i| format!("{}", i)).collect();
126 format!("{} = gep {}, [{}]", result, ptr, indices_str.join(", "))
127 }
128 IrNode::Alloca(ty) => format!("{} = alloca {}", result, ty),
129 IrNode::SharedAlloc(ty, count) => {
130 format!("{} = shared_alloc [{} x {}]", result, count, ty)
131 }
132 IrNode::ExtractField(val, idx) => {
133 format!("{} = extractfield {}, {}", result, val, idx)
134 }
135 IrNode::InsertField(val, idx, new_val) => {
136 format!("{} = insertfield {}, {}, {}", result, val, idx, new_val)
137 }
138
139 IrNode::ThreadId(dim) => format!("{} = thread_id.{}", result, dim),
141 IrNode::BlockId(dim) => format!("{} = block_id.{}", result, dim),
142 IrNode::BlockDim(dim) => format!("{} = block_dim.{}", result, dim),
143 IrNode::GridDim(dim) => format!("{} = grid_dim.{}", result, dim),
144 IrNode::GlobalThreadId(dim) => format!("{} = global_thread_id.{}", result, dim),
145 IrNode::WarpId => format!("{} = warp_id", result),
146 IrNode::LaneId => format!("{} = lane_id", result),
147
148 IrNode::Barrier => "barrier".to_string(),
150 IrNode::MemoryFence(scope) => format!("fence {:?}", scope),
151 IrNode::GridSync => "grid_sync".to_string(),
152
153 IrNode::Atomic(op, ptr, val) => {
155 format!("{} = atomic_{:?} {}, {}", result, op, ptr, val)
156 }
157 IrNode::AtomicCas(ptr, expected, desired) => {
158 format!("{} = atomic_cas {}, {}, {}", result, ptr, expected, desired)
159 }
160
161 IrNode::WarpVote(op, val) => format!("{} = warp_{:?} {}", result, op, val),
163 IrNode::WarpShuffle(op, val, lane) => {
164 format!("{} = warp_shuffle_{:?} {}, {}", result, op, val, lane)
165 }
166 IrNode::WarpReduce(op, val) => format!("{} = warp_reduce_{:?} {}", result, op, val),
167
168 IrNode::Math(op, args) => {
170 let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
171 format!("{} = {:?}({})", result, op, args_str.join(", "))
172 }
173
174 IrNode::Select(cond, then_val, else_val) => {
176 format!("{} = select {}, {}, {}", result, cond, then_val, else_val)
177 }
178 IrNode::Phi(entries) => {
179 let entries_str: Vec<String> = entries
180 .iter()
181 .map(|(block, val)| format!("[{}, {}]", val, block))
182 .collect();
183 format!("{} = phi {}", result, entries_str.join(", "))
184 }
185
186 IrNode::K2HEnqueue(msg) => format!("k2h_enqueue {}", msg),
188 IrNode::H2KDequeue => format!("{} = h2k_dequeue", result),
189 IrNode::H2KIsEmpty => format!("{} = h2k_is_empty", result),
190 IrNode::K2KSend(dest, msg) => format!("k2k_send {}, {}", dest, msg),
191 IrNode::K2KRecv => format!("{} = k2k_recv", result),
192 IrNode::K2KTryRecv => format!("{} = k2k_try_recv", result),
193
194 IrNode::HlcNow => format!("{} = hlc_now", result),
196 IrNode::HlcTick => format!("{} = hlc_tick", result),
197 IrNode::HlcUpdate(ts) => format!("{} = hlc_update {}", result, ts),
198
199 IrNode::Call(name, args) => {
201 let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
202 format!("{} = call @{}({})", result, name, args_str.join(", "))
203 }
204 };
205
206 writeln!(self.output, "{}{}", indent, node_str).unwrap();
207 }
208
209 fn print_terminator(&mut self, term: &Terminator) {
210 let indent = " ".repeat(self.indent);
211 let term_str = match term {
212 Terminator::Return(None) => "ret void".to_string(),
213 Terminator::Return(Some(val)) => format!("ret {}", val),
214 Terminator::Branch(target) => format!("br {}", target),
215 Terminator::CondBranch(cond, then_block, else_block) => {
216 format!("br {}, {}, {}", cond, then_block, else_block)
217 }
218 Terminator::Switch(val, default, cases) => {
219 let cases_str: Vec<String> = cases
220 .iter()
221 .map(|(c, b)| format!("{} -> {}", format_constant(c), b))
222 .collect();
223 format!(
224 "switch {}, default {}, [{}]",
225 val,
226 default,
227 cases_str.join(", ")
228 )
229 }
230 Terminator::Unreachable => "unreachable".to_string(),
231 };
232 writeln!(self.output, "{}{}", indent, term_str).unwrap();
233 }
234
235 fn print_line(&mut self, text: &str) {
236 let indent = " ".repeat(self.indent);
237 write!(self.output, "{}{}", indent, text).unwrap();
238 }
239}
240
241impl Default for IrPrinter {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
247fn format_constant(c: &ConstantValue) -> String {
248 match c {
249 ConstantValue::Bool(b) => format!("{}", b),
250 ConstantValue::I32(v) => format!("{}i32", v),
251 ConstantValue::I64(v) => format!("{}i64", v),
252 ConstantValue::U32(v) => format!("{}u32", v),
253 ConstantValue::U64(v) => format!("{}u64", v),
254 ConstantValue::F32(v) => format!("{}f32", v),
255 ConstantValue::F64(v) => format!("{}f64", v),
256 ConstantValue::Null => "null".to_string(),
257 ConstantValue::Array(elements) => {
258 let elems: Vec<String> = elements.iter().map(format_constant).collect();
259 format!("[{}]", elems.join(", "))
260 }
261 ConstantValue::Struct(fields) => {
262 let fields_str: Vec<String> = fields.iter().map(format_constant).collect();
263 format!("{{{}}}", fields_str.join(", "))
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Dimension, IrBuilder};

    /// Asserts that every snippet in `needles` appears somewhere in `output`.
    fn assert_contains_all(output: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                output.contains(needle),
                "expected printed IR to contain {:?}, got:\n{}",
                needle,
                output
            );
        }
    }

    #[test]
    fn test_print_simple_kernel() {
        let mut builder = IrBuilder::new("saxpy");

        // Parameters are declared only so they appear in the kernel signature.
        let _x = builder.parameter("x", IrType::ptr(IrType::F32));
        let _y = builder.parameter("y", IrType::ptr(IrType::F32));
        let _a = builder.parameter("a", IrType::F32);

        let _idx = builder.thread_id(Dimension::X);
        builder.ret();

        let output = builder.build().pretty_print();
        assert_contains_all(&output, &["saxpy", "thread_id.x", "ret void"]);
    }

    #[test]
    fn test_print_with_arithmetic() {
        let mut builder = IrBuilder::new("test");

        let lhs = builder.const_i32(10);
        let rhs = builder.const_i32(20);
        let _sum = builder.add(lhs, rhs);
        builder.ret();

        let output = builder.build().pretty_print();
        assert_contains_all(&output, &["add", "i32"]);
    }

    #[test]
    fn test_print_with_control_flow() {
        let mut builder = IrBuilder::new("test");

        let cond = builder.const_bool(true);
        let then_block = builder.create_block("then");
        let else_block = builder.create_block("else");
        builder.cond_branch(cond, then_block, else_block);

        // Both arms simply return.
        for target in [then_block, else_block] {
            builder.switch_to_block(target);
            builder.ret();
        }

        let output = builder.build().pretty_print();
        assert_contains_all(&output, &["then:", "else:", "br"]);
    }
}