use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, MatrixDesc, MatrixDescMut, Transpose};
use super::gemm::dispatch::{GemmDispatcher, GemmProblem};
/// Validates operand shapes and dispatches a general matrix-matrix multiply (GEMM).
///
/// `op(A)` / `op(B)` are `a` / `b` after applying `trans_a` / `trans_b`; their
/// shapes come from `MatrixDesc::effective_dims`. The arithmetic itself lives in
/// `GemmDispatcher` (presumably the BLAS contract `C = alpha * op(A) * op(B) + beta * C`
/// — confirm against the dispatcher). `alpha` and `beta` are handed over as raw
/// 64-bit patterns via `to_bits_u64`.
///
/// # Errors
///
/// * [`BlasError::InvalidDimension`] if any of `m`, `n`, or `k` is zero.
/// * [`BlasError::DimensionMismatch`] if the inner dimensions of `op(A)` and
///   `op(B)` differ, or if `c` is not `m x n`.
/// * Any error returned by `GemmDispatcher::dispatch` is propagated unchanged.
///
/// NOTE(review): reference BLAS treats `m == 0 || n == 0` as a quick return and
/// `k == 0` as `C = beta * C`; this entry point rejects all of them instead.
/// NOTE(review): no leading dimensions are forwarded in `GemmProblem` here —
/// confirm the dispatcher derives them from the descriptors' pointers/layout.
// The flat, BLAS-style argument list is intentional; clippy would otherwise flag it.
#[allow(clippy::too_many_arguments)]
pub fn gemm<T: GpuFloat>(
    handle: &BlasHandle,
    trans_a: Transpose,
    trans_b: Transpose,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &MatrixDesc<T>,
    beta: T,
    c: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    // op(A) is m x k_a and op(B) is k_b x n once the transpose flags are applied.
    let (m, k_a) = a.effective_dims(trans_a);
    let (k_b, n) = b.effective_dims(trans_b);

    let has_empty_dim = m == 0 || n == 0 || k_a == 0;
    if has_empty_dim {
        return Err(BlasError::InvalidDimension(
            "GEMM dimensions must be non-zero".into(),
        ));
    }

    // Both operands must agree on the shared (reduction) dimension.
    let k = if k_a == k_b {
        k_a
    } else {
        return Err(BlasError::DimensionMismatch(format!(
            "inner dimensions of op(A) ({k_a}) and op(B) ({k_b}) do not match"
        )));
    };

    let c_has_output_shape = c.rows == m && c.cols == n;
    if !c_has_output_shape {
        return Err(BlasError::DimensionMismatch(format!(
            "C is {}x{} but GEMM produces {}x{}",
            c.rows, c.cols, m, n
        )));
    }

    let problem = GemmProblem {
        m,
        n,
        k,
        trans_a,
        trans_b,
        input_type: T::PTX_TYPE,
        output_type: accumulator_ptx_type::<T>(),
        math_mode: handle.math_mode(),
    };

    // Kernel selection is keyed on the device's SM version.
    let dispatcher = GemmDispatcher::new(handle.sm_version());
    dispatcher.dispatch(
        &problem,
        a.ptr,
        b.ptr,
        c.ptr,
        alpha.to_bits_u64(),
        beta.to_bits_u64(),
        handle.stream(),
    )
}
/// Returns the PTX type of `T`'s associated `Accumulator` type.
///
/// `gemm` uses this as the problem's `output_type`, while `T::PTX_TYPE` is the
/// input type.
fn accumulator_ptx_type<T>() -> PtxType
where
    T: GpuFloat,
{
    <T::Accumulator as GpuFloat>::PTX_TYPE
}
/// Returns the `(m, n)` shape of the GEMM result `op(A) * op(B)`.
///
/// `a_rows` x `a_cols` and `b_rows` x `b_cols` are the stored shapes of `A` and `B`.
/// Under `Transpose::Trans` or `Transpose::ConjTrans` the operand's rows and columns
/// swap; under `Transpose::NoTrans` they are used as-is. Inner dimensions are not
/// checked for agreement here.
pub fn gemm_output_dims(
    a_rows: u32,
    a_cols: u32,
    trans_a: Transpose,
    b_rows: u32,
    b_cols: u32,
    trans_b: Transpose,
) -> (u32, u32) {
    // Shape of op(X) as (rows, cols): unchanged, or swapped for (conjugate) transpose.
    let op_shape = |trans: Transpose, rows: u32, cols: u32| match trans {
        Transpose::NoTrans => (rows, cols),
        Transpose::Trans | Transpose::ConjTrans => (cols, rows),
    };

    // Result rows come from op(A); result columns come from op(B).
    let (m, _) = op_shape(trans_a, a_rows, a_cols);
    let (_, n) = op_shape(trans_b, b_rows, b_cols);
    (m, n)
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- gemm_output_dims: one case per transpose combination ---

    #[test]
    fn output_dims_no_trans() {
        let (m, n) = gemm_output_dims(128, 64, Transpose::NoTrans, 64, 256, Transpose::NoTrans);
        assert_eq!((m, n), (128, 256));
    }

    #[test]
    fn output_dims_trans_a() {
        let (m, n) = gemm_output_dims(64, 128, Transpose::Trans, 128, 256, Transpose::NoTrans);
        assert_eq!((m, n), (128, 256));
    }

    #[test]
    fn output_dims_trans_b() {
        let (m, n) = gemm_output_dims(128, 64, Transpose::NoTrans, 256, 64, Transpose::Trans);
        assert_eq!((m, n), (128, 256));
    }

    #[test]
    fn output_dims_both_trans() {
        let (m, n) = gemm_output_dims(64, 128, Transpose::Trans, 256, 128, Transpose::Trans);
        assert_eq!((m, n), (128, 256));
    }

    // ConjTrans has its own match arm in gemm_output_dims; shape-wise it must
    // behave exactly like Trans.
    #[test]
    fn output_dims_conj_trans_a() {
        let (m, n) = gemm_output_dims(
            64,
            128,
            Transpose::ConjTrans,
            128,
            256,
            Transpose::NoTrans,
        );
        assert_eq!((m, n), (128, 256));
    }

    #[test]
    fn output_dims_conj_trans_b() {
        let (m, n) = gemm_output_dims(
            128,
            64,
            Transpose::NoTrans,
            256,
            64,
            Transpose::ConjTrans,
        );
        assert_eq!((m, n), (128, 256));
    }

    #[test]
    fn output_dims_conj_trans_matches_trans() {
        let conj = gemm_output_dims(3, 5, Transpose::ConjTrans, 7, 11, Transpose::ConjTrans);
        let plain = gemm_output_dims(3, 5, Transpose::Trans, 7, 11, Transpose::Trans);
        assert_eq!(conj, plain);
        assert_eq!(conj, (5, 7));
    }

    // --- accumulator_ptx_type ---

    #[test]
    fn accumulator_ptx_type_f32() {
        assert_eq!(accumulator_ptx_type::<f32>(), PtxType::F32);
    }

    #[test]
    fn accumulator_ptx_type_f64() {
        assert_eq!(accumulator_ptx_type::<f64>(), PtxType::F64);
    }
}