use oxicuda_ptx::prelude::SmVersion;
use crate::error::{BlasError, BlasResult};
use crate::level3::gemm::dispatch::TileConfig;
/// Hardware FP8 storage formats (OCP 8-bit floating-point encodings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Fp8Format {
/// 4 exponent bits, 3 mantissa bits: more precision, narrower range.
E4M3,
/// 5 exponent bits, 2 mantissa bits: wider range, less precision.
E5M2,
}
/// Coarse shape classes used to select FP8 GEMM tile configurations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Fp8WorkloadClass {
    /// Output has fewer than 8192 elements; launch overhead dominates.
    SmallSquare,
    /// Both `m` and `n` are at least 1024.
    LargeSquare,
    /// `m < 128` while `n >= 128` (few output rows).
    SkinnyM,
    /// `n < 128` while `m >= 128` (few output columns).
    SkinnyN,
    /// Shallow reduction: `k < 64`.
    SkinnyK,
    /// No special structure detected.
    General,
}

/// Classifies a GEMM problem by its `m x n x k` shape.
///
/// Checks are ordered and first-match wins: tiny output, then skinny
/// m / n, then shallow k, then large square, else general.
#[must_use]
pub fn classify_workload(m: u32, n: u32, k: u32) -> Fp8WorkloadClass {
    // Widen before multiplying so `m * n` cannot overflow u32.
    let output_elems = u64::from(m) * u64::from(n);
    if output_elems < 8192 {
        Fp8WorkloadClass::SmallSquare
    } else if m < 128 && n >= 128 {
        Fp8WorkloadClass::SkinnyM
    } else if n < 128 && m >= 128 {
        Fp8WorkloadClass::SkinnyN
    } else if k < 64 {
        Fp8WorkloadClass::SkinnyK
    } else if m >= 1024 && n >= 1024 {
        Fp8WorkloadClass::LargeSquare
    } else {
        Fp8WorkloadClass::General
    }
}
pub struct Fp8TileHeuristic;
impl Fp8TileHeuristic {
#[must_use]
pub fn select(m: u32, n: u32, k: u32, sm: SmVersion) -> TileConfig {
let class = classify_workload(m, n, k);
match sm {
SmVersion::Sm89 => Self::select_ada(class, k),
SmVersion::Sm90 | SmVersion::Sm90a => Self::select_hopper(class, k),
SmVersion::Sm100 | SmVersion::Sm120 => Self::select_blackwell(class, k),
_ => Self::select_fallback(),
}
}
fn select_ada(class: Fp8WorkloadClass, k: u32) -> TileConfig {
let stages = if k < 64 { 2 } else { 4 };
match class {
Fp8WorkloadClass::SmallSquare => TileConfig {
tile_m: 64,
tile_n: 64,
tile_k: 32,
warp_m: 32,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyM => TileConfig {
tile_m: 64,
tile_n: 128,
tile_k: 64,
warp_m: 32,
warp_n: 64,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyN => TileConfig {
tile_m: 128,
tile_n: 64,
tile_k: 64,
warp_m: 64,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::LargeSquare => TileConfig {
tile_m: 128,
tile_n: 128,
tile_k: 64,
warp_m: 64,
warp_n: 64,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyK | Fp8WorkloadClass::General => TileConfig {
tile_m: 128,
tile_n: 128,
tile_k: 64,
warp_m: 64,
warp_n: 64,
stages,
use_tensor_core: true,
split_k: 1,
},
}
}
fn select_hopper(class: Fp8WorkloadClass, k: u32) -> TileConfig {
let stages = if k < 64 { 2 } else { 4 };
match class {
Fp8WorkloadClass::SmallSquare => TileConfig {
tile_m: 64,
tile_n: 64,
tile_k: 64,
warp_m: 32,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyM => TileConfig {
tile_m: 64,
tile_n: 256,
tile_k: 128,
warp_m: 32,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyN => TileConfig {
tile_m: 256,
tile_n: 64,
tile_k: 128,
warp_m: 128,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::LargeSquare => TileConfig {
tile_m: 256,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyK => TileConfig {
tile_m: 128,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages: 2,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::General => TileConfig {
tile_m: 128,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
}
}
fn select_blackwell(class: Fp8WorkloadClass, k: u32) -> TileConfig {
let stages = if k < 64 { 2 } else { 4 };
match class {
Fp8WorkloadClass::SmallSquare => TileConfig {
tile_m: 64,
tile_n: 64,
tile_k: 64,
warp_m: 32,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyM => TileConfig {
tile_m: 64,
tile_n: 256,
tile_k: 128,
warp_m: 32,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyN => TileConfig {
tile_m: 256,
tile_n: 64,
tile_k: 128,
warp_m: 128,
warp_n: 32,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::LargeSquare => TileConfig {
tile_m: 256,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::SkinnyK => TileConfig {
tile_m: 256,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages: 2,
use_tensor_core: true,
split_k: 1,
},
Fp8WorkloadClass::General => TileConfig {
tile_m: 256,
tile_n: 256,
tile_k: 128,
warp_m: 64,
warp_n: 128,
stages,
use_tensor_core: true,
split_k: 1,
},
}
}
fn select_fallback() -> TileConfig {
TileConfig {
tile_m: 64,
tile_n: 64,
tile_k: 32,
warp_m: 32,
warp_n: 32,
stages: 1,
use_tensor_core: false,
split_k: 1,
}
}
}
/// Capability queries, validation, and split-K selection for FP8 GEMM.
pub struct Fp8Config;

impl Fp8Config {
    /// Size of one FP8 element in bytes.
    pub const ELEMENT_BYTES: u32 = 1;

    /// Selects a tile configuration for `m x n x k` on `sm`; thin wrapper
    /// over [`Fp8TileHeuristic::select`].
    #[must_use]
    pub fn tile_config(sm: SmVersion, m: u32, n: u32, k: u32) -> TileConfig {
        Fp8TileHeuristic::select(m, n, k, sm)
    }

    /// FP8 tensor-core GEMM requires Ada Lovelace (sm_89) or newer.
    #[must_use]
    pub fn is_available(sm: SmVersion) -> bool {
        sm >= SmVersion::Sm89
    }

    /// Whether `format` is hardware-supported on `sm`. Both E4M3 and E5M2
    /// are available from sm_89 onward, so the arms share one bound.
    #[must_use]
    pub fn is_format_supported(format: Fp8Format, sm: SmVersion) -> bool {
        match format {
            Fp8Format::E4M3 | Fp8Format::E5M2 => sm >= SmVersion::Sm89,
        }
    }

    /// Whether `sm` has warp-group MMA (WGMMA), available from sm_90 on.
    #[must_use]
    pub fn has_wgmma(sm: SmVersion) -> bool {
        sm >= SmVersion::Sm90
    }

    /// Validates an FP8 GEMM request.
    ///
    /// # Errors
    ///
    /// Returns [`BlasError::InvalidDimension`] if any dimension is zero, and
    /// [`BlasError::UnsupportedOperation`] if `sm` lacks FP8 tensor cores or
    /// does not support `format`.
    pub fn validate(sm: SmVersion, format: Fp8Format, m: u32, n: u32, k: u32) -> BlasResult<()> {
        if m == 0 || n == 0 || k == 0 {
            return Err(BlasError::InvalidDimension(format!(
                "FP8 GEMM dimensions must be non-zero: m={m}, n={n}, k={k}"
            )));
        }
        if !Self::is_available(sm) {
            return Err(BlasError::UnsupportedOperation(format!(
                "FP8 Tensor Cores require Ada Lovelace+ (sm_89), got {sm}"
            )));
        }
        if !Self::is_format_supported(format, sm) {
            return Err(BlasError::UnsupportedOperation(format!(
                "FP8 format {format:?} not supported on {sm}"
            )));
        }
        Ok(())
    }

    /// Chooses a split-K factor: small outputs with very deep reductions
    /// (k > 32768) get k/8192 partitions, clamped to [2, 64]; otherwise 1.
    #[must_use]
    pub fn compute_split_k(m: u32, n: u32, k: u32) -> u32 {
        // Widen to u64 before multiplying: the previous `m * n` in u32
        // could overflow (panic in debug, wrap in release — and a wrapped
        // product below 2048 would wrongly enable split-K for huge
        // matrices). classify_workload already widens the same way.
        if u64::from(m) * u64::from(n) < 2048 && k > 32768 {
            (k / 8192).clamp(2, 64)
        } else {
            1
        }
    }
}
// Unit tests: SM capability gating, FP8 format support, workload
// classification boundaries, and per-architecture tile heuristics.
#[cfg(test)]
mod tests {
use super::*;
// --- Capability gating: FP8 arrived with Ada (sm_89) ---
#[test]
fn fp8_not_on_ampere() {
assert!(!Fp8Config::is_available(SmVersion::Sm80));
assert!(!Fp8Config::is_available(SmVersion::Sm86));
}
#[test]
fn fp8_on_ada() {
assert!(Fp8Config::is_available(SmVersion::Sm89));
}
#[test]
fn fp8_on_hopper() {
assert!(Fp8Config::is_available(SmVersion::Sm90));
assert!(Fp8Config::has_wgmma(SmVersion::Sm90a));
}
#[test]
fn tile_config_hopper() {
let cfg = Fp8Config::tile_config(SmVersion::Sm90, 2048, 2048, 2048);
assert!(cfg.use_tensor_core);
assert_eq!(cfg.tile_k, 128);
}
// --- validate(): dimension and capability checks ---
#[test]
fn validate_ok() {
assert!(Fp8Config::validate(SmVersion::Sm89, Fp8Format::E4M3, 128, 128, 128).is_ok());
}
#[test]
fn validate_unsupported() {
assert!(Fp8Config::validate(SmVersion::Sm80, Fp8Format::E4M3, 128, 128, 128).is_err());
}
#[test]
fn validate_zero_dim() {
assert!(Fp8Config::validate(SmVersion::Sm90, Fp8Format::E5M2, 0, 128, 128).is_err());
}
#[test]
fn format_support() {
assert!(Fp8Config::is_format_supported(
Fp8Format::E4M3,
SmVersion::Sm89
));
assert!(Fp8Config::is_format_supported(
Fp8Format::E5M2,
SmVersion::Sm90
));
assert!(!Fp8Config::is_format_supported(
Fp8Format::E4M3,
SmVersion::Sm80
));
}
// --- classify_workload(): one test per class plus degenerate edges ---
#[test]
fn classify_small_square() {
assert_eq!(
classify_workload(64, 64, 256),
Fp8WorkloadClass::SmallSquare
);
// 32 * 128 = 4096 < 8192: SmallSquare wins over SkinnyM.
assert_eq!(
classify_workload(32, 128, 512),
Fp8WorkloadClass::SmallSquare
);
}
#[test]
fn classify_large_square() {
assert_eq!(
classify_workload(1024, 1024, 512),
Fp8WorkloadClass::LargeSquare
);
assert_eq!(
classify_workload(4096, 4096, 4096),
Fp8WorkloadClass::LargeSquare
);
}
#[test]
fn classify_skinny_m() {
assert_eq!(classify_workload(64, 512, 256), Fp8WorkloadClass::SkinnyM);
}
#[test]
fn classify_skinny_n() {
assert_eq!(classify_workload(512, 64, 256), Fp8WorkloadClass::SkinnyN);
}
#[test]
fn classify_skinny_k() {
assert_eq!(classify_workload(256, 256, 32), Fp8WorkloadClass::SkinnyK);
}
#[test]
fn classify_general() {
assert_eq!(classify_workload(256, 512, 128), Fp8WorkloadClass::General);
}
#[test]
fn classify_edge_k_equals_1() {
assert_eq!(classify_workload(1, 1, 1), Fp8WorkloadClass::SmallSquare);
}
#[test]
fn classify_edge_m_equals_1() {
assert_eq!(classify_workload(1, 16384, 256), Fp8WorkloadClass::SkinnyM);
}
#[test]
fn classify_edge_n_equals_1() {
assert_eq!(classify_workload(16384, 1, 256), Fp8WorkloadClass::SkinnyN);
}
// --- Fp8TileHeuristic::select(): per-architecture tile shapes ---
#[test]
fn heuristic_small_square_ada() {
let cfg = Fp8TileHeuristic::select(64, 64, 256, SmVersion::Sm89);
assert_eq!(cfg.tile_m, 64);
assert_eq!(cfg.tile_n, 64);
assert!(cfg.use_tensor_core);
}
#[test]
fn heuristic_skinny_m_hopper() {
let cfg = Fp8TileHeuristic::select(64, 512, 256, SmVersion::Sm90);
assert_eq!(cfg.tile_m, 64);
assert_eq!(cfg.tile_n, 256);
assert_eq!(cfg.tile_k, 128);
}
#[test]
fn heuristic_skinny_n_hopper() {
let cfg = Fp8TileHeuristic::select(512, 64, 256, SmVersion::Sm90);
assert_eq!(cfg.tile_m, 256);
assert_eq!(cfg.tile_n, 64);
}
#[test]
fn heuristic_skinny_k_reduces_stages() {
let cfg = Fp8TileHeuristic::select(256, 256, 32, SmVersion::Sm90);
assert_eq!(cfg.stages, 2);
}
#[test]
fn heuristic_large_square_blackwell() {
let cfg = Fp8TileHeuristic::select(2048, 2048, 2048, SmVersion::Sm100);
assert_eq!(cfg.tile_m, 256);
assert_eq!(cfg.tile_n, 256);
assert_eq!(cfg.stages, 4);
}
#[test]
fn heuristic_general_hopper() {
let cfg = Fp8TileHeuristic::select(256, 512, 128, SmVersion::Sm90);
assert_eq!(cfg.tile_m, 128);
assert_eq!(cfg.tile_n, 256);
assert_eq!(cfg.stages, 4);
}
#[test]
fn heuristic_fallback_pre_ada() {
let cfg = Fp8TileHeuristic::select(256, 256, 256, SmVersion::Sm80);
assert!(!cfg.use_tensor_core);
assert_eq!(cfg.stages, 1);
}
#[test]
fn tile_config_delegates_to_heuristic() {
let general = Fp8Config::tile_config(SmVersion::Sm90, 256, 512, 128);
let skinny_m = Fp8Config::tile_config(SmVersion::Sm90, 64, 512, 256);
assert!(general.tile_m > skinny_m.tile_m);
}
#[test]
fn heuristic_skinny_m_blackwell() {
let cfg = Fp8TileHeuristic::select(1, 16384, 256, SmVersion::Sm100);
assert_eq!(cfg.tile_m, 64);
assert_eq!(cfg.tile_n, 256);
}
#[test]
fn heuristic_skinny_k_ada() {
let cfg = Fp8TileHeuristic::select(256, 256, 16, SmVersion::Sm89);
assert_eq!(cfg.stages, 2);
}
#[test]
fn heuristic_large_square_sm120() {
let cfg = Fp8TileHeuristic::select(4096, 4096, 4096, SmVersion::Sm120);
assert_eq!(cfg.tile_m, 256);
assert_eq!(cfg.tile_n, 256);
assert_eq!(cfg.tile_k, 128);
assert_eq!(cfg.stages, 4);
assert!(cfg.use_tensor_core);
}
#[test]
fn fp8_workload_class_256x256x256_is_general() {
let wl = classify_workload(256, 256, 256);
assert_eq!(
wl,
Fp8WorkloadClass::General,
"256×256×256 should classify as General"
);
}
#[test]
fn fp8_workload_class_large_matrix_is_large_square() {
let wl = classify_workload(4096, 4096, 4096);
assert_eq!(
wl,
Fp8WorkloadClass::LargeSquare,
"4096×4096×4096 should classify as LargeSquare"
);
}
#[test]
fn ada_fp8_e4m3_tile_config_uses_tensor_core() {
let cfg = Fp8TileHeuristic::select(128, 128, 128, SmVersion::Sm89);
assert!(
cfg.use_tensor_core,
"Ada FP8 GEMM must use Tensor Core path (mma.sync)"
);
}
#[test]
fn ada_fp8_e4m3_format_is_supported() {
assert!(
Fp8Config::is_format_supported(Fp8Format::E4M3, SmVersion::Sm89),
"Ada Lovelace (SM89) must support FP8 E4M3"
);
assert!(
Fp8Config::is_available(SmVersion::Sm89),
"Fp8Config::is_available must return true for SM89"
);
}
#[test]
fn ada_fp8_e5m2_format_is_supported() {
assert!(
Fp8Config::is_format_supported(Fp8Format::E5M2, SmVersion::Sm89),
"Ada Lovelace (SM89) must support FP8 E5M2"
);
}
#[test]
fn ada_fp8_mixed_e4m3_e5m2_both_valid_on_sm89() {
let e4m3_ok = Fp8Config::validate(SmVersion::Sm89, Fp8Format::E4M3, 128, 128, 128);
let e5m2_ok = Fp8Config::validate(SmVersion::Sm89, Fp8Format::E5M2, 128, 128, 128);
assert!(
e4m3_ok.is_ok(),
"Ada E4M3 validation must pass: {:?}",
e4m3_ok.err()
);
assert!(
e5m2_ok.is_ok(),
"Ada E5M2 validation must pass: {:?}",
e5m2_ok.err()
);
}
#[test]
fn fp8_not_supported_on_pre_ada_sm80() {
assert!(
!Fp8Config::is_available(SmVersion::Sm80),
"Ampere (SM80) must NOT support FP8 GEMM"
);
assert!(
!Fp8Config::is_format_supported(Fp8Format::E4M3, SmVersion::Sm80),
"Ampere must NOT support E4M3"
);
assert!(
!Fp8Config::is_format_supported(Fp8Format::E5M2, SmVersion::Sm80),
"Ampere must NOT support E5M2"
);
}
// Cross-architecture relation: Hopper tile_k should never be smaller
// than Ada's for the same problem.
#[test]
fn hopper_fp8_wgmma_larger_tile_k_than_ada() {
let ada_cfg = Fp8TileHeuristic::select(4096, 4096, 4096, SmVersion::Sm89);
let hopper_cfg = Fp8TileHeuristic::select(4096, 4096, 4096, SmVersion::Sm90);
assert!(
hopper_cfg.tile_k >= ada_cfg.tile_k,
"Hopper FP8 tile_k ({}) must be >= Ada FP8 tile_k ({})",
hopper_cfg.tile_k,
ada_cfg.tile_k
);
assert!(
Fp8Config::has_wgmma(SmVersion::Sm90),
"Hopper must have WGMMA for high-throughput FP8"
);
}
}