use std::vec::Vec;
#[allow(unused_imports)]
use crate::{
cublas::{result::CublasError, sys, CudaBlas},
driver::{CudaSlice, DevicePtr, DevicePtrMut, DeviceRepr},
};
/// Maps an element type to the cuBLAS data-type / compute-type enums used by
/// grouped GEMM calls, plus the scalar type for alpha/beta factors.
pub trait GroupedGemmDtype: DeviceRepr {
    /// Scalar type of the alpha/beta scaling factors (and accumulation).
    type ComputeType: DeviceRepr + Copy;
    /// cuBLAS storage data type corresponding to `Self` (e.g. `CUDA_R_16F`).
    fn data_type() -> sys::cudaDataType_t;
    /// cuBLAS compute type used for the multiply-accumulate
    /// (e.g. `CUBLAS_COMPUTE_32F`).
    fn compute_type() -> sys::cublasComputeType_t;
}
#[cfg(feature = "f16")]
impl GroupedGemmDtype for half::f16 {
    // f16 storage; alpha/beta and the cuBLAS compute type are f32.
    type ComputeType = f32;
    fn data_type() -> sys::cudaDataType_t {
        sys::cudaDataType_t::CUDA_R_16F
    }
    fn compute_type() -> sys::cublasComputeType_t {
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
}
#[cfg(feature = "f16")]
impl GroupedGemmDtype for half::bf16 {
    // bf16 storage; alpha/beta and the cuBLAS compute type are f32.
    type ComputeType = f32;
    fn data_type() -> sys::cudaDataType_t {
        sys::cudaDataType_t::CUDA_R_16BF
    }
    fn compute_type() -> sys::cublasComputeType_t {
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
}
/// Per-group parameters for a grouped GEMM launch.
///
/// Every `Vec` field except `problem_sizes` must have exactly one entry per
/// group (i.e. length `group_count()`); `validate()` asserts this. Matrices
/// follow the cuBLAS column-major convention.
pub struct GroupedGemmConfig<T: GroupedGemmDtype> {
    // Transpose op applied to each group's B matrices.
    pub transbs: Vec<sys::cublasOperation_t>,
    // Transpose op applied to each group's A matrices.
    pub transas: Vec<sys::cublasOperation_t>,
    // Rows of op(A) / C per group.
    pub ms: Vec<usize>,
    // Columns of op(B) / C per group.
    pub ns: Vec<usize>,
    // Inner (contraction) dimension per group.
    pub ks: Vec<usize>,
    // Scaling factor on A*B per group.
    pub alphas: Vec<T::ComputeType>,
    // Scaling factor on the existing C per group.
    pub betas: Vec<T::ComputeType>,
    // Leading dimension of A per group.
    pub ldas: Vec<usize>,
    // Leading dimension of B per group.
    pub ldbs: Vec<usize>,
    // Leading dimension of C per group.
    pub ldcs: Vec<usize>,
    // Number of individual problems in each group; its length defines the
    // group count and its sum defines the total problem count.
    pub problem_sizes: Vec<usize>,
}
impl<T: GroupedGemmDtype> GroupedGemmConfig<T> {
    /// Total number of individual GEMM problems across all groups
    /// (sum of `problem_sizes`).
    #[must_use]
    pub fn problem_count(&self) -> usize {
        self.problem_sizes.iter().sum()
    }

    /// Number of groups; each group shares one set of dims/ops/scalars.
    #[must_use]
    pub fn group_count(&self) -> usize {
        self.problem_sizes.len()
    }

