// kaio_core/emit/emit_trait.rs

//! The `Emit` trait and implementations for all IR nodes.
//!
//! Individual instruction-category Emit impls are co-located with their types:
//! - `ArithOp` → `instr/arith.rs`
//! - `MemoryOp` → `instr/memory.rs`
//! - `ControlOp` → `instr/control.rs`
//!
//! The tensor-core op type likewise provides its own impl outside this file.
//!
//! This file contains the orchestration-level impls for `PtxModule`,
//! `PtxKernel`, and `PtxInstruction` (including Mov, Cvt, MovPack, Label,
//! Comment), plus the `.reg` declaration emitter.

use std::fmt;

use super::writer::PtxWriter;
use crate::ir::{PtxInstruction, PtxKernel, PtxModule, Register};
use crate::types::PtxType;

17/// Trait for emitting PTX text from an IR node.
18///
19/// Every IR type implements this. The writer handles indentation and
20/// formatting; each `Emit` impl is responsible for producing the
21/// content of its node type.
22pub trait Emit {
23    /// Write this node's PTX representation to the writer.
24    fn emit(&self, w: &mut PtxWriter) -> fmt::Result;
25}
26
27// --- Module-level emission ---
28
29impl Emit for PtxModule {
30    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
31        w.raw_line(&format!(".version {}", self.version))?;
32        w.raw_line(&format!(".target {}", self.target))?;
33        w.raw_line(&format!(".address_size {}", self.address_size))?;
34        for kernel in &self.kernels {
35            w.blank()?;
36            kernel.emit(w)?;
37        }
38        Ok(())
39    }
40}
41
42// --- Kernel-level emission ---
43
44impl Emit for PtxKernel {
45    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
46        // 1. Kernel signature with parameters
47        if self.params.is_empty() {
48            w.raw_line(&format!(".visible .entry {}()", self.name))?;
49        } else {
50            w.raw_line(&format!(".visible .entry {}(", self.name))?;
51            w.indent();
52            for (i, param) in self.params.iter().enumerate() {
53                let comma = if i < self.params.len() - 1 { "," } else { "" };
54                w.line(&format!("{}{}", param.ptx_decl(), comma))?;
55            }
56            w.dedent();
57            w.raw_line(")")?;
58        }
59
60        // 2. Opening brace
61        w.raw_line("{")?;
62        w.indent();
63
64        // 3. Register declarations
65        emit_reg_declarations(&self.registers, w)?;
66
67        // 4. Shared memory declarations
68        for decl in &self.shared_decls {
69            w.line(&format!(
70                ".shared .align {} .b8 {}[{}];",
71                decl.align, decl.name, decl.size_bytes
72            ))?;
73        }
74
75        // 5. Blank line between declarations and body
76        w.blank()?;
77
78        // 6. Instruction body
79        for instr in &self.body {
80            instr.emit(w)?;
81        }
82
83        // 7. Closing brace
84        w.dedent();
85        w.raw_line("}")?;
86        Ok(())
87    }
88}
89
90/// Emit `.reg` declarations grouped by register kind.
91///
92/// Uses the `<N>` syntax: `.reg .b32 %r<5>;` declares `%r0` through `%r4`.
93/// Groups by [`RegKind`](crate::types::RegKind) using fixed-size arrays
94/// indexed by `counter_index()` — no heap allocation, deterministic order.
95fn emit_reg_declarations(registers: &[Register], w: &mut PtxWriter) -> fmt::Result {
96    // Find max index per RegKind
97    let mut max_idx: [Option<u32>; 7] = [None; 7];
98    let mut decl_types: [&str; 7] = [""; 7];
99
100    for reg in registers {
101        let ci = reg.kind.counter_index();
102        match max_idx[ci] {
103            None => {
104                max_idx[ci] = Some(reg.index);
105                decl_types[ci] = reg.ptx_type.reg_decl_type();
106            }
107            Some(prev) if reg.index > prev => {
108                max_idx[ci] = Some(reg.index);
109            }
110            _ => {}
111        }
112    }
113
114    // Emit in counter_index order: R(0), Rd(1), F(2), Fd(3), P(4), H(5), Hb(6)
115    let prefixes = ["%r", "%rd", "%f", "%fd", "%p", "%h", "%hb"];
116    for i in 0..7 {
117        if let Some(max) = max_idx[i] {
118            let count = max + 1;
119            w.line(&format!(
120                ".reg {} {}<{}>;",
121                decl_types[i], prefixes[i], count
122            ))?;
123        }
124    }
125    Ok(())
126}
127
128// --- Instruction-level emission ---
129
130impl Emit for PtxInstruction {
131    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
132        match self {
133            Self::Arith(op) => op.emit(w),
134            Self::Memory(op) => op.emit(w),
135            Self::Control(op) => op.emit(w),
136            Self::TensorCore(op) => op.emit(w),
137            Self::Mov { dst, src, ty } => {
138                let mnemonic = format!("mov{}", ty.ptx_suffix());
139                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
140            }
141            Self::Cvt {
142                dst,
143                src,
144                dst_ty,
145                src_ty,
146            } => {
147                // PTX requires rounding modifiers for conversions involving floats.
148                // KAIO emits .rn for all float-to-float cvt operations for
149                // consistency and PTX validity, even where the conversion is
150                // exact (e.g., f16→f32).
151                let rounding = match (dst_ty, src_ty) {
152                    // int → float (including half): round to nearest even.
153                    // PTX requires `.rn` even for exact conversions like s8→f16
154                    // (ptxas rejects bare `cvt.f16.s8` with "Rounding modifier
155                    // required for instruction 'cvt'"); spec wording about
156                    // optional modifiers does not match real ptxas validation.
157                    (
158                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
159                        PtxType::S8 | PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
160                    ) => ".rn",
161                    // float (including half) → int: round toward zero (matches Rust `as`)
162                    (
163                        PtxType::S8 | PtxType::S32 | PtxType::U32 | PtxType::S64 | PtxType::U64,
164                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
165                    ) => ".rzi",
166                    // float → float (any width, including half): round to nearest
167                    (
168                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
169                        PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64,
170                    ) => ".rn",
171                    // int → int or same type: no rounding modifier
172                    _ => "",
173                };
174                let mnemonic = format!(
175                    "cvt{rounding}{}{}",
176                    dst_ty.ptx_suffix(),
177                    src_ty.ptx_suffix()
178                );
179                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, src])
180            }
181            Self::MovPack { dst, srcs, ty } => {
182                // mov.b{N} %dst, {%s0,%s1,...};
183                //
184                // The vector-pack form of `mov` requires the typeless `.b{N}`
185                // suffix (PTX ISA 9.7.9.10) — `mov.u32 %r, {%h0, %h1};` is
186                // rejected by ptxas. We derive `.b{N}` from the destination
187                // type's byte width.
188                let joined = srcs
189                    .iter()
190                    .map(|r| format!("{r}"))
191                    .collect::<Vec<_>>()
192                    .join(",");
193                let src_list = format!("{{{joined}}}");
194                let bits = ty.size_bytes() * 8;
195                let mnemonic = format!("mov.b{bits}");
196                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &src_list])
197            }
198            Self::Label(name) => {
199                // Labels are at column 0 — dedent, emit, re-indent.
200                // dedent saturates at 0 (safe for edge cases).
201                w.dedent();
202                w.raw_line(&format!("{name}:"))?;
203                w.indent();
204                Ok(())
205            }
206            Self::Comment(text) => w.line(&format!("// {text}")),
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Operand, PtxParam, RegisterAllocator, SpecialReg};
    use crate::types::{PtxType, RegKind};

    /// Builds a register directly, bypassing the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    #[test]
    fn emit_mov_special_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.u32 %r0, %tid.x;\n");
    }

    #[test]
    fn emit_mov_reg_to_reg() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::F, 1, PtxType::F32),
            src: Operand::Reg(reg(RegKind::F, 0, PtxType::F32)),
            ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.f32 %f1, %f0;\n");
    }

    #[test]
    fn emit_mov_shared_addr() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::SharedAddr("sdata".to_string()),
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.u32 %r0, sdata;\n");
    }

    #[test]
    fn emit_cvt() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F32,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.s32 %f0, %r0;\n");
    }

    #[test]
    fn emit_cvt_float_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rzi.u32.f32 %r0, %f0;\n");
    }

    #[test]
    fn emit_cvt_int_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::S32),
            src: reg(RegKind::R, 1, PtxType::U32),
            dst_ty: PtxType::S32,
            src_ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        // No rounding modifier for int → int
        assert_eq!(w.finish(), "    cvt.s32.u32 %r0, %r1;\n");
    }

    #[test]
    fn emit_cvt_f32_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::F, 0, PtxType::F32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f16.f32 %h0, %f0;\n");
    }

    #[test]
    fn emit_cvt_f16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.f16 %f0, %h0;\n");
    }

    #[test]
    fn emit_cvt_f64_to_f32() {
        // Narrowing float → float must carry a rounding modifier.
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::Fd, 0, PtxType::F64),
            dst_ty: PtxType::F32,
            src_ty: PtxType::F64,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.f64 %f0, %fd0;\n");
    }

    #[test]
    fn emit_cvt_f64_to_s64() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::Rd, 0, PtxType::S64),
            src: reg(RegKind::Fd, 1, PtxType::F64),
            dst_ty: PtxType::S64,
            src_ty: PtxType::F64,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rzi.s64.f64 %rd0, %fd1;\n");
    }

    #[test]
    fn emit_cvt_int_to_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::H, 0, PtxType::F16),
            src: reg(RegKind::R, 0, PtxType::S32),
            dst_ty: PtxType::F16,
            src_ty: PtxType::S32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f16.s32 %h0, %r0;\n");
    }

    #[test]
    fn emit_cvt_f16_to_int() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::H, 0, PtxType::F16),
            dst_ty: PtxType::U32,
            src_ty: PtxType::F16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rzi.u32.f16 %r0, %h0;\n");
    }

    #[test]
    fn emit_cvt_bf16_to_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::Hb, 0, PtxType::BF16),
            dst_ty: PtxType::F32,
            src_ty: PtxType::BF16,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvt.rn.f32.bf16 %f0, %hb0;\n");
    }

    #[test]
    fn emit_reg_declarations_with_f16() {
        let regs = vec![
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::H, 0, PtxType::F16),
            reg(RegKind::H, 1, PtxType::F16),
            reg(RegKind::Hb, 0, PtxType::BF16),
        ];
        let mut w = PtxWriter::new();
        w.indent();
        emit_reg_declarations(&regs, &mut w).unwrap();
        let output = w.finish();
        assert!(output.contains(".reg .f32 %f<1>;"));
        assert!(output.contains(".reg .f16 %h<2>;"));
        assert!(output.contains(".reg .bf16 %hb<1>;"));
    }

    #[test]
    fn emit_reg_declarations_uses_max_index_and_fixed_order() {
        // Registers arrive out of order and with a gap. The count is
        // max index + 1, and kinds are emitted in counter_index order
        // (%r before %f) regardless of input order.
        let regs = vec![
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::R, 4, PtxType::U32),
            reg(RegKind::R, 1, PtxType::U32),
        ];
        let mut w = PtxWriter::new();
        w.indent();
        emit_reg_declarations(&regs, &mut w).unwrap();
        assert_eq!(w.finish(), "    .reg .b32 %r<5>;\n    .reg .f32 %f<1>;\n");
    }

    #[test]
    fn emit_label_at_column_zero() {
        let mut w = PtxWriter::new();
        w.indent(); // simulate being inside a kernel body
        let instr = PtxInstruction::Label("EXIT".to_string());
        instr.emit(&mut w).unwrap();
        // Label should be at column 0, no indentation
        assert_eq!(w.finish(), "EXIT:\n");
    }

    #[test]
    fn emit_comment() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Comment("bounds check".to_string());
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    // bounds check\n");
    }

    #[test]
    fn emit_mov_pack_two_f16_into_b32() {
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::MovPack {
            dst: reg(RegKind::R, 7, PtxType::U32),
            srcs: vec![
                reg(RegKind::H, 3, PtxType::F16),
                reg(RegKind::H, 4, PtxType::F16),
            ],
            ty: PtxType::U32,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.b32 %r7, {%h3,%h4};\n");
    }

    #[test]
    fn emit_mov_pack_four_f16_into_b64() {
        // The `.b{N}` width follows the destination type, not the source count.
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::MovPack {
            dst: reg(RegKind::Rd, 0, PtxType::U64),
            srcs: vec![
                reg(RegKind::H, 0, PtxType::F16),
                reg(RegKind::H, 1, PtxType::F16),
                reg(RegKind::H, 2, PtxType::F16),
                reg(RegKind::H, 3, PtxType::F16),
            ],
            ty: PtxType::U64,
        };
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    mov.b64 %rd0, {%h0,%h1,%h2,%h3};\n");
    }

    /// End-to-end emitter test: a mini f16 kernel proving half types flow
    /// through params, register declarations, loads, cvt, arithmetic, and stores.
    ///
    /// Kernel: load f16 → cvt to f32 → add 1.0 → cvt to f16 → store f16,
    /// indexed by `%tid.x` (no bounds check; this test only checks emission).
    #[test]
    fn emit_kernel_f16_flow() {
        use crate::instr::{ArithOp, MemoryOp};

        let mut alloc = RegisterAllocator::new();
        // Registers: rd for pointers, h for f16, f for f32, r for tid
        let rd_in = alloc.alloc(PtxType::U64); // %rd0: input ptr
        let rd_out = alloc.alloc(PtxType::U64); // %rd1: output ptr
        let r_tid = alloc.alloc(PtxType::U32); // %r0: thread id
        let rd_off = alloc.alloc(PtxType::U64); // %rd2: byte offset
        let rd_addr_in = alloc.alloc(PtxType::U64); // %rd3: input addr
        let rd_addr_out = alloc.alloc(PtxType::U64); // %rd4: output addr
        let h_val = alloc.alloc(PtxType::F16); // %h0: loaded f16
        let f_val = alloc.alloc(PtxType::F32); // %f0: f32 value
        let f_one = alloc.alloc(PtxType::F32); // %f1: constant 1.0
        let f_sum = alloc.alloc(PtxType::F32); // %f2: result
        let h_out = alloc.alloc(PtxType::F16); // %h1: output f16

        let mut kernel = PtxKernel::new("f16_add_one");
        kernel.add_param(PtxParam::pointer("in_ptr", PtxType::F16));
        kernel.add_param(PtxParam::pointer("out_ptr", PtxType::F16));

        // Load params
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_in,
            param_name: "in_ptr".to_string(),
            ty: PtxType::U64,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: rd_out,
            param_name: "out_ptr".to_string(),
            ty: PtxType::U64,
        }));
        // Get tid
        kernel.push(PtxInstruction::Mov {
            dst: r_tid,
            src: Operand::SpecialReg(SpecialReg::TidX),
            ty: PtxType::U32,
        });
        // Widen tid to 64 bits ...
        kernel.push(PtxInstruction::Cvt {
            dst: rd_off,
            src: r_tid,
            dst_ty: PtxType::U64,
            src_ty: PtxType::U32,
        });
        // ... and scale it to a byte offset: each f16 is 2 bytes, so
        // offset = tid + tid.
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_off,
            lhs: Operand::Reg(rd_off),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // addr_in = in_ptr + offset
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_in,
            lhs: Operand::Reg(rd_in),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // Load f16 value
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: h_val,
            addr: rd_addr_in,
            ty: PtxType::F16,
        }));
        // Convert to f32
        kernel.push(PtxInstruction::Cvt {
            dst: f_val,
            src: h_val,
            dst_ty: PtxType::F32,
            src_ty: PtxType::F16,
        });
        // Add 1.0
        kernel.push(PtxInstruction::Mov {
            dst: f_one,
            src: Operand::ImmF32(1.0),
            ty: PtxType::F32,
        });
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: f_sum,
            lhs: Operand::Reg(f_val),
            rhs: Operand::Reg(f_one),
            ty: PtxType::F32,
        }));
        // Convert back to f16
        kernel.push(PtxInstruction::Cvt {
            dst: h_out,
            src: f_sum,
            dst_ty: PtxType::F16,
            src_ty: PtxType::F32,
        });
        // addr_out = out_ptr + offset
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: rd_addr_out,
            lhs: Operand::Reg(rd_out),
            rhs: Operand::Reg(rd_off),
            ty: PtxType::U64,
        }));
        // Store f16 result
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: rd_addr_out,
            src: h_out,
            ty: PtxType::F16,
        }));
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        // Verify structure: params, register declarations, instructions
        assert!(output.contains(".param .u64 in_ptr"));
        assert!(output.contains(".param .u64 out_ptr"));
        assert!(output.contains(".reg .f16 %h<2>;"), "f16 reg declarations");
        assert!(output.contains(".reg .f32 %f<3>;"), "f32 reg declarations");
        // int → int widening carries no rounding modifier.
        assert!(output.contains("cvt.u64.u32 %rd2, %r0"));
        // f16/bf16 loads and stores emit `.b16` — the valid ld/st type
        // modifier for 16-bit memory ops (PTX ISA §8.7.9). Register class
        // stays `.f16`. The `cvt` instruction still uses `.f16` / `.f32`
        // because it's a register-to-register conversion, not memory.
        assert!(output.contains("ld.global.b16 %h0"));
        assert!(output.contains("cvt.rn.f32.f16 %f0, %h0"));
        assert!(output.contains("cvt.rn.f16.f32 %h1, %f2"));
        assert!(output.contains("st.global.b16 [%rd4], %h1"));
    }

    #[test]
    fn emit_module_header() {
        let module = PtxModule::new("sm_70");
        let mut w = PtxWriter::new();
        // Emit just the header (module with no kernels)
        module.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            ".version 8.7\n.target sm_70\n.address_size 64\n"
        );
    }

    #[test]
    fn emit_kernel_minimal() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::U32);

        let mut kernel = PtxKernel::new("test_kernel");
        kernel.add_param(PtxParam::scalar("n", PtxType::U32));
        kernel.push(PtxInstruction::Mov {
            dst: r0,
            src: Operand::ImmU32(42),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(alloc.into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        // Validate structure
        assert!(output.contains(".visible .entry test_kernel("));
        assert!(output.contains(".param .u32 n"));
        assert!(output.contains(".reg .b32 %r<1>;"));
        assert!(output.contains("mov.u32 %r0, 42;"));
        assert!(output.contains("ret;"));
        assert!(output.starts_with(".visible .entry"));
        assert!(output.trim_end().ends_with('}'));
    }

    #[test]
    fn emit_kernel_without_params() {
        // A parameterless kernel keeps `()` on the signature line and, with
        // no registers allocated, emits no `.reg` declarations.
        let mut kernel = PtxKernel::new("noop");
        kernel.push(PtxInstruction::Control(crate::instr::ControlOp::Ret));
        kernel.set_registers(RegisterAllocator::new().into_allocated());

        let mut w = PtxWriter::new();
        kernel.emit(&mut w).unwrap();
        let output = w.finish();

        assert!(output.starts_with(".visible .entry noop()\n{\n"));
        assert!(!output.contains(".reg"));
        assert!(output.contains("ret;"));
        assert!(output.trim_end().ends_with('}'));
    }
}