1use super::instruction::PtxInstruction;
4use super::param::PtxParam;
5use super::register::Register;
6use crate::instr::ArithOp;
7use crate::instr::control::ControlOp;
8use crate::instr::memory::MemoryOp;
9use crate::instr::tensor_core::TensorCoreOp;
10use crate::types::RegKind;
11
/// A named shared declaration attached to a [`PtxKernel`].
///
/// NOTE(review): presumably this describes a PTX `.shared` buffer; confirm
/// against the code that emits the declarations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedDecl {
    /// Identifier of the declaration.
    pub name: String,
    /// Required alignment (presumably in bytes; confirm with the emitter).
    pub align: u32,
    /// Size of the declaration in bytes; summed into
    /// `KernelStats::shared_bytes` by `PtxKernel::stats`.
    pub size_bytes: u32,
}
25
26#[derive(Debug, Clone)]
33pub struct PtxKernel {
34 pub name: String,
36 pub params: Vec<PtxParam>,
38 pub body: Vec<PtxInstruction>,
40 pub registers: Vec<Register>,
42 pub shared_decls: Vec<SharedDecl>,
44}
45
46impl PtxKernel {
47 pub fn new(name: &str) -> Self {
49 Self {
50 name: name.to_string(),
51 params: Vec::new(),
52 body: Vec::new(),
53 registers: Vec::new(),
54 shared_decls: Vec::new(),
55 }
56 }
57
58 pub fn add_param(&mut self, param: PtxParam) {
60 self.params.push(param);
61 }
62
63 pub fn push(&mut self, instr: PtxInstruction) {
65 self.body.push(instr);
66 }
67
68 pub fn set_registers(&mut self, regs: Vec<Register>) {
70 self.registers = regs;
71 }
72
73 pub fn add_shared_decl(&mut self, decl: SharedDecl) {
75 self.shared_decls.push(decl);
76 }
77
78 pub fn stats(&self) -> KernelStats {
87 let mut s = KernelStats::default();
88
89 for instr in &self.body {
90 match instr {
91 PtxInstruction::Arith(op) => {
92 s.total_instructions += 1;
93 if matches!(op, ArithOp::Fma { .. }) {
94 s.fma += 1;
95 } else {
96 s.arith_other += 1;
97 }
98 }
99 PtxInstruction::Memory(op) => {
100 s.total_instructions += 1;
101 match op {
102 MemoryOp::LdGlobal { .. } => s.ld_global += 1,
103 MemoryOp::StGlobal { .. } => s.st_global += 1,
104 MemoryOp::LdShared { .. } => s.ld_shared += 1,
105 MemoryOp::StShared { .. } => s.st_shared += 1,
106 MemoryOp::CpAsyncCaSharedGlobal { .. } => s.cp_async += 1,
107 MemoryOp::CpAsyncCommitGroup => s.cp_async_commit += 1,
108 MemoryOp::CpAsyncWaitGroup { .. } => s.cp_async_wait += 1,
109 _ => {}
110 }
111 }
112 PtxInstruction::TensorCore(op) => {
113 s.total_instructions += 1;
114 match op {
115 TensorCoreOp::MmaSync { .. }
116 | TensorCoreOp::MmaSyncInt8 { .. }
117 | TensorCoreOp::MmaSyncBf16 { .. } => s.mma += 1,
118 TensorCoreOp::LdMatrix { .. } => s.ldmatrix += 1,
124 }
125 }
126 PtxInstruction::Control(op) => {
127 s.total_instructions += 1;
128 match op {
129 ControlOp::BarSync { .. } => s.bar_sync += 1,
130 ControlOp::BraPred { .. } | ControlOp::Bra { .. } => s.branches += 1,
131 ControlOp::SetP { .. } => s.setp += 1,
132 _ => {}
133 }
134 }
135 PtxInstruction::Mov { .. } => {
136 s.total_instructions += 1;
137 s.mov += 1;
138 }
139 PtxInstruction::Cvt { .. } => {
140 s.total_instructions += 1;
141 s.cvt += 1;
142 }
143 PtxInstruction::MovPack { .. } => {
144 s.total_instructions += 1;
145 s.mov += 1;
146 }
147 PtxInstruction::Label(_) | PtxInstruction::Comment(_) => {}
148 }
149 }
150
151 for reg in &self.registers {
152 match reg.kind {
153 RegKind::R => s.registers_r += 1,
154 RegKind::Rd => s.registers_rd += 1,
155 RegKind::F => s.registers_f += 1,
156 RegKind::Fd => s.registers_fd += 1,
157 RegKind::P => s.registers_p += 1,
158 RegKind::H => s.registers_h += 1,
159 RegKind::Hb => s.registers_hb += 1,
160 }
161 }
162
163 s.shared_bytes = self.shared_decls.iter().map(|d| d.size_bytes).sum();
164
165 s
166 }
167}
168
/// Summary counters for a kernel, as produced by `PtxKernel::stats`.
///
/// All counters start at zero (`Default`). The per-category instruction
/// counters do not necessarily add up to `total_instructions`, because some
/// memory and control operations have no dedicated counter.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KernelStats {
    /// Body entries that are real instructions (labels and comments excluded).
    pub total_instructions: usize,
    /// `MemoryOp::LdGlobal` instructions.
    pub ld_global: usize,
    /// `MemoryOp::StGlobal` instructions.
    pub st_global: usize,
    /// `MemoryOp::LdShared` instructions.
    pub ld_shared: usize,
    /// `MemoryOp::StShared` instructions.
    pub st_shared: usize,
    /// `ControlOp::BarSync` instructions.
    pub bar_sync: usize,
    /// `TensorCoreOp::MmaSync`, `MmaSyncInt8` and `MmaSyncBf16` instructions
    /// combined.
    pub mma: usize,
    /// `TensorCoreOp::LdMatrix` instructions.
    pub ldmatrix: usize,
    /// `MemoryOp::CpAsyncCaSharedGlobal` instructions.
    pub cp_async: usize,
    /// `MemoryOp::CpAsyncCommitGroup` instructions.
    pub cp_async_commit: usize,
    /// `MemoryOp::CpAsyncWaitGroup` instructions.
    pub cp_async_wait: usize,
    /// `ArithOp::Fma` instructions.
    pub fma: usize,
    /// Arithmetic instructions other than `ArithOp::Fma`.
    pub arith_other: usize,
    /// `PtxInstruction::Mov` and `PtxInstruction::MovPack` entries combined.
    pub mov: usize,
    /// `PtxInstruction::Cvt` entries.
    pub cvt: usize,
    /// `ControlOp::Bra` and `ControlOp::BraPred` instructions combined.
    pub branches: usize,
    /// `ControlOp::SetP` instructions.
    pub setp: usize,
    /// Registers of kind `RegKind::R`.
    pub registers_r: u32,
    /// Registers of kind `RegKind::Rd`.
    pub registers_rd: u32,
    /// Registers of kind `RegKind::F`.
    pub registers_f: u32,
    /// Registers of kind `RegKind::Fd`.
    pub registers_fd: u32,
    /// Registers of kind `RegKind::P`.
    pub registers_p: u32,
    /// Registers of kind `RegKind::H`.
    pub registers_h: u32,
    /// Registers of kind `RegKind::Hb`.
    pub registers_hb: u32,
    /// Sum of `size_bytes` over the kernel's shared declarations.
    pub shared_bytes: u32,
}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Operand;
    use crate::types::PtxType;

    /// Shorthand for building a [`Register`] literal in the tests below.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// A freshly created kernel reports all-zero statistics.
    #[test]
    fn stats_empty_kernel() {
        let kernel = PtxKernel::new("empty");
        let s = kernel.stats();
        assert_eq!(s, KernelStats::default());
    }

    /// Each instruction category lands in its own counter, and labels and
    /// comments are left out of the total.
    #[test]
    fn stats_counts_instruction_types() {
        let mut kernel = PtxKernel::new("test");

        // Two FMAs.
        for _ in 0..2 {
            kernel.push(PtxInstruction::Arith(ArithOp::Fma {
                dst: reg(RegKind::F, 0, PtxType::F32),
                a: Operand::Reg(reg(RegKind::F, 1, PtxType::F32)),
                b: Operand::Reg(reg(RegKind::F, 2, PtxType::F32)),
                c: Operand::Reg(reg(RegKind::F, 3, PtxType::F32)),
                ty: PtxType::F32,
            }));
        }
        // One non-FMA arithmetic instruction.
        kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: reg(RegKind::R, 0, PtxType::U32),
            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
            rhs: Operand::ImmU32(1),
            ty: PtxType::U32,
        }));
        // Global load / store.
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F32,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        }));
        // Shared load / store.
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::F32,
        }));
        kernel.push(PtxInstruction::Memory(MemoryOp::StShared {
            addr: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        }));
        // Parameter load: no dedicated counter, only counted in the total.
        kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
            dst: reg(RegKind::Rd, 0, PtxType::U64),
            param_name: "p0".to_string(),
            ty: PtxType::U64,
        }));
        // Barrier.
        kernel.push(PtxInstruction::Control(ControlOp::BarSync {
            barrier_id: 0,
        }));
        // Predicated branch.
        kernel.push(PtxInstruction::Control(ControlOp::BraPred {
            pred: reg(RegKind::P, 0, PtxType::Pred),
            target: "L0".to_string(),
            negate: false,
        }));
        kernel.push(PtxInstruction::Control(ControlOp::SetP {
            dst: reg(RegKind::P, 0, PtxType::Pred),
            cmp_op: crate::instr::control::CmpOp::Lt,
            lhs: Operand::Reg(reg(RegKind::R, 0, PtxType::U32)),
            rhs: Operand::ImmU32(10),
            ty: PtxType::U32,
        }));
        // Move and convert.
        kernel.push(PtxInstruction::Mov {
            dst: reg(RegKind::R, 0, PtxType::U32),
            src: Operand::ImmU32(0),
            ty: PtxType::U32,
        });
        kernel.push(PtxInstruction::Cvt {
            dst: reg(RegKind::F, 0, PtxType::F32),
            src: reg(RegKind::R, 0, PtxType::U32),
            dst_ty: PtxType::F32,
            src_ty: PtxType::U32,
        });
        // Return: no dedicated counter, only counted in the total.
        kernel.push(PtxInstruction::Control(ControlOp::Ret));
        // Label and comment must not be counted at all.
        kernel.push(PtxInstruction::Label("L0".to_string()));
        kernel.push(PtxInstruction::Comment("test".to_string()));

        let s = kernel.stats();
        // 16 body entries minus the label and the comment.
        assert_eq!(s.total_instructions, 14);
        assert_eq!(s.fma, 2);
        assert_eq!(s.arith_other, 1);
        assert_eq!(s.ld_global, 1);
        assert_eq!(s.st_global, 1);
        assert_eq!(s.ld_shared, 1);
        assert_eq!(s.st_shared, 1);
        assert_eq!(s.bar_sync, 1);
        assert_eq!(s.branches, 1);
        assert_eq!(s.setp, 1);
        assert_eq!(s.mov, 1);
        assert_eq!(s.cvt, 1);
    }

    /// Registers are tallied per kind; kinds that never appear stay at zero.
    #[test]
    fn stats_counts_registers_by_kind() {
        let mut kernel = PtxKernel::new("test");
        kernel.set_registers(vec![
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::R, 1, PtxType::S32),
            reg(RegKind::R, 2, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            reg(RegKind::F, 0, PtxType::F32),
            reg(RegKind::F, 1, PtxType::F32),
            reg(RegKind::Fd, 0, PtxType::F64),
            reg(RegKind::P, 0, PtxType::Pred),
            reg(RegKind::P, 1, PtxType::Pred),
        ]);

        let s = kernel.stats();
        assert_eq!(s.registers_r, 3);
        assert_eq!(s.registers_rd, 1);
        assert_eq!(s.registers_f, 2);
        assert_eq!(s.registers_fd, 1);
        assert_eq!(s.registers_p, 2);
        // No H / Hb registers were registered.
        assert_eq!(s.registers_h, 0);
        assert_eq!(s.registers_hb, 0);
    }

    /// MMA and cp.async-family instructions each hit their own counter and
    /// all of them contribute to the total.
    #[test]
    fn stats_counts_tensor_core_and_cp_async() {
        use crate::fragment::{alloc_a_f16, alloc_b_f16, alloc_c};
        use crate::instr::MmaShape;
        use crate::ir::RegisterAllocator;

        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("tc_stats_test");

        // Two MMA instructions.
        for _ in 0..2 {
            kernel.push(PtxInstruction::TensorCore(
                crate::instr::TensorCoreOp::MmaSync {
                    d: alloc_c(&mut alloc),
                    a: alloc_a_f16(&mut alloc),
                    b: alloc_b_f16(&mut alloc),
                    c: alloc_c(&mut alloc),
                    shape: MmaShape::M16N8K16,
                    d_ty: PtxType::F32,
                    a_ty: PtxType::F16,
                    b_ty: PtxType::F16,
                    c_ty: PtxType::F32,
                },
            ));
        }

        // Three async copies followed by one commit and one wait.
        let dst_shared = reg(RegKind::R, 0, PtxType::U32);
        let src_global = reg(RegKind::Rd, 0, PtxType::U64);
        for _ in 0..3 {
            kernel.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
                dst_shared, src_global, 16,
            )));
        }
        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup));
        kernel.push(PtxInstruction::Memory(MemoryOp::CpAsyncWaitGroup { n: 0 }));

        let s = kernel.stats();
        assert_eq!(s.mma, 2);
        assert_eq!(s.cp_async, 3);
        assert_eq!(s.cp_async_commit, 1);
        assert_eq!(s.cp_async_wait, 1);
        // 2 mma + 3 cp.async + 1 commit + 1 wait.
        assert_eq!(s.total_instructions, 7);
    }

    /// `shared_bytes` is the sum of all declared sizes.
    #[test]
    fn stats_counts_shared_bytes() {
        let mut kernel = PtxKernel::new("test");
        kernel.add_shared_decl(SharedDecl {
            name: "tile_a".to_string(),
            align: 4,
            size_bytes: 4352,
        });
        kernel.add_shared_decl(SharedDecl {
            name: "tile_b".to_string(),
            align: 4,
            size_bytes: 4160,
        });

        let s = kernel.stats();
        assert_eq!(s.shared_bytes, 4352 + 4160);
    }
}