use cudarc::cublas::{GemmConfig, StridedBatchedConfig};
use crate::{Error, Result};
pub(crate) fn gemm_config<T>(
alpha: T,
beta: T,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
out_stride: &[usize],
) -> Result<StridedBatchedConfig<T>> {
use cudarc::cublas::sys::cublasOperation_t;
let lhs_dims = [b, m, k];
let rhs_dims = [b, k, n];
let out_dims = [b, m, n];
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(Error::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
out_stride: out_stride.to_vec(),
mnk: (m, n, k),
})?
};
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(Error::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
out_stride: out_stride.to_vec(),
mnk: (m, n, k),
})?
};
let gemm = GemmConfig {
alpha,
beta,
m: n as i32,
n: m as i32,
k: k as i32,
lda,
ldb,
ldc: n as i32,
transa,
transb,
};
let stride_a: usize = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_dims[1] => stride,
[_, stride] if lhs_dims[0] == 1 => stride,
[stride, _] if lhs_dims[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(Error::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
out_stride: out_stride.to_vec(),
mnk: (m, n, k),
})?,
};
let stride_b: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_dims[1] => stride,
[_, stride] if rhs_dims[0] == 1 => stride,
[stride, _] if rhs_dims[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(Error::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
out_stride: out_stride.to_vec(),
mnk: (m, n, k),
})?,
};
let stride_c: usize = match out_stride[..out_stride.len() - 2] {
[s1, stride] if s1 == stride * out_dims[1] => stride,
[_, stride] if out_dims[0] == 1 => stride,
[stride, _] if out_dims[1] == 1 => stride,
[stride] => stride,
[] => m * n,
_ => Err(Error::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
out_stride: out_stride.to_vec(),
mnk: (m, n, k),
})?,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
stride_a: stride_a as i64,
stride_b: stride_b as i64,
stride_c: stride_c as i64,
})
}