#[cfg(feature = "cuda")]
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig, sys};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64};
/// Single-precision row-major matrix multiply on the GPU: `C = A * B`.
///
/// `a` holds an `m x k` row-major matrix and `b` a `k x n` one; the result is
/// a freshly allocated `m x n` buffer on `device`. Small/skinny problems are
/// routed to a custom kernel; larger ones go through cuBLAS SGEMM.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `m`/`k`/`n`, or a
///   dimension does not fit in `i32` (cuBLAS takes `i32` dimensions).
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f32(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Validate operand sizes before touching cuBLAS.
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Both operands must already live on the target device.
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(m * n, device);
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    // Heuristic: tiny or very skinny problems are cheaper in the custom kernel
    // than paying cuBLAS launch overhead.
    let total_ops = m * k * n;
    if m <= 4 || total_ops < 500_000 {
        return crate::kernels::gpu_small_matmul(a, b, m, k, n, device);
    }
    let blas = device.blas();
    let mut c = alloc_zeros_f32(m * n, device)?;
    // cuBLAS is column-major: a row-major X looks like Xᵀ column-major, so we
    // compute Cᵀ = Bᵀ·Aᵀ by passing (b, a) with m and n swapped; the
    // column-major result is then exactly row-major C.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f32,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f32,
        ldc: n_i32,
    };
    // SAFETY: dims and leading dims were validated against the buffer lengths
    // above, so cuBLAS stays within the allocations.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
/// Double-precision row-major matrix multiply on the GPU: `C = A * B`.
///
/// Same contract as [`gpu_matmul_f32`], but always uses cuBLAS DGEMM — there
/// is no small-problem fast path for f64 in this module.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // Validate operand sizes before touching cuBLAS.
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Both operands must already live on the target device.
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f64(m * n, device);
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    let mut c = alloc_zeros_f64(m * n, device)?;
    // Row-major trick: compute Cᵀ = Bᵀ·Aᵀ by passing (b, a) with m/n swapped
    // (cuBLAS is column-major).
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f64,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f64,
        ldc: n_i32,
    };
    // SAFETY: dims and leading dims validated against the buffer lengths above.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
