1use std::fmt;
12
13use crate::emit::{Emit, PtxWriter};
14use crate::fragment::{FragmentA, FragmentB, FragmentC};
15use crate::types::PtxType;
16
/// Tile shapes supported by the `mma.sync` emitter, named M×N×K.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    /// 16×8×16 tile; PTX token `m16n8k16`, minimum SM 80 (see `min_sm`).
    M16N8K16,
}
27
28impl MmaShape {
29 pub fn ptx_token(&self) -> &'static str {
31 match self {
32 Self::M16N8K16 => "m16n8k16",
33 }
34 }
35
36 pub fn min_sm(&self) -> u32 {
41 match self {
42 Self::M16N8K16 => 80,
43 }
44 }
45}
46
/// A tensor-core (warp-level matrix) operation that can be emitted as PTX.
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    /// Warp-synchronous matrix multiply-accumulate (`mma.sync`),
    /// emitted as `d = a * b + c` in the operand order below.
    MmaSync {
        /// Destination accumulator fragment (first operand list emitted).
        d: FragmentC,
        /// Left-hand matrix fragment.
        a: FragmentA,
        /// Right-hand matrix fragment.
        b: FragmentB,
        /// Source accumulator fragment (last operand list emitted).
        c: FragmentC,
        /// Tile shape; also determines the minimum SM via `MmaShape::min_sm`.
        shape: MmaShape,
        /// PTX element type suffix for `d`.
        d_ty: PtxType,
        /// PTX element type suffix for `a`.
        a_ty: PtxType,
        /// PTX element type suffix for `b`.
        b_ty: PtxType,
        /// PTX element type suffix for `c`.
        c_ty: PtxType,
    },
}
91
92impl TensorCoreOp {
93 pub fn min_sm(&self) -> u32 {
95 match self {
96 Self::MmaSync { shape, .. } => shape.min_sm(),
97 }
98 }
99
100 pub fn feature_label(&self) -> String {
103 match self {
104 Self::MmaSync { shape, .. } => format!("mma.sync.{}", shape.ptx_token()),
105 }
106 }
107}
108
/// Formats a register operand list as a braced, comma-separated PTX vector
/// expression, e.g. `{%r0,%r1,%r2,%r3}`; an empty slice yields `{}`.
///
/// Generalized over any `Display` operand (backward compatible: call sites
/// passing `&[crate::ir::Register]` are unchanged, since the old body already
/// required only `Display` via `format!`). This also lets the helper be
/// unit-tested without constructing project register types.
fn format_reg_list<T: fmt::Display>(regs: &[T]) -> String {
    let joined = regs
        .iter()
        // Idiomatic stringify: `to_string` instead of `format!("{r}")`.
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join(",");
    // `{{`/`}}` escape to literal braces around the joined list.
    format!("{{{joined}}}")
}
119
120impl Emit for TensorCoreOp {
121 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
122 match self {
123 TensorCoreOp::MmaSync {
124 d,
125 a,
126 b,
127 c,
128 shape,
129 d_ty,
130 a_ty,
131 b_ty,
132 c_ty,
133 } => {
134 let mnemonic = format!(
138 "mma.sync.aligned.{}.row.col{}{}{}{}",
139 shape.ptx_token(),
140 d_ty.ptx_suffix(),
141 a_ty.ptx_suffix(),
142 b_ty.ptx_suffix(),
143 c_ty.ptx_suffix(),
144 );
145 let d_list = format_reg_list(&d.regs);
146 let a_list = format_reg_list(&a.regs);
147 let b_list = format_reg_list(&b.regs);
148 let c_list = format_reg_list(&c.regs);
149 w.instruction(
150 &mnemonic,
151 &[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
152 )
153 }
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::ir::RegisterAllocator;

    // Sanity-check the shape's PTX token and its minimum SM requirement.
    #[test]
    fn mma_shape_token_and_min_sm() {
        assert_eq!(MmaShape::M16N8K16.ptx_token(), "m16n8k16");
        assert_eq!(MmaShape::M16N8K16.min_sm(), 80);
    }

    // Golden-output test: full emitted line for f16×f16→f32, including the
    // leading indentation, operand lists, and trailing ";\n".
    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        // NOTE(review): d is allocated after c, so — per the expected string
        // below — it receives the next accumulator registers (%f4..%f7);
        // numbering presumably follows allocator order. TODO confirm against
        // RegisterAllocator.
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();

        // Operand order in the emitted line is d, a, b, c.
        let expected = concat!(
            " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(out, expected);
    }

    // bf16 inputs should only change the type suffixes in the mnemonic.
    #[test]
    fn emit_mma_sync_m16n8k16_bf16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a(&mut alloc);
        let b = alloc_b(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);

        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::BF16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        };

        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        assert!(
            w.finish()
                .contains("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32")
        );
    }

    // Op-level accessors delegate to the shape (min_sm) and build the
    // feature label from the shape token.
    #[test]
    fn min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    // Smoke test: emission also works when routed through the generic
    // PtxInstruction wrapper rather than calling TensorCoreOp::emit directly.
    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut alloc = RegisterAllocator::new();
        let instr = PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        });
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert!(w.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}