#[cfg(feature = "cuda")]
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig, sys};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64};
#[cfg(feature = "cuda")]
/// Row-major GPU matrix multiply: computes `C = A · B` where `A` is `m × k`,
/// `B` is `k × n`, and the returned `C` is `m × n`.
///
/// Small problems are routed to a custom kernel (cuBLAS launch overhead
/// dominates at that scale); everything else goes through cuBLAS SGEMM.
/// cuBLAS is column-major, so the SGEMM call computes `Cᵀ = Bᵀ · Aᵀ` by
/// swapping the operand order and the `m`/`n` dimensions — the row-major
/// buffers can then be passed through unchanged.
///
/// # Errors
/// - `GpuError::ShapeMismatch` when `a.len() != m * k`, `b.len() != k * n`,
///   or (on the cuBLAS path) a dimension does not fit in `i32`.
/// - `GpuError::DeviceMismatch` when either buffer lives on a device other
///   than `device`.
pub fn gpu_matmul_f32(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Degenerate shapes: the result is empty or all zeros; skip the GEMM.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(m * n, device);
    }
    // Small-problem fast path. The custom kernel takes usize dimensions, so
    // dispatch to it BEFORE the i32 conversions below — only the cuBLAS path
    // is constrained to i32, and converting first would reject sizes the
    // small kernel could otherwise handle.
    let total_ops = m * k * n;
    if m <= 4 || total_ops < 500_000 {
        return crate::kernels::gpu_small_matmul(a, b, m, k, n, device);
    }
    // cuBLAS takes i32 dimensions; reject anything that would truncate.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    let mut c = alloc_zeros_f32(m * n, device)?;
    // Row-major trick: m/n and lda/ldb describe the transposed problem.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f32,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f32,
        ldc: n_i32,
    };
    // SAFETY: buffer lengths were validated against m/k/n above and all
    // three buffers live on `device`; operands are passed swapped (b, a) to
    // match the transposed configuration.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
#[cfg(feature = "cuda")]
/// Row-major GPU matrix multiply in double precision: `C = A · B` with `A`
/// of shape `m × k` and `B` of shape `k × n`, returning `C` of shape `m × n`.
///
/// cuBLAS expects column-major storage, so the DGEMM call computes
/// `Cᵀ = Bᵀ · Aᵀ` by swapping the operands and the `m`/`n` dimensions;
/// the row-major buffers are then usable as-is.
///
/// # Errors
/// - `GpuError::ShapeMismatch` for inconsistent buffer lengths or a
///   dimension that does not fit in `i32` (the cuBLAS dimension type).
/// - `GpuError::DeviceMismatch` if an operand lives on another device.
pub fn gpu_matmul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // Operand lengths must agree with the requested dimensions.
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Both operands must live on the device we are about to launch on
    // (a is checked first, then b, matching the error precedence).
    for got in [a.device_ordinal(), b.device_ordinal()] {
        if got != device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: device.ordinal(),
                got,
            });
        }
    }
    // Degenerate shapes: nothing to multiply, return an all-zero result.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f64(m * n, device);
    }
    // cuBLAS takes i32 dimensions; fail loudly instead of truncating.
    let dim_i32 = |v: usize| {
        i32::try_from(v).map_err(|_| GpuError::ShapeMismatch {
            op: "matmul",
            expected: vec![i32::MAX as usize],
            got: vec![v],
        })
    };
    let (m_i32, k_i32, n_i32) = (dim_i32(m)?, dim_i32(k)?, dim_i32(n)?);
    let blas = device.blas();
    let mut c = alloc_zeros_f64(m * n, device)?;
    // Transposed (column-major) view of the row-major problem.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f64,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f64,
        ldc: n_i32,
    };
    // SAFETY: all three buffers were validated against m/k/n and reside on
    // `device`; operands are swapped (b, a) to match the transposed config.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