/// Batched single-precision row-major matmul: `C[i] = A[i] * B[i]` for each of
/// `batch` independent problems stored contiguously.
///
/// `a` holds `batch` dense `m x k` matrices, `b` holds `batch` dense `k x n`
/// matrices; the result holds `batch` dense `m x n` matrices.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `batch`/`m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f32(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Degenerate batch/dims: result is all zeros (possibly empty).
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(batch * m * n, device);
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // Fix: reject operands on the wrong device up front, consistent with
    // gpu_matmul_f32; previously a mismatch fell through to an opaque cuBLAS
    // failure (or worse).
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    let mut c = alloc_zeros_f32(batch * m * n, device)?;
    // Row-major trick per batch: Cᵀ = Bᵀ·Aᵀ with (b, a) passed in swapped
    // order. Note the *first* gemm operand is `b`, so `stride_a` (the first
    // operand's batch stride) is k*n and `stride_b` is m*k.
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: sys::cublasOperation_t::CUBLAS_OP_N,
            transb: sys::cublasOperation_t::CUBLAS_OP_N,
            m: n_i32,
            n: m_i32,
            k: k_i32,
            alpha: 1.0f32,
            lda: n_i32,
            ldb: k_i32,
            beta: 0.0f32,
            ldc: n_i32,
        },
        batch_size: batch as i32,
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    // SAFETY: dims and strides validated against the buffer lengths above.
    unsafe {
        blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
/// Batched double-precision row-major matmul: `C[i] = A[i] * B[i]` for each of
/// `batch` independent problems stored contiguously.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `batch`/`m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // Degenerate batch/dims: result is all zeros (possibly empty).
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f64(batch * m * n, device);
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_f64",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_f64",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // Fix: reject operands on the wrong device up front, consistent with
    // gpu_matmul_f64; this check was previously missing.
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f64", expected: vec![i32::MAX as usize], got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f64", expected: vec![i32::MAX as usize], got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f64", expected: vec![i32::MAX as usize], got: vec![n],
    })?;
    let blas = device.blas();
    let mut c = alloc_zeros_f64(batch * m * n, device)?;
    // Row-major trick per batch: Cᵀ = Bᵀ·Aᵀ with (b, a) swapped; the first
    // gemm operand is `b`, so stride_a = k*n and stride_b = m*k.
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: sys::cublasOperation_t::CUBLAS_OP_N,
            transb: sys::cublasOperation_t::CUBLAS_OP_N,
            m: n_i32,
            n: m_i32,
            k: k_i32,
            alpha: 1.0f64,
            lda: n_i32,
            ldb: k_i32,
            beta: 0.0f64,
            ldc: n_i32,
        },
        batch_size: batch as i32,
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    // SAFETY: dims and strides validated against the buffer lengths above.
    unsafe {
        blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
/// CPU-only build stub: batched f64 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f64(
    _a: &CudaBuffer<f64>,
    _b: &CudaBuffer<f64>,
    _batch: usize,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
/// CPU-only build stub: batched f32 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f32(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _batch: usize,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Row-major f32 matmul into a caller-provided output buffer: `C = A * B`.
///
/// Unlike [`gpu_matmul_f32`] this does not allocate the result; `c` must hold
/// at least `m * n` elements. When any dimension is zero the function returns
/// `Ok(())` and leaves `c` untouched.
///
/// # Errors
/// * `ShapeMismatch` — input/output lengths disagree with `m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f32_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    c: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Fix: catch undersized outputs up front instead of letting cuBLAS write
    // past the allocation. Oversized outputs remain allowed.
    if c.len() < m * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![m, n],
            got: vec![c.len()],
        });
    }
    if m == 0 || k == 0 || n == 0 {
        // Nothing to compute; `c` is deliberately left untouched.
        return Ok(());
    }
    // Same small-problem heuristic as gpu_matmul_f32.
    let total_ops = m * k * n;
    if m <= 4 || total_ops < 500_000 {
        return crate::kernels::gpu_small_matmul_into(a, b, m, k, n, c, device);
    }
    // Fix: checked conversions consistent with gpu_matmul_f32; the previous
    // bare `as i32` casts silently truncated oversized dimensions.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    // Row-major trick: Cᵀ = Bᵀ·Aᵀ with (b, a) swapped and m/n exchanged.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f32,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f32,
        ldc: n_i32,
    };
    // SAFETY: dims and leading dims validated against the buffer lengths above.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(())
}
/// Batched row-major f32 matmul into a caller-provided output buffer:
/// `C[i] = A[i] * B[i]` for each of `batch` contiguous problems.
///
/// `c` must hold at least `batch * m * n` elements. When `batch` or any
/// dimension is zero the function returns `Ok(())` and leaves `c` untouched.
///
/// # Errors
/// * `ShapeMismatch` — input/output lengths disagree with the dimensions, or a
///   dimension does not fit in `i32`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_bmm_f32_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    c: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        // Nothing to compute; `c` is deliberately left untouched.
        return Ok(());
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // Fix: catch undersized outputs up front instead of letting cuBLAS write
    // past the allocation. Oversized outputs remain allowed.
    if c.len() < batch * m * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, m, n],
            got: vec![c.len()],
        });
    }
    // Fix: checked conversions consistent with gpu_bmm_f32; the previous bare
    // `as i32` casts silently truncated oversized dimensions.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    // Row-major trick per batch: the first gemm operand is `b`, so stride_a
    // (the first operand's batch stride) is k*n and stride_b is m*k.
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: sys::cublasOperation_t::CUBLAS_OP_N,
            transb: sys::cublasOperation_t::CUBLAS_OP_N,
            m: n_i32,
            n: m_i32,
            k: k_i32,
            alpha: 1.0f32,
            lda: n_i32,
            ldb: k_i32,
            beta: 0.0f32,
            ldc: n_i32,
        },
        batch_size: batch as i32,
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    // SAFETY: dims and strides validated against the buffer lengths above.
    unsafe {
        blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(())
}
/// CPU-only build stub: in-place f32 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f32_into(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _m: usize,
    _k: usize,
    _n: usize,
    _c: &mut CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// CPU-only build stub: in-place batched f32 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f32_into(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _batch: usize,
    _m: usize,
    _k: usize,
    _n: usize,
    _c: &mut CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// f32 matmul executed in half precision: both inputs are converted to f16 on