    /// Asserts that every per-group vector has exactly `group_count()`
    /// entries.
    ///
    /// # Panics
    /// Panics (naming the offending field) if any length disagrees with
    /// `group_count()`.
    #[inline]
    pub fn validate(&self) {
        let group_count = self.group_count();
        assert_eq!(self.transbs.len(), group_count, "transbs length != group count");
        assert_eq!(self.transas.len(), group_count, "transas length != group count");
        assert_eq!(self.ms.len(), group_count, "ms length != group count");
        assert_eq!(self.ns.len(), group_count, "ns length != group count");
        assert_eq!(self.ks.len(), group_count, "ks length != group count");
        assert_eq!(self.alphas.len(), group_count, "alphas length != group count");
        assert_eq!(self.betas.len(), group_count, "betas length != group count");
        assert_eq!(self.ldas.len(), group_count, "ldas length != group count");
        assert_eq!(self.ldbs.len(), group_count, "ldbs length != group count");
        assert_eq!(self.ldcs.len(), group_count, "ldcs length != group count");
    }
}
/// Batched GEMM over groups of problems, where problems within a group share
/// dimensions, transpose ops, and scaling factors.
pub trait GroupedGemm<T: GroupedGemmDtype> {
    /// Launches one grouped GEMM. The slice arguments are indexed per
    /// *problem* (length `config.problem_count()`), while `config`'s vectors
    /// are indexed per *group*.
    ///
    /// # Safety
    /// The caller must ensure each device slice is large enough for the
    /// dimensions and leading dimensions given in `config`; only the slice
    /// *counts* are checked, not the per-slice element counts.
    unsafe fn grouped_gemm<A: DevicePtr<T>, B: DevicePtr<T>, C: DevicePtrMut<T>>(
        &self,
        config: GroupedGemmConfig<T>,
        a_slices: &[&A],
        b_slices: &[&B],
        c_slices: &mut [&mut C],
    ) -> Result<(), CublasError>;
}
impl<T: GroupedGemmDtype> GroupedGemm<T> for CudaBlas {
    /// Stub for toolkits older than CUDA 12.5, which lack
    /// `cublasGemmGroupedBatchedEx`.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
    ))]
    unsafe fn grouped_gemm<A: DevicePtr<T>, B: DevicePtr<T>, C: DevicePtrMut<T>>(
        &self,
        _config: GroupedGemmConfig<T>,
        _a_slices: &[&A],
        _b_slices: &[&B],
        _c_slices: &mut [&mut C],
    ) -> Result<(), CublasError> {
        panic!("cublas GroupedGemm requires cuda 12.5+");
    }

    /// Real implementation backed by `cublasGemmGroupedBatchedEx`
    /// (CUDA 12.5+). See the trait for the safety contract.
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
    )))]
    unsafe fn grouped_gemm<A: DevicePtr<T>, B: DevicePtr<T>, C: DevicePtrMut<T>>(
        &self,
        config: GroupedGemmConfig<T>,
        a_slices: &[&A],
        b_slices: &[&B],
        c_slices: &mut [&mut C],
    ) -> Result<(), CublasError> {
        config.validate();
        // One slice per problem (not per group); hoist the sum so it is
        // computed once instead of once per assert.
        let problem_count = config.problem_count();
        assert_eq!(a_slices.len(), problem_count);
        assert_eq!(b_slices.len(), problem_count);
        assert_eq!(c_slices.len(), problem_count);
        // Host-side arrays of raw device pointers; the guards keep the device
        // allocations tied to `self.stream` for the duration of the call.
        let (a_ptrs, _a_guards): (Vec<u64>, Vec<_>) =
            a_slices.iter().map(|s| s.device_ptr(&self.stream)).unzip();
        let (b_ptrs, _b_guards): (Vec<u64>, Vec<_>) =
            b_slices.iter().map(|s| s.device_ptr(&self.stream)).unzip();
        let (mut c_ptrs, _c_guards): (Vec<u64>, Vec<_>) = c_slices
            .iter_mut()
            .map(|s| s.device_ptr_mut(&self.stream))
            .unzip();
        let cuda_dtype = T::data_type();
        let group_count = config.group_count();
        // cuBLAS takes host arrays of i32 for dims / leading dims / group
        // sizes; note `usize as i32` truncates if a dim exceeds i32::MAX.
        let m_array: Vec<i32> = config.ms.iter().map(|&x| x as i32).collect();
        let n_array: Vec<i32> = config.ns.iter().map(|&x| x as i32).collect();
        let k_array: Vec<i32> = config.ks.iter().map(|&x| x as i32).collect();
        let lda_array: Vec<i32> = config.ldas.iter().map(|&x| x as i32).collect();
        let ldb_array: Vec<i32> = config.ldbs.iter().map(|&x| x as i32).collect();
        let ldc_array: Vec<i32> = config.ldcs.iter().map(|&x| x as i32).collect();
        let group_size: Vec<i32> = config.problem_sizes.iter().map(|&x| x as i32).collect();
        // `config.alphas`/`config.betas` are already `Vec<T::ComputeType>`,
        // so they are passed directly; the previous `*x as T::ComputeType`
        // copy was a non-primitive cast to an associated type (E0605) and
        // would not compile.
        // SAFETY: all pointer arrays outlive the call, dimension arrays have
        // `group_count` entries (checked by `validate`), and the pointer
        // arrays have `problem_count` entries (asserted above). Per-slice
        // sizing is the caller's obligation per the trait contract.
        unsafe {
            sys::cublasGemmGroupedBatchedEx(
                self.handle,
                config.transas.as_ptr(),
                config.transbs.as_ptr(),
                m_array.as_ptr(),
                n_array.as_ptr(),
                k_array.as_ptr(),
                config.alphas.as_ptr() as _,
                a_ptrs.as_ptr() as _,
                cuda_dtype,
                lda_array.as_ptr(),
                b_ptrs.as_ptr() as _,
                cuda_dtype,
                ldb_array.as_ptr(),
                config.betas.as_ptr() as _,
                c_ptrs.as_mut_ptr() as _,
                cuda_dtype,
                ldc_array.as_ptr(),
                group_count as i32,
                group_size.as_ptr(),
                T::compute_type(),
            )
            .result()?;
        };
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    #![allow(unused)]
    use super::*;
    use crate::driver::CudaContext;
    use std::vec;

    /// Grouped GEMM with two groups: group 0 holds two 2x2 problems, group 1
    /// holds one 3x3 problem. All data is column-major (cuBLAS convention).
    /// Requires a CUDA device.
    #[test]
    #[cfg(feature = "f16")]
    fn test_grouped_gemm_f16() {
        use half::f16;
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let handle = CudaBlas::new(stream.clone()).unwrap();
        // Column-major 2x2: a0 = [[1,2],[3,4]] (row-major view).
        let a0_host = [1.0, 3.0, 2.0, 4.0].map(f16::from_f32);
        // b0 = [[5,6],[7,8]].
        let b0_host = [5.0, 7.0, 6.0, 8.0].map(f16::from_f32);
        // a1 = [[5,6],[7,8]].
        let a1_host = [5.0, 7.0, 6.0, 8.0].map(f16::from_f32);
        // b1 = [[9,10],[11,12]].
        let b1_host = [9.0, 11.0, 10.0, 12.0].map(f16::from_f32);
        // Column-major 3x3: a2 = [[1,2,3],[4,5,6],[7,8,9]].
        let a2_host = [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0].map(f16::from_f32);
        // b2 = [[4,5,6],[7,8,9],[10,11,12]].
        let b2_host = [4.0, 7.0, 10.0, 5.0, 8.0, 11.0, 6.0, 9.0, 12.0].map(f16::from_f32);
        let a0 = stream.clone_htod(&a0_host).unwrap();
        let b0 = stream.clone_htod(&b0_host).unwrap();
        let a1 = stream.clone_htod(&a1_host).unwrap();
        let b1 = stream.clone_htod(&b1_host).unwrap();
        let a2 = stream.clone_htod(&a2_host).unwrap();
        let b2 = stream.clone_htod(&b2_host).unwrap();
        let mut c0 = stream.alloc_zeros::<f16>(4).unwrap();
        let mut c1 = stream.alloc_zeros::<f16>(4).unwrap();
        let mut c2 = stream.alloc_zeros::<f16>(9).unwrap();
        // Two groups: per-group vectors have 2 entries; problem_sizes says
        // the first group covers problems 0-1, the second covers problem 2.
        let config = GroupedGemmConfig {
            transbs: vec![sys::cublasOperation_t::CUBLAS_OP_N; 2],
            transas: vec![sys::cublasOperation_t::CUBLAS_OP_N; 2],
            ms: vec![2, 3],
            ns: vec![2, 3],
            ks: vec![2, 3],
            alphas: vec![1.0; 2],
            betas: vec![0.0; 2],
            ldas: vec![2, 3],
            ldbs: vec![2, 3],
            ldcs: vec![2, 3],
            problem_sizes: vec![2, 1],
        };
        unsafe {
            handle
                .grouped_gemm(
                    config,
                    &[&a0, &a1, &a2],
                    &[&b0, &b1, &b2],
                    &mut [&mut c0, &mut c1, &mut c2],
                )
                .unwrap();
        }
        let c0_host = stream.clone_dtoh(&c0).unwrap();
        let c1_host = stream.clone_dtoh(&c1).unwrap();
        let c2_host = stream.clone_dtoh(&c2).unwrap();
        // c0 = a0*b0 = [[19,22],[43,50]] column-major; likewise below.
        let expected_c0 = [19.0, 43.0, 22.0, 50.0].map(f16::from_f32);
        let expected_c1 = [111.0, 151.0, 122.0, 166.0].map(f16::from_f32);
        let expected_c2 =
            [48.0, 111.0, 174.0, 54.0, 126.0, 198.0, 60.0, 141.0, 222.0].map(f16::from_f32);
        assert_eq!(c0_host, expected_c0);
        assert_eq!(c1_host, expected_c1);
        assert_eq!(c2_host, expected_c2);
    }

    /// Same fixture as `test_grouped_gemm_f16`, exercising the bf16 dtype
    /// path. Requires a CUDA device.
    #[test]
    #[cfg(feature = "f16")]
    fn test_grouped_gemm_raw_bf16() {
        use half::bf16;
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let handle = CudaBlas::new(stream.clone()).unwrap();
        // Column-major matrices, identical values to the f16 test above.
        let a0_host = [1.0, 3.0, 2.0, 4.0].map(bf16::from_f32);
        let b0_host = [5.0, 7.0, 6.0, 8.0].map(bf16::from_f32);
        let a1_host = [5.0, 7.0, 6.0, 8.0].map(bf16::from_f32);
        let b1_host = [9.0, 11.0, 10.0, 12.0].map(bf16::from_f32);
        let a2_host = [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0].map(bf16::from_f32);
        let b2_host = [4.0, 7.0, 10.0, 5.0, 8.0, 11.0, 6.0, 9.0, 12.0].map(bf16::from_f32);
        let a0 = stream.clone_htod(&a0_host).unwrap();
        let b0 = stream.clone_htod(&b0_host).unwrap();
        let a1 = stream.clone_htod(&a1_host).unwrap();
        let b1 = stream.clone_htod(&b1_host).unwrap();
        let a2 = stream.clone_htod(&a2_host).unwrap();
        let b2 = stream.clone_htod(&b2_host).unwrap();
        let mut c0 = stream.alloc_zeros::<bf16>(4).unwrap();
        let mut c1 = stream.alloc_zeros::<bf16>(4).unwrap();
        let mut c2 = stream.alloc_zeros::<bf16>(9).unwrap();
        // Two groups of sizes 2 and 1, as in the f16 test.
        let config = GroupedGemmConfig {
            transbs: vec![sys::cublasOperation_t::CUBLAS_OP_N; 2],
            transas: vec![sys::cublasOperation_t::CUBLAS_OP_N; 2],
            ms: vec![2, 3],
            ns: vec![2, 3],
            ks: vec![2, 3],
            alphas: vec![1.0; 2],
            betas: vec![0.0; 2],
            ldas: vec![2, 3],
            ldbs: vec![2, 3],
            ldcs: vec![2, 3],
            problem_sizes: vec![2, 1],
        };
        unsafe {
            handle
                .grouped_gemm(
                    config,
                    &[&a0, &a1, &a2],
                    &[&b0, &b1, &b2],
                    &mut [&mut c0, &mut c1, &mut c2],
                )
                .unwrap();
        }
        let c0_host = stream.clone_dtoh(&c0).unwrap();
        let c1_host = stream.clone_dtoh(&c1).unwrap();
        let c2_host = stream.clone_dtoh(&c2).unwrap();
        let expected_c0 = [19.0, 43.0, 22.0, 50.0].map(bf16::from_f32);
        let expected_c1 = [111.0, 151.0, 122.0, 166.0].map(bf16::from_f32);
        let expected_c2 =
            [48.0, 111.0, 174.0, 54.0, 126.0, 198.0, 60.0, 141.0, 222.0].map(bf16::from_f32);
        assert_eq!(c0_host, expected_c0);
        assert_eq!(c1_host, expected_c1);
        assert_eq!(c2_host, expected_c2);
    }
}