Skip to main content

kaio_core/ir/
kernel.rs

1//! PTX kernel — a single `.entry` function in a PTX module.
2
3use super::instruction::PtxInstruction;
4use super::param::PtxParam;
5use super::register::Register;
6use crate::instr::ArithOp;
7use crate::instr::control::ControlOp;
8use crate::instr::memory::MemoryOp;
9use crate::instr::tensor_core::TensorCoreOp;
10use crate::types::RegKind;
11
/// Shared memory declaration in a PTX kernel preamble.
///
/// Emitted as `.shared .align {align} .b8 {name}[{size_bytes}];` after
/// register declarations.
///
/// Equality is field-wise, so two declarations compare equal only when
/// name, alignment, and size all match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedDecl {
    /// Name of the shared memory allocation (e.g., `"sdata"`).
    pub name: String,
    /// Alignment in bytes (e.g. 4 for f32, 8 for f64).
    ///
    /// NOTE(review): not validated here; PTX `.align` presumably expects a
    /// power of two — confirm against the emitter/PTX ISA.
    pub align: u32,
    /// Total allocation size in bytes (the length of the emitted `.b8` array).
    pub size_bytes: u32,
}
25
26/// A PTX kernel function (`.visible .entry`).
27///
28/// Built by constructing parameters, allocating registers, and pushing
29/// instructions. Call [`set_registers`](Self::set_registers) with the
30/// allocator's output before emission so the kernel knows which `.reg`
31/// declarations to emit.
32#[derive(Debug, Clone)]
33pub struct PtxKernel {
34    /// Kernel entry point name.
35    pub name: String,
36    /// Declared parameters (in signature order).
37    pub params: Vec<PtxParam>,
38    /// Instruction body.
39    pub body: Vec<PtxInstruction>,
40    /// All registers used, for `.reg` declaration emission.
41    pub registers: Vec<Register>,
42    /// Shared memory declarations (emitted after register declarations).
43    pub shared_decls: Vec<SharedDecl>,
44}
45
46impl PtxKernel {
47    /// Create a new empty kernel with the given name.
48    pub fn new(name: &str) -> Self {
49        Self {
50            name: name.to_string(),
51            params: Vec::new(),
52            body: Vec::new(),
53            registers: Vec::new(),
54            shared_decls: Vec::new(),
55        }
56    }
57
58    /// Add a parameter to the kernel signature.
59    pub fn add_param(&mut self, param: PtxParam) {
60        self.params.push(param);
61    }
62
63    /// Append an instruction to the kernel body.
64    pub fn push(&mut self, instr: PtxInstruction) {
65        self.body.push(instr);
66    }
67
68    /// Set the register list (from [`super::register::RegisterAllocator::into_allocated`]).
69    pub fn set_registers(&mut self, regs: Vec<Register>) {
70        self.registers = regs;
71    }
72
73    /// Add a shared memory declaration to the kernel preamble.
74    pub fn add_shared_decl(&mut self, decl: SharedDecl) {
75        self.shared_decls.push(decl);
76    }
77
78    /// Compute structural statistics about this kernel's emitted PTX.
79    ///
80    /// Walks the instruction body and counts instruction types, registers
81    /// by kind, and declared shared memory. Useful for inspection and
82    /// comparison between kernel variants.
83    ///
84    /// These are **not** runtime profiling data — final hardware register
85    /// allocation and occupancy may differ after CUDA driver compilation.
86    pub fn stats(&self) -> KernelStats {
87        let mut s = KernelStats::default();
88
89        for instr in &self.body {
90            match instr {
91                PtxInstruction::Arith(op) => {
92                    s.total_instructions += 1;
93                    if matches!(op, ArithOp::Fma { .. }) {
94                        s.fma += 1;
95                    } else {
96                        s.arith_other += 1;
97                    }
98                }
99                PtxInstruction::Memory(op) => {
100                    s.total_instructions += 1;
101                    match op {
102                        MemoryOp::LdGlobal { .. } => s.ld_global += 1,
103                        MemoryOp::StGlobal { .. } => s.st_global += 1,
104                        MemoryOp::LdShared { .. } => s.ld_shared += 1,
105                        MemoryOp::StShared { .. } => s.st_shared += 1,
106                        MemoryOp::CpAsyncCaSharedGlobal { .. } => s.cp_async += 1,
107                        MemoryOp::CpAsyncCommitGroup => s.cp_async_commit += 1,
108                        MemoryOp::CpAsyncWaitGroup { .. } => s.cp_async_wait += 1,
109                        _ => {}
110                    }
111                }
112                PtxInstruction::TensorCore(op) => {
113                    s.total_instructions += 1;
114                    match op {
115                        TensorCoreOp::MmaSync { .. }
116                        | TensorCoreOp::MmaSyncInt8 { .. }
117                        | TensorCoreOp::MmaSyncBf16 { .. } => s.mma += 1,
118                        // Counted separately from both `mma` and
119                        // `ld_shared`: it is a warp-collective load, and
120                        // folding it into either would hide exactly the
121                        // instruction-mix shift the ldmatrix loader
122                        // rewire is supposed to show (Sprint 9.3).
123                        TensorCoreOp::LdMatrix { .. } => s.ldmatrix += 1,
124                    }
125                }
126                PtxInstruction::Control(op) => {
127                    s.total_instructions += 1;
128                    match op {
129                        ControlOp::BarSync { .. } => s.bar_sync += 1,
130                        ControlOp::BraPred { .. } | ControlOp::Bra { .. } => s.branches += 1,
131                        ControlOp::SetP { .. } => s.setp += 1,
132                        _ => {}
133                    }
134                }
135                PtxInstruction::Mov { .. } => {
136                    s.total_instructions += 1;
137                    s.mov += 1;
138                }
139                PtxInstruction::Cvt { .. } => {
140                    s.total_instructions += 1;
141                    s.cvt += 1;
142                }
143                PtxInstruction::MovPack { .. } => {
144                    s.total_instructions += 1;
145                    s.mov += 1;
146                }
147                PtxInstruction::Label(_) | PtxInstruction::Comment(_) => {}
148            }
149        }
150
151        for reg in &self.registers {
152            match reg.kind {
153                RegKind::R => s.registers_r += 1,
154                RegKind::Rd => s.registers_rd += 1,
155                RegKind::F => s.registers_f += 1,
156                RegKind::Fd => s.registers_fd += 1,
157                RegKind::P => s.registers_p += 1,
158                RegKind::H => s.registers_h += 1,
159                RegKind::Hb => s.registers_hb += 1,
160            }
161        }
162
163        s.shared_bytes = self.shared_decls.iter().map(|d| d.size_bytes).sum();
164
165        s
166    }
167}
168
/// Structural statistics about a compiled kernel's emitted PTX.
///
/// These describe the instruction mix and declared resource usage in
/// KAIO's generated PTX — useful for inspection and comparison between
/// kernel variants, but **not** a substitute for runtime profiling.
/// Final hardware register allocation and occupancy may differ from
/// these counts after the CUDA driver's backend compilation (PTX → SASS).
///
/// Every field is a plain integer counter, so the struct is `Copy`:
/// snapshots can be stored and compared across kernel variants freely.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelStats {
    /// Total instructions (excludes labels and comments).
    ///
    /// Also includes instructions that have no dedicated counter below
    /// (e.g. `ld.param`, `ret`).
    pub total_instructions: usize,
    /// `ld.global` count.
    pub ld_global: usize,
    /// `st.global` count.
    pub st_global: usize,
    /// `ld.shared` count.
    pub ld_shared: usize,
    /// `st.shared` count.
    pub st_shared: usize,
    /// `bar.sync` count.
    pub bar_sync: usize,
    /// `mma.sync` instruction count (all tensor-core shapes).
    pub mma: usize,
    /// `ldmatrix` instruction count (warp-collective fragment loads —
    /// tracked apart from `ld_shared` so loader-rewire instruction-mix
    /// shifts stay visible).
    pub ldmatrix: usize,
    /// `cp.async.ca.shared.global` instruction count.
    pub cp_async: usize,
    /// `cp.async.commit_group` instruction count.
    pub cp_async_commit: usize,
    /// `cp.async.wait_group` instruction count.
    pub cp_async_wait: usize,
    /// `fma` instruction count.
    pub fma: usize,
    /// Non-FMA arithmetic instructions (add, mul, sub, etc.).
    pub arith_other: usize,
    /// `mov` instruction count, including register-packing moves
    /// (`MovPack`).
    pub mov: usize,
    /// `cvt` instruction count.
    pub cvt: usize,
    /// Branch instructions (`bra`, `@pred bra`).
    pub branches: usize,
    /// `setp` comparison-to-predicate instructions.
    pub setp: usize,
    /// `%r` registers (32-bit integer).
    pub registers_r: u32,
    /// `%rd` registers (64-bit integer).
    pub registers_rd: u32,
    /// `%f` registers (f32).
    pub registers_f: u32,
    /// `%fd` registers (f64).
    pub registers_fd: u32,
    /// `%p` registers (predicate).
    pub registers_p: u32,
    /// `%h` registers (f16).
    pub registers_h: u32,
    /// `%hb` registers (bf16).
    pub registers_hb: u32,
    /// Total declared shared memory in bytes (sum of declaration sizes;
    /// alignment padding is not included).
    pub shared_bytes: u32,
}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Operand;
    use crate::types::PtxType;

    /// Build a register directly, bypassing the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    #[test]
    fn stats_empty_kernel() {
        let kernel = PtxKernel::new("empty");
        let s = kernel.stats();
        assert_eq!(s, KernelStats::default());
    }

    #[test]
    fn stats_counts_instruction_types() {
        let mut kernel = PtxKernel::new("test");

        // 2 FMA
        for _ in 0..2 {
            kernel.push(PtxInstruction::Arith(ArithOp::Fma {
                dst: reg(RegKind::F, 0, PtxType::F32),
                a: Operand::Reg(reg(RegKind::F, 1, PtxType::F32)),
                b: Operand::Reg(reg(RegKind::F, 2, PtxType::F32)),
                c: Operand::Reg(reg(RegKind::F, 3, PtxType::F32)),
                ty: PtxType::F32,
            }));
        }
        // 1 Add (arith_other)
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: reg(RegKind::R, 0, PtxType::U32),
            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
            rhs: Operand::ImmU32(1),
            ty: PtxType::U32,
        }));
        // 1 ld.global + 1 st.global
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F32,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        }));
        // 1 ld.shared + 1 st.shared
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::F32,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StShared {
            addr: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        }));
        // 1 ld.param (memory, total-only)
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: reg(RegKind::Rd, 0, PtxType::U64),
            param_name: "p0".to_string(),
            ty: PtxType::U64,
        }));
        // 1 bar.sync
        kernel.push(PtxInstruction::Control(ControlOp::BarSync {
            barrier_id: 0,
        }));
        // 1 branch + 1 setp
        kernel.push(PtxInstruction::Control(ControlOp::BraPred {
            pred: reg(RegKind::P, 0, PtxType::Pred),
            target: "L0".to_string(),
            negate: false,
        }));
        kernel.push(PtxInstruction::Control(ControlOp::SetP {
            dst: reg(RegKind::P, 0, PtxType::Pred),
            cmp_op: crate::instr::control::CmpOp::Lt,
            lhs: Operand::Reg(reg(RegKind::R, 0, PtxType::U32)),
            rhs: Operand::ImmU32(10),
            ty: PtxType::U32,
        }));
        // 1 mov + 1 cvt
        kernel.push(PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::ImmU32(0),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::R, 0, PtxType::U32),
            dst_ty: PtxType::F32,
            src_ty: PtxType::U32,
        });
        // 1 ret (control, total-only)
        kernel.push(PtxInstruction::Control(ControlOp::Ret));
        // Label + Comment — should not count
        kernel.push(PtxInstruction::Label("L0".to_string()));
        kernel.push(PtxInstruction::Comment("test".to_string()));

        let s = kernel.stats();
        // Compare the whole struct so a stray increment in an unrelated
        // counter (e.g. `mma`, `cp_async`, registers) also fails the test.
        // 2 fma + 1 add + 1 ld.global + 1 st.global + 1 ld.shared +
        // 1 st.shared + 1 ld.param + 1 bar.sync + 1 branch + 1 setp +
        // 1 mov + 1 cvt + 1 ret = 14
        assert_eq!(
            s,
            KernelStats {
                total_instructions: 14,
                ld_global: 1,
                st_global: 1,
                ld_shared: 1,
                st_shared: 1,
                bar_sync: 1,
                fma: 2,
                arith_other: 1,
                mov: 1,
                cvt: 1,
                branches: 1,
                setp: 1,
                ..KernelStats::default()
            }
        );
    }

    #[test]
    fn stats_counts_registers_by_kind() {
        let mut kernel = PtxKernel::new("test");
        kernel.set_registers(vec![
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::R, 1, PtxType::S32),
            reg(RegKind::R, 2, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::F, 1, PtxType::F32),
            reg(RegKind::Fd, 0, PtxType::F64),
            reg(RegKind::P, 0, PtxType::Pred),
            reg(RegKind::P, 1, PtxType::Pred),
            reg(RegKind::H, 0, PtxType::F16),
            reg(RegKind::H, 1, PtxType::F16),
        ]);

        let s = kernel.stats();
        // Registers never contribute to instruction counts; every kind not
        // listed here (including `%hb`) must stay at zero.
        assert_eq!(
            s,
            KernelStats {
                registers_r: 3,
                registers_rd: 1,
                registers_f: 2,
                registers_fd: 1,
                registers_p: 2,
                registers_h: 2,
                ..KernelStats::default()
            }
        );
    }

    #[test]
    fn stats_counts_tensor_core_and_cp_async() {
        use crate::fragment::{alloc_a_f16, alloc_b_f16, alloc_c};
        use crate::instr::MmaShape;
        use crate::ir::RegisterAllocator;

        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("tc_stats_test");

        // 2 mma.sync
        for _ in 0..2 {
            kernel.push(PtxInstruction::TensorCore(
                crate::instr::TensorCoreOp::MmaSync {
                    d: alloc_c(&mut alloc),
                    a: alloc_a_f16(&mut alloc),
                    b: alloc_b_f16(&mut alloc),
                    c: alloc_c(&mut alloc),
                    shape: MmaShape::M16N8K16,
                    d_ty: PtxType::F32,
                    a_ty: PtxType::F16,
                    b_ty: PtxType::F16,
                    c_ty: PtxType::F32,
                },
            ));
        }

        // 3 cp.async loads, 1 commit, 1 wait
        let dst_shared = reg(RegKind::R, 0, PtxType::U32);
        let src_global = reg(RegKind::Rd, 0, PtxType::U64);
        for _ in 0..3 {
            kernel.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
                dst_shared, src_global, 16,
            )));
        }
        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup));
        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncWaitGroup { n: 0 }));

        let s = kernel.stats();
        // 2 mma + 3 cp.async + 1 commit + 1 wait = 7 total. The allocator's
        // registers were never handed to the kernel, so register counts
        // stay at zero.
        assert_eq!(
            s,
            KernelStats {
                total_instructions: 7,
                mma: 2,
                cp_async: 3,
                cp_async_commit: 1,
                cp_async_wait: 1,
                ..KernelStats::default()
            }
        );
    }

    #[test]
    fn stats_counts_shared_bytes() {
        let mut kernel = PtxKernel::new("test");
        kernel.add_shared_decl(SharedDecl {
            name: "tile_a".to_string(),
            align: 4,
            size_bytes: 4352, // 64 * 17 * 4
        });
        kernel.add_shared_decl(SharedDecl {
            name: "tile_b".to_string(),
            align: 4,
            size_bytes: 4160, // 16 * 65 * 4
        });

        let s = kernel.stats();
        assert_eq!(
            s,
            KernelStats {
                shared_bytes: 4352 + 4160,
                ..KernelStats::default()
            }
        );
    }
}