/// the device, multiplied via `cublasGemmEx` with f32 accumulation, and the
/// result is returned as f32.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f16(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    // Shape and device validation, mirroring gpu_matmul_f32.
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(m * n, device);
    }
    // cuBLAS takes i32 dimensions; reject anything that would truncate.
    let to_i32 = |v: usize| {
        i32::try_from(v).map_err(|_| GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![i32::MAX as usize],
            got: vec![v],
        })
    };
    let (m_i32, k_i32, n_i32) = (to_i32(m)?, to_i32(k)?, to_i32(n)?);
    // Downconvert inputs on the device; the output stays f32.
    let a_half = crate::kernels::gpu_f32_to_f16(a, device)?;
    let b_half = crate::kernels::gpu_f32_to_f16(b, device)?;
    let mut out = alloc_zeros_f32(m * n, device)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _hold_a) = a_half.device_ptr(&stream);
        let (b_ptr, _hold_b) = b_half.device_ptr(&stream);
        let (c_ptr, _hold_c) = out.inner_mut().device_ptr_mut(&stream);
        // Row-major trick: pass (b, a) with m/n swapped; inputs are f16, the
        // output is f32, and accumulation happens in f32.
        unsafe {
            cublas_result::gemm_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                n_i32,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                k_i32,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_32F,
                n_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(out)
}
/// CPU-only build stub: f16 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f16(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Batched f32 matmul executed in half precision: inputs are converted to f16
/// on the device, multiplied via `cublasGemmStridedBatchedEx` with f32
/// accumulation, and returned as f32.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `batch`/`m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f16(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(batch * m * n, device);
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_f16",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_f16",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // Fix: reject operands on the wrong device up front, consistent with
    // gpu_matmul_f16; this check was previously missing.
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f16",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f16",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_f16",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let a_f16 = crate::kernels::gpu_f32_to_f16(a, device)?;
    let b_f16 = crate::kernels::gpu_f32_to_f16(b, device)?;
    let mut c = alloc_zeros_f32(batch * m * n, device)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a_f16.device_ptr(&stream);
        let (b_ptr, _rb) = b_f16.device_ptr(&stream);
        let (c_ptr, _rc) = c.inner_mut().device_ptr_mut(&stream);
        // Renamed for clarity: the previous `stride_a_f16` actually held B's
        // per-batch stride (the first gemm operand is `b` in the row-major
        // trick). Values are unchanged.
        let b_batch_stride = (k * n) as i64;
        let a_batch_stride = (m * k) as i64;
        let c_batch_stride = (m * n) as i64;
        // SAFETY: dims and strides validated against the buffer lengths above.
        unsafe {
            cublas_result::gemm_strided_batched_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                n_i32,
                b_batch_stride,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                k_i32,
                a_batch_stride,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_32F,
                n_i32,
                c_batch_stride,
                batch as i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// CPU-only build stub: batched f16 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f16(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _batch: usize,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// f32 matmul executed in bfloat16: inputs are converted to bf16 on the
/// device, multiplied via `cublasGemmEx` with f32 accumulation, and returned
/// as f32.
///
/// # Errors
/// * `ShapeMismatch` — buffer lengths disagree with `m`/`k`/`n`, or a
///   dimension does not fit in `i32`.
/// * `DeviceMismatch` — either input lives on a different device ordinal.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Fix: reject operands on the wrong device up front, consistent with
    // gpu_matmul_f16; this check was previously missing.
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(m * n, device);
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let a_bf16 = crate::kernels::gpu_f32_to_bf16(a, device)?;
    let b_bf16 = crate::kernels::gpu_f32_to_bf16(b, device)?;
    let mut c = alloc_zeros_f32(m * n, device)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a_bf16.device_ptr(&stream);
        let (b_ptr, _rb) = b_bf16.device_ptr(&stream);
        let (c_ptr, _rc) = c.inner_mut().device_ptr_mut(&stream);
        // Row-major trick: pass (b, a) with m/n swapped; inputs are bf16, the
        // output is f32, accumulation in f32.
        // SAFETY: dims and leading dims validated against the buffer lengths
        // above.
        unsafe {
            cublas_result::gemm_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_32F,
                n_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// CPU-only build stub: bf16 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// Pure-bf16 matmul on raw device slices: `C = A * B` with bf16 inputs and
/// bf16 output, accumulated in f32. Operands are `u16` slices holding bf16
/// bit patterns.
///
/// Length checks use `<` (at-least) rather than exact equality, so oversized
/// slices are accepted — presumably to allow padded allocations; TODO confirm
/// against callers.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16_bf16(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    // Inputs must hold at least m*k / k*n elements respectively.
    if a.len() < m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16_bf16",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() < k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16_bf16",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(m * n)?);
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let mut c = device.stream().alloc_zeros::<u16>(m * n)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a.device_ptr(&stream);
        let (b_ptr, _rb) = b.device_ptr(&stream);
        let (c_ptr, _rc) = c.device_ptr_mut(&stream);
        // Row-major trick: pass (b, a) with m/n swapped; all matrices bf16,
        // accumulation in f32.
        unsafe {
            cublas_result::gemm_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// CPU-only build stub: bf16-in/bf16-out matmul requires the `cuda` feature.
/// Note the signature intentionally degenerates to `()` because
/// `cudarc::driver::CudaSlice` does not exist without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16_bf16(
    _a: &(),
    _b: &(),
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Pure-bf16 matmul with the second operand transposed: `C = A * Bᵀ`, where
/// `b` holds an `n x k` row-major matrix (the layout check below requires
/// `b.len() >= n * k`). Typical for weight matrices stored transposed.
///
/// Length checks use `<` (at-least), so oversized slices are accepted —
/// presumably to allow padded allocations; TODO confirm against callers.
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16_bf16_nt(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    if a.len() < m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16_bf16_nt",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    // `b` is n x k (transposed layout), hence the n*k bound.
    if b.len() < n * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_bf16_bf16_nt",
            expected: vec![n, k],
            got: vec![b.len()],
        });
    }
    // Any zero dimension yields an all-zero (possibly empty) result.
    if m == 0 || k == 0 || n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(m * n)?);
    }
    // cuBLAS dimensions are i32; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16_nt",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16_nt",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_bf16_bf16_nt",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let mut c = device.stream().alloc_zeros::<u16>(m * n)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a.device_ptr(&stream);
        let (b_ptr, _rb) = b.device_ptr(&stream);
        let (c_ptr, _rc) = c.device_ptr_mut(&stream);
        // Row-major trick with the first operand (b) transposed: OP_T on b
        // together with lda = k turns the n x k row-major slice into Bᵀ for
        // the column-major computation.
        unsafe {
            cublas_result::gemm_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_T,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// CPU-only build stub: transposed bf16 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16_bf16_nt(
    _a: &(),
    _b: &(),
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Batched pure-bf16 matmul with explicit input strides and transposed second
/// operand: `C[i] = alpha * A[i] * B[i]ᵀ` for each batch.
///
/// Per batch, `a` holds an `m x k` row-major matrix starting at
/// `i * stride_a_elems` and `b` an `n x k` row-major matrix starting at
/// `i * stride_b_elems`; output batches are packed densely (`m * n` apart).
///
/// # Errors
/// * `ShapeMismatch` — a slice is too small for the strided layout, or a
///   dimension/batch count does not fit in `i32`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched_nt(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    m: usize,
    k: usize,
    n: usize,
    batch_count: usize,
    stride_a_elems: usize,
    stride_b_elems: usize,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    if batch_count == 0 || m == 0 || k == 0 || n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(batch_count * m * n)?);
    }
    let op = "matmul_bf16_bf16_strided_batched_nt";
    // Fix: validate that the strided layout fits inside the input slices;
    // previously there was no length validation at all.
    let need_a = stride_a_elems * (batch_count - 1) + m * k;
    if a.len() < need_a {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![need_a],
            got: vec![a.len()],
        });
    }
    let need_b = stride_b_elems * (batch_count - 1) + n * k;
    if b.len() < need_b {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![need_b],
            got: vec![b.len()],
        });
    }
    // Fix: checked conversions consistent with the other entry points; the
    // previous bare `as i32` casts silently truncated oversized values.
    let to_i32 = |v: usize| {
        i32::try_from(v).map_err(|_| GpuError::ShapeMismatch {
            op,
            expected: vec![i32::MAX as usize],
            got: vec![v],
        })
    };
    let (m_i32, k_i32, n_i32, bc_i32) =
        (to_i32(m)?, to_i32(k)?, to_i32(n)?, to_i32(batch_count)?);
    let mut c = device.stream().alloc_zeros::<u16>(batch_count * m * n)?;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a.device_ptr(&stream);
        let (b_ptr, _rb) = b.device_ptr(&stream);
        let (c_ptr, _rc) = c.device_ptr_mut(&stream);
        // Row-major trick with b transposed (OP_T, lda = k); accumulation in
        // f32, all matrices bf16.
        // SAFETY: the strided layout was validated against the slice lengths
        // above.
        unsafe {
            cublas_result::gemm_strided_batched_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_T,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                stride_b_elems as i64,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                stride_a_elems as i64,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                (m * n) as i64,
                bc_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// Batched pure-bf16 matmul with explicit input strides:
/// `C[i] = alpha * A[i] * B[i]` for each batch.
///
/// Per batch, `a` holds an `m x k` row-major matrix starting at
/// `i * stride_a_elems` and `b` a `k x n` row-major matrix starting at
/// `i * stride_b_elems`; output batches are packed densely (`m * n` apart).
///
/// # Errors
/// * `ShapeMismatch` — a slice is too small for the strided layout, or a
///   dimension/batch count does not fit in `i32`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched(
    a: &cudarc::driver::CudaSlice<u16>,
    b: &cudarc::driver::CudaSlice<u16>,
    m: usize,
    k: usize,
    n: usize,
    batch_count: usize,
    stride_a_elems: usize,
    stride_b_elems: usize,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    if batch_count == 0 || m == 0 || k == 0 || n == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(batch_count * m * n)?);
    }
    let op = "matmul_bf16_bf16_strided_batched";
    // Fix: validate that the strided layout fits inside the input slices;
    // previously there was no length validation at all.
    let need_a = stride_a_elems * (batch_count - 1) + m * k;
    if a.len() < need_a {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![need_a],
            got: vec![a.len()],
        });
    }
    let need_b = stride_b_elems * (batch_count - 1) + k * n;
    if b.len() < need_b {
        return Err(GpuError::ShapeMismatch {
            op,
            expected: vec![need_b],
            got: vec![b.len()],
        });
    }
    // Fix: checked conversions consistent with the other entry points; the
    // previous bare `as i32` casts silently truncated oversized values.
    let to_i32 = |v: usize| {
        i32::try_from(v).map_err(|_| GpuError::ShapeMismatch {
            op,
            expected: vec![i32::MAX as usize],
            got: vec![v],
        })
    };
    let (m_i32, k_i32, n_i32, bc_i32) =
        (to_i32(m)?, to_i32(k)?, to_i32(n)?, to_i32(batch_count)?);
    let mut c = device.stream().alloc_zeros::<u16>(batch_count * m * n)?;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _ra) = a.device_ptr(&stream);
        let (b_ptr, _rb) = b.device_ptr(&stream);
        let (c_ptr, _rc) = c.device_ptr_mut(&stream);
        // Row-major trick: pass (b, a) with m/n swapped; all matrices bf16,
        // accumulation in f32.
        // SAFETY: the strided layout was validated against the slice lengths
        // above.
        unsafe {
            cublas_result::gemm_strided_batched_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                stride_b_elems as i64,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                k_i32,
                stride_a_elems as i64,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16BF,
                n_i32,
                (m * n) as i64,
                bc_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
/// CPU-only build stub: strided-batched transposed bf16 matmul requires the
/// `cuda` feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched_nt(
    _a: &(),
    _b: &(),
    _m: usize,
    _k: usize,
    _n: usize,
    _batch_count: usize,
    _stride_a_elems: usize,
    _stride_b_elems: usize,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// CPU-only build stub: strided-batched bf16 matmul requires the `cuda`
/// feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched(
    _a: &(),
    _b: &(),
    _m: usize,
    _k: usize,
    _n: usize,
    _batch_count: usize,
    _stride_a_elems: usize,
    _stride_b_elems: usize,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
/// Reference row-major CPU matmul used by the tests: returns the `m x n`
/// product `A * B`, where `a` is `m x k` and `b` is `k x n`.
#[cfg(test)]
fn cpu_matmul_naive<T>(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Vec<T>
where
    T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
    let mut out = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            // Dot product of row `row` of A with column `col` of B.
            let dot = (0..k).fold(T::default(), |acc, p| {
                acc + a[row * k + p] * b[p * n + col]
            });
            out.push(dot);
        }
    }
    out
}
/// CPU-only build stub: f32 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f32(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
/// CPU-only build stub: f64 matmul requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f64(
    _a: &CudaBuffer<f64>,
    _b: &CudaBuffer<f64>,
    _m: usize,
    _k: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
