kaio_core/instr/tensor_core.rs

//! Tensor-core PTX operations.
//!
//! This module hosts the warp-collective tensor-core instructions. Phase 6
//! added a single shape — `m16n8k16` with fp16 inputs and fp32
//! accumulation — the Ampere+ (SM 8.0) shape used by CUTLASS, cuBLAS, and
//! most production fp16 matmuls since 2020. Sprint 7.1 added the INT8
//! `m16n8k32` shape for dequantize-matmul.
//!
//! Earlier shapes (Volta `m8n8k4`, Turing `m16n8k8`) have different
//! fragment layouts and remain out of scope.
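//!
//! The supported mnemonics (emission details live on the [`TensorCoreOp`]
//! variants and in the tests below):
//!
//! ```text
//! mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32   // fp16 path (bf16 swaps the input suffixes)
//! mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32     // INT8 path
//! ```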

use std::fmt;

use crate::emit::{Emit, PtxWriter};
use crate::fragment::{
    FragmentA, FragmentA_M16N8K32, FragmentB, FragmentB_M16N8K32, FragmentC, FragmentC_M16N8K32,
};
use crate::types::PtxType;

/// The shape of an `mma.sync` instruction.
///
/// Each variant corresponds to a distinct hardware tile geometry. Adding a
/// new shape requires a new enum variant **and** a new set of fragment types
/// because the per-thread register distribution is shape-dependent. See
/// [`crate::fragment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    /// 16×16 × 16×8 → 16×8, fp16 / bf16 inputs with fp32 accumulate (Ampere+).
    M16N8K16,
    /// 16×32 × 32×8 → 16×8, signed int8 inputs with int32 accumulate (Ampere+).
    ///
    /// Used by `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32`. K is 32,
    /// not 16 — twice the K-tile of the fp16 path. Introduced in
    /// Sprint 7.1 for INT8 dequantize-matmul.
    M16N8K32,
}

impl MmaShape {
    /// PTX shape token (e.g. `"m16n8k16"`).
    pub fn ptx_token(&self) -> &'static str {
        match self {
            Self::M16N8K16 => "m16n8k16",
            Self::M16N8K32 => "m16n8k32",
        }
    }

    /// Minimum SM version required to execute this shape.
    ///
    /// Used by [`crate::ir::PtxModule::validate`] to reject kernels that
    /// emit tensor-core ops against a target SM that is too low.
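    ///
    /// A sketch of the intended check (hypothetical `target_sm` value and
    /// error type; the real logic lives in [`crate::ir::PtxModule::validate`]):
    ///
    /// ```ignore
    /// if target_sm < shape.min_sm() {
    ///     return Err(ValidationError::SmTooLow); // hypothetical error variant
    /// }
    /// ```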
    pub fn min_sm(&self) -> u32 {
        match self {
            Self::M16N8K16 => 80,
            Self::M16N8K32 => 80,
        }
    }
}

/// Tensor-core PTX instruction variants.
///
/// Warp-collective operations — `mma.sync` is executed cooperatively by
/// all 32 threads in a warp with a rigid NVIDIA-defined register layout.
/// See [`crate::fragment`] for the per-thread register distribution.
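///
/// A construction sketch for the fp16 path, mirroring the
/// `emit_mma_sync_m16n8k16_f16_f32` test below (register comments reflect
/// that test's expected emission):
///
/// ```ignore
/// let mut alloc = RegisterAllocator::new();
/// let a = alloc_a(&mut alloc); // %r0..%r3, packed half2
/// let b = alloc_b(&mut alloc); // %r4..%r5, packed half2
/// let c = alloc_c(&mut alloc); // %f0..%f3, fp32 accumulator in
/// let d = alloc_c(&mut alloc); // %f4..%f7, fp32 accumulator out
/// let op = TensorCoreOp::MmaSync {
///     d, a, b, c,
///     shape: MmaShape::M16N8K16,
///     d_ty: PtxType::F32,
///     a_ty: PtxType::F16,
///     b_ty: PtxType::F16,
///     c_ty: PtxType::F32,
/// };
/// ```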
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    /// Synchronous matrix-multiply-accumulate:
    /// `mma.sync.aligned.{shape}.row.col.{d_ty}.{a_ty}.{b_ty}.{c_ty}`
    /// `{d_regs}, {a_regs}, {b_regs}, {c_regs};`
    ///
    /// Computes `D = A * B + C` across the warp. A is row-major, B is
    /// column-major (the `.row.col` modifiers are fixed for the fp16
    /// `m16n8k16` form).
    ///
    /// Example emission:
    /// ```text
    /// mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    ///     {%f4,%f5,%f6,%f7},
    ///     {%r0,%r1,%r2,%r3},
    ///     {%r4,%r5},
    ///     {%f0,%f1,%f2,%f3};
    /// ```
    MmaSync {
        /// Destination (D) fragment — `.f32` accumulator output.
        d: FragmentC,
        /// Input A fragment — `.b32` packed half2 registers.
        a: FragmentA,
        /// Input B fragment — `.b32` packed half2 registers.
        b: FragmentB,
        /// Input C fragment — `.f32` accumulator input.
        c: FragmentC,
        /// Matrix shape — for this variant, only [`MmaShape::M16N8K16`];
        /// the INT8 `m16n8k32` shape goes through
        /// [`MmaSyncInt8`](Self::MmaSyncInt8) instead.
        shape: MmaShape,
        /// D element type (currently `F32`).
        d_ty: PtxType,
        /// A element type (currently `F16` or `BF16`).
        a_ty: PtxType,
        /// B element type (currently `F16` or `BF16`).
        b_ty: PtxType,
        /// C element type (currently `F32`).
        c_ty: PtxType,
    },
    /// INT8 matrix-multiply-accumulate:
    /// `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32`
    /// `{d_regs}, {a_regs}, {b_regs}, {c_regs};`
    ///
    /// Computes `D = A * B + C` across the warp where A and B are signed
    /// 8-bit integer matrices (packed i8x4 into `.b32` fragment registers)
    /// and C/D are `.s32` accumulator matrices. A is row-major, B is
    /// column-major (`.row.col`).
    ///
    /// Shape is implicitly [`MmaShape::M16N8K32`]; element types are
    /// implicitly `.s32.s8.s8.s32`. Requires SM 8.0+ (Ampere or newer).
    ///
    /// Used as the fast-path compute primitive for `kaio_ops::matmul_int8`
    /// (Sprint 7.1). A dequantize-to-f16 fallback uses the regular
    /// [`MmaSync`](Self::MmaSync) variant with `.f32.f16.f16.f32` types.
    ///
    /// Example emission:
    /// ```text
    /// mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32
    ///     {%r10,%r11,%r12,%r13},
    ///     {%r0,%r1,%r2,%r3},
    ///     {%r4,%r5},
    ///     {%r6,%r7,%r8,%r9};
    /// ```
    MmaSyncInt8 {
        /// Destination (D) fragment — four `.s32` accumulator registers.
        d: FragmentC_M16N8K32,
        /// Input A fragment — four `.b32` registers, each packing 4 signed i8.
        a: FragmentA_M16N8K32,
        /// Input B fragment — two `.b32` registers, each packing 4 signed i8.
        b: FragmentB_M16N8K32,
        /// Input C fragment — four `.s32` accumulator registers.
        c: FragmentC_M16N8K32,
    },
}

impl TensorCoreOp {
    /// Minimum SM version required to execute this op.
    pub fn min_sm(&self) -> u32 {
        match self {
            Self::MmaSync { shape, .. } => shape.min_sm(),
            Self::MmaSyncInt8 { .. } => MmaShape::M16N8K32.min_sm(),
        }
    }

    /// Short human-readable label used in validation errors
    /// (e.g. `"mma.sync.m16n8k16"`).
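    ///
    /// The exact labels, as asserted in this module's tests (`fp16_op` and
    /// `int8_op` stand in for instances of each variant):
    ///
    /// ```ignore
    /// assert_eq!(fp16_op.feature_label(), "mma.sync.m16n8k16");
    /// assert_eq!(int8_op.feature_label(), "mma.sync.m16n8k32.s8.s8.s32");
    /// ```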
    pub fn feature_label(&self) -> String {
        match self {
            Self::MmaSync { shape, .. } => format!("mma.sync.{}", shape.ptx_token()),
            Self::MmaSyncInt8 { .. } => {
                format!("mma.sync.{}.s8.s8.s32", MmaShape::M16N8K32.ptx_token())
            }
        }
    }
}

/// Format a fragment register list as `{%x0,%x1,...}` (no surrounding
/// whitespace).
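///
/// e.g. the fp16 D fragment from the tests, `[%f4, %f5, %f6, %f7]`, formats
/// as `"{%f4,%f5,%f6,%f7}"`.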
fn format_reg_list(regs: &[crate::ir::Register]) -> String {
    let joined = regs
        .iter()
        .map(|r| format!("{r}"))
        .collect::<Vec<_>>()
        .join(",");
    format!("{{{joined}}}")
}

impl Emit for TensorCoreOp {
    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
        match self {
            TensorCoreOp::MmaSync {
                d,
                a,
                b,
                c,
                shape,
                d_ty,
                a_ty,
                b_ty,
                c_ty,
            } => {
                // mma.sync.aligned.{shape}.row.col.{dty}.{aty}.{bty}.{cty}
                // For fp16/bf16 inputs with fp32 accumulate, .row.col is
                // the only valid operand-layout modifier.
                let mnemonic = format!(
                    "mma.sync.aligned.{}.row.col{}{}{}{}",
                    shape.ptx_token(),
                    d_ty.ptx_suffix(),
                    a_ty.ptx_suffix(),
                    b_ty.ptx_suffix(),
                    c_ty.ptx_suffix(),
                );
                let d_list = format_reg_list(&d.regs);
                let a_list = format_reg_list(&a.regs);
                let b_list = format_reg_list(&b.regs);
                let c_list = format_reg_list(&c.regs);
                w.instruction(
                    &mnemonic,
                    &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
                )
            }
            TensorCoreOp::MmaSyncInt8 { d, a, b, c } => {
                // Full instruction: mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32
                // The `.row.col` layout qualifiers are mandatory per PTX ISA —
                // A is row-major, B is col-major. Type suffix order is
                // {d_ty}.{a_ty}.{b_ty}.{c_ty} — s32 accumulator with s8 inputs.
                let mnemonic = "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32";
                let d_list = format_reg_list(&d.regs);
                let a_list = format_reg_list(&a.regs);
                let b_list = format_reg_list(&b.regs);
                let c_list = format_reg_list(&c.regs);
                w.instruction(
                    mnemonic,
                    &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
                )
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::ir::RegisterAllocator;

    #[test]
    fn mma_shape_token_and_min_sm() {
        assert_eq!(MmaShape::M16N8K16.ptx_token(), "m16n8k16");
        assert_eq!(MmaShape::M16N8K16.min_sm(), 80);
        assert_eq!(MmaShape::M16N8K32.ptx_token(), "m16n8k32");
        assert_eq!(MmaShape::M16N8K32.min_sm(), 80);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();

        // Check the full line — operand order: D, A, B, C.
        let expected = concat!(
            "    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_bf16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::BF16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        assert!(
            w.finish()
                .contains("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32")
        );
    }

    #[test]
    fn min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    #[test]
    fn emit_mma_sync_int8_m16n8k32() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a_M16N8K32(&mut alloc);
        let b = alloc_b_M16N8K32(&mut alloc);
        let c = alloc_c_M16N8K32(&mut alloc);
        let d = alloc_c_M16N8K32(&mut alloc);

        let op = TensorCoreOp::MmaSyncInt8 { d, a, b, c };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();

        // Register layout:
        //   A = 4 × alloc_packed_int8x4  → %r0..%r3
        //   B = 2 × alloc_packed_int8x4  → %r4..%r5
        //   C = 4 × alloc(S32)           → %r6..%r9
        //   D = 4 × alloc(S32)           → %r10..%r13
        // All live in the %r class (S8/S32/U32 share RegKind::R).
        // Operand order in mma: D, A, B, C.
        let expected = concat!(
            "    mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 ",
            "{%r10,%r11,%r12,%r13}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%r6,%r7,%r8,%r9};\n",
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn int8_min_sm_and_feature_label() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k32.s8.s8.s32");
    }

    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut alloc = RegisterAllocator::new();
        let instr = PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        });
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert!(w.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}