kaio_core/instr/tensor_core.rs
//! Tensor-core PTX operations.
//!
//! This module hosts the warp-collective tensor-core instructions. Phase 6
//! supports a single shape — `m16n8k16` with fp16 inputs and fp32
//! accumulation — the Ampere+ (SM 8.0) shape used by CUTLASS, cuBLAS, and
//! most production fp16 matmuls since 2020. Sprint 7.1 adds the int8
//! `m16n8k32` shape (see [`MmaShape::M16N8K32`]).
//!
//! Earlier shapes (Volta `m8n8k4`, Turing `m16n8k8`) have different
//! fragment layouts and are out of scope for Phase 6.

use std::fmt;

use crate::emit::{Emit, PtxWriter};
use crate::fragment::{
    FragmentA, FragmentA_M16N8K32, FragmentB, FragmentB_M16N8K32, FragmentC, FragmentC_M16N8K32,
};
use crate::types::PtxType;

/// The shape of an `mma.sync` instruction.
///
/// Each variant corresponds to a distinct hardware tile geometry. Adding a
/// new shape requires a new enum variant **and** a new set of fragment types
/// because the per-thread register distribution is shape-dependent. See
/// [`crate::fragment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    /// 16×16 × 16×8 → 16×8, fp16 / bf16 inputs with fp32 accumulate (Ampere+).
    M16N8K16,
    /// 16×32 × 32×8 → 16×8, signed int8 inputs with int32 accumulate (Ampere+).
    ///
    /// Used by `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32`. K is 32,
    /// not 16 — twice the K-tile of the fp16 path. Introduced in
    /// Sprint 7.1 for INT8 dequantize-matmul.
    M16N8K32,
}

impl MmaShape {
    /// PTX shape token (e.g. `"m16n8k16"`).
    pub fn ptx_token(&self) -> &'static str {
        match self {
            Self::M16N8K16 => "m16n8k16",
            Self::M16N8K32 => "m16n8k32",
        }
    }

    /// Minimum SM version required to execute this shape.
    ///
    /// Used by [`crate::ir::PtxModule::validate`] to reject kernels that
    /// emit tensor-core ops against a target SM that is too low.
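    ///
    /// A minimal sketch of that check (the `target_sm` value and the comment
    /// are illustrative; the real logic lives in
    /// [`crate::ir::PtxModule::validate`]):
    ///
    /// ```ignore
    /// let target_sm = 75; // e.g. Turing, which predates these shapes
    /// if MmaShape::M16N8K16.min_sm() > target_sm {
    ///     // validate() rejects the kernel with an unsupported-feature error.
    /// }
    /// ```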
    pub fn min_sm(&self) -> u32 {
        match self {
            Self::M16N8K16 => 80,
            Self::M16N8K32 => 80,
        }
    }
}

/// Tensor-core PTX instruction variants.
///
/// Warp-collective operations — `mma.sync` is executed cooperatively by
/// all 32 threads in a warp with a rigid NVIDIA-defined register layout.
/// See [`crate::fragment`] for the per-thread register distribution.
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    /// Synchronous matrix-multiply-accumulate:
    /// `mma.sync.aligned.{shape}.row.col.{d_ty}.{a_ty}.{b_ty}.{c_ty}`
    /// `{d_regs}, {a_regs}, {b_regs}, {c_regs};`
    ///
    /// Computes `D = A * B + C` across the warp. A is row-major, B is
    /// column-major (the `.row.col` modifiers are fixed for the fp16
    /// `m16n8k16` form).
    ///
    /// Example emission:
    /// ```text
    /// mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    /// {%f4,%f5,%f6,%f7},
    /// {%r0,%r1,%r2,%r3},
    /// {%r4,%r5},
    /// {%f0,%f1,%f2,%f3};
    /// ```
    MmaSync {
        /// Destination (D) fragment — `.f32` accumulator output.
        d: FragmentC,
        /// Input A fragment — `.b32` packed half2 registers.
        a: FragmentA,
        /// Input B fragment — `.b32` packed half2 registers.
        b: FragmentB,
        /// Input C fragment — `.f32` accumulator input.
        c: FragmentC,
        /// Matrix shape (currently only [`MmaShape::M16N8K16`]; the int8
        /// `m16n8k32` path uses the dedicated
        /// [`MmaSyncInt8`](Self::MmaSyncInt8) variant instead).
        shape: MmaShape,
        /// D element type (currently `F32`).
        d_ty: PtxType,
        /// A element type (currently `F16` or `BF16`).
        a_ty: PtxType,
        /// B element type (currently `F16` or `BF16`).
        b_ty: PtxType,
        /// C element type (currently `F32`).
        c_ty: PtxType,
    },
    /// INT8 matrix-multiply-accumulate:
    /// `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32`
    /// `{d_regs}, {a_regs}, {b_regs}, {c_regs};`
    ///
    /// Computes `D = A * B + C` across the warp where A and B are signed
    /// 8-bit integer matrices (packed i8x4 into `.b32` fragment registers)
    /// and C/D are `.s32` accumulator matrices. A is row-major, B is
    /// column-major (`.row.col`).
    ///
    /// Shape is implicitly [`MmaShape::M16N8K32`]; element types are
    /// implicitly `.s32.s8.s8.s32`. Requires SM 8.0+ (Ampere or newer).
    ///
    /// Used as the fast-path compute primitive for `kaio_ops::matmul_int8`
    /// (Sprint 7.1). A dequant-to-f16 fallback uses the regular
    /// [`MmaSync`](Self::MmaSync) variant with `.f16.f16.f32` types.
    ///
    /// Example emission:
    /// ```text
    /// mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32
    /// {%r8,%r9,%r10,%r11},
    /// {%r0,%r1,%r2,%r3},
    /// {%r4,%r5},
    /// {%r12,%r13,%r14,%r15};
    /// ```
    MmaSyncInt8 {
        /// Destination (D) fragment — four `.s32` accumulator registers.
        d: FragmentC_M16N8K32,
        /// Input A fragment — four `.b32` registers, each packing 4 signed i8.
        a: FragmentA_M16N8K32,
        /// Input B fragment — two `.b32` registers, each packing 4 signed i8.
        b: FragmentB_M16N8K32,
        /// Input C fragment — four `.s32` accumulator registers.
        c: FragmentC_M16N8K32,
    },
}

impl TensorCoreOp {
    /// Minimum SM version required to execute this op.
    pub fn min_sm(&self) -> u32 {
        match self {
            Self::MmaSync { shape, .. } => shape.min_sm(),
            Self::MmaSyncInt8 { .. } => MmaShape::M16N8K32.min_sm(),
        }
    }

    /// Short human-readable label used in validation errors
    /// (e.g. `"mma.sync.m16n8k16"`).
    pub fn feature_label(&self) -> String {
        match self {
            Self::MmaSync { shape, .. } => format!("mma.sync.{}", shape.ptx_token()),
            Self::MmaSyncInt8 { .. } => {
                format!("mma.sync.{}.s8.s8.s32", MmaShape::M16N8K32.ptx_token())
            }
        }
    }
}

/// Format a fragment register list as `{%x0,%x1,...}` (no surrounding
/// whitespace).
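///
/// For example, the four-register fp32 accumulator fragment used in the tests
/// below renders as:
///
/// ```text
/// {%f4,%f5,%f6,%f7}
/// ```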
fn format_reg_list(regs: &[crate::ir::Register]) -> String {
    let joined = regs
        .iter()
        .map(|r| format!("{r}"))
        .collect::<Vec<_>>()
        .join(",");
    format!("{{{joined}}}")
}

impl Emit for TensorCoreOp {
    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
        match self {
            TensorCoreOp::MmaSync {
                d,
                a,
                b,
                c,
                shape,
                d_ty,
                a_ty,
                b_ty,
                c_ty,
            } => {
                // mma.sync.aligned.{shape}.row.col.{dty}.{aty}.{bty}.{cty}
                // For fp16/bf16 inputs with fp32 accumulate, .row.col is
                // the only valid operand-layout modifier.
                let mnemonic = format!(
                    "mma.sync.aligned.{}.row.col{}{}{}{}",
                    shape.ptx_token(),
                    d_ty.ptx_suffix(),
                    a_ty.ptx_suffix(),
                    b_ty.ptx_suffix(),
                    c_ty.ptx_suffix(),
                );
                let d_list = format_reg_list(&d.regs);
                let a_list = format_reg_list(&a.regs);
                let b_list = format_reg_list(&b.regs);
                let c_list = format_reg_list(&c.regs);
                w.instruction(
                    &mnemonic,
                    &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
                )
            }
            TensorCoreOp::MmaSyncInt8 { d, a, b, c } => {
                // Full instruction: mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32
                // The `.row.col` layout qualifiers are mandatory per PTX ISA —
                // A is row-major, B is col-major. Type suffix order is
                // {d_ty}.{a_ty}.{b_ty}.{c_ty} — s32 accumulator with s8 inputs.
                let mnemonic = "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32";
                let d_list = format_reg_list(&d.regs);
                let a_list = format_reg_list(&a.regs);
                let b_list = format_reg_list(&b.regs);
                let c_list = format_reg_list(&c.regs);
                w.instruction(
                    mnemonic,
                    &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
                )
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::ir::RegisterAllocator;

    #[test]
    fn mma_shape_token_and_min_sm() {
        assert_eq!(MmaShape::M16N8K16.ptx_token(), "m16n8k16");
        assert_eq!(MmaShape::M16N8K16.min_sm(), 80);
        assert_eq!(MmaShape::M16N8K32.ptx_token(), "m16n8k32");
        assert_eq!(MmaShape::M16N8K32.min_sm(), 80);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();

        // Check the full line — operand order: D, A, B, C.
        let expected = concat!(
            " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_bf16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::BF16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        assert!(
            w.finish()
                .contains("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32")
        );
    }

    #[test]
    fn min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    #[test]
    fn emit_mma_sync_int8_m16n8k32() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a_M16N8K32(&mut alloc);
        let b = alloc_b_M16N8K32(&mut alloc);
        let c = alloc_c_M16N8K32(&mut alloc);
        let d = alloc_c_M16N8K32(&mut alloc);

        let op = TensorCoreOp::MmaSyncInt8 { d, a, b, c };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();

        // Register layout:
        // A = 4 × alloc_packed_int8x4 → %r0..%r3
        // B = 2 × alloc_packed_int8x4 → %r4..%r5
        // C = 4 × alloc(S32) → %r6..%r9
        // D = 4 × alloc(S32) → %r10..%r13
        // All live in the %r class (S8/S32/U32 share RegKind::R).
        // Operand order in mma: D, A, B, C.
        let expected = concat!(
            " mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 ",
            "{%r10,%r11,%r12,%r13}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%r6,%r7,%r8,%r9};\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn int8_min_sm_and_feature_label() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k32.s8.s8.s32");
    }

    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut alloc = RegisterAllocator::new();
        let instr = PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        });
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert!(w.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}