use oxicuda_ptx::prelude::SmVersion;
use crate::error::{BlasError, BlasResult};
use crate::level3::gemm::dispatch::TileConfig;
/// Tiling and capability heuristics for BF16 (bfloat16) GEMM kernels.
///
/// Stateless: all items are associated constants/functions keyed on the
/// target [`SmVersion`].
pub struct Bf16Config;

impl Bf16Config {
    /// Size of one BF16 element in bytes.
    pub const ELEMENT_BYTES: u32 = 2;

    /// Select a tile configuration for the given SM architecture.
    ///
    /// The problem dimensions (`_m`, `_n`, `_k`) are currently unused: tile
    /// shapes are chosen purely by architecture generation, and `split_k` is
    /// always 1 here — dimension-aware split-K selection is left to the
    /// caller via [`Self::compute_split_k`].
    #[must_use]
    pub fn tile_config(sm: SmVersion, _m: u32, _n: u32, _k: u32) -> TileConfig {
        match sm {
            // SM 8.0 / 8.6 / 8.9 all use the same 128x128x32 tensor-core
            // tiling (the former Sm89 arm was byte-identical to this one).
            SmVersion::Sm80 | SmVersion::Sm86 | SmVersion::Sm89 => TileConfig {
                tile_m: 128,
                tile_n: 128,
                tile_k: 32,
                warp_m: 64,
                warp_n: 64,
                stages: 4,
                use_tensor_core: true,
                split_k: 1,
            },
            // SM 9.0 and newer share a wider 128x256x64 tiling (the former
            // Sm100/Sm120 arm was byte-identical to the Sm90 one).
            SmVersion::Sm90 | SmVersion::Sm90a | SmVersion::Sm100 | SmVersion::Sm120 => {
                TileConfig {
                    tile_m: 128,
                    tile_n: 256,
                    tile_k: 64,
                    warp_m: 64,
                    warp_n: 128,
                    stages: 4,
                    use_tensor_core: true,
                    split_k: 1,
                }
            }
            // Pre-Ampere architectures lack BF16 tensor cores; fall back to
            // a conservative non-tensor-core tiling.
            _ => TileConfig {
                tile_m: 64,
                tile_n: 64,
                tile_k: 16,
                warp_m: 32,
                warp_n: 32,
                stages: 2,
                use_tensor_core: false,
                split_k: 1,
            },
        }
    }

    /// Whether the architecture has BF16 tensor-core support (SM 8.0+).
    #[must_use]
    pub fn has_tensor_core(sm: SmVersion) -> bool {
        sm >= SmVersion::Sm80
    }

    /// Whether the architecture supports `wgmma`-style instructions (SM 9.0+).
    #[must_use]
    pub fn has_wgmma(sm: SmVersion) -> bool {
        sm >= SmVersion::Sm90
    }

    /// Validate GEMM problem dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`BlasError::InvalidDimension`] if any of `m`, `n`, `k` is zero.
    pub fn validate_dimensions(m: u32, n: u32, k: u32) -> BlasResult<()> {
        if m == 0 || n == 0 || k == 0 {
            return Err(BlasError::InvalidDimension(format!(
                "BF16 GEMM dimensions must be non-zero: m={m}, n={n}, k={k}"
            )));
        }
        Ok(())
    }

    /// Whether BF16 GEMM runs efficiently on this architecture.
    ///
    /// Currently the same gate as tensor-core availability (SM 8.0+).
    #[must_use]
    pub fn is_efficient(sm: SmVersion) -> bool {
        Self::has_tensor_core(sm)
    }

    /// Heuristic split-K factor: split the K dimension when the output is
    /// small (`m * n < 4096`) but K is deep (`k > 16384`), exposing more
    /// parallelism; otherwise 1 (no split).
    ///
    /// The output size is computed in `u64` so that large `m`/`n` cannot
    /// overflow (`u32` multiplication would panic in debug builds and wrap
    /// in release builds, e.g. for `m = n = 65536`).
    #[must_use]
    pub fn compute_split_k(m: u32, n: u32, k: u32) -> u32 {
        let output_elems = u64::from(m) * u64::from(n);
        if output_elems < 4096 && k > 16384 {
            (k / 4096).clamp(2, 32)
        } else {
            1
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_bf16_tensor_core_on_turing() {
        // SM 7.5 predates BF16 tensor-core support, so both the capability
        // query and the selected tiling must report no tensor cores.
        assert!(!Bf16Config::has_tensor_core(SmVersion::Sm75));
        let tiling = Bf16Config::tile_config(SmVersion::Sm75, 512, 512, 512);
        assert!(!tiling.use_tensor_core);
    }

    #[test]
    fn bf16_tensor_core_on_ampere() {
        // SM 8.0 is the first generation with BF16 tensor cores.
        assert!(Bf16Config::has_tensor_core(SmVersion::Sm80));
        let tiling = Bf16Config::tile_config(SmVersion::Sm80, 1024, 1024, 1024);
        assert!(tiling.use_tensor_core);
    }

    #[test]
    fn wgmma_support() {
        // wgmma availability starts at SM 9.0, not SM 8.9.
        assert!(!Bf16Config::has_wgmma(SmVersion::Sm89));
        assert!(Bf16Config::has_wgmma(SmVersion::Sm90));
    }

    #[test]
    fn efficiency_check() {
        let pre_ampere = Bf16Config::is_efficient(SmVersion::Sm75);
        let ampere = Bf16Config::is_efficient(SmVersion::Sm80);
        assert!(!pre_ampere);
        assert!(ampere);
    }

    #[test]
    fn validate_ok() {
        // All dimensions non-zero: validation succeeds.
        assert!(matches!(Bf16Config::validate_dimensions(256, 256, 128), Ok(())));
    }

    #[test]
    fn validate_zero() {
        // A single zero dimension is rejected.
        assert!(Bf16Config::validate_dimensions(256, 0, 128).is_err());
    }
}