use std::fmt;
use crate::emit::{Emit, PtxWriter};
use crate::fragment::{FragmentA, FragmentB, FragmentC};
use crate::types::PtxType;
/// Tile shapes supported for warp-level `mma.sync` emission.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    M16N8K16,
}

impl MmaShape {
    /// The shape token as it appears in the PTX mnemonic, e.g. `m16n8k16`.
    pub fn ptx_token(&self) -> &'static str {
        // Single-variant enum: the pattern is irrefutable.
        let Self::M16N8K16 = self;
        "m16n8k16"
    }

    /// Minimum SM version on which this shape is available (80 for m16n8k16).
    pub fn min_sm(&self) -> u32 {
        match *self {
            Self::M16N8K16 => 80,
        }
    }
}
/// A tensor-core operation to be lowered to PTX.
///
/// Currently only warp-synchronous matrix-multiply-accumulate
/// (`mma.sync`) is modeled; `emit` renders operands in d, a, b, c order.
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    MmaSync {
        // Destination fragment (first PTX operand).
        d: FragmentC,
        // Input fragment A (second PTX operand).
        a: FragmentA,
        // Input fragment B (third PTX operand).
        b: FragmentB,
        // Accumulator input fragment (fourth PTX operand).
        c: FragmentC,
        // Tile shape; also determines the minimum SM version.
        shape: MmaShape,
        // Element types per operand; each contributes a `.ty` suffix
        // to the mnemonic in d/a/b/c order.
        d_ty: PtxType,
        a_ty: PtxType,
        b_ty: PtxType,
        c_ty: PtxType,
    },
}
impl TensorCoreOp {
    /// Minimum SM version required to run this operation,
    /// delegated to the tile shape.
    pub fn min_sm(&self) -> u32 {
        // Irrefutable: TensorCoreOp has exactly one variant.
        let Self::MmaSync { shape, .. } = self;
        shape.min_sm()
    }

    /// A short label naming the hardware feature this op relies on,
    /// e.g. `mma.sync.m16n8k16`.
    pub fn feature_label(&self) -> String {
        let Self::MmaSync { shape, .. } = self;
        ["mma.sync.", shape.ptx_token()].concat()
    }
}
/// Renders a register list as a brace-wrapped, comma-separated PTX
/// operand, e.g. `{%r0,%r1}`. An empty slice yields `{}`.
fn format_reg_list(regs: &[crate::ir::Register]) -> String {
    let rendered: Vec<String> = regs.iter().map(|reg| reg.to_string()).collect();
    format!("{{{}}}", rendered.join(","))
}
impl Emit for TensorCoreOp {
    /// Writes this op as a single PTX instruction via `w`.
    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
        // Irrefutable destructuring: MmaSync is the only variant.
        let TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape,
            d_ty,
            a_ty,
            b_ty,
            c_ty,
        } = self;
        // Layout is fixed as row-major A / column-major B; the type
        // suffixes (which include their leading dot) follow in
        // d/a/b/c order.
        let opcode = format!(
            "mma.sync.aligned.{}.row.col{}{}{}{}",
            shape.ptx_token(),
            d_ty.ptx_suffix(),
            a_ty.ptx_suffix(),
            b_ty.ptx_suffix(),
            c_ty.ptx_suffix(),
        );
        // Operands in PTX order: destination first, then a, b, c.
        let operands = [
            format_reg_list(&d.regs),
            format_reg_list(&a.regs),
            format_reg_list(&b.regs),
            format_reg_list(&c.regs),
        ];
        w.instruction(
            &opcode,
            &[
                &operands[0] as &dyn fmt::Display,
                &operands[1],
                &operands[2],
                &operands[3],
            ],
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::ir::RegisterAllocator;

    /// Builds an m16n8k16 `MmaSync` op with f32 accumulate and the given
    /// input types. Fragments are allocated in a/b/c/d order so register
    /// numbering is deterministic (a -> %r0..%r3, b -> %r4..%r5,
    /// c -> %f0..%f3, d -> %f4..%f7).
    fn mma_op(alloc: &mut RegisterAllocator, a_ty: PtxType, b_ty: PtxType) -> TensorCoreOp {
        let a = alloc_a(alloc);
        let b = alloc_b(alloc);
        let c = alloc_c(alloc);
        let d = alloc_c(alloc);
        TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty,
            b_ty,
            c_ty: PtxType::F32,
        }
    }

    /// Emits `op` through a freshly indented writer and returns the PTX text.
    fn emit_to_string(op: &TensorCoreOp) -> String {
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        w.finish()
    }

    #[test]
    fn mma_shape_token_and_min_sm() {
        assert_eq!(MmaShape::M16N8K16.ptx_token(), "m16n8k16");
        assert_eq!(MmaShape::M16N8K16.min_sm(), 80);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut alloc = RegisterAllocator::new();
        let op = mma_op(&mut alloc, PtxType::F16, PtxType::F16);
        // Exact output, including register numbering from allocation order.
        let expected = concat!(
            " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(emit_to_string(&op), expected);
    }

    #[test]
    fn emit_mma_sync_m16n8k16_bf16_f32() {
        let mut alloc = RegisterAllocator::new();
        let op = mma_op(&mut alloc, PtxType::BF16, PtxType::BF16);
        // Only the mnemonic matters here; register numbering is covered above.
        assert!(
            emit_to_string(&op)
                .contains("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32")
        );
    }

    #[test]
    fn min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let op = mma_op(&mut alloc, PtxType::F16, PtxType::F16);
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut alloc = RegisterAllocator::new();
        // The wrapper instruction must delegate to the tensor-core emitter.
        let instr = PtxInstruction::TensorCore(mma_op(&mut alloc, PtxType::F16, PtxType::F16));
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert!(w.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}