use crate::error::PtxGenError;
use crate::ir::{PtxType, WmmaShape};
/// Configuration of a WMMA operation: a tile shape plus the element types of
/// the A and B operands and of the C/D accumulator.
///
/// The fields are public and [`WmmaConfig::new`] stores them unchecked, so a
/// value of this type is not necessarily a supported combination; call
/// [`WmmaConfig::validate`] to find out.
#[derive(Debug, Clone)]
pub struct WmmaConfig {
/// Tile shape; determines the `(M, N, K)` extents reported by `dimensions`.
pub shape: WmmaShape,
/// Element type of the A operand. `validate` requires it to equal `b_type`
/// and to be `PtxType::F16`.
pub a_type: PtxType,
/// Element type of the B operand. `validate` requires it to equal `a_type`.
pub b_type: PtxType,
/// Element type of the C/D accumulator. `validate` accepts `PtxType::F16`
/// or `PtxType::F32`.
pub c_type: PtxType,
}
impl WmmaConfig {
#[must_use]
pub const fn new(shape: WmmaShape, a_type: PtxType, b_type: PtxType, c_type: PtxType) -> Self {
Self {
shape,
a_type,
b_type,
c_type,
}
}
pub fn validate(&self) -> Result<(), PtxGenError> {
if self.a_type != self.b_type {
return Err(PtxGenError::InvalidType(format!(
"WMMA requires matching A/B types, got A={}, B={}",
self.a_type.as_ptx_str(),
self.b_type.as_ptx_str()
)));
}
if !matches!(self.a_type, PtxType::F16) {
return Err(PtxGenError::InvalidType(format!(
"WMMA A/B type must be F16, got {}",
self.a_type.as_ptx_str()
)));
}
if !matches!(self.c_type, PtxType::F16 | PtxType::F32) {
return Err(PtxGenError::InvalidType(format!(
"WMMA C/D type must be F16 or F32, got {}",
self.c_type.as_ptx_str()
)));
}
Ok(())
}
pub fn fragment_size_a(&self) -> Result<u32, PtxGenError> {
self.validate()?;
Ok(match self.shape {
WmmaShape::M16N16K16 | WmmaShape::M8N32K16 | WmmaShape::M32N8K16 => 8,
})
}
pub fn fragment_size_b(&self) -> Result<u32, PtxGenError> {
self.validate()?;
Ok(match self.shape {
WmmaShape::M16N16K16 | WmmaShape::M8N32K16 | WmmaShape::M32N8K16 => 8,
})
}
pub fn fragment_size_c(&self) -> Result<u32, PtxGenError> {
self.validate()?;
Ok(match self.c_type {
PtxType::F16 => 4,
PtxType::F32 => 8,
_ => {
return Err(PtxGenError::InvalidType(format!(
"unsupported accumulator type: {}",
self.c_type.as_ptx_str()
)));
}
})
}
#[must_use]
pub const fn dimensions(&self) -> (u32, u32, u32) {
match self.shape {
WmmaShape::M16N16K16 => (16, 16, 16),
WmmaShape::M8N32K16 => (8, 32, 16),
WmmaShape::M32N8K16 => (32, 8, 16),
}
}
#[must_use]
pub const fn total_elements_a(&self) -> u32 {
let (m, _, k) = self.dimensions();
m * k
}
#[must_use]
pub const fn total_elements_b(&self) -> u32 {
let (_, n, k) = self.dimensions();
k * n
}
#[must_use]
pub const fn total_elements_c(&self) -> u32 {
let (m, n, _) = self.dimensions();
m * n
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config whose A and B operands are both `F16`, the only operand
    /// type `validate` accepts, leaving shape and accumulator type to the test.
    fn f16_operands(shape: WmmaShape, c_type: PtxType) -> WmmaConfig {
        WmmaConfig::new(shape, PtxType::F16, PtxType::F16, c_type)
    }

    #[test]
    fn valid_f16_f32_config() {
        let cfg = f16_operands(WmmaShape::M16N16K16, PtxType::F32);
        assert!(cfg.validate().is_ok());
        assert_eq!(cfg.fragment_size_a().expect("valid"), 8);
        assert_eq!(cfg.fragment_size_b().expect("valid"), 8);
        assert_eq!(cfg.fragment_size_c().expect("valid"), 8);
    }

    #[test]
    fn valid_f16_f16_config() {
        let cfg = f16_operands(WmmaShape::M16N16K16, PtxType::F16);
        assert!(cfg.validate().is_ok());
        assert_eq!(cfg.fragment_size_c().expect("valid"), 4);
    }

    #[test]
    fn mismatched_ab_types() {
        let cfg = WmmaConfig::new(
            WmmaShape::M16N16K16,
            PtxType::F16,
            PtxType::F32,
            PtxType::F32,
        );
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn invalid_a_type() {
        let cfg = WmmaConfig::new(
            WmmaShape::M16N16K16,
            PtxType::F32,
            PtxType::F32,
            PtxType::F32,
        );
        assert!(cfg.validate().is_err());
    }

    /// The fragment size queries validate first, so an unsupported
    /// configuration must surface as an error from each of them.
    #[test]
    fn fragment_sizes_propagate_validation_errors() {
        let cfg = WmmaConfig::new(
            WmmaShape::M16N16K16,
            PtxType::F32,
            PtxType::F32,
            PtxType::F32,
        );
        assert!(cfg.fragment_size_a().is_err());
        assert!(cfg.fragment_size_b().is_err());
        assert!(cfg.fragment_size_c().is_err());
    }

    #[test]
    fn dimensions() {
        let cfg = f16_operands(WmmaShape::M16N16K16, PtxType::F32);
        assert_eq!(cfg.dimensions(), (16, 16, 16));
        assert_eq!(cfg.total_elements_a(), 256);
        assert_eq!(cfg.total_elements_b(), 256);
        assert_eq!(cfg.total_elements_c(), 256);
    }

    #[test]
    fn m8n32k16_dimensions() {
        let cfg = f16_operands(WmmaShape::M8N32K16, PtxType::F32);
        assert_eq!(cfg.dimensions(), (8, 32, 16));
        assert_eq!(cfg.total_elements_a(), 128);
        assert_eq!(cfg.total_elements_b(), 512);
        assert_eq!(cfg.total_elements_c(), 256);
    }

    /// Mirror image of `m8n32k16_dimensions`: the A and B counts swap while C
    /// stays the same.
    #[test]
    fn m32n8k16_dimensions() {
        let cfg = f16_operands(WmmaShape::M32N8K16, PtxType::F32);
        assert_eq!(cfg.dimensions(), (32, 8, 16));
        assert_eq!(cfg.total_elements_a(), 512);
        assert_eq!(cfg.total_elements_b(), 128);
        assert_eq!(cfg.total_elements_c(), 256);
    }

    #[test]
    fn test_wmma_m16n16k16_f16_accumulator_fragment_layout() {
        let cfg_f32 = f16_operands(WmmaShape::M16N16K16, PtxType::F32);
        assert_eq!(
            cfg_f32.fragment_size_a().expect("valid config"),
            8,
            "m16n16k16 A fragment must be 8 registers/thread"
        );
        assert_eq!(
            cfg_f32.fragment_size_b().expect("valid config"),
            8,
            "m16n16k16 B fragment must be 8 registers/thread"
        );
        assert_eq!(
            cfg_f32.fragment_size_c().expect("valid config"),
            8,
            "m16n16k16 C fragment (F32) must be 8 registers/thread"
        );

        let cfg_f16 = f16_operands(WmmaShape::M16N16K16, PtxType::F16);
        assert_eq!(
            cfg_f16.fragment_size_c().expect("valid config"),
            4,
            "m16n16k16 C fragment (F16) must be 4 registers/thread (2 values packed per reg)"
        );
    }

    #[test]
    fn test_wmma_m16n16k16_total_element_counts() {
        let cfg = f16_operands(WmmaShape::M16N16K16, PtxType::F32);
        assert_eq!(
            cfg.total_elements_a(),
            256,
            "m16n16k16 total A elements = M*K = 16*16 = 256"
        );
        assert_eq!(
            cfg.total_elements_b(),
            256,
            "m16n16k16 total B elements = K*N = 16*16 = 256"
        );
        assert_eq!(
            cfg.total_elements_c(),
            256,
            "m16n16k16 total C elements = M*N = 16*16 = 256"
        );
    }

    #[test]
    fn test_wmma_all_shapes_same_ab_fragment_register_count() {
        for shape in [
            WmmaShape::M16N16K16,
            WmmaShape::M8N32K16,
            WmmaShape::M32N8K16,
        ] {
            let cfg = f16_operands(shape, PtxType::F32);
            assert_eq!(
                cfg.fragment_size_a().expect("valid"),
                8,
                "{shape:?} A fragment must be 8 regs"
            );
            assert_eq!(
                cfg.fragment_size_b().expect("valid"),
                8,
                "{shape:?} B fragment must be 8 regs"
            );
        }
    }
}