use oxicuda_ptx::prelude::SmVersion;
use crate::error::{BlasError, BlasResult};
use crate::level3::gemm::dispatch::TileConfig;
/// GEMM configuration helpers for double-precision (FP64) operands.
///
/// Stateless namespace type: everything is exposed through associated
/// constants and functions.
pub struct F64Config;

impl F64Config {
    /// Size in bytes of a single FP64 element.
    pub const ELEMENT_BYTES: u32 = 8;

    /// Selects the FP64 GEMM tile configuration for the architecture `sm`.
    ///
    /// The problem dimensions (`_m`, `_n`, `_k`) are currently unused; the choice
    /// depends only on `sm`:
    /// - `Sm80`/`Sm86`: 64x64x16 tiles, 3 pipeline stages, tensor cores enabled.
    /// - `Sm90`/`Sm90a`/`Sm100`/`Sm120`: 64x64x16 tiles, 4 stages, tensor cores enabled.
    /// - any other architecture: 64x64x8 tiles, 2 stages, tensor cores disabled.
    ///
    /// `split_k` is always 1 here; see [`F64Config::compute_split_k`].
    #[must_use]
    pub fn tile_config(sm: SmVersion, _m: u32, _n: u32, _k: u32) -> TileConfig {
        match sm {
            SmVersion::Sm80 | SmVersion::Sm86 => TileConfig {
                tile_m: 64,
                tile_n: 64,
                tile_k: 16,
                warp_m: 32,
                warp_n: 32,
                stages: 3,
                use_tensor_core: true,
                split_k: 1,
            },
            // Hopper and Blackwell currently share one configuration; split this
            // arm if per-architecture tuning is introduced.
            SmVersion::Sm90 | SmVersion::Sm90a | SmVersion::Sm100 | SmVersion::Sm120 => {
                TileConfig {
                    tile_m: 64,
                    tile_n: 64,
                    tile_k: 16,
                    warp_m: 32,
                    warp_n: 32,
                    stages: 4,
                    use_tensor_core: true,
                    split_k: 1,
                }
            }
            // NOTE(review): `has_tensor_core` reports true for every `sm >= Sm80`,
            // but any such variant not listed above lands here with tensor cores
            // disabled — confirm this is intended for those variants.
            _ => TileConfig {
                tile_m: 64,
                tile_n: 64,
                tile_k: 8,
                warp_m: 32,
                warp_n: 32,
                stages: 2,
                use_tensor_core: false,
                split_k: 1,
            },
        }
    }

    /// Returns `true` when `sm` is ordered at or after `SmVersion::Sm80`.
    #[must_use]
    pub fn has_tensor_core(sm: SmVersion) -> bool {
        sm >= SmVersion::Sm80
    }

    /// Checks that all FP64 GEMM dimensions are non-zero.
    ///
    /// # Errors
    ///
    /// Returns [`BlasError::InvalidDimension`] if any of `m`, `n`, or `k` is zero.
    pub fn validate_dimensions(m: u32, n: u32, k: u32) -> BlasResult<()> {
        if m == 0 || n == 0 || k == 0 {
            return Err(BlasError::InvalidDimension(format!(
                "FP64 GEMM dimensions must be non-zero: m={m}, n={n}, k={k}"
            )));
        }
        Ok(())
    }

    /// Chooses how many slices to split the `k` dimension into.
    ///
    /// Split-K is used only when the output is small (`m * n < 4096`) and the
    /// reduction is long (`k > 4096`). The factor is then `k / 1024` clamped to
    /// `[2, 16]`. Because `k > 4096`, the factor is in practice at least 4.
    /// Otherwise it returns 1 (no split).
    #[must_use]
    pub fn compute_split_k(m: u32, n: u32, k: u32) -> u32 {
        // Widen before multiplying: `m * n` in u32 overflows for large outputs
        // (e.g. 65536 x 65536). That panics in debug builds and wraps in release,
        // where a small wrapped product would wrongly enable split-K.
        let output_elems = u64::from(m) * u64::from(n);
        if output_elems < 4096 && k > 4096 {
            (k / 1024).clamp(2, 16)
        } else {
            1
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_core_availability() {
        assert!(!F64Config::has_tensor_core(SmVersion::Sm75));
        assert!(F64Config::has_tensor_core(SmVersion::Sm80));
        assert!(F64Config::has_tensor_core(SmVersion::Sm90));
    }

    #[test]
    fn tile_config_turing_no_tensor_core() {
        let cfg = F64Config::tile_config(SmVersion::Sm75, 1024, 1024, 1024);
        assert!(!cfg.use_tensor_core);
        assert_eq!(cfg.tile_k, 8);
        assert_eq!(cfg.stages, 2);
    }

    #[test]
    fn tile_config_ampere_uses_tensor_core() {
        for sm in [SmVersion::Sm80, SmVersion::Sm86] {
            let cfg = F64Config::tile_config(sm, 1024, 1024, 1024);
            assert!(cfg.use_tensor_core);
            assert_eq!(cfg.stages, 3);
            assert_eq!(cfg.tile_k, 16);
        }
    }

    #[test]
    fn tile_config_hopper_and_blackwell_use_four_stages() {
        for sm in [
            SmVersion::Sm90,
            SmVersion::Sm90a,
            SmVersion::Sm100,
            SmVersion::Sm120,
        ] {
            let cfg = F64Config::tile_config(sm, 1024, 1024, 1024);
            assert!(cfg.use_tensor_core);
            assert_eq!(cfg.stages, 4);
            assert_eq!(cfg.tile_k, 16);
        }
    }

    #[test]
    fn validate_zero_dimension() {
        for (m, n, k) in [(0, 1024, 1024), (1024, 0, 1024), (1024, 1024, 0)] {
            assert!(matches!(
                F64Config::validate_dimensions(m, n, k),
                Err(BlasError::InvalidDimension(_))
            ));
        }
    }

    #[test]
    fn validate_valid_dimensions() {
        assert!(F64Config::validate_dimensions(128, 256, 512).is_ok());
    }

    #[test]
    fn split_k_small_mn_large_k() {
        // 8192 / 1024 = 8, which is inside the [2, 16] clamp range.
        assert_eq!(F64Config::compute_split_k(32, 32, 8192), 8);
    }

    #[test]
    fn split_k_clamped_to_sixteen() {
        assert_eq!(F64Config::compute_split_k(32, 32, 1 << 20), 16);
    }

    #[test]
    fn split_k_requires_k_above_threshold() {
        assert_eq!(F64Config::compute_split_k(32, 32, 4096), 1);
    }

    #[test]
    fn split_k_requires_small_output() {
        // 64 * 64 == 4096 is not below the threshold.
        assert_eq!(F64Config::compute_split_k(64, 64, 8192), 1);
    }

    #[test]
    fn split_k_balanced() {
        assert_eq!(F64Config::compute_split_k(1024, 1024, 1024), 1);
    }
}