#[cfg(feature = "cuda")]
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros, cpu_to_gpu, gpu_to_cpu};
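/// Row-major matrix multiply `C = A @ B` on the GPU via cuBLAS, where `a` is
/// `m x k`, `b` is `k x n`, and the returned buffer holds the `m x n` result.
/// Both operands must already be resident on `device`; shapes and device
/// ordinals are validated up front. If the cuBLAS handle cannot be created,
/// the computation falls back to a naive CPU matmul.
///
/// A minimal usage sketch (illustrative only; the import paths assume a
/// `ferrotorch_gpu` crate layout and may differ from the actual module tree):
///
/// ```ignore
/// use ferrotorch_gpu::device::GpuDevice;
/// use ferrotorch_gpu::transfer::{cpu_to_gpu, gpu_to_cpu};
///
/// let dev = GpuDevice::new(0)?;
/// // A is 2x3 and B is 3x2, both row-major.
/// let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &dev)?;
/// let b = cpu_to_gpu(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &dev)?;
/// let c = gpu_matmul_f32(&a, &b, 2, 3, 2, &dev)?;
/// assert_eq!(gpu_to_cpu(&c, &dev)?, vec![58.0, 64.0, 139.0, 154.0]);
/// ```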
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![k, n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
    // Degenerate shapes (any dimension zero): nothing to compute, so return an
    // all-zeros `m * n` buffer without touching cuBLAS.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros::<f32>(m * n, device);
    }
    // cuBLAS takes `i32` dimensions; reject anything that does not fit.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
match CudaBlas::new(device.stream().clone()) {
Ok(blas) => {
            let mut c = alloc_zeros::<f32>(m * n, device)?;
            // cuBLAS expects column-major matrices, while our buffers are
            // row-major. Computing B^T * A^T = C^T in column-major terms is
            // equivalent to computing A * B in row-major terms, so the m/n
            // dimensions and the A/B operands are swapped below.
            let cfg = GemmConfig {
                transa: sys::cublasOperation_t::CUBLAS_OP_N,
                transb: sys::cublasOperation_t::CUBLAS_OP_N,
                m: n_i32,
                n: m_i32,
                k: k_i32,
                alpha: 1.0f32,
                lda: n_i32,
                ldb: k_i32,
                beta: 0.0f32,
                ldc: n_i32,
            };
            // Safety: the buffer lengths were validated against the GEMM
            // dimensions above, so cuBLAS never reads or writes out of bounds.
            unsafe {
                blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
            }
Ok(c)
}
Err(_) => {
eprintln!(
"ferrotorch-gpu: cuBLAS handle creation failed, \
falling back to CPU matmul for [{m}x{k}] @ [{k}x{n}]"
);
cpu_matmul_fallback_f32(a, b, m, k, n, device)
}
}
}
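/// `f64` counterpart of [`gpu_matmul_f32`]: row-major `C = A @ B` via cuBLAS,
/// with the same shape and device validation and the same CPU fallback when
/// the cuBLAS handle cannot be created.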
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![k, n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
    // Degenerate shapes: nothing to compute, so return an all-zeros buffer.
    if m == 0 || k == 0 || n == 0 {
        return alloc_zeros::<f64>(m * n, device);
    }
    // cuBLAS takes `i32` dimensions; reject anything that does not fit.
    let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
match CudaBlas::new(device.stream().clone()) {
Ok(blas) => {
            let mut c = alloc_zeros::<f64>(m * n, device)?;
            // Same row-major/column-major swap as in `gpu_matmul_f32`:
            // computing B^T * A^T = C^T in cuBLAS's column-major terms yields
            // C = A * B in row-major terms.
            let cfg = GemmConfig {
                transa: sys::cublasOperation_t::CUBLAS_OP_N,
                transb: sys::cublasOperation_t::CUBLAS_OP_N,
                m: n_i32,
                n: m_i32,
                k: k_i32,
                alpha: 1.0f64,
                lda: n_i32,
                ldb: k_i32,
                beta: 0.0f64,
                ldc: n_i32,
            };
            // Safety: buffer lengths were validated against the GEMM dimensions above.
            unsafe {
                blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
            }
Ok(c)
}
Err(_) => {
eprintln!(
"ferrotorch-gpu: cuBLAS handle creation failed, \
falling back to CPU matmul for [{m}x{k}] @ [{k}x{n}]"
);
cpu_matmul_fallback_f64(a, b, m, k, n, device)
}
}
}
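/// CPU fallback used when the cuBLAS handle cannot be created: downloads both
/// operands to the host, multiplies them with `cpu_matmul_naive`, and uploads
/// the result back to `device`.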
#[cfg(feature = "cuda")]
fn cpu_matmul_fallback_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let a_host = gpu_to_cpu(a, device)?;
let b_host = gpu_to_cpu(b, device)?;
let c_host = cpu_matmul_naive(&a_host, &b_host, m, k, n);
cpu_to_gpu(&c_host, device)
}
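/// `f64` counterpart of `cpu_matmul_fallback_f32`.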
#[cfg(feature = "cuda")]
fn cpu_matmul_fallback_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let a_host = gpu_to_cpu(a, device)?;
let b_host = gpu_to_cpu(b, device)?;
let c_host = cpu_matmul_naive(&a_host, &b_host, m, k, n);
cpu_to_gpu(&c_host, device)
}
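/// Naive O(m * k * n) triple-loop matrix multiply over row-major slices. Used
/// both as the runtime CPU fallback above and as the reference oracle in the
/// tests below.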
fn cpu_matmul_naive<T>(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Vec<T>
where
T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let mut c = vec![T::default(); m * n];
for i in 0..m {
for j in 0..n {
let mut sum = T::default();
for p in 0..k {
sum = sum + a[i * k + p] * b[p * n + j];
}
c[i * n + j] = sum;
}
}
c
}
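/// Stub compiled when the `cuda` feature is disabled: always fails with
/// [`GpuError::NoCudaFeature`].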
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
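/// `f64` stub compiled when the `cuda` feature is disabled: always fails with
/// [`GpuError::NoCudaFeature`].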
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
use crate::device::GpuDevice;
fn setup_f32(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
(dev, buf)
}
fn assert_buf_close_f32(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32], tol: f32) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
fn assert_buf_close_f64(buf: &CudaBuffer<f64>, device: &GpuDevice, expected: &[f64], tol: f64) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
#[test]
fn matmul_f32_basic() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 4);
assert_buf_close_f32(&c, &dev, &expected, 1e-4);
}
#[test]
fn matmul_f32_identity() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
let (dev, a) = setup_f32(&a_data);
let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
let c = gpu_matmul_f32(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &a_data, 1e-6);
}
#[test]
fn matmul_f32_row_vector() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0];
let b_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let expected: Vec<f32> = vec![4.0, 5.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 1, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 2);
assert_buf_close_f32(&c, &dev, &expected, 1e-6);
}
#[test]
fn matmul_f32_wrong_a_length() {
        let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f32_wrong_b_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f32_empty() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f32 empty");
assert_eq!(c.len(), 0);
}
#[test]
fn matmul_f64_basic() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a_data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f64> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f64> = vec![58.0, 64.0, 139.0, 154.0];
let a = cpu_to_gpu(&a_data, &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f64(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f64");
assert_eq!(c.len(), 4);
assert_buf_close_f64(&c, &dev, &expected, 1e-10);
}
    // Cross-check the GPU result against the naive CPU reference on a
    // non-square, moderately sized problem.
    #[test]
    fn matmul_f32_vs_cpu() {
let m = 64;
let k = 48;
let n = 32;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
.collect();
let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, m, k, n, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &expected, 1e-3);
}
    // Not a strict performance assertion; this just reports GPU vs. naive CPU
    // timings so large regressions are visible in the test output.
    #[test]
    fn matmul_f32_1024x1024_perf() {
let dim = 1024;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
.collect();
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let gpu_start = std::time::Instant::now();
let _c = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
let gpu_elapsed = gpu_start.elapsed();
let cpu_start = std::time::Instant::now();
let _c_cpu = cpu_matmul_naive(&a_data, &b_data, dim, dim, dim);
let cpu_elapsed = cpu_start.elapsed();
eprintln!(
"matmul {dim}x{dim}: GPU = {:.3}ms, CPU = {:.3}ms, speedup = {:.1}x",
gpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() / gpu_elapsed.as_secs_f64(),
);
}
}