Skip to main content

kaio_core/ir/
kernel.rs

1//! PTX kernel — a single `.entry` function in a PTX module.
2
3use super::instruction::PtxInstruction;
4use super::param::PtxParam;
5use super::register::Register;
6use crate::instr::ArithOp;
7use crate::instr::control::ControlOp;
8use crate::instr::memory::MemoryOp;
9use crate::instr::tensor_core::TensorCoreOp;
10use crate::types::RegKind;
11
/// One shared-memory allocation declared in a PTX kernel's preamble.
///
/// The declaration is emitted after the register declarations, in the form
/// `.shared .align {align} .b8 {name}[{size_bytes}];`.
#[derive(Debug, Clone)]
pub struct SharedDecl {
    /// Identifier of the allocation as it appears in the PTX (e.g., `"sdata"`).
    pub name: String,
    /// Required alignment, in bytes (4 for f32, 8 for f64).
    pub align: u32,
    /// Size of the whole allocation, in bytes.
    pub size_bytes: u32,
}
25
26/// A PTX kernel function (`.visible .entry`).
27///
28/// Built by constructing parameters, allocating registers, and pushing
29/// instructions. Call [`set_registers`](Self::set_registers) with the
30/// allocator's output before emission so the kernel knows which `.reg`
31/// declarations to emit.
32#[derive(Debug, Clone)]
33pub struct PtxKernel {
34    /// Kernel entry point name.
35    pub name: String,
36    /// Declared parameters (in signature order).
37    pub params: Vec<PtxParam>,
38    /// Instruction body.
39    pub body: Vec<PtxInstruction>,
40    /// All registers used, for `.reg` declaration emission.
41    pub registers: Vec<Register>,
42    /// Shared memory declarations (emitted after register declarations).
43    pub shared_decls: Vec<SharedDecl>,
44}
45
46impl PtxKernel {
47    /// Create a new empty kernel with the given name.
48    pub fn new(name: &str) -> Self {
49        Self {
50            name: name.to_string(),
51            params: Vec::new(),
52            body: Vec::new(),
53            registers: Vec::new(),
54            shared_decls: Vec::new(),
55        }
56    }
57
58    /// Add a parameter to the kernel signature.
59    pub fn add_param(&mut self, param: PtxParam) {
60        self.params.push(param);
61    }
62
63    /// Append an instruction to the kernel body.
64    pub fn push(&mut self, instr: PtxInstruction) {
65        self.body.push(instr);
66    }
67
68    /// Set the register list (from [`super::register::RegisterAllocator::into_allocated`]).
69    pub fn set_registers(&mut self, regs: Vec<Register>) {
70        self.registers = regs;
71    }
72
73    /// Add a shared memory declaration to the kernel preamble.
74    pub fn add_shared_decl(&mut self, decl: SharedDecl) {
75        self.shared_decls.push(decl);
76    }
77
78    /// Compute structural statistics about this kernel's emitted PTX.
79    ///
80    /// Walks the instruction body and counts instruction types, registers
81    /// by kind, and declared shared memory. Useful for inspection and
82    /// comparison between kernel variants.
83    ///
84    /// These are **not** runtime profiling data — final hardware register
85    /// allocation and occupancy may differ after CUDA driver compilation.
86    pub fn stats(&self) -> KernelStats {
87        let mut s = KernelStats::default();
88
89        for instr in &self.body {
90            match instr {
91                PtxInstruction::Arith(op) => {
92                    s.total_instructions += 1;
93                    if matches!(op, ArithOp::Fma { .. }) {
94                        s.fma += 1;
95                    } else {
96                        s.arith_other += 1;
97                    }
98                }
99                PtxInstruction::Memory(op) => {
100                    s.total_instructions += 1;
101                    match op {
102                        MemoryOp::LdGlobal { .. } => s.ld_global += 1,
103                        MemoryOp::StGlobal { .. } => s.st_global += 1,
104                        MemoryOp::LdShared { .. } => s.ld_shared += 1,
105                        MemoryOp::StShared { .. } => s.st_shared += 1,
106                        MemoryOp::CpAsyncCaSharedGlobal { .. } => s.cp_async += 1,
107                        MemoryOp::CpAsyncCommitGroup => s.cp_async_commit += 1,
108                        MemoryOp::CpAsyncWaitGroup { .. } => s.cp_async_wait += 1,
109                        _ => {}
110                    }
111                }
112                PtxInstruction::TensorCore(op) => {
113                    s.total_instructions += 1;
114                    match op {
115                        TensorCoreOp::MmaSync { .. } | TensorCoreOp::MmaSyncInt8 { .. } => {
116                            s.mma += 1
117                        }
118                    }
119                }
120                PtxInstruction::Control(op) => {
121                    s.total_instructions += 1;
122                    match op {
123                        ControlOp::BarSync { .. } => s.bar_sync += 1,
124                        ControlOp::BraPred { .. } | ControlOp::Bra { .. } => s.branches += 1,
125                        ControlOp::SetP { .. } => s.setp += 1,
126                        _ => {}
127                    }
128                }
129                PtxInstruction::Mov { .. } => {
130                    s.total_instructions += 1;
131                    s.mov += 1;
132                }
133                PtxInstruction::Cvt { .. } => {
134                    s.total_instructions += 1;
135                    s.cvt += 1;
136                }
137                PtxInstruction::MovPack { .. } => {
138                    s.total_instructions += 1;
139                    s.mov += 1;
140                }
141                PtxInstruction::Label(_) | PtxInstruction::Comment(_) => {}
142            }
143        }
144
145        for reg in &self.registers {
146            match reg.kind {
147                RegKind::R => s.registers_r += 1,
148                RegKind::Rd => s.registers_rd += 1,
149                RegKind::F => s.registers_f += 1,
150                RegKind::Fd => s.registers_fd += 1,
151                RegKind::P => s.registers_p += 1,
152                RegKind::H => s.registers_h += 1,
153                RegKind::Hb => s.registers_hb += 1,
154            }
155        }
156
157        s.shared_bytes = self.shared_decls.iter().map(|d| d.size_bytes).sum();
158
159        s
160    }
161}
162
/// Structural statistics about a compiled kernel's emitted PTX.
///
/// These describe the instruction mix and declared resource usage in
/// KAIO's generated PTX — useful for inspection and comparison between
/// kernel variants, but **not** a substitute for runtime profiling.
/// Final hardware register allocation and occupancy may differ from
/// these counts after the CUDA driver's backend compilation (PTX → SASS).
///
/// The type is a plain bundle of counters, so it is `Clone`: a snapshot can
/// be kept and compared (`==`) against the stats of another kernel variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelStats {
    /// Total instructions (excludes labels and comments).
    pub total_instructions: usize,
    /// `ld.global` count.
    pub ld_global: usize,
    /// `st.global` count.
    pub st_global: usize,
    /// `ld.shared` count.
    pub ld_shared: usize,
    /// `st.shared` count.
    pub st_shared: usize,
    /// `bar.sync` count.
    pub bar_sync: usize,
    /// `mma.sync` instruction count (all tensor-core shapes).
    pub mma: usize,
    /// `cp.async.ca.shared.global` instruction count.
    pub cp_async: usize,
    /// `cp.async.commit_group` instruction count.
    pub cp_async_commit: usize,
    /// `cp.async.wait_group` instruction count.
    pub cp_async_wait: usize,
    /// `fma` instruction count.
    pub fma: usize,
    /// Non-FMA arithmetic instructions (add, mul, sub, etc.).
    pub arith_other: usize,
    /// `mov` instruction count (packed moves are counted here as well).
    pub mov: usize,
    /// `cvt` instruction count.
    pub cvt: usize,
    /// Branch instructions (`bra`, `@pred bra`).
    pub branches: usize,
    /// `setp` comparison-to-predicate instructions.
    pub setp: usize,
    /// `%r` registers (32-bit integer).
    pub registers_r: u32,
    /// `%rd` registers (64-bit integer).
    pub registers_rd: u32,
    /// `%f` registers (f32).
    pub registers_f: u32,
    /// `%fd` registers (f64).
    pub registers_fd: u32,
    /// `%p` registers (predicate).
    pub registers_p: u32,
    /// `%h` registers (f16).
    pub registers_h: u32,
    /// `%hb` registers (bf16).
    pub registers_hb: u32,
    /// Total declared shared memory in bytes.
    pub shared_bytes: u32,
}
221
222#[cfg(test)]
223mod tests {
224    use super::*;
225    use crate::ir::Operand;
226    use crate::types::PtxType;
227
228    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
229        Register {
230            kind,
231            index,
232            ptx_type,
233        }
234    }
235
236    #[test]
237    fn stats_empty_kernel() {
238        let kernel = PtxKernel::new("empty");
239        let s = kernel.stats();
240        assert_eq!(s, KernelStats::default());
241    }
242
243    #[test]
244    fn stats_counts_instruction_types() {
245        let mut kernel = PtxKernel::new("test");
246
247        // 2 FMA
248        for _ in 0..2 {
249            kernel.push(PtxInstruction::Arith(ArithOp::Fma {
250                dst: reg(RegKind::F, 0, PtxType::F32),
251                a: Operand::Reg(reg(RegKind::F, 1, PtxType::F32)),
252                b: Operand::Reg(reg(RegKind::F, 2, PtxType::F32)),
253                c: Operand::Reg(reg(RegKind::F, 3, PtxType::F32)),
254                ty: PtxType::F32,
255            }));
256        }
257        // 1 Add (arith_other)
258        kernel.push(PtxInstruction::Arith(ArithOp::Add {
259            dst: reg(RegKind::R, 0, PtxType::U32),
260            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
261            rhs: Operand::ImmU32(1),
262            ty: PtxType::U32,
263        }));
264        // 1 ld.global + 1 st.global
265        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
266            dst: reg(RegKind::F, 0, PtxType::F32),
267            addr: reg(RegKind::Rd, 0, PtxType::U64),
268            ty: PtxType::F32,
269        }));
270        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
271            addr: reg(RegKind::Rd, 0, PtxType::U64),
272            src: reg(RegKind::F, 0, PtxType::F32),
273            ty: PtxType::F32,
274        }));
275        // 1 ld.shared + 1 st.shared
276        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
277            dst: reg(RegKind::F, 0, PtxType::F32),
278            addr: reg(RegKind::R, 0, PtxType::U32),
279            ty: PtxType::F32,
280        }));
281        kernel.push(PtxInstruction::Memory(MemoryOp::StShared {
282            addr: reg(RegKind::R, 0, PtxType::U32),
283            src: reg(RegKind::F, 0, PtxType::F32),
284            ty: PtxType::F32,
285        }));
286        // 1 ld.param (memory, total-only)
287        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
288            dst: reg(RegKind::Rd, 0, PtxType::U64),
289            param_name: "p0".to_string(),
290            ty: PtxType::U64,
291        }));
292        // 1 bar.sync
293        kernel.push(PtxInstruction::Control(ControlOp::BarSync {
294            barrier_id: 0,
295        }));
296        // 1 branch + 1 setp
297        kernel.push(PtxInstruction::Control(ControlOp::BraPred {
298            pred: reg(RegKind::P, 0, PtxType::Pred),
299            target: "L0".to_string(),
300            negate: false,
301        }));
302        kernel.push(PtxInstruction::Control(ControlOp::SetP {
303            dst: reg(RegKind::P, 0, PtxType::Pred),
304            cmp_op: crate::instr::control::CmpOp::Lt,
305            lhs: Operand::Reg(reg(RegKind::R, 0, PtxType::U32)),
306            rhs: Operand::ImmU32(10),
307            ty: PtxType::U32,
308        }));
309        // 1 mov + 1 cvt
310        kernel.push(PtxInstruction::Mov {
311            dst: reg(RegKind::R, 0, PtxType::U32),
312            src: Operand::ImmU32(0),
313            ty: PtxType::U32,
314        });
315        kernel.push(PtxInstruction::Cvt {
316            dst: reg(RegKind::F, 0, PtxType::F32),
317            src: reg(RegKind::R, 0, PtxType::U32),
318            dst_ty: PtxType::F32,
319            src_ty: PtxType::U32,
320        });
321        // 1 ret
322        kernel.push(PtxInstruction::Control(ControlOp::Ret));
323        // Label + Comment — should not count
324        kernel.push(PtxInstruction::Label("L0".to_string()));
325        kernel.push(PtxInstruction::Comment("test".to_string()));
326
327        let s = kernel.stats();
328        // 2 fma + 1 add + 1 ld.global + 1 st.global + 1 ld.shared +
329        // 1 st.shared + 1 ld.param + 1 bar.sync + 1 branch + 1 setp +
330        // 1 mov + 1 cvt + 1 ret = 14
331        assert_eq!(s.total_instructions, 14);
332        assert_eq!(s.fma, 2);
333        assert_eq!(s.arith_other, 1);
334        assert_eq!(s.ld_global, 1);
335        assert_eq!(s.st_global, 1);
336        assert_eq!(s.ld_shared, 1);
337        assert_eq!(s.st_shared, 1);
338        assert_eq!(s.bar_sync, 1);
339        assert_eq!(s.branches, 1);
340        assert_eq!(s.setp, 1);
341        assert_eq!(s.mov, 1);
342        assert_eq!(s.cvt, 1);
343    }
344
345    #[test]
346    fn stats_counts_registers_by_kind() {
347        let mut kernel = PtxKernel::new("test");
348        kernel.set_registers(vec![
349            reg(RegKind::R, 0, PtxType::U32),
350            reg(RegKind::R, 1, PtxType::S32),
351            reg(RegKind::R, 2, PtxType::U32),
352            reg(RegKind::Rd, 0, PtxType::U64),
353            reg(RegKind::F, 0, PtxType::F32),
354            reg(RegKind::F, 1, PtxType::F32),
355            reg(RegKind::Fd, 0, PtxType::F64),
356            reg(RegKind::P, 0, PtxType::Pred),
357            reg(RegKind::P, 1, PtxType::Pred),
358        ]);
359
360        let s = kernel.stats();
361        assert_eq!(s.registers_r, 3);
362        assert_eq!(s.registers_rd, 1);
363        assert_eq!(s.registers_f, 2);
364        assert_eq!(s.registers_fd, 1);
365        assert_eq!(s.registers_p, 2);
366    }
367
368    #[test]
369    fn stats_counts_tensor_core_and_cp_async() {
370        use crate::fragment::{alloc_a, alloc_b, alloc_c};
371        use crate::instr::MmaShape;
372        use crate::ir::RegisterAllocator;
373
374        let mut alloc = RegisterAllocator::new();
375        let mut kernel = PtxKernel::new("tc_stats_test");
376
377        // 2 mma.sync
378        for _ in 0..2 {
379            kernel.push(PtxInstruction::TensorCore(
380                crate::instr::TensorCoreOp::MmaSync {
381                    d: alloc_c(&mut alloc),
382                    a: alloc_a(&mut alloc),
383                    b: alloc_b(&mut alloc),
384                    c: alloc_c(&mut alloc),
385                    shape: MmaShape::M16N8K16,
386                    d_ty: PtxType::F32,
387                    a_ty: PtxType::F16,
388                    b_ty: PtxType::F16,
389                    c_ty: PtxType::F32,
390                },
391            ));
392        }
393
394        // 3 cp.async loads, 1 commit, 1 wait
395        let dst_shared = reg(RegKind::R, 0, PtxType::U32);
396        let src_global = reg(RegKind::Rd, 0, PtxType::U64);
397        for _ in 0..3 {
398            kernel.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
399                dst_shared, src_global, 16,
400            )));
401        }
402        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup));
403        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncWaitGroup { n: 0 }));
404
405        let s = kernel.stats();
406        assert_eq!(s.mma, 2);
407        assert_eq!(s.cp_async, 3);
408        assert_eq!(s.cp_async_commit, 1);
409        assert_eq!(s.cp_async_wait, 1);
410        // 2 mma + 3 cp.async + 1 commit + 1 wait = 7 total
411        assert_eq!(s.total_instructions, 7);
412    }
413
414    #[test]
415    fn stats_counts_shared_bytes() {
416        let mut kernel = PtxKernel::new("test");
417        kernel.add_shared_decl(SharedDecl {
418            name: "tile_a".to_string(),
419            align: 4,
420            size_bytes: 4352, // 64 * 17 * 4
421        });
422        kernel.add_shared_decl(SharedDecl {
423            name: "tile_b".to_string(),
424            align: 4,
425            size_bytes: 4160, // 16 * 65 * 4
426        });
427
428        let s = kernel.stats();
429        assert_eq!(s.shared_bytes, 4352 + 4160);
430    }
431}