Skip to main content

kaio_core/emit/
emit_trait.rs

1//! The `Emit` trait and implementations for all IR nodes.
2//!
//! Individual instruction-category `Emit` impls are co-located with their types:
//! - `ArithOp` → `instr/arith.rs`
//! - `MemoryOp` → `instr/memory.rs`
//! - `ControlOp` → `instr/control.rs`
//! - tensor-core ops (`PtxInstruction::TensorCore`) likewise implement `Emit`
//!   outside this file
7//!
8//! This file contains the orchestration-level impls for `PtxModule`,
9//! `PtxKernel`, and `PtxInstruction` (including Mov, Cvt, Label, Comment).
10
11use std::fmt;
12
13use super::writer::PtxWriter;
14use crate::ir::{PtxInstruction, PtxKernel, PtxModule, Register};
15use crate::types::PtxType;
16
17/// Trait for emitting PTX text from an IR node.
18///
19/// Every IR type implements this. The writer handles indentation and
20/// formatting; each `Emit` impl is responsible for producing the
21/// content of its node type.
22pub trait Emit {
23    /// Write this node's PTX representation to the writer.
24    fn emit(&self, w: &mut PtxWriter) -> fmt::Result;
25}
26
27// --- Module-level emission ---
28
29impl Emit for PtxModule {
30    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
31        w.raw_line(&format!(".version {}", self.version))?;
32        w.raw_line(&format!(".target {}", self.target))?;
33        w.raw_line(&format!(".address_size {}", self.address_size))?;
34        for kernel in &self.kernels {
35            w.blank()?;
36            kernel.emit(w)?;
37        }
38        Ok(())
39    }
40}
41
42// --- Kernel-level emission ---
43
44impl Emit for PtxKernel {
45    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
46        // 1. Kernel signature with parameters
47        if self.params.is_empty() {
48            w.raw_line(&format!(".visible .entry {}()", self.name))?;
49        } else {
50            w.raw_line(&format!(".visible .entry {}(", self.name))?;
51            w.indent();
52            for (i, param) in self.params.iter().enumerate() {
53                let comma = if i < self.params.len() - 1 { "," } else { "" };
54                w.line(&format!("{}{}", param.ptx_decl(), comma))?;
55            }
56            w.dedent();
57            w.raw_line(")")?;
58        }
59
60        // 2. Opening brace
61        w.raw_line("{")?;
62        w.indent();
63
64        // 3. Register declarations
65        emit_reg_declarations(&self.registers, w)?;
66
67        // 4. Shared memory declarations
68        for decl in &self.shared_decls {
69            w.line(&format!(
70                ".shared .align {} .b8 {}[{}];",
71                decl.align, decl.name, decl.size_bytes
72            ))?;
73        }
74
75        // 5. Blank line between declarations and body
76        w.blank()?;
77
78        // 6. Instruction body
79        for instr in &self.body {
80            instr.emit(w)?;
81        }
82
83        // 7. Closing brace
84        w.dedent();
85        w.raw_line("}")?;
86        Ok(())
87    }
88}
89
90/// Emit `.reg` declarations grouped by register kind.
91///
92/// Uses the `<N>` syntax: `.reg .b32 %r<5>;` declares `%r0` through `%r4`.
93/// Groups by [`RegKind`](crate::types::RegKind) using fixed-size arrays
94/// indexed by `counter_index()` — no heap allocation, deterministic order.
95fn emit_reg_declarations(registers: &[Register], w: &mut PtxWriter) -> fmt::Result {
96    // Find max index per RegKind
97    let mut max_idx: [Option<u32>; 7] = [None; 7];
98    let mut decl_types: [&str; 7] = [""; 7];
99
100    for reg in registers {
101        let ci = reg.kind.counter_index();
102        match max_idx[ci] {
103            None => {
104                max_idx[ci] = Some(reg.index);
105                decl_types[ci] = reg.ptx_type.reg_decl_type();
106            }
107            Some(prev) if reg.index > prev => {
108                max_idx[ci] = Some(reg.index);
109            }
110            _ => {}
111        }
112    }
113
114    // Emit in counter_index order: R(0), Rd(1), F(2), Fd(3), P(4), H(5), Hb(6)
115    let prefixes = ["%r", "%rd", "%f", "%fd", "%p", "%h", "%hb"];
116    for i in 0..7 {
117        if let Some(max) = max_idx[i] {
118            let count = max + 1;
119            w.line(&format!(
120                ".reg {} {}<{}>;",
121                decl_types[i], prefixes[i], count
122            ))?;
123        }
124    }
125    Ok(())
126}
127
128// --- Instruction-level emission ---
129
130impl Emit for PtxInstruction {
131    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
132        match self {
133            Self::Arith(op) => op.emit(w),
134            Self::Memory(op) => op.emit(w),
135            Self::Control(op) => op.emit(w),
136            Self::TensorCore(op) => op.emit(w),
137            Self::Mov { dst, src, ty } => {
138                let mnemonic = format!("mov{}", ty.ptx_suffix());
139                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
140            }
141            Self::Cvt {
142                dst,
143                src,
144                dst_ty,
145                src_ty,
146            } => {
147                // PTX requires rounding modifiers for conversions involving floats.
148                // KAIO emits .rn for all float-to-float cvt operations for
149                // consistency and PTX validity, even where the conversion is
150                // exact (e.g., f16→f32).
151                let rounding = match (dst_ty, src_ty) {
152                    // int → float (including half): round to nearest even
153                    (
154                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
155                        PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
156                    ) => ".rn",
157                    // float (including half) → int: round toward zero (matches Rust `as`)
158                    (
159                        PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
160                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
161                    ) => ".rzi",
162                    // float → float (any width, including half): round to nearest
163                    (
164                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
165                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
166                    ) => ".rn",
167                    // int → int or same type: no rounding modifier
168                    _ => "",
169                };
170                let mnemonic = format!(
171                    "cvt{rounding}{}{}",
172                    dst_ty.ptx_suffix(),
173                    src_ty.ptx_suffix()
174                );
175                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
176            }
177            Self::Label(name) => {
178                // Labels are at column 0 — dedent, emit, re-indent.
179                // dedent saturates at 0 (safe for edge cases).
180                w.dedent();
181                w.raw_line(&format!("{name}:"))?;
182                w.indent();
183                Ok(())
184            }
185            Self::Comment(text) => w.line(&format!("// {text}")),
186        }
187    }
188}
189
// Unit tests for the emission code in this file. Most tests call `w.indent()`
// once before emitting, mirroring the single indentation level of a kernel
// body in `PtxKernel::emit`, so expected lines start with four spaces.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Operand, PtxParam, RegisterAllocator, SpecialReg};
    use crate::types::{PtxType, RegKind};

    /// Build a `Register` directly, bypassing `RegisterAllocator`, so a test
    /// can choose the exact kind/index/type that appears in the expected text.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    // --- mov ---

    #[test]
    fn emit_mov_special_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.u32 %r0, %tid.x;\n");
    }

    #[test]
    fn emit_mov_reg_to_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::F, 1, PtxType::F32),
            src: Operand::Reg(reg(RegKind::F, 0, PtxType::F32)),
            ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.f32 %f1, %f0;\n");
    }

    // A shared-memory symbol operand is emitted as the bare symbol name.
    #[test]
    fn emit_mov_shared_addr() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SharedAddr("sdata".to_string()),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.u32 %r0, sdata;\n");
    }

    // --- cvt: rounding-modifier selection ---

    // int → float: `.rn`
    #[test]
    fn emit_cvt() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F32,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.s32 %f0, %r0;\n");
    }

    // float → int: `.rzi` (truncation, like Rust `as`)
    #[test]
    fn emit_cvt_float_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rzi.u32.f32 %r0, %f0;\n");
    }

    #[test]
    fn emit_cvt_int_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::S32),
            src: reg(RegKind::R, 1, PtxType::U32),
            dst_ty: PtxType::S32,
            src_ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        // No rounding modifier for int → int
        assert_eq!(w.finish(), "    cvt.s32.u32 %r0, %r1;\n");
    }

    // Narrowing float → float: `.rn`
    #[test]
    fn emit_cvt_f32_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f16.f32 %h0, %f0;\n");
    }

    // Exact widening float → float still gets `.rn` (emitter policy).
    #[test]
    fn emit_cvt_f16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.f16 %f0, %h0;\n");
    }

    #[test]
    fn emit_cvt_int_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f16.s32 %h0, %r0;\n");
    }

    #[test]
    fn emit_cvt_f16_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rzi.u32.f16 %r0, %h0;\n");
    }

    #[test]
    fn emit_cvt_bf16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::Hb, 0, PtxType::BF16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::BF16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.bf16 %f0, %hb0;\n");
    }

    // --- declarations, labels, comments ---

    // `<N>` is (highest index + 1) per kind: H uses indices 0 and 1 → `%h<2>`.
    #[test]
    fn emit_reg_declarations_with_f16() {
        let regs = vec![
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::H, 0, PtxType::F16),
            reg(RegKind::H, 1, PtxType::F16),
            reg(RegKind::Hb, 0, PtxType::BF16),
        ];
        let mut w = PtxWriter::new();
        w.indent();
        emit_reg_declarations(&regs, &mut w).unwrap();
        let output = w.finish();
        assert!(output.contains(".reg .f32 %f<1>;"));
        assert!(output.contains(".reg .f16 %h<2>;"));
        assert!(output.contains(".reg .bf16 %hb<1>;"));
    }

    #[test]
    fn emit_label_at_column_zero() {
        let mut w = PtxWriter::new();
        w.indent(); // simulate being inside a kernel body
        let instr = PtxInstruction::Label("EXIT".to_string());
        instr.emit(&mut w).unwrap();
        // Label should be at column 0, no indentation
        assert_eq!(w.finish(), "EXIT:\n");
    }

    #[test]
    fn emit_comment() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Comment("bounds check".to_string());
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    // bounds check\n");
    }

    /// End-to-end emitter test: a mini f16 kernel proving half types flow
    /// through params, register declarations, loads, cvt, arithmetic, and stores.
    ///
    /// Kernel: load f16 → cvt to f32 → add 1.0 → cvt to f16 → store f16
    ///
    /// Only the emitted text is checked; the kernel is never assembled or run.
    #[test]
    fn emit_kernel_f16_flow() {
        use crate::instr::{ArithOp, MemoryOp};

        let mut alloc = RegisterAllocator::new();
        // Registers: rd for pointers, h for f16, f for f32, r for tid
        let rd_in = alloc.alloc(PtxType::U64); // %rd0: input ptr
        let rd_out = alloc.alloc(PtxType::U64); // %rd1: output ptr
        let r_tid = alloc.alloc(PtxType::U32); // %r0: thread id
        let rd_off = alloc.alloc(PtxType::U64); // %rd2: byte offset
        let rd_addr_in = alloc.alloc(PtxType::U64); // %rd3: input addr
        let rd_addr_out = alloc.alloc(PtxType::U64); // %rd4: output addr
        let h_val = alloc.alloc(PtxType::F16); // %h0: loaded f16
        let f_val = alloc.alloc(PtxType::F32); // %f0: f32 value
        let f_one = alloc.alloc(PtxType::F32); // %f1: constant 1.0
        let f_sum = alloc.alloc(PtxType::F32); // %f2: result
        let h_out = alloc.alloc(PtxType::F16); // %h1: output f16

        let mut kernel = PtxKernel::new("f16_add_one");
        kernel.add_param(PtxParam::pointer("in_ptr", PtxType::F16));
        kernel.add_param(PtxParam::pointer("out_ptr", PtxType::F16));

        // Load params
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_in,
            param_name: "in_ptr".to_string(),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_out,
            param_name: "out_ptr".to_string(),
            ty: PtxType::U64,
        }));
        // Get tid
        kernel.push(PtxInstruction::Mov {
            dst: r_tid,
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        });
        // Widen tid to u64 for use as the address offset. NOTE: it is not
        // scaled by the 2-byte f16 element size, so this is not a correct
        // byte offset for a real launch; the test only checks emitted text.
        kernel.push(PtxInstruction::Cvt {
            dst: rd_off,
            src: r_tid,
            dst_ty: PtxType::U64,
            src_ty: PtxType::U32,
        });
        // addr_in = in_ptr + offset
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_in,
            lhs: Operand::Reg(rd_in),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // Load f16 value
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: h_val,
            addr: rd_addr_in,
            ty: PtxType::F16,
        }));
        // Convert to f32
        kernel.push(PtxInstruction::Cvt {
            dst: f_val,
            src: h_val,
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        });
        // Add 1.0
        kernel.push(PtxInstruction::Mov {
            dst: f_one,
            src: Operand::ImmF32(1.0),
            ty: PtxType::F32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: f_sum,
            lhs: Operand::Reg(f_val),
            rhs: Operand::Reg(f_one),
            ty: PtxType::F32,
        }));
        // Convert back to f16
        kernel.push(PtxInstruction::Cvt {
            dst: h_out,
            src: f_sum,
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        });
        // addr_out = out_ptr + offset
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_out,
            lhs: Operand::Reg(rd_out),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // Store f16 result
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: rd_addr_out,
            src: h_out,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        // Verify structure: params, register declarations, instructions
        assert!(output.contains(".param .u64 in_ptr"));
        assert!(output.contains(".param .u64 out_ptr"));
        assert!(output.contains(".reg .f16 %h<2>;"), "f16 reg declarations");
        assert!(output.contains(".reg .f32 %f<3>;"), "f32 reg declarations");
        // f16/bf16 loads and stores emit `.b16`, the ld/st type modifier for
        // 16-bit memory ops (see the `ld`/`st` type lists in the PTX ISA).
        // The register class stays `.f16`. The `cvt` instruction still uses
        // `.f16` / `.f32` because it's a register-to-register conversion, not memory.
        assert!(output.contains("ld.global.b16 %h0"));
        assert!(output.contains("cvt.rn.f32.f16 %f0, %h0"));
        assert!(output.contains("cvt.rn.f16.f32 %h1, %f2"));
        assert!(output.contains("st.global.b16 [%rd4], %h1"));
    }

    // --- module / kernel structure ---

    // Expects `PtxModule::new` to default to `.version 8.7` and
    // `.address_size 64`.
    #[test]
    fn emit_module_header() {
        let module = PtxModule::new("sm_70");
        let mut w = PtxWriter::new();
        // Emit just the header (module with no kernels)
        module.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            ".version 8.7\n.target sm_70\n.address_size 64\n"
        );
    }

    #[test]
    fn emit_kernel_minimal() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::U32);

        let mut kernel = PtxKernel::new("test_kernel");
        kernel.add_param(PtxParam::scalar("n", PtxType::U32));
        kernel.push(PtxInstruction::Mov {
            dst: r0,
            src: Operand::ImmU32(42),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        // Validate structure
        assert!(output.contains(".visible .entry test_kernel("));
        assert!(output.contains(".param .u32 n"));
        assert!(output.contains(".reg .b32 %r<1>;"));
        assert!(output.contains("mov.u32 %r0, 42;"));
        assert!(output.contains("ret;"));
        assert!(output.starts_with(".visible .entry"));
        assert!(output.trim_end().ends_with('}'));
    }
}