/// Tile shapes for warpgroup matrix-multiply-accumulate (WGMMA) operations.
///
/// Each variant is named `M{m}N{n}K{k}` after the tile dimensions returned by
/// [`WgmmaShape::dims`]. The `K16` variants map to `F16` macros and the `K32`
/// variants map to `F8` macros (see [`WgmmaShape::macro_name`]).
///
/// NOTE(review): presumably these mirror the PTX `wgmma.mma_async` shapes
/// (16-bit inputs use K=16, 8-bit float inputs use K=32) — confirm against the
/// header that defines the macros.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgmmaShape {
    /// 64 x 64 x 16 tile (`F16` macro family).
    M64N64K16,
    /// 64 x 128 x 16 tile (`F16` macro family).
    M64N128K16,
    /// 64 x 256 x 16 tile (`F16` macro family).
    M64N256K16,
    /// 64 x 64 x 32 tile (`F8` macro family).
    M64N64K32,
    /// 64 x 128 x 32 tile (`F8` macro family).
    M64N128K32,
    /// 64 x 256 x 32 tile (`F8` macro family).
    M64N256K32,
}

impl WgmmaShape {
    /// Every supported shape, in declaration order.
    ///
    /// Useful for iterating over all shapes (e.g. to emit every macro or to
    /// test each variant) without hand-maintaining a separate list.
    pub const ALL: [WgmmaShape; 6] = [
        WgmmaShape::M64N64K16,
        WgmmaShape::M64N128K16,
        WgmmaShape::M64N256K16,
        WgmmaShape::M64N64K32,
        WgmmaShape::M64N128K32,
        WgmmaShape::M64N256K32,
    ];

    /// Returns the tile dimensions as `(m, n, k)`.
    ///
    /// This is a `const fn`, so it can be evaluated at compile time
    /// (e.g. in `const` items or array lengths).
    #[must_use]
    pub const fn dims(self) -> (u32, u32, u32) {
        match self {
            WgmmaShape::M64N64K16 => (64, 64, 16),
            WgmmaShape::M64N128K16 => (64, 128, 16),
            WgmmaShape::M64N256K16 => (64, 256, 16),
            WgmmaShape::M64N64K32 => (64, 64, 32),
            WgmmaShape::M64N128K32 => (64, 128, 32),
            WgmmaShape::M64N256K32 => (64, 256, 32),
        }
    }

    /// Returns the name of the macro that implements this shape.
    ///
    /// Names have the form `ATOMR_WGMMA_{F16|F8}_M{m}N{n}K{k}`. The `M…N…K…`
    /// suffix always matches [`WgmmaShape::dims`]. The `K16` shapes use the
    /// `F16` family and the `K32` shapes use the `F8` family.
    ///
    /// NOTE(review): the macros are presumably defined in a companion C/CUDA
    /// header; keep these strings in sync with it.
    #[must_use]
    pub const fn macro_name(self) -> &'static str {
        match self {
            WgmmaShape::M64N64K16 => "ATOMR_WGMMA_F16_M64N64K16",
            WgmmaShape::M64N128K16 => "ATOMR_WGMMA_F16_M64N128K16",
            WgmmaShape::M64N256K16 => "ATOMR_WGMMA_F16_M64N256K16",
            WgmmaShape::M64N64K32 => "ATOMR_WGMMA_F8_M64N64K32",
            WgmmaShape::M64N128K32 => "ATOMR_WGMMA_F8_M64N128K32",
            WgmmaShape::M64N256K32 => "ATOMR_WGMMA_F8_M64N256K32",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected `(m, n, k)` dims and macro name.
    ///
    /// Covering all six variants catches typos in any individual match arm,
    /// not just the first and last.
    const CASES: [(WgmmaShape, (u32, u32, u32), &str); 6] = [
        (WgmmaShape::M64N64K16, (64, 64, 16), "ATOMR_WGMMA_F16_M64N64K16"),
        (WgmmaShape::M64N128K16, (64, 128, 16), "ATOMR_WGMMA_F16_M64N128K16"),
        (WgmmaShape::M64N256K16, (64, 256, 16), "ATOMR_WGMMA_F16_M64N256K16"),
        (WgmmaShape::M64N64K32, (64, 64, 32), "ATOMR_WGMMA_F8_M64N64K32"),
        (WgmmaShape::M64N128K32, (64, 128, 32), "ATOMR_WGMMA_F8_M64N128K32"),
        (WgmmaShape::M64N256K32, (64, 256, 32), "ATOMR_WGMMA_F8_M64N256K32"),
    ];

    #[test]
    fn dims_round_trip() {
        for &(shape, dims, _) in CASES.iter() {
            assert_eq!(shape.dims(), dims, "wrong dims for {:?}", shape);
        }
    }

    #[test]
    fn macro_names_match_header() {
        for &(shape, _, name) in CASES.iter() {
            assert_eq!(shape.macro_name(), name, "wrong macro name for {:?}", shape);
        }
    }

    /// The `M…N…K…` suffix of each macro name must agree with `dims()`, so the
    /// two match tables cannot silently drift apart.
    #[test]
    fn macro_name_suffix_matches_dims() {
        for &(shape, _, _) in CASES.iter() {
            let (m, n, k) = shape.dims();
            let suffix = format!("_M{}N{}K{}", m, n, k);
            assert!(
                shape.macro_name().ends_with(&suffix),
                "{:?}: macro name {} does not end with {}",
                shape,
                shape.macro_name(),
                suffix
            );
        }
    }
}