use oxicuda_ptx::arch::{ArchCapabilities, SmVersion};
use oxicuda_ptx::ir::PtxType;
/// A single tensor-core MMA tile shape plus the instruction family used
/// to execute it. Shapes are per-instruction (M x N x K), not per-CTA tile;
/// see `TensorCoreValidator::validate_tile` for tile-level divisibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TensorCoreConfig {
    /// M dimension of one MMA operation.
    pub mma_m: u32,
    /// N dimension of one MMA operation.
    pub mma_n: u32,
    /// K (reduction) dimension of one MMA operation.
    pub mma_k: u32,
    /// Which PTX instruction family realizes this shape.
    pub instruction: TcInstruction,
}
/// PTX tensor-core instruction families, selected by architecture:
/// `Wmma` is used for sm_75, `Mma` for Ampere/Ada paths, and `Wgmma`
/// for Hopper and later (see `TensorCoreValidator::config`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcInstruction {
    // wmma.* — the older warp-synchronous API (Turing path here).
    Wmma,
    // mma.* — warp-level MMA (Ampere/Ada paths here).
    Mma,
    // wgmma.* — warpgroup MMA (Hopper+ paths here).
    Wgmma,
}
/// Stateless namespace for tensor-core support queries and tile validation.
pub struct TensorCoreValidator;
impl TensorCoreValidator {
    /// Returns `true` when `(input, accumulator)` maps to some tensor-core
    /// configuration on the given SM version.
    pub fn is_supported(sm: SmVersion, input: PtxType, accumulator: PtxType) -> bool {
        Self::config(sm, input, accumulator).is_some()
    }

    /// Resolves the MMA shape and instruction family for `(sm, input,
    /// accumulator)`, or `None` when the combination is unsupported.
    pub fn config(sm: SmVersion, input: PtxType, accumulator: PtxType) -> Option<TensorCoreConfig> {
        let caps = ArchCapabilities::for_sm(sm);
        // Architectures without tensor cores never qualify.
        if !caps.has_tensor_cores {
            return None;
        }
        // Only F32 and F64 accumulation are modeled by the per-arch tables.
        if !matches!(accumulator, PtxType::F32 | PtxType::F64) {
            return None;
        }
        // Dispatch to the per-generation lookup table.
        match sm {
            SmVersion::Sm75 => Self::turing_config(input, accumulator),
            SmVersion::Sm80 | SmVersion::Sm86 => Self::ampere_config(input, accumulator, &caps),
            SmVersion::Sm89 => Self::ada_config(input, accumulator, &caps),
            SmVersion::Sm90 | SmVersion::Sm90a => Self::hopper_config(input, accumulator, &caps),
            SmVersion::Sm100 | SmVersion::Sm120 => {
                Self::blackwell_config(input, accumulator, &caps)
            }
        }
    }

    /// Turing (sm_75): WMMA only, F16 inputs accumulating into F32.
    fn turing_config(input: PtxType, accumulator: PtxType) -> Option<TensorCoreConfig> {
        let supported = input == PtxType::F16 && accumulator == PtxType::F32;
        supported.then_some(TensorCoreConfig {
            mma_m: 16,
            mma_n: 16,
            mma_k: 16,
            instruction: TcInstruction::Wmma,
        })
    }

    /// Ampere (sm_80/86): warp-level `mma` with m16n8; K depends on input
    /// precision (TF32 input halves K relative to the half-precision path).
    fn ampere_config(
        input: PtxType,
        accumulator: PtxType,
        caps: &ArchCapabilities,
    ) -> Option<TensorCoreConfig> {
        if !caps.has_ampere_mma {
            return None;
        }
        let mma_k = match (input, accumulator) {
            (PtxType::F16 | PtxType::BF16, PtxType::F32) => 16,
            (PtxType::F32, PtxType::F32) => 8,
            _ => return None,
        };
        Some(TensorCoreConfig {
            mma_m: 16,
            mma_n: 8,
            mma_k,
            instruction: TcInstruction::Mma,
        })
    }

    /// Ada (sm_89): everything Ampere offers, plus an FP8 path.
    fn ada_config(
        input: PtxType,
        accumulator: PtxType,
        caps: &ArchCapabilities,
    ) -> Option<TensorCoreConfig> {
        Self::ampere_config(input, accumulator, caps).or_else(|| {
            // NOTE(review): FP8 operands appear to be modeled as U8 here —
            // no dedicated FP8 variant of PtxType is visible; confirm upstream.
            let fp8_path = caps.has_fp8 && input == PtxType::U8 && accumulator == PtxType::F32;
            fp8_path.then_some(TensorCoreConfig {
                mma_m: 16,
                mma_n: 8,
                mma_k: 32,
                instruction: TcInstruction::Mma,
            })
        })
    }

