use std::fmt;
use crate::emit::{Emit, PtxWriter};
use crate::fragment::{
FragmentA_BF16, FragmentA_F16, FragmentA_M16N8K32, FragmentB_BF16, FragmentB_F16,
FragmentB_M16N8K32, FragmentC, FragmentC_M16N8K32,
};
use crate::types::PtxType;
/// Warp-level `mma.sync` tile shapes (`m<M>n<N>k<K>`) this module can emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MmaShape {
    /// 16x8x16 tile (M = 16, N = 8, K = 16).
    M16N8K16,
    /// 16x8x32 tile (M = 16, N = 8, K = 32).
    M16N8K32,
}

impl MmaShape {
    /// The PTX shape qualifier (without a leading dot), e.g. `"m16n8k16"`.
    pub fn ptx_token(&self) -> &'static str {
        match *self {
            MmaShape::M16N8K16 => "m16n8k16",
            MmaShape::M16N8K32 => "m16n8k32",
        }
    }

    /// Minimum SM architecture number required by this shape.
    ///
    /// Both supported shapes currently report `80`.
    pub fn min_sm(&self) -> u32 {
        match *self {
            MmaShape::M16N8K16 | MmaShape::M16N8K32 => 80,
        }
    }
}
/// Destination register set for `ldmatrix`: two registers (`.x2`) or four (`.x4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LdMatrixDst {
    /// `.x2` form: two destination registers.
    X2([crate::ir::Register; 2]),
    /// `.x4` form: four destination registers.
    X4([crate::ir::Register; 4]),
}

impl LdMatrixDst {
    /// The `.num` qualifier for `ldmatrix`, without the leading dot (`"x2"` or `"x4"`).
    pub fn num_token(&self) -> &'static str {
        match *self {
            LdMatrixDst::X2(..) => "x2",
            LdMatrixDst::X4(..) => "x4",
        }
    }

    /// All destination registers as a slice, in declaration order.
    pub fn regs(&self) -> &[crate::ir::Register] {
        match self {
            LdMatrixDst::X2(pair) => &pair[..],
            LdMatrixDst::X4(quad) => &quad[..],
        }
    }
}
/// A tensor-core-related PTX operation that this module can emit.
///
/// Each variant carries the operand registers/fragments it needs. The exact
/// mnemonic for each variant is produced by the `Emit` impl for this type.
#[derive(Debug, Clone)]
pub enum TensorCoreOp {
    /// Warp-wide `mma.sync.aligned.<shape>.row.col<d><a><b><c>` with f16 A/B fragments.
    ///
    /// The four `*_ty` fields supply the type suffixes, in D, A, B, C order.
    /// Operands are emitted as `d, a, b, c` (per the PTX ISA, D = A * B + C).
    /// NOTE(review): nothing here rejects shape/type combinations that PTX has no
    /// f16 form for (e.g. `MmaShape::M16N8K32` with f16 fragments) — confirm
    /// callers only construct valid combinations.
    MmaSync {
        /// Output accumulator fragment (first operand).
        d: FragmentC,
        /// A-matrix fragment (f16 layout).
        a: FragmentA_F16,
        /// B-matrix fragment (f16 layout).
        b: FragmentB_F16,
        /// Input accumulator fragment (last operand).
        c: FragmentC,
        /// Tile shape; also determines `min_sm`.
        shape: MmaShape,
        /// Type suffix for D.
        d_ty: PtxType,
        /// Type suffix for A.
        a_ty: PtxType,
        /// Type suffix for B.
        b_ty: PtxType,
        /// Type suffix for C.
        c_ty: PtxType,
    },
    /// Fixed-form int8 MMA: `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32`.
    MmaSyncInt8 {
        /// Output accumulator fragment (first operand).
        d: FragmentC_M16N8K32,
        /// A-matrix fragment.
        a: FragmentA_M16N8K32,
        /// B-matrix fragment.
        b: FragmentB_M16N8K32,
        /// Input accumulator fragment (last operand).
        c: FragmentC_M16N8K32,
    },
    /// Fixed-form bf16 MMA: `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`.
    MmaSyncBf16 {
        /// Output accumulator fragment (first operand).
        d: FragmentC,
        /// A-matrix fragment (bf16 layout).
        a: FragmentA_BF16,
        /// B-matrix fragment (bf16 layout).
        b: FragmentB_BF16,
        /// Input accumulator fragment (last operand).
        c: FragmentC,
    },
    /// `ldmatrix.sync.aligned.m8n8.<x2|x4>[.trans].shared.b16 {dst..}, [addr]`.
    LdMatrix {
        /// Destination registers; the variant selects `.x2` or `.x4`.
        dst: LdMatrixDst,
        /// Register holding the address, emitted as `[addr]` in the `.shared` state space.
        /// The tests allocate it as `PtxType::U32` — presumably a 32-bit shared-window
        /// address; confirm against callers.
        addr: crate::ir::Register,
        /// When `true`, the `.trans` qualifier is emitted.
        trans: bool,
    },
}
impl TensorCoreOp {
    /// Minimum SM architecture number (e.g. `80` for `sm_80`) this op requires.
    ///
    /// MMA variants defer to [`MmaShape::min_sm`] for their tile shape;
    /// `ldmatrix` is reported as needing `sm_75`.
    pub fn min_sm(&self) -> u32 {
        let shape = match self {
            Self::LdMatrix { .. } => return 75,
            Self::MmaSync { shape, .. } => *shape,
            Self::MmaSyncInt8 { .. } => MmaShape::M16N8K32,
            Self::MmaSyncBf16 { .. } => MmaShape::M16N8K16,
        };
        shape.min_sm()
    }

