// kaio_core/instr/tensor_core.rs
//! Tensor-core PTX operations.
//!
//! This module hosts the warp-collective tensor-core instructions. Phase 6
//! supports a single shape — `m16n8k16` with fp16 inputs and fp32
//! accumulation — which is the Ampere+ (SM 8.0) shape used by CUTLASS,
//! cuBLAS, and every production fp16 matmul since 2020.
//!
//! Earlier shapes (Volta `m8n8k4`, Turing `m16n8k8`) have different
//! fragment layouts and are out of scope for Phase 6.

use std::fmt;

use crate::emit::{Emit, PtxWriter};
use crate::fragment::{FragmentA, FragmentB, FragmentC};
use crate::types::PtxType;
16
/// The shape of an `mma.sync` instruction.
///
/// Exactly one variant today. Adding a new shape later means adding a
/// new enum variant **and** a new set of fragment types — the fragment
/// register count is shape-dependent. See [`crate::fragment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    /// 16×16 × 16×8 → 16×8, fp16 inputs with fp32 accumulate (Ampere+).
    M16N8K16,
}

impl MmaShape {
    /// PTX shape token (e.g. `"m16n8k16"`).
    ///
    /// `const` so the token can also be used in const contexts
    /// (e.g. building static mnemonic tables).
    pub const fn ptx_token(&self) -> &'static str {
        match self {
            Self::M16N8K16 => "m16n8k16",
        }
    }

    /// Minimum SM version required to execute this shape.
    ///
    /// Used by [`crate::ir::PtxModule::validate`] to reject kernels that
    /// emit tensor-core ops against too-low a target SM.
    ///
    /// `const` for the same reason as [`Self::ptx_token`].
    pub const fn min_sm(&self) -> u32 {
        match self {
            Self::M16N8K16 => 80,
        }
    }
}
46
/// Tensor-core PTX instruction variants.
///
/// Warp-collective operations — `mma.sync` is executed cooperatively by
/// all 32 threads in a warp with a rigid NVIDIA-defined register layout.
/// See [`crate::fragment`] for the per-thread register distribution.
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    /// Synchronous matrix-multiply-accumulate:
    /// `mma.sync.aligned.{shape}.row.col.{d_ty}.{a_ty}.{b_ty}.{c_ty}`
    /// `{d_regs}, {a_regs}, {b_regs}, {c_regs};`
    ///
    /// Computes `D = A * B + C` across the warp. A is row-major, B is
    /// column-major (the `.row.col` modifiers are fixed for the fp16
    /// `m16n8k16` form).
    ///
    /// Note: D and C share the [`FragmentC`] type — both are accumulator
    /// fragments; the op reads C and writes D.
    ///
    /// Example emission:
    /// ```text
    /// mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    ///     {%f4,%f5,%f6,%f7},
    ///     {%r0,%r1,%r2,%r3},
    ///     {%r4,%r5},
    ///     {%f0,%f1,%f2,%f3};
    /// ```
    MmaSync {
        /// Destination (D) fragment — `.f32` accumulator output.
        d: FragmentC,
        /// Input A fragment — `.b32` packed half2 registers.
        a: FragmentA,
        /// Input B fragment — `.b32` packed half2 registers.
        b: FragmentB,
        /// Input C fragment — `.f32` accumulator input.
        c: FragmentC,
        /// Matrix shape (currently only [`MmaShape::M16N8K16`]).
        shape: MmaShape,
        /// D element type (currently `F32`).
        d_ty: PtxType,
        /// A element type (currently `F16` or `BF16`).
        a_ty: PtxType,
        /// B element type (currently `F16` or `BF16`).
        b_ty: PtxType,
        /// C element type (currently `F32`).
        c_ty: PtxType,
    },
}
91
92impl TensorCoreOp {
93    /// Minimum SM version required to execute this op.
94    pub fn min_sm(&self) -> u32 {
95        match self {
96            Self::MmaSync { shape, .. } => shape.min_sm(),
97        }
98    }
99
100    /// Short human-readable label used in validation errors
101    /// (e.g. `"mma.sync.m16n8k16"`).
102    pub fn feature_label(&self) -> String {
103        match self {
104            Self::MmaSync { shape, .. } => format!("mma.sync.{}", shape.ptx_token()),
105        }
106    }
107}
108
/// Format a fragment register list as `{%x0,%x1,...}` (no surrounding
/// whitespace).
///
/// Generic over any [`fmt::Display`] operand: fragment `regs` fields
/// deref-coerce to `&[Register]`, and the generic form keeps this helper
/// independent of the concrete register type (and unit-testable).
///
/// An empty slice yields `"{}"`.
fn format_reg_list<T: fmt::Display>(regs: &[T]) -> String {
    let joined = regs
        .iter()
        // `.to_string()` over `format!("{r}")` — same output, clearer
        // intent (clippy: useless_format-style idiom).
        .map(|r| r.to_string())
        .collect::<Vec<_>>()
        .join(",");
    format!("{{{joined}}}")
}
119
120impl Emit for TensorCoreOp {
121    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
122        match self {
123            TensorCoreOp::MmaSync {
124                d,
125                a,
126                b,
127                c,
128                shape,
129                d_ty,
130                a_ty,
131                b_ty,
132                c_ty,
133            } => {
134                // mma.sync.aligned.{shape}.row.col.{dty}.{aty}.{bty}.{cty}
135                // For fp16/bf16 inputs with fp32 accumulate, .row.col is
136                // the only valid operand-layout modifier.
137                let mnemonic = format!(
138                    "mma.sync.aligned.{}.row.col{}{}{}{}",
139                    shape.ptx_token(),
140                    d_ty.ptx_suffix(),
141                    a_ty.ptx_suffix(),
142                    b_ty.ptx_suffix(),
143                    c_ty.ptx_suffix(),
144                );
145                let d_list = format_reg_list(&d.regs);
146                let a_list = format_reg_list(&a.regs);
147                let b_list = format_reg_list(&b.regs);
148                let c_list = format_reg_list(&c.regs);
149                w.instruction(
150                    &mnemonic,
151                    &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
152                )
153            }
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::ir::RegisterAllocator;

    #[test]
    fn mma_shape_token_and_min_sm() {
        let shape = MmaShape::M16N8K16;
        assert_eq!(shape.ptx_token(), "m16n8k16");
        assert_eq!(shape.min_sm(), 80);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut ra = RegisterAllocator::new();
        // Allocation order fixes the register numbers asserted below:
        // A -> %r0..%r3, B -> %r4..%r5, C -> %f0..%f3, D -> %f4..%f7.
        let a = alloc_a(&mut ra);
        let b = alloc_b(&mut ra);
        let c = alloc_c(&mut ra);
        let d = alloc_c(&mut ra);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };

        let mut writer = PtxWriter::new();
        writer.indent();
        op.emit(&mut writer).unwrap();

        // Full-line check — operand order: D, A, B, C.
        let expected = "    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \
                        {%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n";
        assert_eq!(writer.finish(), expected);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_bf16_f32() {
        let mut ra = RegisterAllocator::new();
        let a = alloc_a(&mut ra);
        let b = alloc_b(&mut ra);
        let c = alloc_c(&mut ra);
        let d = alloc_c(&mut ra);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::BF16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        };

        let mut writer = PtxWriter::new();
        writer.indent();
        op.emit(&mut writer).unwrap();

        // Only the mnemonic matters here; register numbering is covered
        // by the f16 test above.
        let ptx = writer.finish();
        assert!(ptx.contains("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"));
    }

    #[test]
    fn min_sm_and_feature_label() {
        let mut ra = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut ra),
            a: alloc_a(&mut ra),
            b: alloc_b(&mut ra),
            c: alloc_c(&mut ra),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;

        let mut ra = RegisterAllocator::new();
        let instr = PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut ra),
            a: alloc_a(&mut ra),
            b: alloc_b(&mut ra),
            c: alloc_c(&mut ra),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        });

        let mut writer = PtxWriter::new();
        writer.indent();
        instr.emit(&mut writer).unwrap();
        assert!(writer.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}