use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_ptx::prelude::*;
use crate::error::{BlasError, BlasResult};
use crate::types::{MathMode, Transpose};
/// Immutable description of one GEMM problem instance.
///
/// Hashable/comparable so it can feed the dispatch heuristics and,
/// indirectly, the kernel-cache key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct GemmProblem {
    /// Rows of the output C (M dimension).
    pub m: u32,
    /// Columns of the output C (N dimension).
    pub n: u32,
    /// Shared reduction (inner) dimension.
    pub k: u32,
    /// Transpose mode applied to A.
    pub trans_a: Transpose,
    /// Transpose mode applied to B.
    pub trans_b: Transpose,
    /// Element type of the A/B inputs.
    pub input_type: PtxType,
    /// Element type of C; also used as the accumulator type at codegen.
    pub output_type: PtxType,
    /// Requested math mode (tensor-core vs. default SIMT).
    pub math_mode: MathMode,
}
/// CTA-level tiling parameters produced by the dispatch heuristics.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TileConfig {
    /// Tile extent along M per thread block.
    pub tile_m: u32,
    /// Tile extent along N per thread block.
    pub tile_n: u32,
    /// Tile extent along K per main-loop step.
    pub tile_k: u32,
    /// Per-warp tile extent along M.
    pub warp_m: u32,
    /// Per-warp tile extent along N.
    pub warp_n: u32,
    /// Pipeline stage count; multiplies the shared-memory footprint.
    pub stages: u32,
    /// Whether the generated kernel uses the tensor-core path.
    pub use_tensor_core: bool,
    /// Split-K factor; 1 means no split (becomes grid.z at launch).
    pub split_k: u32,
}
/// Coarse problem classification; each category selects a tile heuristic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmCategory {
    /// Default path via the generic tile selector.
    Standard,
    /// M or N is small (< 32): thin tiles along the small dimension.
    Skinny,
    /// K dominates M and N: split the reduction across grid.z.
    SplitK,
    /// Large problems on SM90+: stream-K style decomposition.
    StreamK,
    /// Warp-specialized (producer/consumer) kernels where applicable.
    WarpSpecialized,
    /// Shapes the bandwidth optimizer flags as memory-bound.
    BandwidthLimited,
}
/// A compiled GEMM kernel plus the metadata needed to launch it.
struct CompiledGemm {
    // Kept alive so `kernel` remains valid; never read directly.
    _module: Arc<Module>,
    kernel: Kernel,
    /// Tile configuration the kernel was generated for.
    tile_config: TileConfig,
    /// Dynamic shared memory required per block, in bytes.
    shared_mem_bytes: u32,
}
/// Cache key identifying one generated kernel variant.
///
/// Deliberately excludes the problem shape (m/n/k): those are passed as
/// launch arguments (see `dispatch`), so only type, transpose, and tile
/// choices distinguish compiled kernels.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct GemmKernelKey {
    input_type: PtxType,
    output_type: PtxType,
    trans_a: Transpose,
    trans_b: Transpose,
    tile_config: TileConfig,
}
/// Classifies GEMM problems, picks tile heuristics, and caches compiled kernels.
pub struct GemmDispatcher {
    /// Target SM architecture used for classification and codegen.
    sm_version: SmVersion,
    /// Compiled-kernel cache, keyed by type/transpose/tile configuration.
    compiled: RwLock<HashMap<GemmKernelKey, Arc<CompiledGemm>>>,
}
impl GemmDispatcher {
pub fn new(sm: SmVersion) -> Self {
Self {
sm_version: sm,
compiled: RwLock::new(HashMap::new()),
}
}
/// Classifies `problem`, selects a tile configuration, fetches or compiles
/// the matching kernel, and launches it on `stream`.
///
/// `a_ptr`/`b_ptr`/`c_ptr` are raw device addresses; `alpha_bits` and
/// `beta_bits` carry the scalar coefficients as raw `u64` bit patterns so a
/// single argument tuple works for every precision.
///
/// # Errors
/// Propagates compilation errors from [`Self::get_or_compile`] and returns
/// [`BlasError::LaunchFailed`] if the kernel launch itself fails.
#[allow(clippy::too_many_arguments)]
pub fn dispatch(
    &self,
    problem: &GemmProblem,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    alpha_bits: u64,
    beta_bits: u64,
    stream: &oxicuda_driver::Stream,
) -> BlasResult<()> {
    let category = self.classify(problem);
    let tile_config = self.heuristic_tile_config(problem, &category);
    let compiled = self.get_or_compile(problem, &tile_config)?;
    // Grid/block derive from the *compiled* tile config so a cached kernel
    // is always launched with the geometry it was generated for.
    let grid = Self::compute_grid(problem, &compiled.tile_config);
    let block = Self::compute_block(&compiled.tile_config);
    let params = LaunchParams::new(grid, block).with_shared_mem(compiled.shared_mem_bytes);
    let args = (
        a_ptr, b_ptr, c_ptr, problem.m, problem.n, problem.k, alpha_bits, beta_bits,
    );
    compiled
        .kernel
        // Fix: original read `launch(¶ms, …)` — a mangled `&params`.
        .launch(&params, stream, &args)
        .map_err(|e| BlasError::LaunchFailed(format!("GEMM kernel launch failed: {e}")))?;
    Ok(())
}
/// Buckets `problem` into a [`GemmCategory`].
///
/// Checks run in strict priority order: skinny shapes first, then
/// K-dominant split-K shapes, then bandwidth-limited shapes, then
/// warp-specialized applicability, then stream-K on SM90+, and finally
/// the `Standard` fallback.
pub fn classify(&self, problem: &GemmProblem) -> GemmCategory {
    let (m, n, k) = (problem.m, problem.n, problem.k);
    // Either output dimension under 32 makes wide tiles wasteful.
    if m.min(n) < 32 {
        return GemmCategory::Skinny;
    }
    // Reduction dominates both output dims and is long in absolute terms.
    if k > 4 * m && k > 4 * n && k >= 1024 {
        return GemmCategory::SplitK;
    }
    let elem_bytes = problem.input_type.size_bytes();
    let bandwidth_bound = super::bandwidth_opt::is_bandwidth_limited(
        m as usize, n as usize, k as usize, elem_bytes,
    );
    if bandwidth_bound {
        return GemmCategory::BandwidthLimited;
    }
    if super::warp_specialized::WarpSpecializedGemm::is_applicable(problem, self.sm_version) {
        return GemmCategory::WarpSpecialized;
    }
    // Stream-K only for >= 64M-MAC problems on SM90 or newer.
    let volume = u64::from(m) * u64::from(n) * u64::from(k);
    if self.sm_version >= SmVersion::Sm90 && volume >= 64 * 1024 * 1024 {
        return GemmCategory::StreamK;
    }
    GemmCategory::Standard
}
/// Maps a classified problem to a concrete [`TileConfig`].
///
/// Tensor-core codegen is enabled only when all three hold: the caller
/// requested `MathMode::TensorCore`, the target SM reports tensor cores,
/// and the validator accepts the input/output type pair.
pub fn heuristic_tile_config(
    &self,
    problem: &GemmProblem,
    category: &GemmCategory,
) -> TileConfig {
    let caps = self.sm_version.capabilities();
    let use_tc = problem.math_mode == MathMode::TensorCore
        && caps.has_tensor_cores
        && super::tensor_core::TensorCoreValidator::is_supported(
            self.sm_version,
            problem.input_type,
            problem.output_type,
        );
    match *category {
        GemmCategory::Standard => super::tiles::TileSelector::new(self.sm_version, use_tc)
            .select(problem.m, problem.n, problem.k),
        GemmCategory::Skinny => self.skinny_tile_config(problem, use_tc),
        GemmCategory::SplitK => self.splitk_tile_config(problem, use_tc),
        GemmCategory::StreamK => self.streamk_tile_config(use_tc),
        GemmCategory::WarpSpecialized => self.warp_specialized_tile_config(problem),
        GemmCategory::BandwidthLimited => self.bandwidth_limited_tile_config(problem),
    }
}
/// Tile heuristic for shapes where M or N is small (< 32).
///
/// The thin dimension gets the smallest covering tile (8/16/32) while the
/// wide dimension gets a large tile; K-tile and staging grow on the
/// tensor-core path.
fn skinny_tile_config(&self, problem: &GemmProblem, use_tc: bool) -> TileConfig {
    let thin = problem.m.min(problem.n);
    let tile_thin = match thin {
        0..=8 => 8,
        9..=16 => 16,
        _ => 32,
    };
    let tile_wide = if use_tc { 128 } else { 64 };
    // Thin tile goes to whichever of M/N is the smaller dimension.
    let (tile_m, tile_n) = if problem.m < problem.n {
        (tile_thin, tile_wide)
    } else {
        (tile_wide, tile_thin)
    };
    let stages = if use_tc && self.sm_version >= SmVersion::Sm80 {
        2
    } else {
        1
    };
    TileConfig {
        tile_m,
        tile_n,
        tile_k: if use_tc { 32 } else { 8 },
        warp_m: tile_m.min(32),
        warp_n: tile_n.min(32),
        stages,
        use_tensor_core: use_tc,
        split_k: 1,
    }
}
/// Tile heuristic for K-dominant problems: choose a split factor so each
/// slice covers roughly 256 elements of K (clamped to [2, 32]), layered
/// on a standard tensor-core or SIMT base tile.
fn splitk_tile_config(&self, problem: &GemmProblem, use_tc: bool) -> TileConfig {
    const TARGET_K_PER_SPLIT: u32 = 256;
    let split_k = (problem.k / TARGET_K_PER_SPLIT).clamp(2, 32);
    let mut cfg = if use_tc {
        TileConfig {
            tile_m: 128,
            tile_n: 128,
            tile_k: 32,
            warp_m: 64,
            warp_n: 64,
            // Ampere+ supports deeper async pipelines.
            stages: if self.sm_version >= SmVersion::Sm80 { 3 } else { 2 },
            use_tensor_core: true,
            split_k: 1,
        }
    } else {
        TileConfig {
            tile_m: 64,
            tile_n: 64,
            tile_k: 8,
            warp_m: 32,
            warp_n: 32,
            stages: 1,
            use_tensor_core: false,
            split_k: 1,
        }
    };
    cfg.split_k = split_k;
    cfg
}
/// Tile heuristic for the stream-K path. `split_k` stays 1: the stream-K
/// kernel manages its own work decomposition.
fn streamk_tile_config(&self, use_tc: bool) -> TileConfig {
    // (tile_m, tile_n, tile_k, warp extent, stages) per code path; the
    // warp tile is square in both cases.
    let (tile_m, tile_n, tile_k, warp, stages) = if use_tc {
        (256, 128, 64, 64, 4)
    } else {
        (128, 64, 16, 32, 2)
    };
    TileConfig {
        tile_m,
        tile_n,
        tile_k,
        warp_m: warp,
        warp_n: warp,
        stages,
        use_tensor_core: use_tc,
        split_k: 1,
    }
}
/// Tile heuristic for the warp-specialized (producer/consumer) path.
///
/// Very large problems (>= 256M MACs) get a deeper 4-stage pipeline; if
/// the warp-specialized constructor rejects the configuration, a fixed
/// tensor-core tile is used as fallback.
fn warp_specialized_tile_config(&self, problem: &GemmProblem) -> TileConfig {
    let volume = u64::from(problem.m) * u64::from(problem.n) * u64::from(problem.k);
    let stages = if volume >= 256 * 1024 * 1024 { 4 } else { 3 };
    let ws = super::warp_specialized::WarpSpecializedGemm::new(
        128,
        128,
        64,
        2,
        6,
        stages,
        self.sm_version,
        problem.input_type,
        problem.output_type,
    );
    match ws {
        Ok(gemm) => gemm.to_tile_config(),
        Err(_) => TileConfig {
            tile_m: 256,
            tile_n: 128,
            tile_k: 64,
            warp_m: 64,
            warp_n: 64,
            stages: 4,
            use_tensor_core: true,
            split_k: 1,
        },
    }
}
/// Tile heuristic for bandwidth-bound shapes: delegates to the bandwidth
/// optimizer and converts its selection into a [`TileConfig`].
fn bandwidth_limited_tile_config(&self, problem: &GemmProblem) -> TileConfig {
    use super::bandwidth_opt::{
        select_bandwidth_tiles, BandwidthGemmConfig, BandwidthPrecision, BandwidthStrategy,
    };
    let precision = match problem.input_type {
        PtxType::F16 => BandwidthPrecision::F16,
        PtxType::BF16 => BandwidthPrecision::BF16,
        PtxType::F64 => BandwidthPrecision::F64,
        // Everything else is treated as F32 for bandwidth modelling.
        _ => BandwidthPrecision::F32,
    };
    let bw = select_bandwidth_tiles(&BandwidthGemmConfig {
        m: problem.m as usize,
        n: problem.n as usize,
        k: problem.k as usize,
        sm_version: self.sm_version,
        precision,
        strategy: BandwidthStrategy::Auto,
    });
    // Warp tile = CTA tile divided across warps; max(1) guards the division.
    TileConfig {
        tile_m: bw.tile_m as u32,
        tile_n: bw.tile_n as u32,
        tile_k: bw.tile_k as u32,
        warp_m: (bw.tile_m / bw.warps_m.max(1)) as u32,
        warp_n: (bw.tile_n / bw.warps_n.max(1)) as u32,
        stages: bw.pipeline_stages as u32,
        use_tensor_core: false,
        split_k: 1,
    }
}
/// Returns a compiled kernel for the given problem/tile combination,
/// generating and caching it on first use.
///
/// Fast path: shared read lock + cache lookup. Slow path: PTX generation
/// and module load performed *outside* any lock (so other shapes are not
/// blocked), then insertion under the write lock. If another thread
/// compiled the same key in the meantime, the already-cached kernel wins
/// and is returned, instead of being clobbered as the old `insert` did.
///
/// # Errors
/// [`BlasError::PtxGeneration`] if codegen fails; [`BlasError::LaunchFailed`]
/// for module-load/kernel-lookup failures or a poisoned cache lock.
fn get_or_compile(
    &self,
    problem: &GemmProblem,
    tile_config: &TileConfig,
) -> BlasResult<Arc<CompiledGemm>> {
    let key = GemmKernelKey {
        input_type: problem.input_type,
        output_type: problem.output_type,
        trans_a: problem.trans_a,
        trans_b: problem.trans_b,
        tile_config: tile_config.clone(),
    };
    {
        let cache = self
            .compiled
            .read()
            .map_err(|_| BlasError::LaunchFailed("kernel cache lock poisoned".into()))?;
        if let Some(entry) = cache.get(&key) {
            return Ok(Arc::clone(entry));
        }
    }
    let template = GemmTemplate {
        tile_m: tile_config.tile_m,
        tile_n: tile_config.tile_n,
        tile_k: tile_config.tile_k,
        warp_m: tile_config.warp_m,
        warp_n: tile_config.warp_n,
        precision: problem.input_type,
        accumulator: problem.output_type,
        use_tensor_core: tile_config.use_tensor_core,
        stages: tile_config.stages,
        target: self.sm_version,
        epilogue: EpilogueKind::LinearCombination,
    };
    let ptx = template
        .generate()
        .map_err(|e| BlasError::PtxGeneration(format!("GEMM PTX generation failed: {e}")))?;
    let kernel_name = template.kernel_name();
    let module = Arc::new(
        Module::from_ptx(&ptx)
            .map_err(|e| BlasError::LaunchFailed(format!("module load failed: {e}")))?,
    );
    let kernel = Kernel::from_module(Arc::clone(&module), &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup failed: {e}")))?;
    // Shared memory: one A tile (tile_m x tile_k) plus one B tile
    // (tile_k x tile_n) per pipeline stage.
    let elem_bytes = problem.input_type.size_bytes() as u32;
    let smem_a = tile_config.tile_m * tile_config.tile_k * elem_bytes;
    let smem_b = tile_config.tile_k * tile_config.tile_n * elem_bytes;
    let shared_mem_bytes = (smem_a + smem_b) * tile_config.stages;
    let entry = Arc::new(CompiledGemm {
        _module: module,
        kernel,
        tile_config: tile_config.clone(),
        shared_mem_bytes,
    });
    let mut cache = self
        .compiled
        .write()
        .map_err(|_| BlasError::LaunchFailed("kernel cache lock poisoned".into()))?;
    // entry() keeps a kernel that raced us into the cache rather than
    // replacing it, so every caller shares one compiled instance per key.
    Ok(Arc::clone(cache.entry(key).or_insert(entry)))
}
/// Grid dimensions: one block per (tile_n x tile_m) output tile, rounded
/// up, with grid.z carrying the split-K factor.
fn compute_grid(problem: &GemmProblem, tc: &TileConfig) -> Dim3 {
    Dim3::new(
        problem.n.div_ceil(tc.tile_n),
        problem.m.div_ceil(tc.tile_m),
        tc.split_k,
    )
}
/// Block dimensions: 32 threads per (warp_m x warp_n) warp tile of the
/// CTA tile, capped at 1024 threads (the value hard-capped here).
///
/// Warp counts are clamped to at least 1: a degenerate config with
/// `warp_m > tile_m` (or likewise for N) would otherwise truncate to 0
/// warps and yield an invalid zero-thread block.
fn compute_block(tc: &TileConfig) -> Dim3 {
    const WARP_SIZE: u32 = 32;
    let warps_m = (tc.tile_m / tc.warp_m.max(1)).max(1);
    let warps_n = (tc.tile_n / tc.warp_n.max(1)).max(1);
    let threads = (warps_m * warps_n * WARP_SIZE).min(1024);
    Dim3::new(threads, 1, 1)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Test helper: no-transpose F32-in/F32-out problem with default math mode.
fn make_problem(m: u32, n: u32, k: u32) -> GemmProblem {
    GemmProblem {
        m,
        n,
        k,
        trans_a: Transpose::NoTrans,
        trans_b: Transpose::NoTrans,
        input_type: PtxType::F32,
        output_type: PtxType::F32,
        math_mode: MathMode::Default,
    }
}
// ---- Basic classification and heuristic smoke tests ----
#[test]
fn classify_standard() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(512, 512, 512);
    assert_eq!(d.classify(&p), GemmCategory::Standard);
}
#[test]
fn classify_skinny_m() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(8, 512, 512);
    assert_eq!(d.classify(&p), GemmCategory::Skinny);
}
#[test]
fn classify_skinny_n() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(512, 16, 512);
    assert_eq!(d.classify(&p), GemmCategory::Skinny);
}
#[test]
fn classify_split_k() {
    // K=8192 exceeds both 4*M and 4*N and the 1024 floor.
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(64, 64, 8192);
    assert_eq!(d.classify(&p), GemmCategory::SplitK);
}
#[test]
fn classify_stream_k_on_hopper() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let p = make_problem(4096, 4096, 4096);
    assert_eq!(d.classify(&p), GemmCategory::StreamK);
}
#[test]
fn classify_standard_on_ampere_large() {
    // Same shape as above but SM80: the StreamK branch requires >= Sm90.
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(4096, 4096, 4096);
    assert_eq!(d.classify(&p), GemmCategory::Standard);
}
#[test]
fn heuristic_simt_tile() {
    // Default math mode => no tensor cores; all tile dims positive.
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(256, 256, 256);
    let cat = d.classify(&p);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(!tc.use_tensor_core);
    assert!(tc.tile_m > 0);
    assert!(tc.tile_n > 0);
    assert!(tc.tile_k > 0);
    assert_eq!(tc.split_k, 1);
}
#[test]
fn heuristic_tc_tile_ampere() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let mut p = make_problem(1024, 1024, 1024);
    p.math_mode = MathMode::TensorCore;
    p.input_type = PtxType::F16;
    let cat = d.classify(&p);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(tc.use_tensor_core);
    assert_eq!(tc.stages, 3);
}
#[test]
fn compute_grid_basic() {
    // 512/128 = 4 blocks along N (grid.x), 256/128 = 2 along M (grid.y).
    let p = make_problem(256, 512, 128);
    let tc = TileConfig {
        tile_m: 128,
        tile_n: 128,
        tile_k: 32,
        warp_m: 64,
        warp_n: 64,
        stages: 1,
        use_tensor_core: false,
        split_k: 1,
    };
    let grid = GemmDispatcher::compute_grid(&p, &tc);
    assert_eq!(grid.x, 4); assert_eq!(grid.y, 2); assert_eq!(grid.z, 1);
}
// ---- Warp-specialized classification on Hopper, plus block sizing ----
#[test]
fn classify_warp_specialized_hopper_f16() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let mut p = make_problem(4096, 4096, 4096);
    p.input_type = PtxType::F16;
    assert_eq!(d.classify(&p), GemmCategory::WarpSpecialized);
}
#[test]
fn classify_warp_specialized_hopper_bf16() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let mut p = make_problem(4096, 4096, 4096);
    p.input_type = PtxType::BF16;
    assert_eq!(d.classify(&p), GemmCategory::WarpSpecialized);
}
#[test]
fn classify_stream_k_hopper_f32_not_warp_specialized() {
    // Same shape in F32 falls through to the StreamK branch instead.
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let p = make_problem(4096, 4096, 4096);
    assert_eq!(d.classify(&p), GemmCategory::StreamK);
}
#[test]
fn heuristic_warp_specialized_tile() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let mut p = make_problem(4096, 4096, 4096);
    p.input_type = PtxType::F16;
    p.output_type = PtxType::F32;
    let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::WarpSpecialized);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(tc.use_tensor_core);
    assert_eq!(tc.tile_m, 128);
    assert_eq!(tc.tile_n, 128);
    assert_eq!(tc.tile_k, 64);
}
#[test]
fn compute_block_basic() {
    // 128/64 = 2 warps per dim => 2*2*32 = 128 threads.
    let tc = TileConfig {
        tile_m: 128,
        tile_n: 128,
        tile_k: 32,
        warp_m: 64,
        warp_n: 64,
        stages: 1,
        use_tensor_core: false,
        split_k: 1,
    };
    let block = GemmDispatcher::compute_block(&tc);
    assert_eq!(block.x, 128);
}
// ---- Classification boundaries (skinny < 32, split-K thresholds) ----
#[test]
fn classify_large_square_as_standard() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(1024, 1024, 1024);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Standard,
        "1024x1024x1024 on Ampere should be Standard"
    );
}
#[test]
fn classify_thin_m_as_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(16, 1024, 512);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Skinny,
        "M=16 should produce Skinny"
    );
}
#[test]
fn classify_thin_n_as_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(1024, 8, 512);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Skinny,
        "N=8 should produce Skinny"
    );
}
#[test]
fn skinny_takes_priority_over_splitk() {
    // Shape satisfies both predicates; Skinny is tested first in classify().
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(16, 16, 65536);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Skinny,
        "Skinny check runs before SplitK"
    );
}
#[test]
fn classify_k_heavy_as_splitk() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(64, 64, 8192);
    assert_eq!(
        d.classify(&p),
        GemmCategory::SplitK,
        "K=8192 >> M=64, N=64 should be SplitK"
    );
}
#[test]
fn classify_moderate_k_not_splitk() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(64, 64, 200);
    assert_ne!(
        d.classify(&p),
        GemmCategory::SplitK,
        "K=200 is not > 4*M=256, so not SplitK"
    );
}
#[test]
fn boundary_skinny_m_31_is_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(31, 512, 512);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Skinny,
        "M=31 < 32 should be Skinny"
    );
}
#[test]
fn boundary_skinny_m_32_not_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(32, 512, 512);
    assert_ne!(
        d.classify(&p),
        GemmCategory::Skinny,
        "M=32 is not < 32, should not be Skinny"
    );
}
#[test]
fn boundary_skinny_n_31_is_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(512, 31, 512);
    assert_eq!(
        d.classify(&p),
        GemmCategory::Skinny,
        "N=31 < 32 should be Skinny"
    );
}
#[test]
fn boundary_skinny_n_32_not_skinny() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(512, 32, 512);
    assert_ne!(
        d.classify(&p),
        GemmCategory::Skinny,
        "N=32 is not < 32, should not be Skinny"
    );
}
// ---- Tile-config properties per category ----
#[test]
fn skinny_tile_has_small_dim_for_thin_m() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let mut p = make_problem(8, 512, 512);
    p.math_mode = MathMode::Default;
    let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::Skinny);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.tile_m <= 16,
        "skinny M=8 should have tile_m <= 16, got {}",
        tc.tile_m
    );
}
#[test]
fn skinny_tile_has_small_dim_for_thin_n() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let mut p = make_problem(512, 16, 512);
    p.math_mode = MathMode::Default;
    let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::Skinny);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.tile_n <= 16,
        "skinny N=16 should have tile_n <= 16, got {}",
        tc.tile_n
    );
}
#[test]
fn splitk_tile_has_split_factor_gt_1() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(64, 64, 8192);
    let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::SplitK);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(tc.split_k > 1, "SplitK tile config must have split_k > 1");
}
#[test]
fn stream_k_only_on_hopper() {
    let d_ampere = GemmDispatcher::new(SmVersion::Sm80);
    let d_hopper = GemmDispatcher::new(SmVersion::Sm90);
    let p = make_problem(4096, 4096, 4096);
    let cat_ampere = d_ampere.classify(&p);
    let cat_hopper = d_hopper.classify(&p);
    assert_ne!(
        cat_ampere,
        GemmCategory::StreamK,
        "Ampere should not use StreamK"
    );
    assert_eq!(
        cat_hopper,
        GemmCategory::StreamK,
        "Hopper with large F32 problem should use StreamK"
    );
}
#[test]
fn stream_k_tile_has_split_k_1() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let p = make_problem(4096, 4096, 4096); let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::StreamK);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert_eq!(
        tc.split_k, 1,
        "StreamK manages its own decomposition, split_k must be 1"
    );
}
// ---- Cross-category invariants and wgmma tile validity ----
#[test]
fn all_categories_produce_positive_tile_dims() {
    // One representative shape per category; every config must be sane.
    let configs: &[(SmVersion, u32, u32, u32, PtxType)] = &[
        (SmVersion::Sm80, 1024, 1024, 1024, PtxType::F32),
        (SmVersion::Sm80, 8, 512, 256, PtxType::F32),
        (SmVersion::Sm80, 512, 16, 256, PtxType::F32),
        (SmVersion::Sm80, 64, 64, 8192, PtxType::F32),
        (SmVersion::Sm90, 4096, 4096, 4096, PtxType::F32),
        (SmVersion::Sm90, 4096, 4096, 4096, PtxType::F16),
    ];
    for &(sm, m, n, k, itype) in configs {
        let d = GemmDispatcher::new(sm);
        let mut p = make_problem(m, n, k);
        p.input_type = itype;
        let cat = d.classify(&p);
        let tc = d.heuristic_tile_config(&p, &cat);
        assert!(tc.tile_m > 0, "tile_m=0 for {:?}", cat);
        assert!(tc.tile_n > 0, "tile_n=0 for {:?}", cat);
        assert!(tc.tile_k > 0, "tile_k=0 for {:?}", cat);
        assert!(tc.stages > 0, "stages=0 for {:?}", cat);
    }
}
#[test]
fn hopper_simt_stages_ge_turing_simt_stages() {
    // Newer architecture should never pick a shallower pipeline.
    let d_hopper = GemmDispatcher::new(SmVersion::Sm90);
    let d_turing = GemmDispatcher::new(SmVersion::Sm75);
    let p = make_problem(1024, 1024, 1024);
    let cat_h = d_hopper.classify(&p);
    let cat_t = d_turing.classify(&p);
    let tc_h = d_hopper.heuristic_tile_config(&p, &cat_h);
    let tc_t = d_turing.heuristic_tile_config(&p, &cat_t);
    assert!(
        tc_h.stages >= tc_t.stages,
        "Hopper ({}) should have >= SIMT stages as Turing ({})",
        tc_h.stages,
        tc_t.stages
    );
}
#[test]
fn simt_fallback_no_tensor_core() {
    let d = GemmDispatcher::new(SmVersion::Sm75);
    let p = make_problem(512, 512, 512); let cat = d.classify(&p);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        !tc.use_tensor_core,
        "SIMT/Default math mode should not use tensor core"
    );
}
#[test]
fn hopper_warp_specialized_f16_tile_valid_for_wgmma() {
    let d = GemmDispatcher::new(SmVersion::Sm90);
    let mut p = make_problem(4096, 4096, 4096);
    p.input_type = PtxType::F16;
    p.output_type = PtxType::F32;
    p.math_mode = MathMode::TensorCore;
    let cat = d.classify(&p);
    assert_eq!(
        cat,
        GemmCategory::WarpSpecialized,
        "Hopper F16 large problem should select WarpSpecialized"
    );
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(tc.use_tensor_core, "Hopper warp-specialized must use TC");
    assert_eq!(
        tc.tile_m % 16,
        0,
        "tile_m must be multiple of 16 for wgmma, got {}",
        tc.tile_m
    );
    assert_eq!(
        tc.tile_k % 16,
        0,
        "tile_k must be multiple of 16 for wgmma, got {}",
        tc.tile_k
    );
    assert!(
        tc.stages >= 2,
        "Hopper TMA pipeline needs >= 2 stages, got {}",
        tc.stages
    );
}
// ---- Generated-PTX content checks (substring assertions) ----
#[test]
fn hopper_warp_specialized_ptx_contains_mma_and_cp_async() {
    let gemm = super::super::warp_specialized::WarpSpecializedGemm::new(
        128,
        128,
        64,
        2,
        6,
        3,
        SmVersion::Sm90,
        PtxType::F16,
        PtxType::F32,
    )
    .expect("valid Hopper warp-specialized config");
    let ptx = gemm.generate_kernel().expect("PTX generation must succeed");
    assert!(
        ptx.contains("mma.sync.aligned"),
        "Hopper warp-specialized PTX must contain mma.sync.aligned"
    );
    assert!(
        ptx.contains("cp.async"),
        "Hopper TMA pipeline PTX must contain cp.async"
    );
    assert!(
        ptx.contains("cp.async.commit_group"),
        "Producer path must commit async groups"
    );
    assert!(
        ptx.contains("bar.arrive"),
        "Producer path must signal consumer via bar.arrive"
    );
    assert!(
        ptx.contains(".target sm_90"),
        "PTX must target sm_90 for Hopper"
    );
}
#[test]
fn ada_fp8_e4m3_ptx_contains_correct_mma_shape() {
    let gemm = super::super::warp_specialized::WarpSpecializedGemm::new(
        128,
        128,
        64,
        2,
        6,
        2,
        SmVersion::Sm90, PtxType::E4M3,
        PtxType::F32,
    )
    .expect("valid FP8 E4M3 warp-specialized config");
    let ptx = gemm.generate_kernel().expect("PTX generation must succeed");
    assert!(
        ptx.contains("e4m3"),
        "FP8 E4M3 PTX must reference e4m3 type"
    );
    assert!(
        ptx.contains("m16n8k32"),
        "FP8 E4M3 must use m16n8k32 MMA shape (2x K vs F16 m16n8k16)"
    );
    assert!(
        ptx.contains("mma.sync.aligned"),
        "FP8 PTX must contain mma.sync.aligned"
    );
}
#[test]
fn ada_fp8_e5m2_ptx_contains_correct_mma_shape() {
    let gemm = super::super::warp_specialized::WarpSpecializedGemm::new(
        128,
        128,
        64,
        2,
        6,
        2,
        SmVersion::Sm90a,
        PtxType::E5M2,
        PtxType::F32,
    )
    .expect("valid FP8 E5M2 config");
    let ptx = gemm.generate_kernel().expect("PTX generation must succeed");
    assert!(
        ptx.contains("e5m2"),
        "FP8 E5M2 PTX must reference e5m2 type"
    );
    assert!(
        ptx.contains("m16n8k32"),
        "FP8 E5M2 must use m16n8k32 MMA shape"
    );
}
// ---- Turing tensor-core path and extreme-skinny shapes ----
#[test]
fn turing_sm75_f16_tensor_core_path() {
    let d = GemmDispatcher::new(SmVersion::Sm75);
    let mut p = make_problem(1024, 1024, 512);
    p.input_type = PtxType::F16;
    p.output_type = PtxType::F32;
    p.math_mode = MathMode::TensorCore;
    let cat = d.classify(&p);
    assert_eq!(
        cat,
        GemmCategory::Standard,
        "Turing should use Standard category for this shape"
    );
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.use_tensor_core,
        "Turing SM75 with F16 + TensorCore math must use TC path"
    );
    assert_eq!(
        tc.tile_k % 16,
        0,
        "Turing tile_k must be multiple of 16 for WMMA m16n16k16, got {}",
        tc.tile_k
    );
}
#[test]
fn turing_sm75_tc_stages_capped_at_2() {
    let d = GemmDispatcher::new(SmVersion::Sm75);
    let mut p = make_problem(1024, 1024, 1024);
    p.input_type = PtxType::F16;
    p.math_mode = MathMode::TensorCore;
    let cat = d.classify(&p);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.stages <= 2,
        "Turing TC path must have at most 2 pipeline stages, got {}",
        tc.stages
    );
}
#[test]
fn skinny_m1_classifies_and_uses_small_tile() {
    // GEMV-like extreme: a single output row.
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(1, 4096, 4096);
    let cat = d.classify(&p);
    assert_eq!(
        cat,
        GemmCategory::Skinny,
        "M=1 must classify as Skinny (< 32)"
    );
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.tile_m <= 8,
        "M=1 skinny tile must have tile_m <= 8 to avoid wasted threads, got {}",
        tc.tile_m
    );
}
#[test]
fn skinny_matrix_path_documented_efficiency() {
    let d = GemmDispatcher::new(SmVersion::Sm80);
    let p = make_problem(4, 2048, 2048);
    let cat = d.classify(&p);
    assert_eq!(cat, GemmCategory::Skinny);
    let tc = d.heuristic_tile_config(&p, &cat);
    assert!(
        tc.tile_m <= 16,
        "Small-M skinny tile must be compact (tile_m <= 16) for efficiency, got {}",
        tc.tile_m
    );
}
// ---- Shared-memory budget: mirrors the smem formula in get_or_compile ----
#[test]
fn tile_config_fits_shared_memory_budget() {
    let sm_versions = [SmVersion::Sm75, SmVersion::Sm80, SmVersion::Sm90];
    let test_problems: &[(u32, u32, u32)] =
        &[(1024, 1024, 1024), (512, 512, 512), (256, 256, 256)];
    for sm in sm_versions {
        let sm_limit = sm.max_shared_mem_per_block();
        for &(m, n, k) in test_problems {
            let d = GemmDispatcher::new(sm);
            let p = make_problem(m, n, k);
            let cat = d.classify(&p);
            let tc = d.heuristic_tile_config(&p, &cat);
            // F32 problems only here, hence the fixed 4-byte element size.
            let elem_bytes = 4u32;
            let smem_a = tc.tile_m * tc.tile_k * elem_bytes;
            let smem_b = tc.tile_k * tc.tile_n * elem_bytes;
            let total_smem = (smem_a + smem_b) * tc.stages;
            assert!(
                total_smem <= sm_limit,
                "SM{} ({:?}): smem={} > limit={} for {}x{}x{}",
                sm as u32,
                cat,
                total_smem,
                sm_limit,
                m,
                n,
                k
            );
        }
    }
}
}