    /// Short label naming the PTX feature this op relies on.
    ///
    /// The generic f16 `MmaSync` label names only the shape (e.g. `mma.sync.m16n8k16`).
    /// The int8 and bf16 variants append their type suffixes. `ldmatrix` reports its
    /// `.num` count, plus `.trans` when transposition is requested.
    pub fn feature_label(&self) -> String {
        match self {
            Self::MmaSync { shape, .. } => {
                let token = shape.ptx_token();
                format!("mma.sync.{token}")
            }
            Self::MmaSyncInt8 { .. } => {
                let token = MmaShape::M16N8K32.ptx_token();
                format!("mma.sync.{token}.s8.s8.s32")
            }
            Self::MmaSyncBf16 { .. } => {
                let token = MmaShape::M16N8K16.ptx_token();
                format!("mma.sync.{token}.bf16.bf16.f32")
            }
            Self::LdMatrix { dst, trans, .. } => {
                let count = dst.num_token();
                let trans_suffix = if *trans { ".trans" } else { "" };
                format!("ldmatrix.m8n8.{count}{trans_suffix}")
            }
        }
    }
}
/// Renders registers as a PTX vector operand, e.g. `{%r0,%r1,%r2}`.
///
/// Elements are separated by `,` with no spaces. An empty slice yields `{}`.
///
/// The function is generic over any `Display` element, so it also accepts
/// `&[Register]` and fragment register arrays. It writes straight into a single
/// `String` instead of allocating one string per register plus an
/// intermediate `Vec` for the join.
fn format_reg_list<T: fmt::Display>(regs: &[T]) -> String {
    use std::fmt::Write as _;

    let mut out = String::from("{");
    for (i, reg) in regs.iter().enumerate() {
        if i > 0 {
            out.push(',');
        }
        // Writing into a `String` only fails if `reg`'s `Display` impl itself
        // errors; `format!` would panic in that case too, so match that.
        write!(out, "{reg}").expect("a Display implementation returned an error");
    }
    out.push('}');
    out
}
impl Emit for TensorCoreOp {
fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
match self {
TensorCoreOp::MmaSync {
d,
a,
b,
c,
shape,
d_ty,
a_ty,
b_ty,
c_ty,
} => {
let mnemonic = format!(
"mma.sync.aligned.{}.row.col{}{}{}{}",
shape.ptx_token(),
d_ty.ptx_suffix(),
a_ty.ptx_suffix(),
b_ty.ptx_suffix(),
c_ty.ptx_suffix(),
);
let d_list = format_reg_list(&d.regs);
let a_list = format_reg_list(&a.regs);
let b_list = format_reg_list(&b.regs);
let c_list = format_reg_list(&c.regs);
w.instruction(
&mnemonic,
&[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
)
}
TensorCoreOp::MmaSyncInt8 { d, a, b, c } => {
let mnemonic = "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32";
let d_list = format_reg_list(&d.regs);
let a_list = format_reg_list(&a.regs);
let b_list = format_reg_list(&b.regs);
let c_list = format_reg_list(&c.regs);
w.instruction(
mnemonic,
&[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
)
}
TensorCoreOp::MmaSyncBf16 { d, a, b, c } => {
let mnemonic = "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32";
let d_list = format_reg_list(&d.regs);
let a_list = format_reg_list(&a.regs);
let b_list = format_reg_list(&b.regs);
let c_list = format_reg_list(&c.regs);
w.instruction(
mnemonic,
&[&d_list as &dyn fmt::Display, &a_list, &b_list, &c_list],
)
}
TensorCoreOp::LdMatrix { dst, addr, trans } => {
let mnemonic = format!(
"ldmatrix.sync.aligned.m8n8.{}{}.shared.b16",
dst.num_token(),
if *trans { ".trans" } else { "" },
);
let d_list = format_reg_list(dst.regs());
let addr_operand = format!("[{addr}]");
w.instruction(&mnemonic, &[&d_list as &dyn fmt::Display, &addr_operand])
}
}
}
}
// Unit tests: golden-string checks of the emitted PTX, plus `min_sm` /
// `feature_label` metadata. Register names in the expected strings depend on
// the order in which registers are allocated from a fresh `RegisterAllocator`,
// so the allocation order inside each test is load-bearing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a_f16, alloc_b_f16, alloc_c};
    use crate::ir::RegisterAllocator;