#[cfg(feature = "cuda")]
/// Batched row-major GPU matmul: for each `i` in `0..batch`, computes
/// `C[i] = A[i] · B[i]` where `A[i]` is `m × k` and `B[i]` is `k × n`.
/// Batches are stored contiguously (`a.len() == batch * m * k`, etc.).
///
/// Uses cuBLAS strided-batched SGEMM with the usual row-major trick:
/// `Cᵀ = Bᵀ · Aᵀ` via swapped operands and swapped `m`/`n`.
///
/// # Errors
/// - `GpuError::ShapeMismatch` for inconsistent buffer lengths or any of
///   `batch`/`m`/`k`/`n` exceeding `i32::MAX`.
/// - `GpuError::DeviceMismatch` if an operand lives on another device.
pub fn gpu_bmm_f32(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Degenerate problem: the result is empty or all zeros.
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(batch * m * n, device);
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // Device-residency validation, consistent with gpu_matmul_f32/f64
    // (previously missing here).
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // cuBLAS takes i32 dimensions and batch counts; reject values that
    // would silently truncate (the old `batch as i32` cast could wrap).
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let batch_i32 = i32::try_from(batch).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm",
        expected: vec![i32::MAX as usize],
        got: vec![batch],
    })?;
    let blas = device.blas();
    let mut c = alloc_zeros_f32(batch * m * n, device)?;
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: sys::cublasOperation_t::CUBLAS_OP_N,
            transb: sys::cublasOperation_t::CUBLAS_OP_N,
            // Row-major trick: transposed dims for the column-major view.
            m: n_i32,
            n: m_i32,
            k: k_i32,
            alpha: 1.0f32,
            lda: n_i32,
            ldb: k_i32,
            beta: 0.0f32,
            ldc: n_i32,
        },
        batch_size: batch_i32,
        // Operands are swapped below, so stride_a is b's per-batch stride
        // and stride_b is a's.
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    // SAFETY: lengths were validated above and all buffers are on `device`.
    unsafe {
        blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(c)
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_bmm_f32`: always returns
/// `GpuError::NoCudaFeature` so callers get a uniform runtime error instead
/// of a compile failure when the `cuda` feature is disabled.
pub fn gpu_bmm_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// Like `gpu_matmul_f32`, but writes the `m × n` product into the
/// caller-provided buffer `c` instead of allocating a new one.
///
/// NOTE(review): when any dimension is zero this returns without touching
/// `c`. For `k == 0` with `m, n > 0` the mathematical result is all zeros,
/// so callers relying on that must clear `c` themselves — confirm intended.
///
/// # Errors
/// `GpuError::ShapeMismatch` if `a`, `b`, or `c` have lengths inconsistent
/// with `m`/`k`/`n`, or a dimension exceeds `i32::MAX` on the cuBLAS path.
pub fn gpu_matmul_f32_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    c: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    if m == 0 || k == 0 || n == 0 {
        return Ok(());
    }
    // New check: an undersized output buffer would let cuBLAS write out of
    // bounds on the device; reject it before launching anything.
    if c.len() != m * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_into",
            expected: vec![m, n],
            got: vec![c.len()],
        });
    }
    // Small problems use the custom kernel (lower launch overhead).
    let total_ops = m * k * n;
    if m <= 4 || total_ops < 500_000 {
        return crate::kernels::gpu_small_matmul_into(a, b, m, k, n, c, device);
    }
    // Checked conversions: the previous `as i32` casts silently truncated
    // dimensions above i32::MAX (inconsistent with gpu_matmul_f32).
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "matmul_into",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let blas = device.blas();
    // Row-major trick: compute Cᵀ = Bᵀ · Aᵀ (column-major view), so the
    // operands and the m/n dimensions are swapped.
    let cfg = GemmConfig {
        transa: sys::cublasOperation_t::CUBLAS_OP_N,
        transb: sys::cublasOperation_t::CUBLAS_OP_N,
        m: n_i32,
        n: m_i32,
        k: k_i32,
        alpha: 1.0f32,
        lda: n_i32,
        ldb: k_i32,
        beta: 0.0f32,
        ldc: n_i32,
    };
    // SAFETY: a, b, and c lengths were all validated against m/k/n above.
    unsafe {
        blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Like `gpu_bmm_f32`, but writes the `batch × m × n` result into the
/// caller-provided buffer `c` instead of allocating a new one.
///
/// NOTE(review): when any dimension is zero this returns without touching
/// `c` — callers expecting a zeroed result must clear it themselves.
///
/// # Errors
/// `GpuError::ShapeMismatch` if `a`, `b`, or `c` have lengths inconsistent
/// with `batch`/`m`/`k`/`n`, or any of those exceeds `i32::MAX`.
pub fn gpu_bmm_f32_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    c: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    if batch == 0 || m == 0 || k == 0 || n == 0 {
        return Ok(());
    }
    if a.len() != batch * m * k {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != batch * k * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, k, n],
            got: vec![b.len()],
        });
    }
    // New check: an undersized output buffer would let cuBLAS write out of
    // bounds on the device; reject it before launching anything.
    if c.len() != batch * m * n {
        return Err(GpuError::ShapeMismatch {
            op: "bmm_into",
            expected: vec![batch, m, n],
            got: vec![c.len()],
        });
    }
    // Checked conversions: the previous `as i32` casts (including the batch
    // count) silently truncated values above i32::MAX.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![m],
    })?;
    let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![k],
    })?;
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![n],
    })?;
    let batch_i32 = i32::try_from(batch).map_err(|_| GpuError::ShapeMismatch {
        op: "bmm_into",
        expected: vec![i32::MAX as usize],
        got: vec![batch],
    })?;
    let blas = device.blas();
    let cfg = StridedBatchedConfig {
        gemm: GemmConfig {
            transa: sys::cublasOperation_t::CUBLAS_OP_N,
            transb: sys::cublasOperation_t::CUBLAS_OP_N,
            // Row-major trick: transposed dims for the column-major view.
            m: n_i32,
            n: m_i32,
            k: k_i32,
            alpha: 1.0f32,
            lda: n_i32,
            ldb: k_i32,
            beta: 0.0f32,
            ldc: n_i32,
        },
        batch_size: batch_i32,
        // Operands are swapped below, so stride_a is b's per-batch stride
        // and stride_b is a's.
        stride_a: (k * n) as i64,
        stride_b: (m * k) as i64,
        stride_c: (m * n) as i64,
    };
    // SAFETY: a, b, and c lengths were validated against batch/m/k/n above.
    unsafe {
        blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
    }
    Ok(())
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_matmul_f32_into`: always returns
/// `GpuError::NoCudaFeature` when the `cuda` feature is disabled.
pub fn gpu_matmul_f32_into(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_c: &mut CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_bmm_f32_into`: always returns
/// `GpuError::NoCudaFeature` when the `cuda` feature is disabled.
pub fn gpu_bmm_f32_into(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_c: &mut CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
/// Mixed-precision GPU matmul: both f32 inputs are converted to f16 on the
/// device, multiplied via `cublasGemmEx` with f32 accumulation, and the
/// `m × n` product is returned as f32. Shapes and validation mirror
/// `gpu_matmul_f32`.
///
/// # Errors
/// - `GpuError::ShapeMismatch` for inconsistent buffer lengths or a
///   dimension that does not fit in `i32`.
/// - `GpuError::DeviceMismatch` if an operand lives on another device.
pub fn gpu_matmul_f16(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use core::ffi::c_void;
    use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    // Operand lengths must agree with the requested dimensions.
    if a.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![m, k],
            got: vec![a.len()],
        });
    }
    if b.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![k, n],
            got: vec![b.len()],
        });
    }
    // Both operands must live on the launch device (a first, then b).
    if a.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: a.device_ordinal(),
        });
    }
    if b.device_ordinal() != device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: device.ordinal(),
            got: b.device_ordinal(),
        });
    }
    // Degenerate shapes: the product is empty or all zeros.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros_f32(m * n, device);
    }
    // cuBLAS takes i32 dimensions; fail loudly instead of truncating.
    let dim_i32 = |v: usize| {
        i32::try_from(v).map_err(|_| GpuError::ShapeMismatch {
            op: "matmul_f16",
            expected: vec![i32::MAX as usize],
            got: vec![v],
        })
    };
    let (m_i32, k_i32, n_i32) = (dim_i32(m)?, dim_i32(k)?, dim_i32(n)?);
    // Quantize both operands to half precision on the device.
    let a_f16 = crate::kernels::gpu_f32_to_f16(a, device)?;
    let b_f16 = crate::kernels::gpu_f32_to_f16(b, device)?;
    let mut c = alloc_zeros_f32(m * n, device)?;
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    let blas = device.blas();
    let stream = device.stream();
    {
        let (a_ptr, _record_a) = a_f16.device_ptr(stream);
        let (b_ptr, _record_b) = b_f16.device_ptr(stream);
        let (c_ptr, _record_c) = c.inner_mut().device_ptr_mut(stream);
        // SAFETY: the `_record_*` guards keep the device pointers registered
        // on `stream` for the duration of the call; buffer lengths were
        // validated against m/k/n, and the operands are swapped (b, a) with
        // matching leading dimensions for the row-major Cᵀ = Bᵀ·Aᵀ trick.
        unsafe {
            cublas_result::gemm_ex(
                *blas.handle(),
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
                n_i32,
                m_i32,
                k_i32,
                (&alpha) as *const f32 as *const c_void,
                b_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                n_i32,
                a_ptr as *const c_void,
                cublas_sys::cudaDataType_t::CUDA_R_16F,
                k_i32,
                (&beta) as *const f32 as *const c_void,
                c_ptr as *mut c_void,
                cublas_sys::cudaDataType_t::CUDA_R_32F,
                n_i32,
                cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
                cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
            )?;
        }
    }
    Ok(c)
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_matmul_f16`: always returns
/// `GpuError::NoCudaFeature` when the `cuda` feature is disabled.
pub fn gpu_matmul_f16(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
/// Reference CPU implementation of row-major matmul, used as ground truth
/// for the GPU tests: returns the `m × n` product of `a` (`m × k`) and
/// `b` (`k × n`), both flattened row-major.
fn cpu_matmul_naive<T>(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Vec<T>
where
    T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
    // One output element per (row, col) pair, accumulated with a fold over
    // the shared dimension.
    (0..m)
        .flat_map(|row| {
            (0..n).map(move |col| {
                (0..k).fold(T::default(), |acc, p| acc + a[row * k + p] * b[p * n + col])
            })
        })
        .collect()
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_matmul_f32`: always returns
/// `GpuError::NoCudaFeature` when the `cuda` feature is disabled.
pub fn gpu_matmul_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for `gpu_matmul_f64`: always returns
/// `GpuError::NoCudaFeature` when the `cuda` feature is disabled.
pub fn gpu_matmul_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
// Integration tests: these require both the `cuda` feature and a working
// CUDA device 0, so they are compiled out entirely otherwise.
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
use crate::device::GpuDevice;
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
// Opens device 0 and uploads `data`, panicking on failure (test-only helper).
fn setup_f32(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
(dev, buf)
}
// Downloads `buf` and asserts element-wise closeness to `expected` within `tol`.
fn assert_buf_close_f32(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32], tol: f32) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
// f64 variant of the closeness helper above.
fn assert_buf_close_f64(buf: &CudaBuffer<f64>, device: &GpuDevice, expected: &[f64], tol: f64) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
// 2x3 · 3x2 product with a hand-computed expected result.
#[test]
fn matmul_f32_basic() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 4);
assert_buf_close_f32(&c, &dev, &expected, 1e-4);
}
// Multiplying by the identity must reproduce the input exactly (tight tol).
#[test]
fn matmul_f32_identity() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
let (dev, a) = setup_f32(&a_data);
let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
let c = gpu_matmul_f32(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &a_data, 1e-6);
}
// m == 1 edge case: a 1x3 row vector times a 3x2 matrix.
#[test]
fn matmul_f32_row_vector() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0];
let b_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let expected: Vec<f32> = vec![4.0, 5.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 1, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 2);
assert_buf_close_f32(&c, &dev, &expected, 1e-6);
}
// Shape validation: `a` length inconsistent with m*k must error.
#[test]
fn matmul_f32_wrong_a_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]); let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
// Shape validation: `b` length inconsistent with k*n must error.
#[test]
fn matmul_f32_wrong_b_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
// All-zero dimensions produce an empty result rather than an error.
#[test]
fn matmul_f32_empty() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f32 empty");
assert_eq!(c.len(), 0);
}
// Same 2x3 · 3x2 case as matmul_f32_basic, in double precision.
#[test]
fn matmul_f64_basic() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a_data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f64> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f64> = vec![58.0, 64.0, 139.0, 154.0];
let a = cpu_to_gpu(&a_data, &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f64(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f64");
assert_eq!(c.len(), 4);
assert_buf_close_f64(&c, &dev, &expected, 1e-10);
}
// Cross-checks the GPU result against the naive CPU reference on a
// deterministic pseudo-random 64x48 · 48x32 problem.
#[test]
fn matmul_f32_vs_cpu() {
let m = 64;
let k = 48;
let n = 32;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
.collect();
let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, m, k, n, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &expected, 1e-3);
}
// Smoke/perf comparison: prints GPU vs CPU timings; no assertion on
// speedup (results vary by hardware and include first-launch overhead).
#[test]
fn matmul_f32_1024x1024_perf() {
let dim = 1024;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
.collect();
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let gpu_start = std::time::Instant::now();
let _c = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
let gpu_elapsed = gpu_start.elapsed();
let cpu_start = std::time::Instant::now();
let _c_cpu = cpu_matmul_naive(&a_data, &b_data, dim, dim, dim);
let cpu_elapsed = cpu_start.elapsed();
eprintln!(
"matmul {dim}x{dim}: GPU = {:.3}ms, CPU = {:.3}ms, speedup = {:.1}x",
gpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() / gpu_elapsed.as_secs_f64(),
);
}
// f16 path: loose tolerance (1e-2) to allow for half-precision rounding.
#[test]
fn matmul_f16_basic() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f16");
assert_eq!(c.len(), 4);
assert_buf_close_f32(&c, &dev, &expected, 1e-2);
}
// Small integers are exactly representable in f16, so identity should be
// near-exact even in half precision.
#[test]
fn matmul_f16_identity() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
let (dev, a) = setup_f32(&a_data);
let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
let c = gpu_matmul_f16(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f16");
assert_buf_close_f32(&c, &dev, &a_data, 1e-3);
}
// All-zero dimensions produce an empty result on the f16 path too.
#[test]
fn matmul_f16_empty() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f16 empty");
assert_eq!(c.len(), 0);
}
// k == 0 with m, n > 0: the product is defined and must be all zeros.
#[test]
fn matmul_f16_k_zero() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 2, 0, 3, &dev).expect("gpu_matmul_f16 k=0");
assert_eq!(c.len(), 6);
let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
assert!(host.iter().all(|&x| x == 0.0));
}
// Shape validation on the f16 path: bad `a` length.
#[test]
fn matmul_f16_wrong_a_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]); let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch {
op: "matmul_f16", ..
} => {}
other => panic!("unexpected error: {other}"),
}
}
// Shape validation on the f16 path: bad `b` length.
#[test]
fn matmul_f16_wrong_b_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch {
op: "matmul_f16", ..
} => {}
other => panic!("unexpected error: {other}"),
}
}
// Accuracy check for the f16 path against the f32 CPU reference: each
// element must be within 5% relative or 0.1 absolute error.
#[test]
fn matmul_f16_vs_f32_reference() {
let m = 64;
let k = 48;
let n = 32;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
.collect();
let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, m, k, n, &dev).expect("gpu_matmul_f16");
let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "output length mismatch");
let mut max_err: f32 = 0.0;
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
let abs_err = (got - exp).abs();
let rel_err = if exp.abs() > 1e-6 {
abs_err / exp.abs()
} else {
abs_err
};
max_err = max_err.max(abs_err);
assert!(
rel_err < 0.05 || abs_err < 0.1,
"element {i}: f16 got {got}, f32 ref {exp}, abs_err {abs_err}, rel_err {rel_err}",
);
}
eprintln!("matmul_f16_vs_f32: {m}x{k} @ {k}x{n}, max absolute error = {max_err:.6}",);
}
// Smoke/perf comparison of the f32 vs f16 GPU paths; prints timings only.
#[test]
fn matmul_f16_1024x1024_perf() {
let dim = 1024;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
.collect();
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let f32_start = std::time::Instant::now();
let _c32 = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
let f32_elapsed = f32_start.elapsed();
let f16_start = std::time::Instant::now();
let _c16 = gpu_matmul_f16(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f16");
let f16_elapsed = f16_start.elapsed();
eprintln!(
"matmul {dim}x{dim}: f32 = {:.3}ms, f16 = {:.3}ms, f16 speedup = {:.1}x",
f32_elapsed.as_secs_f64() * 1000.0,
f16_elapsed.as_secs_f64() * 1000.0,
f32_elapsed.as_secs_f64() / f16_elapsed.as_secs_f64(),
);
}
}