use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::templates::gemm::{EpilogueKind, GemmTemplate};
use std::sync::Arc;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, Transpose};
/// Groups with at most this many problems launch one kernel per problem;
/// larger groups take the unified single-launch path (see `gemm_grouped`).
const INDIVIDUAL_DISPATCH_LIMIT: usize = 4;
/// Output-tile height (rows of C/D) covered by one thread block; also the
/// block's x-dimension in both dispatch paths.
const TILE_M: u32 = 16;
/// Output-tile width (columns of C/D) covered by one thread block; also the
/// block's y-dimension in both dispatch paths.
const TILE_N: u32 = 16;
/// K-dimension tile size forwarded to the GEMM template. Not used for launch
/// geometry here — presumably sized to the kernel's inner loop; see the
/// template generator for its actual use.
const TILE_K: u32 = 16;
/// Operand descriptor for one GEMM within a grouped batch.
///
/// Both dispatch paths pass C and D separately alongside alpha/beta, matching
/// the template's `LinearCombination` epilogue — presumably computing
/// `D = alpha * op(A) * op(B) + beta * C` (confirm against the kernel).
/// Leading dimensions follow the column-major BLAS convention and are
/// validated by `validate_problem`.
#[derive(Debug, Clone, Copy)]
pub struct GroupedGemmProblem {
// Transpose op applied to A.
pub trans_a: Transpose,
// Transpose op applied to B.
pub trans_b: Transpose,
// Rows of op(A) and of C/D.
pub m: u32,
// Columns of op(B) and of C/D.
pub n: u32,
// Columns of op(A) / rows of op(B).
pub k: u32,
// Device pointer to A.
pub a_ptr: CUdeviceptr,
// Leading dimension of A as stored.
pub lda: u32,
// Device pointer to B.
pub b_ptr: CUdeviceptr,
// Leading dimension of B as stored.
pub ldb: u32,
// Device pointer to C.
pub c_ptr: CUdeviceptr,
// Leading dimension of C.
pub ldc: u32,
// Device pointer to D.
pub d_ptr: CUdeviceptr,
// Leading dimension of D.
pub ldd: u32,
}
/// Validates the dimensions and leading dimensions of one grouped GEMM problem.
///
/// Per the BLAS convention, each leading dimension must be at least the number
/// of rows of the corresponding matrix *as stored* (before the transpose op is
/// applied): `lda >= m` for NoTrans / `>= k` for Trans, `ldb >= k` for
/// NoTrans / `>= n` for Trans, and `ldc`/`ldd >= m`.
///
/// The type parameter keeps the call sites uniform with the dispatch
/// functions; no element-size-dependent checks are performed here.
///
/// # Errors
/// Returns `BlasError::InvalidDimension` when any of m/n/k is zero or a
/// leading dimension is too small.
fn validate_problem<T: GpuFloat>(problem: &GroupedGemmProblem) -> BlasResult<()> {
    if problem.m == 0 || problem.n == 0 || problem.k == 0 {
        return Err(BlasError::InvalidDimension(
            "m, n, and k must all be positive in every grouped problem".into(),
        ));
    }
    // Rows of A as stored: A is m x k when NoTrans, k x m when (Conj)Trans.
    let rows_a = match problem.trans_a {
        Transpose::NoTrans => problem.m,
        Transpose::Trans | Transpose::ConjTrans => problem.k,
    };
    // Rows of B as stored: B is k x n when NoTrans, n x k when (Conj)Trans.
    let rows_b = match problem.trans_b {
        Transpose::NoTrans => problem.k,
        Transpose::Trans | Transpose::ConjTrans => problem.n,
    };
    if problem.lda < rows_a {
        return Err(BlasError::InvalidDimension(format!(
            "lda ({}) must be >= rows of A as stored ({rows_a})",
            problem.lda
        )));
    }
    if problem.ldb < rows_b {
        return Err(BlasError::InvalidDimension(format!(
            "ldb ({}) must be >= rows of B as stored ({rows_b})",
            problem.ldb
        )));
    }
    if problem.ldc < problem.m {
        return Err(BlasError::InvalidDimension(format!(
            "ldc ({}) must be >= m ({})",
            problem.ldc, problem.m
        )));
    }
    if problem.ldd < problem.m {
        return Err(BlasError::InvalidDimension(format!(
            "ldd ({}) must be >= m ({})",
            problem.ldd, problem.m
        )));
    }
    Ok(())
}
/// Builds the PTX GEMM template shared by both dispatch paths: a single-stage,
/// non-tensor-core kernel with a linear-combination epilogue, tiled by the
/// module-level `TILE_*` constants and accumulating in the input precision.
fn build_gemm_template<T: GpuFloat>(sm: SmVersion) -> GemmTemplate {
    GemmTemplate {
        target: sm,
        precision: T::PTX_TYPE,
        accumulator: T::PTX_TYPE,
        tile_m: TILE_M,
        tile_n: TILE_N,
        tile_k: TILE_K,
        // A single warp tile spans the entire thread-block tile.
        warp_m: TILE_M,
        warp_n: TILE_N,
        use_tensor_core: false,
        stages: 1,
        epilogue: EpilogueKind::LinearCombination,
    }
}
/// Renders the GEMM template for `sm`, returning the generated PTX source
/// together with the name of the kernel entry point inside it.
fn generate_gemm_ptx<T: GpuFloat>(sm: SmVersion) -> BlasResult<(String, String)> {
    let tpl = build_gemm_template::<T>(sm);
    // Capture the entry-point name before rendering the PTX body.
    let entry = tpl.kernel_name();
    Ok((tpl.generate()?, entry))
}
/// Launches one kernel per problem on the handle's stream.
///
/// Used for small groups (<= `INDIVIDUAL_DISPATCH_LIMIT`) where per-launch
/// overhead is cheaper than building the device-side problem table required
/// by `dispatch_unified`.
///
/// # Errors
/// Propagates PTX generation, module-load, and launch failures; launch errors
/// are annotated with the failing problem's index.
fn dispatch_individual<T: GpuFloat>(
    handle: &BlasHandle,
    problems: &[GroupedGemmProblem],
    alpha: T,
    beta: T,
) -> BlasResult<()> {
    let sm = handle.sm_version();
    let (ptx_source, kernel_name) = generate_gemm_ptx::<T>(sm)?;
    let module = Arc::new(oxicuda_driver::Module::from_ptx(&ptx_source).map_err(BlasError::Cuda)?);
    // The module Arc is moved straight into the kernel (matching
    // dispatch_unified); the previous extra Arc::clone was redundant.
    let kernel = Kernel::from_module(module, &kernel_name).map_err(BlasError::Cuda)?;
    // Scalars cross the ABI as raw bit patterns so one kernel signature
    // serves every element type.
    let alpha_bits = alpha.to_bits_u64();
    let beta_bits = beta.to_bits_u64();
    for (idx, p) in problems.iter().enumerate() {
        // One thread block per TILE_M x TILE_N output tile.
        let grid = Dim3::new(p.m.div_ceil(TILE_M), p.n.div_ceil(TILE_N), 1);
        let block = Dim3::new(TILE_M, TILE_N, 1);
        let params = LaunchParams::new(grid, block);
        let args = (
            p.m, p.n, p.k, alpha_bits, p.a_ptr, p.lda, p.b_ptr, p.ldb, beta_bits, p.c_ptr, p.ldc,
            p.d_ptr, p.ldd,
        );
        kernel
            .launch(&params, handle.stream(), &args)
            .map_err(|e| BlasError::LaunchFailed(format!("grouped problem {idx}: {e}")))?;
    }
    Ok(())
}
const PROBLEM_ROW_U32S: usize = 16;
/// Serializes the problem descriptors into a flat u32 table for upload to the
/// device, one `PROBLEM_ROW_U32S`-word row per problem.
///
/// Row layout: m, n, k, lda, ldb, ldc, ldd, then each 64-bit device pointer
/// (A, B, C, D) as a low/high u32 pair, then the packed transpose flags
/// (A's code in bits 0-1, B's in bits 2-3).
fn pack_problem_table(problems: &[GroupedGemmProblem]) -> Vec<u32> {
    let mut table = Vec::with_capacity(problems.len() * PROBLEM_ROW_U32S);
    for p in problems {
        table.extend_from_slice(&[p.m, p.n, p.k, p.lda, p.ldb, p.ldc, p.ldd]);
        // Split each 64-bit pointer into little-endian u32 halves.
        for ptr in [p.a_ptr, p.b_ptr, p.c_ptr, p.d_ptr] {
            table.push(ptr as u32);
            table.push((ptr >> 32) as u32);
        }
        table.push(encode_transpose(p.trans_a) | (encode_transpose(p.trans_b) << 2));
    }
    table
}
/// Maps a transpose op to its 2-bit device-side code
/// (0 = NoTrans, 1 = Trans, 2 = ConjTrans).
fn encode_transpose(t: Transpose) -> u32 {
    match t {
        Transpose::ConjTrans => 2,
        Transpose::Trans => 1,
        Transpose::NoTrans => 0,
    }
}
/// Exclusive prefix sum of per-problem thread-block (tile) counts.
///
/// Entry `i` is the first global block index belonging to problem `i`; the
/// final entry is the total block count across the group. Saturating
/// arithmetic guards the u32 accumulator against overflow on pathological
/// dimension inputs.
fn compute_block_prefix_sums(problems: &[GroupedGemmProblem]) -> Vec<u32> {
    let mut prefix = Vec::with_capacity(problems.len() + 1);
    let mut running = 0u32;
    prefix.push(running);
    for p in problems {
        let tiles = p.m.div_ceil(TILE_M).saturating_mul(p.n.div_ceil(TILE_N));
        running = running.saturating_add(tiles);
        prefix.push(running);
    }
    prefix
}
/// Launches the whole group in a single kernel launch.
///
/// Uploads a packed problem table plus a block-index prefix-sum table to the
/// device, then launches one grid whose x-dimension covers the total tile
/// count of all problems — presumably the kernel binary-searches the prefix
/// table to map each block to its problem (confirm against the template).
///
/// Fixes a mojibake bug where `&params` had been mangled to `¶ms`
/// (U+00B6), which is not valid Rust.
///
/// # Errors
/// Propagates PTX generation, module-load, allocation, copy, and launch
/// failures.
fn dispatch_unified<T: GpuFloat>(
    handle: &BlasHandle,
    problems: &[GroupedGemmProblem],
    alpha: T,
    beta: T,
) -> BlasResult<()> {
    let sm = handle.sm_version();
    let (ptx_source, kernel_name) = generate_gemm_ptx::<T>(sm)?;
    let module = oxicuda_driver::Module::from_ptx(&ptx_source).map_err(BlasError::Cuda)?;
    let module = Arc::new(module);
    let kernel = Kernel::from_module(module, &kernel_name).map_err(BlasError::Cuda)?;
    // Upload the packed per-problem descriptor table.
    let table_host = pack_problem_table(problems);
    let mut table_device = DeviceBuffer::<u32>::alloc(table_host.len()).map_err(BlasError::Cuda)?;
    table_device
        .copy_from_host(&table_host)
        .map_err(BlasError::Cuda)?;
    // Upload the block-to-problem prefix table; its last entry is the total
    // number of thread blocks for the whole group.
    let prefix_host = compute_block_prefix_sums(problems);
    let total_blocks = prefix_host.last().copied().unwrap_or(0);
    let mut prefix_device =
        DeviceBuffer::<u32>::alloc(prefix_host.len()).map_err(BlasError::Cuda)?;
    prefix_device
        .copy_from_host(&prefix_host)
        .map_err(BlasError::Cuda)?;
    let grid = Dim3::new(total_blocks, 1, 1);
    let block = Dim3::new(TILE_M, TILE_N, 1);
    let params = LaunchParams::new(grid, block);
    let alpha_bits = alpha.to_bits_u64();
    let beta_bits = beta.to_bits_u64();
    let args = (
        problems.len() as u32,
        alpha_bits,
        beta_bits,
        table_device.as_device_ptr(),
        prefix_device.as_device_ptr(),
    );
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(e.to_string()))
}
/// Runs a grouped GEMM over `problems` on `handle`'s stream.
///
/// An empty group is a no-op. Every problem is validated up front; the group
/// is then dispatched either as individual launches (small groups) or as one
/// unified launch (large groups), keyed on `INDIVIDUAL_DISPATCH_LIMIT`.
///
/// # Errors
/// Returns `BlasError::InvalidArgument` (annotated with the problem index)
/// when validation fails, or propagates dispatch errors.
pub fn gemm_grouped<T: GpuFloat>(
    handle: &BlasHandle,
    problems: &[GroupedGemmProblem],
    alpha: T,
    beta: T,
) -> BlasResult<()> {
    if problems.is_empty() {
        return Ok(());
    }
    // Fail fast on the first invalid problem before touching the device.
    problems.iter().enumerate().try_for_each(|(idx, problem)| {
        validate_problem::<T>(problem)
            .map_err(|e| BlasError::InvalidArgument(format!("grouped problem {idx}: {e}")))
    })?;
    if problems.len() > INDIVIDUAL_DISPATCH_LIMIT {
        dispatch_unified::<T>(handle, problems, alpha, beta)
    } else {
        dispatch_individual::<T>(handle, problems, alpha, beta)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a NoTrans/NoTrans problem with tight (minimal legal) leading
    /// dimensions and dummy device addresses.
    fn make_problem(m: u32, n: u32, k: u32) -> GroupedGemmProblem {
        GroupedGemmProblem {
            trans_a: Transpose::NoTrans,
            trans_b: Transpose::NoTrans,
            m,
            n,
            k,
            a_ptr: 0x1000,
            lda: m,
            b_ptr: 0x2000,
            ldb: k,
            c_ptr: 0x3000,
            ldc: m,
            d_ptr: 0x4000,
            ldd: m,
        }
    }

    #[test]
    fn validate_rejects_zero_dimension() {
        assert!(validate_problem::<f32>(&make_problem(0, 64, 64)).is_err());
    }

    #[test]
    fn validate_accepts_valid_problem() {
        assert!(validate_problem::<f32>(&make_problem(128, 64, 32)).is_ok());
    }

    #[test]
    fn encode_transpose_round_trip() {
        assert_eq!(encode_transpose(Transpose::NoTrans), 0);
        assert_eq!(encode_transpose(Transpose::Trans), 1);
        assert_eq!(encode_transpose(Transpose::ConjTrans), 2);
    }

    #[test]
    fn pack_problem_table_row_size() {
        let table = pack_problem_table(&[make_problem(64, 64, 64)]);
        assert_eq!(table.len(), PROBLEM_ROW_U32S);
    }

    #[test]
    fn prefix_sums_correct() {
        // 32x32 -> 2x2 = 4 tiles; 64x64 -> 4x4 = 16 tiles.
        let problems = vec![make_problem(32, 32, 16), make_problem(64, 64, 16)];
        assert_eq!(compute_block_prefix_sums(&problems), vec![0, 4, 20]);
    }

    #[test]
    fn validate_transposed_problem() {
        // With both ops transposed, lda must cover k and ldb must cover n.
        let p = GroupedGemmProblem {
            trans_a: Transpose::Trans,
            trans_b: Transpose::Trans,
            m: 64,
            n: 32,
            k: 16,
            a_ptr: 0x1000,
            lda: 16,
            b_ptr: 0x2000,
            ldb: 32,
            c_ptr: 0x3000,
            ldc: 64,
            d_ptr: 0x4000,
            ldd: 64,
        };
        assert!(validate_problem::<f32>(&p).is_ok());
    }
}