    /// Hopper (sm_90/90a): prefer warpgroup `wgmma` shapes when available,
    /// otherwise fall back to the warp-level Ampere shapes.
    fn hopper_config(
        input: PtxType,
        accumulator: PtxType,
        caps: &ArchCapabilities,
    ) -> Option<TensorCoreConfig> {
        if caps.has_wgmma {
            // (N, K) of the wgmma shape depends on the input precision.
            let shape = match (input, accumulator) {
                (PtxType::F16 | PtxType::BF16, PtxType::F32) => Some((256, 16)),
                (PtxType::F32, PtxType::F32) => Some((128, 8)),
                _ => None,
            };
            if let Some((mma_n, mma_k)) = shape {
                return Some(TensorCoreConfig {
                    mma_m: 64,
                    mma_n,
                    mma_k,
                    instruction: TcInstruction::Wgmma,
                });
            }
        }
        Self::ampere_config(input, accumulator, caps)
    }

    /// Blackwell (sm_100/120): currently modeled identically to Hopper.
    fn blackwell_config(
        input: PtxType,
        accumulator: PtxType,
        caps: &ArchCapabilities,
    ) -> Option<TensorCoreConfig> {
        Self::hopper_config(input, accumulator, caps)
    }

    /// Checks that a CTA tile decomposes evenly into MMA instructions:
    /// each tile dimension must be a multiple of the corresponding MMA
    /// dimension. Returns a human-readable message on the first violation.
    pub fn validate_tile(
        tile_m: u32,
        tile_n: u32,
        tile_k: u32,
        config: &TensorCoreConfig,
    ) -> Result<(), String> {
        if tile_m % config.mma_m != 0 {
            Err(format!(
                "tile_m ({tile_m}) must be a multiple of mma_m ({})",
                config.mma_m
            ))
        } else if tile_n % config.mma_n != 0 {
            Err(format!(
                "tile_n ({tile_n}) must be a multiple of mma_n ({})",
                config.mma_n
            ))
        } else if tile_k % config.mma_k != 0 {
            Err(format!(
                "tile_k ({tile_k}) must be a multiple of mma_k ({})",
                config.mma_k
            ))
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: the Ampere F16 warp-level MMA shape.
    fn mma_16x8x16() -> TensorCoreConfig {
        TensorCoreConfig {
            mma_m: 16,
            mma_n: 8,
            mma_k: 16,
            instruction: TcInstruction::Mma,
        }
    }

    #[test]
    fn turing_f16_supported() {
        let ok = TensorCoreValidator::is_supported(SmVersion::Sm75, PtxType::F16, PtxType::F32);
        assert!(ok);
    }

    #[test]
    fn turing_f32_not_supported() {
        let ok = TensorCoreValidator::is_supported(SmVersion::Sm75, PtxType::F32, PtxType::F32);
        assert!(!ok);
    }

    #[test]
    fn ampere_f16_mma() {
        let cfg = TensorCoreValidator::config(SmVersion::Sm80, PtxType::F16, PtxType::F32)
            .expect("should have config");
        assert_eq!(
            (cfg.instruction, cfg.mma_m, cfg.mma_n, cfg.mma_k),
            (TcInstruction::Mma, 16, 8, 16)
        );
    }

    #[test]
    fn ampere_tf32() {
        let cfg = TensorCoreValidator::config(SmVersion::Sm80, PtxType::F32, PtxType::F32)
            .expect("TF32 path should exist on Ampere");
        assert_eq!(cfg.mma_k, 8);
    }

    #[test]
    fn hopper_wgmma_f16() {
        let cfg = TensorCoreValidator::config(SmVersion::Sm90, PtxType::F16, PtxType::F32)
            .expect("Hopper should support WGMMA for F16");
        assert_eq!(cfg.instruction, TcInstruction::Wgmma);
        assert_eq!(cfg.mma_m, 64);
    }

    #[test]
    fn validate_tile_ok() {
        assert!(TensorCoreValidator::validate_tile(128, 128, 32, &mma_16x8x16()).is_ok());
    }

    #[test]
    fn validate_tile_bad_m() {
        assert!(TensorCoreValidator::validate_tile(100, 128, 32, &mma_16x8x16()).is_err());
    }

    #[test]
    fn f64_not_supported_tc() {
        let ok = TensorCoreValidator::is_supported(SmVersion::Sm80, PtxType::F64, PtxType::F64);
        assert!(!ok);
    }
}