use crate::{
dtypes::*,
shapes::*,
tensor::{cuda::Cuda, Tensor},
};
use cudarc::{
cublas::{
result::CublasError, sys::cublasOperation_t, CudaBlas, Gemm, GemmConfig,
StridedBatchedConfig,
},
driver::{DevicePtr, DevicePtrMut},
};
const TRANS: cublasOperation_t = cublasOperation_t::CUBLAS_OP_T;
const NO_TRANS: cublasOperation_t = cublasOperation_t::CUBLAS_OP_N;
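/// Builds the cuBLAS `GemmConfig` for a row-major matmul `out = lhs * rhs`
/// (plus `beta * out`). cuBLAS is column-major, so when `out` is laid out
/// row-major we compute `out^T = rhs^T * lhs^T` instead, which is exactly the
/// column-major view of `out`. The returned `bool` means "swap the operands":
/// `true` tells the caller to pass `rhs` as cuBLAS operand A and `lhs` as B.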
fn gemm_cfg<M: Dim, K: Dim, N: Dim, E: Dtype>(
(m, k, n): (M, K, N),
lhs_strides: [usize; 2],
rhs_strides: [usize; 2],
beta: E,
out_strides: [usize; 2],
) -> (GemmConfig<E>, bool) {
let (lhs_stride, lhs_trans) = super::matrix_strides((m.size(), k.size()), lhs_strides);
let (rhs_stride, rhs_trans) = super::matrix_strides((k.size(), n.size()), rhs_strides);
let (out_stride, out_trans) = super::matrix_strides((m.size(), n.size()), out_strides);
if !out_trans {
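// `out` is stored row-major: compute its column-major view
// `out^T = rhs^T * lhs^T`, so `rhs` becomes operand A and `lhs` operand B
// (hence `swap_ops = true` below).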
let cfg = GemmConfig {
transa: if rhs_trans { TRANS } else { NO_TRANS },
transb: if lhs_trans { TRANS } else { NO_TRANS },
m: n.size() as i32,
n: m.size() as i32,
k: k.size() as i32,
alpha: E::ONE,
lda: rhs_stride as i32,
ldb: lhs_stride as i32,
beta,
ldc: out_stride as i32,
};
(cfg, true)
} else {
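// `out` is stored column-major: cuBLAS can compute `out = lhs * rhs`
// directly. The trans flags are inverted relative to the branch above: a
// row-major operand is what column-major cuBLAS sees as transposed.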
let cfg = GemmConfig {
transa: if lhs_trans { NO_TRANS } else { TRANS },
transb: if rhs_trans { NO_TRANS } else { TRANS },
m: m.size() as i32,
n: n.size() as i32,
k: k.size() as i32,
alpha: E::ONE,
lda: lhs_stride as i32,
ldb: rhs_stride as i32,
beta,
ldc: out_stride as i32,
};
(cfg, false)
}
}
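// Worked example: for a contiguous row-major (2, 3) lhs, (3, 4) rhs, and
// (2, 4) out, all three report `trans = false`, so the first branch fires
// with m = 4, n = 2, k = 3, lda = 4, ldb = 3, ldc = 4 and `swap_ops = true`:
// cuBLAS computes the column-major `out^T = rhs^T * lhs^T`, whose bytes are
// exactly the row-major `out = lhs * rhs`.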
// `cudarc` ships `Gemm` impls for the plain float dtypes but not for dfdx's
// `AMP<f16>` wrapper, so it gets a manual impl here: inputs and outputs stay
// fp16 (`CUDA_R_16F`) while accumulation runs in fp32 (`CUBLAS_COMPUTE_32F`),
// the usual mixed-precision trade-off.
#[cfg(feature = "f16")]
impl Gemm<AMP<f16>> for CudaBlas {
unsafe fn gemm<A: DevicePtr<AMP<f16>>, B: DevicePtr<AMP<f16>>, C: DevicePtrMut<AMP<f16>>>(
&self,
cfg: GemmConfig<AMP<f16>>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.alpha.0.to_f32();
let beta: f32 = cfg.beta.0.to_f32();
cudarc::cublas::result::gemm_ex(
*self.handle(),
cfg.transa,
cfg.transb,
cfg.m,
cfg.n,
cfg.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.lda,
*b.device_ptr() as *const _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.ldb,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.ldc,
cudarc::cublas::sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudarc::cublas::sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
unsafe fn gemm_strided_batched<
A: DevicePtr<AMP<f16>>,
B: DevicePtr<AMP<f16>>,
C: DevicePtrMut<AMP<f16>>,
>(
&self,
cfg: StridedBatchedConfig<AMP<f16>>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.gemm.alpha.0.to_f32();
let beta: f32 = cfg.gemm.beta.0.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*self.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
cudarc::cublas::sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
cudarc::cublas::sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cudarc::cublas::sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
}
impl Cuda {
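/// Launches a single matmul `out = lhs * rhs + beta * out` through cuBLAS.
/// Strides are passed explicitly so transposed/permuted tensors work without
/// copies; when `gemm_cfg` signals `swap_ops`, the operands are handed to
/// cuBLAS in reversed order to produce a row-major output.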
#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn gemm<
E: Dtype,
M: Dim,
K: Dim,
N: Dim,
A: DevicePtr<E>,
B: DevicePtr<E>,
C: DevicePtrMut<E>,
>(
&self,
(m, k, n): (M, K, N),
lhs: &A,
lhs_strides: [usize; 2],
rhs: &B,
rhs_strides: [usize; 2],
beta: E,
out: &mut C,
out_strides: [usize; 2],
) -> Result<(), CublasError>
where
CudaBlas: Gemm<E>,
{
let (cfg, swap_ops) = gemm_cfg((m, k, n), lhs_strides, rhs_strides, beta, out_strides);
if !swap_ops {
self.blas.gemm(cfg, lhs, rhs, out)
} else {
self.blas.gemm(cfg, rhs, lhs, out)
}
}
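/// Batched variant of [`Self::gemm`], mapping onto cuBLAS's strided-batched
/// GEMM: the same matmul is applied `batch` times, with each operand advanced
/// by its own batch stride between applications.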
#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn gemm_batch<
E: Dtype,
Batch: Dim,
M: Dim,
K: Dim,
N: Dim,
A: DevicePtr<E>,
B: DevicePtr<E>,
C: DevicePtrMut<E>,
>(
&self,
(batch, m, k, n): (Batch, M, K, N),
lhs: &A,
lhs_strides: [usize; 3],
rhs: &B,
rhs_strides: [usize; 3],
beta: E,
out: &mut C,
out_strides: [usize; 3],
) -> Result<(), CublasError>
where
CudaBlas: Gemm<E>,
{
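// every batch element must write to distinct memory; a zero output batch
// stride would make them alias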
assert_ne!(out_strides[0], 0);
let (gemm, swap_ops) = gemm_cfg(
(m, k, n),
[lhs_strides[1], lhs_strides[2]],
[rhs_strides[1], rhs_strides[2]],
beta,
[out_strides[1], out_strides[2]],
);
if !swap_ops {
let cfg = StridedBatchedConfig {
gemm,
stride_a: lhs_strides[0] as i64,
stride_b: rhs_strides[0] as i64,
stride_c: out_strides[0] as i64,
batch_size: batch.size() as i32,
};
self.blas.gemm_strided_batched(cfg, lhs, rhs, out)
} else {
let cfg = StridedBatchedConfig {
gemm,
stride_a: rhs_strides[0] as i64,
stride_b: lhs_strides[0] as i64,
stride_c: out_strides[0] as i64,
batch_size: batch.size() as i32,
};
self.blas.gemm_strided_batched(cfg, rhs, lhs, out)
}
}
}
impl<E: Dtype> super::MatMatKernel<E> for Cuda
where
CudaBlas: Gemm<E>,
{
fn forward<M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(M, K), E, Self>,
rhs: &Tensor<(K, N), E, Self>,
) -> Result<Tensor<(M, N), E, Self>, Self::Err> {
let (m, _) = lhs.shape;
let (k, n) = rhs.shape;
let shape = (m, n);
let strides = shape.strides();
let mut storage = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
unsafe {
self.gemm(
(m, k, n),
lhs.data.as_ref(),
lhs.strides,
rhs.data.as_ref(),
rhs.strides,
Default::default(),
&mut storage,
strides,
)
}?;
Ok(self.build_tensor(shape, strides, storage))
}
fn backward<M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(M, K), E, Self>,
grad_lhs: &mut Self::Vec,
rhs: &Tensor<(K, N), E, Self>,
grad_rhs: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Self::Err> {
let (m, _) = lhs.shape;
let (k, n) = rhs.shape;
let strides = (m, n).strides();
self.par_stream.wait_for_default()?;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
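// grad_lhs += grad_out * rhs^T (rhs is transposed by swapping its two
// strides); runs on the parallel stream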
self.gemm(
(m, n, k),
grad_out,
strides,
rhs.data.as_ref(),
[rhs.strides[1], rhs.strides[0]],
E::ONE,
grad_lhs,
lhs.strides,
)?;
self.blas.set_stream(None)?;
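// grad_rhs += lhs^T * grad_out; runs on the default stream, overlapping
// with the gemm above since the two write to different gradients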
self.gemm(
(k, m, n),
lhs.data.as_ref(),
[lhs.strides[1], lhs.strides[0]],
grad_out,
strides,
E::ONE,
grad_rhs,
rhs.strides,
)?;
}
self.dev.wait_for(self.par_stream.as_ref())?;
Ok(())
}
}
impl<E: Dtype> super::MatMatBrKernel<E> for Cuda
where
CudaBlas: Gemm<E>,
{
fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, M, K), E, Self>,
rhs: &Tensor<(K, N), E, Self>,
) -> Result<Tensor<(B, M, N), E, Self>, Self::Err> {
let (batch, m, _) = lhs.shape;
let (k, n) = rhs.shape;
let shape = (batch, m, n);
let strides = shape.strides();
let mut storage = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
unsafe {
self.gemm_batch(
(batch, m, k, n),
lhs.data.as_ref(),
lhs.strides,
rhs.data.as_ref(),
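// a batch stride of 0 broadcasts the same rhs to every batch element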
[0, rhs.strides[0], rhs.strides[1]],
Default::default(),
&mut storage,
strides,
)?;
}
Ok(self.build_tensor(shape, strides, storage))
}
fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, M, K), E, Self>,
grad_lhs: &mut Self::Vec,
rhs: &Tensor<(K, N), E, Self>,
grad_rhs: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Self::Err> {
let (batch, m, _) = lhs.shape;
let (k, n) = rhs.shape;
let strides = (batch, m, n).strides();
self.par_stream.wait_for_default()?;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
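// grad_rhs accumulates a contribution from every batch element; since all
// iterations write to the same output, this must be a loop of plain gemms
// with beta = 1 rather than one strided-batched call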
for i in 0..batch.size() {
self.gemm(
(k, m, n),
&lhs.data.slice(i * lhs.strides[0]..),
[lhs.strides[2], lhs.strides[1]],
&grad_out.slice(i * strides[0]..),
[strides[1], strides[2]],
E::ONE,
grad_rhs,
rhs.strides,
)?;
}
self.blas.set_stream(None)?;
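// grad_lhs += grad_out * rhs^T in one strided-batched call, broadcasting
// the transposed rhs with a batch stride of 0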
self.gemm_batch(
(batch, m, n, k),
grad_out,
strides,
rhs.data.as_ref(),
[0, rhs.strides[1], rhs.strides[0]],
E::ONE,
grad_lhs,
lhs.strides,
)?;
}
self.dev.wait_for(self.par_stream.as_ref())?;
Ok(())
}
}
impl<E: Dtype> super::MatMatBatch3Kernel<E> for Cuda
where
CudaBlas: Gemm<E>,
{
fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, M, K), E, Self>,
rhs: &Tensor<(B, K, N), E, Self>,
) -> Result<Tensor<(B, M, N), E, Self>, Self::Err> {
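// broadcasted (stride 0) inputs are not supported here; each batch element
// must be a distinct matrix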
assert_ne!(lhs.strides[0], 0);
assert_ne!(rhs.strides[0], 0);
let (batch, m, _) = lhs.shape;
let (_, k, n) = rhs.shape;
let shape = (batch, m, n);
let strides = shape.strides();
let mut storage = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
unsafe {
self.gemm_batch(
(batch, m, k, n),
lhs.data.as_ref(),
lhs.strides,
rhs.data.as_ref(),
rhs.strides,
Default::default(),
&mut storage,
strides,
)?;
}
Ok(self.build_tensor(shape, strides, storage))
}
fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, M, K), E, Self>,
grad_lhs: &mut Self::Vec,
rhs: &Tensor<(B, K, N), E, Self>,
grad_rhs: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Self::Err> {
let (batch, m, _) = lhs.shape;
let (_, k, n) = rhs.shape;
let strides = (batch, m, n).strides();
self.par_stream.wait_for_default()?;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
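// grad_lhs += grad_out * rhs^T, batched: swapping rhs's two matrix strides
// transposes each batch element in place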
self.gemm_batch(
(batch, m, n, k),
grad_out,
strides,
rhs.data.as_ref(),
[rhs.strides[0], rhs.strides[2], rhs.strides[1]],
E::ONE,
grad_lhs,
lhs.strides,
)?;
self.blas.set_stream(None)?;
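// grad_rhs += lhs^T * grad_out, batched, on the default stream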
self.gemm_batch(
(batch, k, m, n),
lhs.data.as_ref(),
[lhs.strides[0], lhs.strides[2], lhs.strides[1]],
grad_out,
strides,
E::ONE,
grad_rhs,
rhs.strides,
)?;
}
self.dev.wait_for(self.par_stream.as_ref())?;
Ok(())
}
}
impl<E: Dtype> super::MatMatBatch4Kernel<E> for Cuda
where
CudaBlas: Gemm<E>,
{
fn forward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, S, M, K), E, Self>,
rhs: &Tensor<(B, S, K, N), E, Self>,
) -> Result<Tensor<(B, S, M, N), E, Self>, Self::Err> {
assert_ne!(lhs.strides[0], 0);
assert_ne!(rhs.strides[0], 0);
assert_ne!(lhs.strides[1], 0);
assert_ne!(rhs.strides[1], 0);
let (batch, seq, m, _) = lhs.shape;
let (_, _, k, n) = rhs.shape;
let shape = (batch, seq, m, n);
let strides = shape.strides();
let mut storage = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
let half_batch = batch.size() / 2;
self.par_stream.wait_for_default()?;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
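// run the first half of the outer batch dim on the parallel stream; each
// element is itself a strided-batched gemm over the seq dim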
for b in 0..half_batch {
self.gemm_batch(
(seq, m, k, n),
&lhs.data.slice(b * lhs.strides[0]..),
[lhs.strides[1], lhs.strides[2], lhs.strides[3]],
&rhs.data.slice(b * rhs.strides[0]..),
[rhs.strides[1], rhs.strides[2], rhs.strides[3]],
Default::default(),
&mut storage.slice_mut(b * strides[0]..),
[strides[1], strides[2], strides[3]],
)?;
}
self.blas.set_stream(None)?;
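// second half on the default stream, overlapping with the first; the
// `wait_for` below rejoins the two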
for b in half_batch..batch.size() {
self.gemm_batch(
(seq, m, k, n),
&lhs.data.slice(b * lhs.strides[0]..),
[lhs.strides[1], lhs.strides[2], lhs.strides[3]],
&rhs.data.slice(b * rhs.strides[0]..),
[rhs.strides[1], rhs.strides[2], rhs.strides[3]],
Default::default(),
&mut storage.slice_mut(b * strides[0]..),
[strides[1], strides[2], strides[3]],
)?;
}
}
self.dev.wait_for(self.par_stream.as_ref())?;
Ok(self.build_tensor(shape, strides, storage))
}
fn backward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
&self,
lhs: &Tensor<(B, S, M, K), E, Self>,
grad_lhs: &mut Self::Vec,
rhs: &Tensor<(B, S, K, N), E, Self>,
grad_rhs: &mut Self::Vec,
grad_out: &Self::Vec,
) -> Result<(), Self::Err> {
let (batch, seq, m, _) = lhs.shape;
let (_, _, k, n) = rhs.shape;
let strides = (batch, seq, m, n).strides();
self.par_stream.wait_for_default()?;
unsafe {
self.blas.set_stream(Some(self.par_stream.as_ref()))?;
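// grad_lhs += grad_out * rhs^T for each outer batch element, batched over
// seq, on the parallel stream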
for i in 0..batch.size() {
self.gemm_batch(
(seq, m, n, k),
&grad_out.slice(i * strides[0]..),
[strides[1], strides[2], strides[3]],
&rhs.data.slice(i * rhs.strides[0]..),
[rhs.strides[1], rhs.strides[3], rhs.strides[2]],
E::ONE,
&mut grad_lhs.slice_mut(i * lhs.strides[0]..),
[lhs.strides[1], lhs.strides[2], lhs.strides[3]],
)?;
}
self.blas.set_stream(None)?;
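// grad_rhs += lhs^T * grad_out for each outer batch element, batched over
// seq, on the default stream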
for i in 0..batch.size() {
self.gemm_batch(
(seq, k, m, n),
&lhs.data.slice(i * lhs.strides[0]..),
[lhs.strides[1], lhs.strides[3], lhs.strides[2]],
&grad_out.slice(i * strides[0]..),
[strides[1], strides[2], strides[3]],
E::ONE,
&mut grad_rhs.slice_mut(i * rhs.strides[0]..),
[rhs.strides[1], rhs.strides[2], rhs.strides[3]],
)?;
}
}
self.dev.wait_for(self.par_stream.as_ref())?;
Ok(())
}
}
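// A minimal smoke-test sketch for the forward path above, not part of the
// original module. It assumes the crate's public API (`Cuda::default`,
// `TensorFrom::tensor`, `matmul`, `AsArray::array`) and requires a CUDA
// device at test time.
#[cfg(test)]
mod tests {
    use crate::{shapes::*, tensor::*, tensor_ops::*};

    #[test]
    fn test_matmul_smoke() {
        let dev = Cuda::default();
        let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1., 2., 3.], [4., 5., 6.]]);
        let b: Tensor<Rank2<3, 2>, f32, _> = dev.tensor([[1., 0.], [0., 1.], [1., 1.]]);
        // (2, 3) x (3, 2) -> (2, 2), exercising `MatMatKernel::forward`
        let c = a.matmul(b);
        assert_eq!(c.array(), [[4., 5.], [10., 11.]]);
    }
}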