    // Shape tokens and SM requirements for every `MmaShape` variant.
    #[test]
    fn mma_shape_token_and_min_sm() {
        assert_eq!(MmaShape::M16N8K16.ptx_token(), "m16n8k16");
        assert_eq!(MmaShape::M16N8K16.min_sm(), 80);
        assert_eq!(MmaShape::M16N8K32.ptx_token(), "m16n8k32");
        assert_eq!(MmaShape::M16N8K32.min_sm(), 80);
    }

    // f16 inputs with f32 accumulators. Allocating A, B, C, D in that order
    // yields A = %r0..%r3, B = %r4..%r5, C = %f0..%f3, D = %f4..%f7. Operands
    // are emitted in D, A, B, C order.
    #[test]
    fn emit_mma_sync_m16n8k16_f16_f32() {
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a_f16(&mut alloc);
        let b = alloc_b_f16(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);
        let op = TensorCoreOp::MmaSync {
            d,
            a,
            b,
            c,
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();
        let expected = concat!(
            " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(out, expected);
    }

    // bf16 variant: same fragment register layout as the f16 case, with a
    // fixed bf16/f32 mnemonic.
    #[test]
    fn emit_mma_sync_bf16_m16n8k16() {
        use crate::fragment::{alloc_a_bf16, alloc_b_bf16};
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a_bf16(&mut alloc);
        let b = alloc_b_bf16(&mut alloc);
        let c = alloc_c(&mut alloc);
        let d = alloc_c(&mut alloc);
        let op = TensorCoreOp::MmaSyncBf16 { d, a, b, c };
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();
        let expected = concat!(
            " mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 ",
            "{%f4,%f5,%f6,%f7}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%f0,%f1,%f2,%f3};\n",
        );
        assert_eq!(out, expected);
    }

    // bf16 MMA requires sm_80, and its label includes the type suffixes.
    #[test]
    fn min_sm_and_feature_label_bf16() {
        use crate::fragment::{alloc_a_bf16, alloc_b_bf16};
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncBf16 {
            d: alloc_c(&mut alloc),
            a: alloc_a_bf16(&mut alloc),
            b: alloc_b_bf16(&mut alloc),
            c: alloc_c(&mut alloc),
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16.bf16.bf16.f32");
    }

    // The generic f16 `MmaSync` label names only the shape, not the types.
    #[test]
    fn min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k16");
    }

    // int8 variant: every fragment lives in %r registers. Allocating A, B, C, D
    // in order yields A = %r0..%r3, B = %r4..%r5, C = %r6..%r9, D = %r10..%r13.
    #[test]
    fn emit_mma_sync_int8_m16n8k32() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let a = alloc_a_M16N8K32(&mut alloc);
        let b = alloc_b_M16N8K32(&mut alloc);
        let c = alloc_c_M16N8K32(&mut alloc);
        let d = alloc_c_M16N8K32(&mut alloc);
        let op = TensorCoreOp::MmaSyncInt8 { d, a, b, c };
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();
        let expected = concat!(
            " mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 ",
            "{%r10,%r11,%r12,%r13}, {%r0,%r1,%r2,%r3}, {%r4,%r5}, {%r6,%r7,%r8,%r9};\n",
        );
        assert_eq!(out, expected);
    }

    // int8 MMA requires sm_80, and its label includes the s8/s32 type suffixes.
    #[test]
    fn int8_min_sm_and_feature_label() {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        };
        assert_eq!(op.min_sm(), 80);
        assert_eq!(op.feature_label(), "mma.sync.m16n8k32.s8.s8.s32");
    }