use crate::device::GpuDevice;
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
/// Opens CUDA device 0 and uploads `data`, returning both.
fn setup_f32(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
    (dev, buf)
}
/// Downloads `buf` and asserts elementwise closeness to `expected` within `tol`.
fn assert_buf_close_f32(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32], tol: f32) {
    let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
    assert_eq!(host.len(), expected.len(), "length mismatch");
    for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - exp).abs() < tol,
            "element {i}: got {got}, expected {exp}, diff {}",
            (got - exp).abs(),
        );
    }
}
/// f64 counterpart of `assert_buf_close_f32`.
fn assert_buf_close_f64(buf: &CudaBuffer<f64>, device: &GpuDevice, expected: &[f64], tol: f64) {
    let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
    assert_eq!(host.len(), expected.len(), "length mismatch");
    for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - exp).abs() < tol,
            "element {i}: got {got}, expected {exp}, diff {}",
            (got - exp).abs(),
        );
    }
}
/// 2x3 · 3x2 f32 product against hand-computed values.
#[test]
fn matmul_f32_basic() {
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f32(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f32");
    assert_eq!(c.len(), 4);
    assert_buf_close_f32(&c, &dev, &expected, 1e-4);
}
/// Multiplying by the 2x2 identity must return the input unchanged.
#[test]
fn matmul_f32_identity() {
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
    let (dev, a) = setup_f32(&a_data);
    let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
    let c = gpu_matmul_f32(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f32");
    assert_buf_close_f32(&c, &dev, &a_data, 1e-6);
}
/// 1x3 row vector times a 3x2 matrix (exercises the m == 1 path).
#[test]
fn matmul_f32_row_vector() {
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0];
    let b_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let expected: Vec<f32> = vec![4.0, 5.0];
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f32(&a, &b, 1, 3, 2, &dev).expect("gpu_matmul_f32");
    assert_eq!(c.len(), 2);
    assert_buf_close_f32(&c, &dev, &expected, 1e-6);
}
/// An undersized `a` buffer must produce a ShapeMismatch error.
#[test]
fn matmul_f32_wrong_a_length() {
    let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]);
    let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
    let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
    match err {
        GpuError::ShapeMismatch { op: "matmul", .. } => {}
        other => panic!("unexpected error: {other}"),
    }
}
/// An undersized `b` buffer must produce a ShapeMismatch error.
#[test]
fn matmul_f32_wrong_b_length() {
    let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
    let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
    let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
    match err {
        GpuError::ShapeMismatch { op: "matmul", .. } => {}
        other => panic!("unexpected error: {other}"),
    }
}
/// All-zero dimensions must succeed and return an empty buffer.
#[test]
fn matmul_f32_empty() {
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
    let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f32(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f32 empty");
    assert_eq!(c.len(), 0);
}
/// 2x3 · 3x2 f64 product against hand-computed values (tight tolerance).
#[test]
fn matmul_f64_basic() {
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a_data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: Vec<f64> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected: Vec<f64> = vec![58.0, 64.0, 139.0, 154.0];
    let a = cpu_to_gpu(&a_data, &dev).expect("cpu_to_gpu a");
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f64(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f64");
    assert_eq!(c.len(), 4);
    assert_buf_close_f64(&c, &dev, &expected, 1e-10);
}
/// Mid-sized f32 product cross-checked against the naive CPU reference.
#[test]
fn matmul_f32_vs_cpu() {
    let m = 64;
    let k = 48;
    let n = 32;
    // Deterministic pseudo-random values in [0, 1).
    let a_data: Vec<f32> = (0..m * k)
        .map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
        .collect();
    let b_data: Vec<f32> = (0..k * n)
        .map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
        .collect();
    let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f32(&a, &b, m, k, n, &dev).expect("gpu_matmul_f32");
    assert_buf_close_f32(&c, &dev, &expected, 1e-3);
}
/// Informational timing comparison (GPU vs naive CPU) for a 1024³ problem;
/// prints the speedup but makes no assertions.
#[test]
fn matmul_f32_1024x1024_perf() {
    let dim = 1024;
    let a_data: Vec<f32> = (0..dim * dim)
        .map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
        .collect();
    let b_data: Vec<f32> = (0..dim * dim)
        .map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
        .collect();
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let gpu_start = std::time::Instant::now();
    let _c = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
    let gpu_elapsed = gpu_start.elapsed();
    let cpu_start = std::time::Instant::now();
    let _c_cpu = cpu_matmul_naive(&a_data, &b_data, dim, dim, dim);
    let cpu_elapsed = cpu_start.elapsed();
    eprintln!(
        "matmul {dim}x{dim}: GPU = {:.3}ms, CPU = {:.3}ms, speedup = {:.1}x",
        gpu_elapsed.as_secs_f64() * 1000.0,
        cpu_elapsed.as_secs_f64() * 1000.0,
        cpu_elapsed.as_secs_f64() / gpu_elapsed.as_secs_f64(),
    );
}
/// 2x3 · 3x2 product through the f16 path; loose tolerance for half precision.
#[test]
fn matmul_f16_basic() {
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f16(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f16");
    assert_eq!(c.len(), 4);
    assert_buf_close_f32(&c, &dev, &expected, 1e-2);
}
/// Identity multiplication through the f16 path preserves exactly
/// representable values.
#[test]
fn matmul_f16_identity() {
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
    let (dev, a) = setup_f32(&a_data);
    let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
    let c = gpu_matmul_f16(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f16");
    assert_buf_close_f32(&c, &dev, &a_data, 1e-3);
}
/// All-zero dimensions through the f16 path yield an empty buffer.
#[test]
fn matmul_f16_empty() {
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
    let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f16(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f16 empty");
    assert_eq!(c.len(), 0);
}
#[test]
fn matmul_f16_k_zero() {
    // k == 0 with nonzero m and n: the output is an m x n buffer of zeros
    // (an empty sum for every element).
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
    let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f16(&a, &b, 2, 0, 3, &dev).expect("gpu_matmul_f16 k=0");
    assert_eq!(c.len(), 6);
    let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
    assert!(host.iter().all(|&x| x == 0.0));
}
#[test]
fn matmul_f16_wrong_a_length() {
    // A holds 3 elements but a 2x2 LHS needs 4; expect a ShapeMismatch
    // tagged with the "matmul_f16" op name.
    let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]);
    let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
    let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
    match err {
        GpuError::ShapeMismatch { op: "matmul_f16", .. } => {}
        other => panic!("unexpected error: {other}"),
    }
}
#[test]
fn matmul_f16_wrong_b_length() {
    // B holds 3 elements but a 2x2 RHS needs 4; expect a ShapeMismatch
    // tagged with the "matmul_f16" op name.
    let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
    let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
    let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
    match err {
        GpuError::ShapeMismatch { op: "matmul_f16", .. } => {}
        other => panic!("unexpected error: {other}"),
    }
}
#[test]
fn matmul_f16_vs_f32_reference() {
    // Element-wise comparison of the f16 GEMM against an f32 CPU reference:
    // each entry must be within 5% relative error, or 0.1 absolute error
    // when the reference value is near zero.
    let (m, k, n) = (64, 48, 32);
    let a_data: Vec<f32> = (0..m * k)
        .map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
        .collect();
    let b_data: Vec<f32> = (0..k * n)
        .map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
        .collect();
    let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
    let (dev, a) = setup_f32(&a_data);
    let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
    let c = gpu_matmul_f16(&a, &b, m, k, n, &dev).expect("gpu_matmul_f16");
    let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
    assert_eq!(host.len(), expected.len(), "output length mismatch");
    let mut max_err: f32 = 0.0;
    for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
        let abs_err = (got - exp).abs();
        // Fall back to absolute error when the reference is essentially zero.
        let rel_err = if exp.abs() > 1e-6 {
            abs_err / exp.abs()
        } else {
            abs_err
        };
        max_err = max_err.max(abs_err);
        assert!(
            rel_err < 0.05 || abs_err < 0.1,
            "element {i}: f16 got {got}, f32 ref {exp}, abs_err {abs_err}, rel_err {rel_err}",
        );
    }
    eprintln!("matmul_f16_vs_f32: {m}x{k} @ {k}x{n}, max absolute error = {max_err:.6}",);
}
#[test]
fn matmul_f16_1024x1024_perf() {
    // Times the f32 and f16 GEMM paths on the same 1024x1024 inputs and
    // reports the f16 speedup; no correctness assertions here.
    let dim = 1024;
    // Deterministic pseudo-random fill in [0, 1).
    let fill = |mul: usize, add: usize| -> Vec<f32> {
        (0..dim * dim)
            .map(|i| ((i * mul + add) % 1000) as f32 / 1000.0)
            .collect()
    };
    let (dev, a) = setup_f32(&fill(7, 13));
    let b = cpu_to_gpu(&fill(11, 3), &dev).expect("cpu_to_gpu b");
    let t32 = std::time::Instant::now();
    let _c32 = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
    let f32_elapsed = t32.elapsed();
    let t16 = std::time::Instant::now();
    let _c16 = gpu_matmul_f16(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f16");
    let f16_elapsed = t16.elapsed();
    eprintln!(
        "matmul {dim}x{dim}: f32 = {:.3}ms, f16 = {:.3}ms, f16 speedup = {:.1}x",
        f32_elapsed.as_secs_f64() * 1000.0,
        f16_elapsed.as_secs_f64() * 1000.0,
        f32_elapsed.as_secs_f64() / f16_elapsed.as_secs_f64(),
    );
}
/// Rounds each f32 to bf16 on the host and uploads the raw bit patterns
/// as a `u16` device buffer.
fn upload_as_bf16(
    dev: &GpuDevice,
    data: &[f32],
) -> cudarc::driver::CudaSlice<u16> {
    let bits: Vec<u16> = data
        .iter()
        .map(|&v| half::bf16::from_f32(v).to_bits())
        .collect();
    dev.stream().clone_htod(&bits).expect("bf16 upload")
}
/// Copies a device buffer of raw bf16 bit patterns back to the host and
/// widens every element to f32.
fn download_bf16_as_f32(
    dev: &GpuDevice,
    buf: &cudarc::driver::CudaSlice<u16>,
) -> Vec<f32> {
    let raw: Vec<u16> = dev.stream().clone_dtoh(buf).expect("bf16 download");
    raw.iter()
        .map(|&bits| half::bf16::from_bits(bits).to_f32())
        .collect()
}
#[test]
fn matmul_bf16_bf16_basic_2x3_3x2() {
    // Small 2x3 @ 3x2 product with integer inputs. bf16 keeps only ~8
    // significant bits, so allow up to 1.0 of absolute error on outputs
    // in the 58..154 range.
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = upload_as_bf16(&dev, &a_data);
    let b = upload_as_bf16(&dev, &b_data);
    let c = gpu_matmul_bf16_bf16(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_bf16_bf16");
    let got = download_bf16_as_f32(&dev, &c);
    assert_eq!(got.len(), expected.len());
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        let diff = (g - e).abs();
        assert!(diff < 1.0, "element {i}: got {g}, expected {e}, diff {}", diff);
    }
}
#[test]
fn matmul_bf16_bf16_identity_rows() {
    // I (4x4) @ X (4x3) must reproduce X bit-for-bit: every input value is a
    // small integer exactly representable in bf16, so no rounding can occur.
    let i_data: Vec<f32> = {
        let mut diag = vec![0.0f32; 16];
        for d in 0..4 {
            diag[d * 4 + d] = 1.0;
        }
        diag
    };
    let x_data: Vec<f32> = (0..12).map(|i| (i as f32) - 6.0).collect();
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let i = upload_as_bf16(&dev, &i_data);
    let x = upload_as_bf16(&dev, &x_data);
    let c = gpu_matmul_bf16_bf16(&i, &x, 4, 4, 3, &dev).expect("matmul");
    let got = download_bf16_as_f32(&dev, &c);
    for (out, reference) in got.iter().zip(x_data.iter()) {
        assert_eq!(out, reference, "identity matmul must preserve X exactly");
    }
}
#[test]
fn matmul_bf16_bf16_nt_basic_2x3_2x3() {
    // NT product: A (2x3) times the transpose of B (2x3), yielding a 2x2
    // result whose entries are row-by-row dot products of A and B.
    let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let expected: Vec<f32> = vec![50.0, 68.0, 122.0, 167.0];
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = upload_as_bf16(&dev, &a_data);
    let b = upload_as_bf16(&dev, &b_data);
    let c = gpu_matmul_bf16_bf16_nt(&a, &b, 2, 3, 2, &dev).expect("matmul_nt");
    let got = download_bf16_as_f32(&dev, &c);
    for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
        // 2% relative plus 1.0 absolute slack covers bf16 rounding.
        assert!(
            (g - e).abs() <= e.abs() * 0.02 + 1.0,
            "nt[{i}]: got {g}, expected {e}",
        );
    }
}
#[test]
fn matmul_bf16_bf16_nt_equivalent_to_explicit_transpose() {
    // A @ B^T via the NT kernel must agree with pre-transposing B on the
    // host and running the plain NN kernel.
    //
    // Tolerance note: outputs here reach magnitude ~40, where one bf16 ULP
    // is 0.25 (bf16 carries 7 explicit mantissa bits). The two GEMM paths
    // may round the final bf16 store differently by one ULP, so the old
    // fixed absolute tolerance of 0.01 could fail spuriously; compare with
    // a relative bound (2%) plus a small absolute floor instead.
    let m = 4;
    let k = 3;
    let n = 5;
    let a_data: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.5 - 2.0).collect();
    // B is stored row-major as n x k, the layout the NT kernel consumes.
    let b_data: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.25 + 1.0).collect();
    // Explicit host-side transpose of B into k x n for the NN reference.
    let mut b_t: Vec<f32> = vec![0.0; k * n];
    for row in 0..n {
        for col in 0..k {
            b_t[col * n + row] = b_data[row * k + col];
        }
    }
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = upload_as_bf16(&dev, &a_data);
    let b = upload_as_bf16(&dev, &b_data);
    let bt = upload_as_bf16(&dev, &b_t);
    let c_nt = gpu_matmul_bf16_bf16_nt(&a, &b, m, k, n, &dev).unwrap();
    let c_ref = gpu_matmul_bf16_bf16(&a, &bt, m, k, n, &dev).unwrap();
    let nt = download_bf16_as_f32(&dev, &c_nt);
    let rf = download_bf16_as_f32(&dev, &c_ref);
    for (i, (&x, &y)) in nt.iter().zip(rf.iter()).enumerate() {
        assert!(
            (x - y).abs() <= y.abs() * 0.02 + 0.02,
            "nt[{i}]={x} vs ref[{i}]={y}",
        );
    }
}
#[test]
fn matmul_bf16_strided_batched_matches_per_batch_reference() {
    // Two independent 2x3 @ (2x3)^T problems pushed through the
    // strided-batched NT kernel with alpha = 0.5 must match the
    // single-matrix NT kernel run on each batch, scaled by the same alpha.
    let dev = GpuDevice::new(0).expect("cuda");
    let a0: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b0: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
    let a1: Vec<f32> = vec![0.5, 0.5, 0.5, 1.0, 1.0, 1.0];
    let b1: Vec<f32> = vec![2.0, 2.0, 2.0, -1.0, -1.0, -1.0];
    let a: Vec<f32> = [&a0[..], &a1[..]].concat();
    let b: Vec<f32> = [&b0[..], &b1[..]].concat();
    let a_gpu = upload_as_bf16(&dev, &a);
    let b_gpu = upload_as_bf16(&dev, &b);
    let c = gpu_matmul_bf16_bf16_strided_batched_nt(
        &a_gpu, &b_gpu, 2, 3, 2, 2, 6, 6, 0.5, &dev,
    )
    .expect("batched");
    let got = download_bf16_as_f32(&dev, &c);
    // Per-batch reference: run the plain NT kernel on one batch's inputs.
    let single_nt = |lhs: &[f32], rhs: &[f32]| -> Vec<f32> {
        let out = gpu_matmul_bf16_bf16_nt(
            &upload_as_bf16(&dev, lhs),
            &upload_as_bf16(&dev, rhs),
            2,
            3,
            2,
            &dev,
        )
        .unwrap();
        download_bf16_as_f32(&dev, &out)
    };
    let expected0 = single_nt(&a0, &b0);
    let expected1 = single_nt(&a1, &b1);
    for (i, (&g, &e)) in got[..4].iter().zip(expected0.iter()).enumerate() {
        let scaled = e * 0.5;
        assert!(
            (g - scaled).abs() < scaled.abs() * 0.05 + 0.1,
            "b0[{i}]: got {g}, expected {scaled}",
        );
    }
    for (i, (&g, &e)) in got[4..].iter().zip(expected1.iter()).enumerate() {
        let scaled = e * 0.5;
        assert!(
            (g - scaled).abs() < scaled.abs() * 0.05 + 0.1,
            "b1[{i}]: got {g}, expected {scaled}",
        );
    }
}
#[test]
fn matmul_bf16_bf16_large_dims_finite() {
    // 512x512 bf16 GEMM on values in [-0.5, 0.5): checks only that the
    // output has the right length and contains no NaN/inf.
    let dim = 512;
    // Deterministic pseudo-random fill centered on zero.
    let fill = |mul: usize, add: usize| -> Vec<f32> {
        (0..dim * dim)
            .map(|i| (((i * mul + add) % 1000) as f32 / 1000.0) - 0.5)
            .collect()
    };
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = upload_as_bf16(&dev, &fill(7, 13));
    let b = upload_as_bf16(&dev, &fill(11, 3));
    let c = gpu_matmul_bf16_bf16(&a, &b, dim, dim, dim, &dev).expect("matmul");
    let got = download_bf16_as_f32(&dev, &c);
    assert_eq!(got.len(), dim * dim);
    for &v in &got {
        assert!(v.is_finite(), "non-finite output in bf16 matmul");
    }
}
}