Skip to main content

ringkernel_ir/
printer.rs

1//! IR pretty printer.
2//!
3//! Produces a human-readable text representation of IR modules.
4
5use crate::{nodes::*, BlockId, IrModule, IrNode, IrType, Terminator, ValueId};
6use std::fmt::Write;
7
/// IR pretty printer.
///
/// Accumulates the textual form of a module in an internal buffer; call
/// [`IrPrinter::print`] to render a module and take the resulting string.
#[derive(Debug)]
pub struct IrPrinter {
    /// Current nesting depth; each level is rendered as two spaces.
    indent: usize,
    /// Text produced so far.
    output: String,
}
13
14impl IrPrinter {
15    /// Create a new printer.
16    pub fn new() -> Self {
17        Self {
18            indent: 0,
19            output: String::new(),
20        }
21    }
22
23    /// Print a module.
24    pub fn print(mut self, module: &IrModule) -> String {
25        self.print_module(module);
26        self.output
27    }
28
29    fn print_module(&mut self, module: &IrModule) {
30        // Header
31        writeln!(self.output, "; RingKernel IR Module: {}", module.name).unwrap();
32        writeln!(
33            self.output,
34            "; Capabilities: {:?}",
35            module.required_capabilities.flags()
36        )
37        .unwrap();
38        writeln!(self.output).unwrap();
39
40        // Parameters
41        self.print_line("define kernel @");
42        write!(self.output, "{}(", module.name).unwrap();
43        for (i, param) in module.parameters.iter().enumerate() {
44            if i > 0 {
45                write!(self.output, ", ").unwrap();
46            }
47            write!(self.output, "{} %{}", param.ty, param.name).unwrap();
48        }
49        writeln!(self.output, ") {{").unwrap();
50
51        self.indent += 1;
52
53        // Print blocks in order (entry first)
54        self.print_block(module, module.entry_block);
55        for block_id in module.blocks.keys() {
56            if *block_id != module.entry_block {
57                self.print_block(module, *block_id);
58            }
59        }
60
61        self.indent -= 1;
62        self.print_line("}");
63    }
64
65    fn print_block(&mut self, module: &IrModule, block_id: BlockId) {
66        let block = match module.blocks.get(&block_id) {
67            Some(b) => b,
68            None => return,
69        };
70
71        // Block label
72        writeln!(self.output).unwrap();
73        writeln!(self.output, "{}:", block.label).unwrap();
74
75        // Instructions
76        for inst in &block.instructions {
77            self.print_instruction(module, inst.result, &inst.result_type, &inst.node);
78        }
79
80        // Terminator
81        if let Some(term) = &block.terminator {
82            self.print_terminator(term);
83        }
84    }
85
86    fn print_instruction(
87        &mut self,
88        _module: &IrModule,
89        result: ValueId,
90        ty: &IrType,
91        node: &IrNode,
92    ) {
93        let indent = "  ".repeat(self.indent);
94
95        let node_str = match node {
96            // Constants
97            IrNode::Constant(c) => format!("{} = const {}", result, format_constant(c)),
98            IrNode::Parameter(idx) => format!("{} = param {}", result, idx),
99            IrNode::Undef => format!("{} = undef", result),
100
101            // Binary ops
102            IrNode::BinaryOp(op, lhs, rhs) => {
103                format!("{} = {} {} {}, {}", result, op, ty, lhs, rhs)
104            }
105
106            // Unary ops
107            IrNode::UnaryOp(op, val) => {
108                format!("{} = {} {} {}", result, op, ty, val)
109            }
110
111            // Comparison
112            IrNode::Compare(op, lhs, rhs) => {
113                format!("{} = cmp {} {}, {}", result, op, lhs, rhs)
114            }
115
116            // Cast
117            IrNode::Cast(kind, val, target_ty) => {
118                format!("{} = cast {:?} {} to {}", result, kind, val, target_ty)
119            }
120
121            // Memory
122            IrNode::Load(ptr) => format!("{} = load {}", result, ptr),
123            IrNode::Store(ptr, val) => format!("store {}, {}", ptr, val),
124            IrNode::GetElementPtr(ptr, indices) => {
125                let indices_str: Vec<String> = indices.iter().map(|i| format!("{}", i)).collect();
126                format!("{} = gep {}, [{}]", result, ptr, indices_str.join(", "))
127            }
128            IrNode::Alloca(ty) => format!("{} = alloca {}", result, ty),
129            IrNode::SharedAlloc(ty, count) => {
130                format!("{} = shared_alloc [{} x {}]", result, count, ty)
131            }
132            IrNode::ExtractField(val, idx) => {
133                format!("{} = extractfield {}, {}", result, val, idx)
134            }
135            IrNode::InsertField(val, idx, new_val) => {
136                format!("{} = insertfield {}, {}, {}", result, val, idx, new_val)
137            }
138
139            // GPU indexing
140            IrNode::ThreadId(dim) => format!("{} = thread_id.{}", result, dim),
141            IrNode::BlockId(dim) => format!("{} = block_id.{}", result, dim),
142            IrNode::BlockDim(dim) => format!("{} = block_dim.{}", result, dim),
143            IrNode::GridDim(dim) => format!("{} = grid_dim.{}", result, dim),
144            IrNode::GlobalThreadId(dim) => format!("{} = global_thread_id.{}", result, dim),
145            IrNode::WarpId => format!("{} = warp_id", result),
146            IrNode::LaneId => format!("{} = lane_id", result),
147
148            // Synchronization
149            IrNode::Barrier => "barrier".to_string(),
150            IrNode::MemoryFence(scope) => format!("fence {:?}", scope),
151            IrNode::GridSync => "grid_sync".to_string(),
152
153            // Atomics
154            IrNode::Atomic(op, ptr, val) => {
155                format!("{} = atomic_{:?} {}, {}", result, op, ptr, val)
156            }
157            IrNode::AtomicCas(ptr, expected, desired) => {
158                format!("{} = atomic_cas {}, {}, {}", result, ptr, expected, desired)
159            }
160
161            // Warp ops
162            IrNode::WarpVote(op, val) => format!("{} = warp_{:?} {}", result, op, val),
163            IrNode::WarpShuffle(op, val, lane) => {
164                format!("{} = warp_shuffle_{:?} {}, {}", result, op, val, lane)
165            }
166            IrNode::WarpReduce(op, val) => format!("{} = warp_reduce_{:?} {}", result, op, val),
167
168            // Math
169            IrNode::Math(op, args) => {
170                let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
171                format!("{} = {:?}({})", result, op, args_str.join(", "))
172            }
173
174            // Control flow
175            IrNode::Select(cond, then_val, else_val) => {
176                format!("{} = select {}, {}, {}", result, cond, then_val, else_val)
177            }
178            IrNode::Phi(entries) => {
179                let entries_str: Vec<String> = entries
180                    .iter()
181                    .map(|(block, val)| format!("[{}, {}]", val, block))
182                    .collect();
183                format!("{} = phi {}", result, entries_str.join(", "))
184            }
185
186            // Messaging
187            IrNode::K2HEnqueue(msg) => format!("k2h_enqueue {}", msg),
188            IrNode::H2KDequeue => format!("{} = h2k_dequeue", result),
189            IrNode::H2KIsEmpty => format!("{} = h2k_is_empty", result),
190            IrNode::K2KSend(dest, msg) => format!("k2k_send {}, {}", dest, msg),
191            IrNode::K2KRecv => format!("{} = k2k_recv", result),
192            IrNode::K2KTryRecv => format!("{} = k2k_try_recv", result),
193
194            // HLC
195            IrNode::HlcNow => format!("{} = hlc_now", result),
196            IrNode::HlcTick => format!("{} = hlc_tick", result),
197            IrNode::HlcUpdate(ts) => format!("{} = hlc_update {}", result, ts),
198
199            // Call
200            IrNode::Call(name, args) => {
201                let args_str: Vec<String> = args.iter().map(|a| format!("{}", a)).collect();
202                format!("{} = call @{}({})", result, name, args_str.join(", "))
203            }
204        };
205
206        writeln!(self.output, "{}{}", indent, node_str).unwrap();
207    }
208
209    fn print_terminator(&mut self, term: &Terminator) {
210        let indent = "  ".repeat(self.indent);
211        let term_str = match term {
212            Terminator::Return(None) => "ret void".to_string(),
213            Terminator::Return(Some(val)) => format!("ret {}", val),
214            Terminator::Branch(target) => format!("br {}", target),
215            Terminator::CondBranch(cond, then_block, else_block) => {
216                format!("br {}, {}, {}", cond, then_block, else_block)
217            }
218            Terminator::Switch(val, default, cases) => {
219                let cases_str: Vec<String> = cases
220                    .iter()
221                    .map(|(c, b)| format!("{} -> {}", format_constant(c), b))
222                    .collect();
223                format!(
224                    "switch {}, default {}, [{}]",
225                    val,
226                    default,
227                    cases_str.join(", ")
228                )
229            }
230            Terminator::Unreachable => "unreachable".to_string(),
231        };
232        writeln!(self.output, "{}{}", indent, term_str).unwrap();
233    }
234
235    fn print_line(&mut self, text: &str) {
236        let indent = "  ".repeat(self.indent);
237        write!(self.output, "{}{}", indent, text).unwrap();
238    }
239}
240
241impl Default for IrPrinter {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
247fn format_constant(c: &ConstantValue) -> String {
248    match c {
249        ConstantValue::Bool(b) => format!("{}", b),
250        ConstantValue::I32(v) => format!("{}i32", v),
251        ConstantValue::I64(v) => format!("{}i64", v),
252        ConstantValue::U32(v) => format!("{}u32", v),
253        ConstantValue::U64(v) => format!("{}u64", v),
254        ConstantValue::F32(v) => format!("{}f32", v),
255        ConstantValue::F64(v) => format!("{}f64", v),
256        ConstantValue::Null => "null".to_string(),
257        ConstantValue::Array(elements) => {
258            let elems: Vec<String> = elements.iter().map(format_constant).collect();
259            format!("[{}]", elems.join(", "))
260        }
261        ConstantValue::Struct(fields) => {
262            let fields_str: Vec<String> = fields.iter().map(format_constant).collect();
263            format!("{{{}}}", fields_str.join(", "))
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Dimension, IrBuilder};

    /// Assert that each fragment appears somewhere in the printed text.
    fn assert_contains_all(text: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(text.contains(fragment));
        }
    }

    #[test]
    fn test_print_simple_kernel() {
        let mut builder = IrBuilder::new("saxpy");

        // Parameters are declared but not yet referenced by any instruction.
        let _x = builder.parameter("x", IrType::ptr(IrType::F32));
        let _y = builder.parameter("y", IrType::ptr(IrType::F32));
        let _a = builder.parameter("a", IrType::F32);

        // A full saxpy would use this as the element index.
        let _tid = builder.thread_id(Dimension::X);

        builder.ret();

        let module = builder.build();
        let text = module.pretty_print();

        assert_contains_all(&text, &["saxpy", "thread_id.x", "ret void"]);
    }

    #[test]
    fn test_print_with_arithmetic() {
        let mut builder = IrBuilder::new("test");

        let ten = builder.const_i32(10);
        let twenty = builder.const_i32(20);
        let _sum = builder.add(ten, twenty);

        builder.ret();

        let module = builder.build();
        let text = module.pretty_print();

        // The constants live in the value table rather than in a block; the
        // add instruction refers to them by ValueId and carries the i32 type.
        assert_contains_all(&text, &["add", "i32"]);
    }

    #[test]
    fn test_print_with_control_flow() {
        let mut builder = IrBuilder::new("test");

        let flag = builder.const_bool(true);
        let taken = builder.create_block("then");
        let fallthrough = builder.create_block("else");

        builder.cond_branch(flag, taken, fallthrough);

        for target in [taken, fallthrough] {
            builder.switch_to_block(target);
            builder.ret();
        }

        let module = builder.build();
        let text = module.pretty_print();

        assert_contains_all(&text, &["then:", "else:", "br"]);
    }
}