    // Packed-half2 destinations and the U32 address share the %r counter:
    // dst = %r0..%r3, addr = %r4.
    #[test]
    fn emit_ldmatrix_x4_shared_b16() {
        let mut alloc = RegisterAllocator::new();
        let dst = LdMatrixDst::X4([
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
        ]);
        let addr = alloc.alloc(PtxType::U32);
        let op = TensorCoreOp::LdMatrix {
            dst,
            addr,
            trans: false,
        };
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();
        let expected = concat!(
            " ldmatrix.sync.aligned.m8n8.x4.shared.b16 ",
            "{%r0,%r1,%r2,%r3}, [%r4];\n",
        );
        assert_eq!(out, expected);
    }

    // `.trans` is placed after the `.num` qualifier and before `.shared`.
    #[test]
    fn emit_ldmatrix_x2_trans_shared_b16() {
        let mut alloc = RegisterAllocator::new();
        let dst = LdMatrixDst::X2([alloc.alloc_packed_half2(), alloc.alloc_packed_half2()]);
        let addr = alloc.alloc(PtxType::U32);
        let op = TensorCoreOp::LdMatrix {
            dst,
            addr,
            trans: true,
        };
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        let out = w.finish();
        let expected = concat!(
            " ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 ",
            "{%r0,%r1}, [%r2];\n",
        );
        assert_eq!(out, expected);
    }

    // ldmatrix reports sm_75 regardless of count or transpose. The label
    // reflects both the count and the `.trans` flag.
    #[test]
    fn ldmatrix_min_sm_and_feature_label() {
        let mut alloc = RegisterAllocator::new();
        let x4 = TensorCoreOp::LdMatrix {
            dst: LdMatrixDst::X4([
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
            ]),
            addr: alloc.alloc(PtxType::U32),
            trans: false,
        };
        assert_eq!(x4.min_sm(), 75);
        assert_eq!(x4.feature_label(), "ldmatrix.m8n8.x4");
        let x2t = TensorCoreOp::LdMatrix {
            dst: LdMatrixDst::X2([alloc.alloc_packed_half2(), alloc.alloc_packed_half2()]),
            addr: alloc.alloc(PtxType::U32),
            trans: true,
        };
        assert_eq!(x2t.min_sm(), 75);
        assert_eq!(x2t.feature_label(), "ldmatrix.m8n8.x2.trans");
    }

    // `num_token` / `regs` accessors for both destination widths.
    #[test]
    fn ldmatrix_dst_accessors() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc_packed_half2();
        let r1 = alloc.alloc_packed_half2();
        let x2 = LdMatrixDst::X2([r0, r1]);
        assert_eq!(x2.num_token(), "x2");
        assert_eq!(x2.regs(), &[r0, r1]);
        let r2 = alloc.alloc_packed_half2();
        let r3 = alloc.alloc_packed_half2();
        let x4 = LdMatrixDst::X4([r0, r1, r2, r3]);
        assert_eq!(x4.num_token(), "x4");
        assert_eq!(x4.regs().len(), 4);
    }

    // A `TensorCoreOp` wrapped in `PtxInstruction::TensorCore` emits through
    // the same path and produces the mma mnemonic.
    #[test]
    fn tensor_core_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut alloc = RegisterAllocator::new();
        let instr = PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        });
        let mut w = PtxWriter::new();
        w.indent();
        instr.emit(&mut w).unwrap();
        assert!(w.finish().contains("mma.sync.aligned.m16n8k16.row.col"));
    }
}