#[cfg(feature = "cuda")]
use cudarc::cublas::{Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig, sys};
#[cfg(feature = "cuda")]
use cudarc::driver::DevicePtr;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64};
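// cuBLAS is column-major while these buffers are row-major. Every GEMM call
// below uses the identity C = A * B  <=>  C^T = B^T * A^T: a row-major matrix
// handed to cuBLAS unmodified is exactly its column-major transpose, so we
// pass `b` as cuBLAS "A", `a` as cuBLAS "B", and swap the m/n dimensions;
// cuBLAS then writes C^T in column-major, which is C in row-major. This is
// why every GemmConfig below has m = n_i32, n = m_i32, and why the batched
// stride fields appear "swapped" relative to their names.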
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![k, n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(m * n, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
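// Small or skinny problems are routed to a hand-written kernel; the
// threshold is a heuristic below which cuBLAS launch overhead tends to
// dominate.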
let total_ops = m * k * n;
if m <= 4 || total_ops < 500_000 {
return crate::kernels::gpu_small_matmul(a, b, m, k, n, device);
}
let blas = device.blas();
let mut c = alloc_zeros_f32(m * n, device)?;
let cfg = GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f32,
lda: n_i32,
ldb: k_i32,
beta: 0.0f32,
ldc: n_i32,
};
unsafe {
blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(c)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul",
expected: vec![k, n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if m == 0 || k == 0 || n == 0 {
return alloc_zeros_f64(m * n, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let mut c = alloc_zeros_f64(m * n, device)?;
let cfg = GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f64,
lda: n_i32,
ldb: k_i32,
beta: 0.0f64,
ldc: n_i32,
};
unsafe {
blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(c)
}
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if batch == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(batch * m * n, device);
}
if a.len() != batch * m * k {
return Err(GpuError::ShapeMismatch {
op: "bmm",
expected: vec![batch, m, k],
got: vec![a.len()],
});
}
if b.len() != batch * k * n {
return Err(GpuError::ShapeMismatch {
op: "bmm",
expected: vec![batch, k, n],
got: vec![b.len()],
});
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "bmm",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "bmm",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "bmm",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let mut c = alloc_zeros_f32(batch * m * n, device)?;
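// Arguments are swapped for the row-major trick (see the note at the top of
// this file), so cuBLAS stride_a advances through our `b` (k * n elements
// per matrix) and stride_b through our `a` (m * k elements per matrix).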
let cfg = StridedBatchedConfig {
gemm: GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f32,
lda: n_i32,
ldb: k_i32,
beta: 0.0f32,
ldc: n_i32,
},
batch_size: batch as i32,
stride_a: (k * n) as i64,
stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
};
unsafe {
blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(c)
}
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
batch: usize,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if batch == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f64(batch * m * n, device);
}
if a.len() != batch * m * k {
return Err(GpuError::ShapeMismatch {
op: "bmm_f64",
expected: vec![batch, m, k],
got: vec![a.len()],
});
}
if b.len() != batch * k * n {
return Err(GpuError::ShapeMismatch {
op: "bmm_f64",
expected: vec![batch, k, n],
got: vec![b.len()],
});
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let mut c = alloc_zeros_f64(batch * m * n, device)?;
let cfg = StridedBatchedConfig {
gemm: GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f64,
lda: n_i32,
ldb: k_i32,
beta: 0.0f64,
ldc: n_i32,
},
batch_size: batch as i32,
stride_a: (k * n) as i64,
stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
};
unsafe {
blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(c)
}
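/// Maps a flat index into the broadcast output batch dimensions (`out_lead`)
/// to the matrix offset within a source operand whose leading dimensions are
/// `src_lead`, applying NumPy-style broadcasting: `src_lead` is right-aligned
/// against `out_lead`, and axes of size 1 are pinned to coordinate 0.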
#[cfg(feature = "cuda")]
fn flat_offset(flat: usize, src_lead: &[usize], out_lead: &[usize]) -> usize {
if src_lead.is_empty() {
return 0;
}
let nd = out_lead.len();
let prefix = nd - src_lead.len();
let mut idx = vec![0usize; nd];
let mut rem = flat;
for i in 0..nd {
let stride: usize = out_lead[i + 1..].iter().product();
let s = stride.max(1);
idx[i] = rem / s;
rem %= s;
}
let mut src_strides = vec![1usize; src_lead.len()];
if src_lead.len() >= 2 {
for i in (0..src_lead.len() - 1).rev() {
src_strides[i] = src_strides[i + 1] * src_lead[i + 1];
}
}
let mut off = 0usize;
for (i, &stride) in src_strides.iter().enumerate() {
let out_axis = i + prefix;
let dim = src_lead[i];
let coord = if dim == 1 { 0 } else { idx[out_axis] };
off += coord * stride;
}
off
}
#[cfg(feature = "cuda")]
struct StrideRun {
batch: usize,
a_start: usize, b_start: usize, c_start: usize, stride_a: i64, stride_b: i64,
stride_c: i64,
}
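/// Greedily coalesces the per-output-matrix offsets produced by `flat_offset`
/// into maximal runs with constant (A, B, C) deltas, so each run can be
/// dispatched as a single strided-batched GEMM instead of one GEMM per matrix.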
#[cfg(feature = "cuda")]
fn compute_stride_runs(
a_lead: &[usize],
b_lead: &[usize],
out_lead: &[usize],
m: usize,
k: usize,
n: usize,
) -> Vec<StrideRun> {
let total: usize = out_lead.iter().product();
if total == 0 {
return Vec::new();
}
let a_mat = m * k;
let b_mat = k * n;
let c_mat = m * n;
let offs: Vec<(usize, usize, usize)> = (0..total)
.map(|f| {
(
flat_offset(f, a_lead, out_lead) * a_mat,
flat_offset(f, b_lead, out_lead) * b_mat,
f * c_mat,
)
})
.collect();
let mut runs = Vec::new();
let mut i = 0;
while i < total {
let (a0, b0, c0) = offs[i];
if i + 1 == total {
runs.push(StrideRun {
batch: 1,
a_start: a0,
b_start: b0,
c_start: c0,
stride_a: 0,
stride_b: 0,
stride_c: 0,
});
i += 1;
continue;
}
let (a1, b1, c1) = offs[i + 1];
let da = a1 as i64 - a0 as i64;
let db = b1 as i64 - b0 as i64;
let dc = c1 as i64 - c0 as i64;
let mut j = i + 1;
while j + 1 < total {
let (an, bn, cn) = offs[j];
let (an1, bn1, cn1) = offs[j + 1];
if an1 as i64 - an as i64 == da
&& bn1 as i64 - bn as i64 == db
&& cn1 as i64 - cn as i64 == dc
{
j += 1;
} else {
break;
}
}
runs.push(StrideRun {
batch: j - i + 1,
a_start: a0,
b_start: b0,
c_start: c0,
stride_a: da,
stride_b: db,
stride_c: dc,
});
i = j + 1;
}
runs
}
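/// Checks that the flat buffer lengths match the declared leading shapes and
/// that `a_lead` / `b_lead` broadcast to `out_lead` (after right-alignment,
/// each axis must equal the output axis or be 1).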
#[cfg(feature = "cuda")]
fn validate_broadcast_shapes(
op: &'static str,
a_lead: &[usize],
b_lead: &[usize],
out_lead: &[usize],
a_len: usize,
b_len: usize,
m: usize,
k: usize,
n: usize,
) -> GpuResult<()> {
let a_batch_count: usize = if a_lead.is_empty() {
1
} else {
a_lead.iter().product()
};
let b_batch_count: usize = if b_lead.is_empty() {
1
} else {
b_lead.iter().product()
};
let expected_a = a_batch_count * m * k;
let expected_b = b_batch_count * k * n;
if a_len != expected_a {
return Err(GpuError::ShapeMismatch {
op,
expected: vec![expected_a],
got: vec![a_len],
});
}
if b_len != expected_b {
return Err(GpuError::ShapeMismatch {
op,
expected: vec![expected_b],
got: vec![b_len],
});
}
let max_len = a_lead.len().max(b_lead.len());
if out_lead.len() < max_len {
return Err(GpuError::ShapeMismatch {
op,
expected: vec![max_len],
got: vec![out_lead.len()],
});
}
for i in 0..out_lead.len() {
let pa = out_lead.len() - a_lead.len();
let pb = out_lead.len() - b_lead.len();
let da = if i < pa { 1 } else { a_lead[i - pa] };
let db = if i < pb { 1 } else { b_lead[i - pb] };
let target = out_lead[i];
let ok = (da == target || da == 1) && (db == target || db == 1);
if !ok {
return Err(GpuError::ShapeMismatch {
op,
expected: vec![target],
got: vec![da, db],
});
}
}
Ok(())
}
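/// Batched matmul with NumPy-style broadcasting over the leading dimensions:
/// the output batch is decomposed into constant-stride runs (see
/// `compute_stride_runs`) and each run is executed as one strided-batched
/// GEMM, with broadcast operands expressed as a stride of 0.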
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_bmm_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_lead: &[usize],
b_lead: &[usize],
out_lead: &[usize],
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_broadcast_shapes(
"broadcast_bmm_f32",
a_lead,
b_lead,
out_lead,
a.len(),
b.len(),
m,
k,
n,
)?;
let total: usize = out_lead.iter().product();
let out_numel = total * m * n;
if total == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(out_numel, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f32",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f32",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f32",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let runs = compute_stride_runs(a_lead, b_lead, out_lead, m, k, n);
let blas = device.blas();
let mut c = alloc_zeros_f32(out_numel, device)?;
for run in &runs {
let batch_i32 = i32::try_from(run.batch).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f32",
expected: vec![i32::MAX as usize],
got: vec![run.batch],
})?;
let cfg = StridedBatchedConfig {
gemm: GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f32,
lda: n_i32,
ldb: k_i32,
beta: 0.0f32,
ldc: n_i32,
},
batch_size: batch_i32,
stride_a: run.stride_b,
stride_b: run.stride_a,
stride_c: run.stride_c,
};
let a_extent = mat_extent(run.batch, run.stride_a, m * k);
let b_extent = mat_extent(run.batch, run.stride_b, k * n);
let c_extent = mat_extent(run.batch, run.stride_c, m * n);
let a_view = a.inner().slice(run.a_start..(run.a_start + a_extent));
let b_view = b.inner().slice(run.b_start..(run.b_start + b_extent));
let mut c_view = c
.inner_mut()
.slice_mut(run.c_start..(run.c_start + c_extent));
unsafe {
blas.gemm_strided_batched(cfg, &b_view, &a_view, &mut c_view)?;
}
}
Ok(c)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_bmm_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_lead: &[usize],
b_lead: &[usize],
out_lead: &[usize],
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
validate_broadcast_shapes(
"broadcast_bmm_f64",
a_lead,
b_lead,
out_lead,
a.len(),
b.len(),
m,
k,
n,
)?;
let total: usize = out_lead.iter().product();
let out_numel = total * m * n;
if total == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f64(out_numel, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let runs = compute_stride_runs(a_lead, b_lead, out_lead, m, k, n);
let blas = device.blas();
let mut c = alloc_zeros_f64(out_numel, device)?;
for run in &runs {
let batch_i32 = i32::try_from(run.batch).map_err(|_| GpuError::ShapeMismatch {
op: "broadcast_bmm_f64",
expected: vec![i32::MAX as usize],
got: vec![run.batch],
})?;
let cfg = StridedBatchedConfig {
gemm: GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f64,
lda: n_i32,
ldb: k_i32,
beta: 0.0f64,
ldc: n_i32,
},
batch_size: batch_i32,
stride_a: run.stride_b,
stride_b: run.stride_a,
stride_c: run.stride_c,
};
let a_extent = mat_extent(run.batch, run.stride_a, m * k);
let b_extent = mat_extent(run.batch, run.stride_b, k * n);
let c_extent = mat_extent(run.batch, run.stride_c, m * n);
let a_view = a.inner().slice(run.a_start..(run.a_start + a_extent));
let b_view = b.inner().slice(run.b_start..(run.b_start + b_extent));
let mut c_view = c
.inner_mut()
.slice_mut(run.c_start..(run.c_start + c_extent));
unsafe {
blas.gemm_strided_batched(cfg, &b_view, &a_view, &mut c_view)?;
}
}
Ok(c)
}
#[cfg(feature = "cuda")]
fn b_mat_extent_f32(batch: usize, stride: i64, per_mat: usize) -> usize {
if batch == 0 {
return 0;
}
if stride == 0 {
return per_mat;
}
let abs = stride.unsigned_abs() as usize;
(batch - 1) * abs + per_mat
}
#[cfg(feature = "cuda")]
pub fn gpu_dot_f32(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if a.len() != n {
return Err(GpuError::ShapeMismatch {
op: "dot",
expected: vec![n],
got: vec![a.len()],
});
}
if b.len() != n {
return Err(GpuError::ShapeMismatch {
op: "dot",
expected: vec![n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if n == 0 {
return alloc_zeros_f32(1, device);
}
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "dot",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let stream = device.stream();
let mut host_result: f32 = 0.0;
let (a_ptr, _record_a) = a.inner().device_ptr(&stream);
let (b_ptr, _record_b) = b.inner().device_ptr(&stream);
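// With the default host pointer mode, cublasSdot_v2 writes the result to a
// host pointer and blocks until it is available; the explicit synchronize
// below is defensive before the scalar is uploaded back to the device.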
unsafe {
sys::cublasSdot_v2(
*blas.handle(),
n_i32,
a_ptr as *const f32,
1,
b_ptr as *const f32,
1,
&mut host_result as *mut f32,
)
.result()?;
}
stream.synchronize()?;
let mut out = alloc_zeros_f32(1, device)?;
stream.memcpy_htod(std::slice::from_ref(&host_result), out.inner_mut())?;
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_dot_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if a.len() != n {
return Err(GpuError::ShapeMismatch {
op: "dot_f64",
expected: vec![n],
got: vec![a.len()],
});
}
if b.len() != n {
return Err(GpuError::ShapeMismatch {
op: "dot_f64",
expected: vec![n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if n == 0 {
return alloc_zeros_f64(1, device);
}
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "dot_f64",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let stream = device.stream();
let mut host_result: f64 = 0.0;
let (a_ptr, _record_a) = a.inner().device_ptr(&stream);
let (b_ptr, _record_b) = b.inner().device_ptr(&stream);
unsafe {
sys::cublasDdot_v2(
*blas.handle(),
n_i32,
a_ptr as *const f64,
1,
b_ptr as *const f64,
1,
&mut host_result as *mut f64,
)
.result()?;
}
stream.synchronize()?;
let mut out = alloc_zeros_f64(1, device)?;
stream.memcpy_htod(std::slice::from_ref(&host_result), out.inner_mut())?;
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_mv_f32(
a: &CudaBuffer<f32>,
x: &CudaBuffer<f32>,
m: usize,
k: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "mv",
expected: vec![m, k],
got: vec![a.len()],
});
}
if x.len() != k {
return Err(GpuError::ShapeMismatch {
op: "mv",
expected: vec![k],
got: vec![x.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if x.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: x.device_ordinal(),
});
}
if m == 0 || k == 0 {
return alloc_zeros_f32(m, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "mv",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "mv",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let blas = device.blas();
let mut y = alloc_zeros_f32(m, device)?;
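// A row-major m x k matrix is a column-major k x m matrix, so y = A * x is
// computed as gemv with OP_T on that k x m view (lda = k).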
let cfg = GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_T,
m: k_i32,
n: m_i32,
alpha: 1.0f32,
lda: k_i32,
incx: 1,
beta: 0.0f32,
incy: 1,
};
unsafe {
blas.gemv(cfg, a.inner(), x.inner(), y.inner_mut())?;
}
Ok(y)
}
#[cfg(feature = "cuda")]
pub fn gpu_mv_f64(
a: &CudaBuffer<f64>,
x: &CudaBuffer<f64>,
m: usize,
k: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "mv_f64",
expected: vec![m, k],
got: vec![a.len()],
});
}
if x.len() != k {
return Err(GpuError::ShapeMismatch {
op: "mv_f64",
expected: vec![k],
got: vec![x.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if x.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: x.device_ordinal(),
});
}
if m == 0 || k == 0 {
return alloc_zeros_f64(m, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "mv_f64",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "mv_f64",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let blas = device.blas();
let mut y = alloc_zeros_f64(m, device)?;
let cfg = GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_T,
m: k_i32,
n: m_i32,
alpha: 1.0f64,
lda: k_i32,
incx: 1,
beta: 0.0f64,
incy: 1,
};
unsafe {
blas.gemv(cfg, a.inner(), x.inner(), y.inner_mut())?;
}
Ok(y)
}
#[cfg(feature = "cuda")]
pub fn gpu_vm_f32(
x: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if x.len() != k {
return Err(GpuError::ShapeMismatch {
op: "vm",
expected: vec![k],
got: vec![x.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "vm",
expected: vec![k, n],
got: vec![b.len()],
});
}
if x.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: x.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if k == 0 || n == 0 {
return alloc_zeros_f32(n, device);
}
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "vm",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "vm",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let mut y = alloc_zeros_f32(n, device)?;
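// A row-major k x n matrix B is a column-major n x k matrix, so y = x^T * B
// is computed as gemv with OP_N on that n x k view (lda = n), yielding n
// outputs.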
let cfg = GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: k_i32,
alpha: 1.0f32,
lda: n_i32,
incx: 1,
beta: 0.0f32,
incy: 1,
};
unsafe {
blas.gemv(cfg, b.inner(), x.inner(), y.inner_mut())?;
}
Ok(y)
}
#[cfg(feature = "cuda")]
pub fn gpu_vm_f64(
x: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if x.len() != k {
return Err(GpuError::ShapeMismatch {
op: "vm_f64",
expected: vec![k],
got: vec![x.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "vm_f64",
expected: vec![k, n],
got: vec![b.len()],
});
}
if x.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: x.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if k == 0 || n == 0 {
return alloc_zeros_f64(n, device);
}
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "vm_f64",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "vm_f64",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let blas = device.blas();
let mut y = alloc_zeros_f64(n, device)?;
let cfg = GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: k_i32,
alpha: 1.0f64,
lda: n_i32,
incx: 1,
beta: 0.0f64,
incy: 1,
};
unsafe {
blas.gemv(cfg, b.inner(), x.inner(), y.inner_mut())?;
}
Ok(y)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_dot_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_dot_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mv_f32(
_a: &CudaBuffer<f32>,
_x: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mv_f64(
_a: &CudaBuffer<f64>,
_x: &CudaBuffer<f64>,
_m: usize,
_k: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_vm_f32(
_x: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_vm_f64(
_x: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_broadcast_bmm_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_a_lead: &[usize],
_b_lead: &[usize],
_out_lead: &[usize],
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_broadcast_bmm_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_a_lead: &[usize],
_b_lead: &[usize],
_out_lead: &[usize],
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f32_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
c: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_into",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul_into",
expected: vec![k, n],
got: vec![b.len()],
});
}
if m == 0 || k == 0 || n == 0 {
// Nothing to compute. When k == 0 with m * n > 0 the mathematical result
// is all zeros; `c` is assumed to be pre-zeroed by the caller (as the
// non-`_into` variants do via alloc_zeros_f32), so it is left untouched.
return Ok(());
}
let total_ops = m * k * n;
if m <= 4 || total_ops < 500_000 {
return crate::kernels::gpu_small_matmul_into(a, b, m, k, n, c, device);
}
let m_i32 = m as i32;
let k_i32 = k as i32;
let n_i32 = n as i32;
let blas = device.blas();
let cfg = GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f32,
lda: n_i32,
ldb: k_i32,
beta: 0.0f32,
ldc: n_i32,
};
unsafe {
blas.gemm(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(())
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_bmm_f32_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
c: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
if batch == 0 || m == 0 || k == 0 || n == 0 {
// Nothing to compute; as in gpu_matmul_f32_into, `c` is assumed to be
// pre-zeroed by the caller for the k == 0 case.
return Ok(());
}
if a.len() != batch * m * k {
return Err(GpuError::ShapeMismatch {
op: "bmm_into",
expected: vec![batch, m, k],
got: vec![a.len()],
});
}
if b.len() != batch * k * n {
return Err(GpuError::ShapeMismatch {
op: "bmm_into",
expected: vec![batch, k, n],
got: vec![b.len()],
});
}
let m_i32 = m as i32;
let k_i32 = k as i32;
let n_i32 = n as i32;
let blas = device.blas();
let cfg = StridedBatchedConfig {
gemm: GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: n_i32,
n: m_i32,
k: k_i32,
alpha: 1.0f32,
lda: n_i32,
ldb: k_i32,
beta: 0.0f32,
ldc: n_i32,
},
batch_size: batch as i32,
stride_a: (k * n) as i64,
stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
};
unsafe {
blas.gemm_strided_batched(cfg, b.inner(), a.inner(), c.inner_mut())?;
}
Ok(())
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f32_into(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_c: &mut CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f32_into(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_c: &mut CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_f16(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_f16",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul_f16",
expected: vec![k, n],
got: vec![b.len()],
});
}
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: a.device_ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: device.ordinal(),
got: b.device_ordinal(),
});
}
if m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(m * n, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_f16",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_f16",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_f16",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
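// Mixed precision: inputs are converted to f16 on device, the GEMM runs with
// f16 operands but accumulates and stores in f32 (CUBLAS_COMPUTE_32F with a
// CUDA_R_32F output), trading operand bandwidth for some precision.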
let a_f16 = crate::kernels::gpu_f32_to_f16(a, device)?;
let b_f16 = crate::kernels::gpu_f32_to_f16(b, device)?;
let mut c = alloc_zeros_f32(m * n, device)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _record_a) = a_f16.device_ptr(&stream);
let (b_ptr, _record_b) = b_f16.device_ptr(&stream);
let (c_ptr, _record_c) = c.inner_mut().device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16F,
n_i32,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16F,
k_i32,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_32F,
n_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f16(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_bmm_f16(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if batch == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(batch * m * n, device);
}
if a.len() != batch * m * k {
return Err(GpuError::ShapeMismatch {
op: "bmm_f16",
expected: vec![batch, m, k],
got: vec![a.len()],
});
}
if b.len() != batch * k * n {
return Err(GpuError::ShapeMismatch {
op: "bmm_f16",
expected: vec![batch, k, n],
got: vec![b.len()],
});
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f16",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f16",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_f16",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let a_f16 = crate::kernels::gpu_f32_to_f16(a, device)?;
let b_f16 = crate::kernels::gpu_f32_to_f16(b, device)?;
let mut c = alloc_zeros_f32(batch * m * n, device)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a_f16.device_ptr(&stream);
let (b_ptr, _rb) = b_f16.device_ptr(&stream);
let (c_ptr, _rc) = c.inner_mut().device_ptr_mut(&stream);
// Named after the cuBLAS argument slots: with the row-major swap, cuBLAS
// "A" is our b (k * n elements per matrix) and "B" is our a (m * k).
let stride_a_f16 = (k * n) as i64;
let stride_b_f16 = (m * k) as i64;
let stride_c = (m * n) as i64;
unsafe {
cublas_result::gemm_strided_batched_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16F,
n_i32,
stride_a_f16,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16F,
k_i32,
stride_b_f16,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_32F,
n_i32,
stride_c,
batch as i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_f16(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if a.len() != m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() != k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16",
expected: vec![k, n],
got: vec![b.len()],
});
}
if m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(m * n, device);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let a_bf16 = crate::kernels::gpu_f32_to_bf16(a, device)?;
let b_bf16 = crate::kernels::gpu_f32_to_bf16(b, device)?;
let mut c = alloc_zeros_f32(m * n, device)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a_bf16.device_ptr(&stream);
let (b_ptr, _rb) = b_bf16.device_ptr(&stream);
let (c_ptr, _rc) = c.inner_mut().device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_32F,
n_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_bmm_bf16(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if batch == 0 || m == 0 || k == 0 || n == 0 {
return alloc_zeros_f32(batch * m * n, device);
}
if a.len() != batch * m * k {
return Err(GpuError::ShapeMismatch {
op: "bmm_bf16",
expected: vec![batch, m, k],
got: vec![a.len()],
});
}
if b.len() != batch * k * n {
return Err(GpuError::ShapeMismatch {
op: "bmm_bf16",
expected: vec![batch, k, n],
got: vec![b.len()],
});
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_bf16",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_bf16",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "bmm_bf16",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
let a_bf16 = crate::kernels::gpu_f32_to_bf16(a, device)?;
let b_bf16 = crate::kernels::gpu_f32_to_bf16(b, device)?;
let mut c = alloc_zeros_f32(batch * m * n, device)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a_bf16.device_ptr(&stream);
let (b_ptr, _rb) = b_bf16.device_ptr(&stream);
let (c_ptr, _rc) = c.inner_mut().device_ptr_mut(&stream);
// As above: cuBLAS "A" is our b (k * n per matrix), "B" is our a (m * k).
let stride_a_bf16 = (k * n) as i64;
let stride_b_bf16 = (m * k) as i64;
let stride_c = (m * n) as i64;
unsafe {
cublas_result::gemm_strided_batched_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
stride_a_bf16,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
stride_b_bf16,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_32F,
n_i32,
stride_c,
batch as i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_bmm_bf16(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_batch: usize,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16_bf16(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
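// Note: the length checks below use `<` rather than `!=`, which permits
// oversized (e.g. padded) buffers; only the leading m * k / k * n elements
// are read.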
if a.len() < m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16_bf16",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() < k * n {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16_bf16",
expected: vec![k, n],
got: vec![b.len()],
});
}
if m == 0 || k == 0 || n == 0 {
return Ok(device.stream().alloc_zeros::<u16>(m * n)?);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
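// Pure-bf16 storage path: operands and the output are bf16 (represented as
// raw u16 bits), while accumulation still happens in f32 via
// CUBLAS_COMPUTE_32F.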
let mut c = device.stream().alloc_zeros::<u16>(m * n)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a.device_ptr(&stream);
let (b_ptr, _rb) = b.device_ptr(&stream);
let (c_ptr, _rc) = c.device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16_bf16(
_a: &(),
_b: &(),
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_matmul_bf16_bf16_nt(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if a.len() < m * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16_bf16_nt",
expected: vec![m, k],
got: vec![a.len()],
});
}
if b.len() < n * k {
return Err(GpuError::ShapeMismatch {
op: "matmul_bf16_bf16_nt",
expected: vec![n, k],
got: vec![b.len()],
});
}
if m == 0 || k == 0 || n == 0 {
return Ok(device.stream().alloc_zeros::<u16>(m * n)?);
}
let m_i32 = i32::try_from(m).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16_nt",
expected: vec![i32::MAX as usize],
got: vec![m],
})?;
let k_i32 = i32::try_from(k).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16_nt",
expected: vec![i32::MAX as usize],
got: vec![k],
})?;
let n_i32 = i32::try_from(n).map_err(|_| GpuError::ShapeMismatch {
op: "matmul_bf16_bf16_nt",
expected: vec![i32::MAX as usize],
got: vec![n],
})?;
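// NT variant: `b` holds B^T row-major as n x k, so this computes C = A * B^T.
// Under the row-major swap that becomes cuBLAS OP_T on the first operand
// with lda = k.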
let mut c = device.stream().alloc_zeros::<u16>(m * n)?;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a.device_ptr(&stream);
let (b_ptr, _rb) = b.device_ptr(&stream);
let (c_ptr, _rc) = c.device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_T,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_bf16_bf16_nt(
_a: &(),
_b: &(),
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched_nt(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
m: usize,
k: usize,
n: usize,
batch_count: usize,
stride_a_elems: usize,
stride_b_elems: usize,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if batch_count == 0 || m == 0 || k == 0 || n == 0 {
return Ok(device.stream().alloc_zeros::<u16>(batch_count * m * n)?);
}
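// Strides are in elements and supplied by the caller, so windows of larger
// buffers can be batched over; `alpha` scales every product in the batch.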
let (m_i32, k_i32, n_i32, bc_i32) = (m as i32, k as i32, n as i32, batch_count as i32);
let mut c = device.stream().alloc_zeros::<u16>(batch_count * m * n)?;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a.device_ptr(&stream);
let (b_ptr, _rb) = b.device_ptr(&stream);
let (c_ptr, _rc) = c.device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_strided_batched_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_T,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
stride_b_elems as i64,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
stride_a_elems as i64,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
(m * n) as i64,
bc_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
m: usize,
k: usize,
n: usize,
batch_count: usize,
stride_a_elems: usize,
stride_b_elems: usize,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use core::ffi::c_void;
use cudarc::cublas::{result as cublas_result, sys as cublas_sys};
use cudarc::driver::{DevicePtr, DevicePtrMut};
if batch_count == 0 || m == 0 || k == 0 || n == 0 {
return Ok(device.stream().alloc_zeros::<u16>(batch_count * m * n)?);
}
let (m_i32, k_i32, n_i32, bc_i32) = (m as i32, k as i32, n as i32, batch_count as i32);
let mut c = device.stream().alloc_zeros::<u16>(batch_count * m * n)?;
let beta: f32 = 0.0;
let blas = device.blas();
let stream = device.stream();
{
let (a_ptr, _ra) = a.device_ptr(&stream);
let (b_ptr, _rb) = b.device_ptr(&stream);
let (c_ptr, _rc) = c.device_ptr_mut(&stream);
unsafe {
cublas_result::gemm_strided_batched_ex(
*blas.handle(),
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
cublas_sys::cublasOperation_t::CUBLAS_OP_N,
n_i32,
m_i32,
k_i32,
(&alpha) as *const f32 as *const c_void,
b_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
stride_b_elems as i64,
a_ptr as *const c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
k_i32,
stride_a_elems as i64,
(&beta) as *const f32 as *const c_void,
c_ptr as *mut c_void,
cublas_sys::cudaDataType_t::CUDA_R_16BF,
n_i32,
(m * n) as i64,
bc_i32,
cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)?;
}
}
Ok(c)
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched_nt(
_a: &(),
_b: &(),
_m: usize,
_k: usize,
_n: usize,
_batch_count: usize,
_stride_a_elems: usize,
_stride_b_elems: usize,
_alpha: f32,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_matmul_bf16_bf16_strided_batched(
_a: &(),
_b: &(),
_m: usize,
_k: usize,
_n: usize,
_batch_count: usize,
_stride_a_elems: usize,
_stride_b_elems: usize,
_alpha: f32,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
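/// Naive O(m * k * n) row-major reference matmul used by the tests below.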
#[cfg(test)]
fn cpu_matmul_naive<T>(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Vec<T>
where
T: Copy + Default + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
{
let mut c = vec![T::default(); m * n];
for i in 0..m {
for j in 0..n {
let mut sum = T::default();
for p in 0..k {
sum = sum + a[i * k + p] * b[p * n + j];
}
c[i * n + j] = sum;
}
}
c
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f32(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_matmul_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_m: usize,
_k: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use super::*;
use crate::device::GpuDevice;
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
fn setup_f32(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
(dev, buf)
}
fn assert_buf_close_f32(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32], tol: f32) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
fn assert_buf_close_f64(buf: &CudaBuffer<f64>, device: &GpuDevice, expected: &[f64], tol: f64) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < tol,
"element {i}: got {got}, expected {exp}, diff {}",
(got - exp).abs(),
);
}
}
#[test]
fn matmul_f32_basic() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 4);
assert_buf_close_f32(&c, &dev, &expected, 1e-4);
}
#[test]
fn matmul_f32_identity() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
let (dev, a) = setup_f32(&a_data);
let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
let c = gpu_matmul_f32(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &a_data, 1e-6);
}
#[test]
fn matmul_f32_row_vector() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0];
let b_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let expected: Vec<f32> = vec![4.0, 5.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 1, 3, 2, &dev).expect("gpu_matmul_f32");
assert_eq!(c.len(), 2);
assert_buf_close_f32(&c, &dev, &expected, 1e-6);
}
#[test]
fn matmul_f32_wrong_a_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]); let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f32_wrong_b_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f32(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch { op: "matmul", .. } => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f32_empty() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f32 empty");
assert_eq!(c.len(), 0);
}
#[test]
fn matmul_f64_basic() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a_data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f64> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f64> = vec![58.0, 64.0, 139.0, 154.0];
let a = cpu_to_gpu(&a_data, &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f64(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f64");
assert_eq!(c.len(), 4);
assert_buf_close_f64(&c, &dev, &expected, 1e-10);
}
#[test]
fn matmul_f32_vs_cpu() {
let m = 64;
let k = 48;
let n = 32;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
.collect();
let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f32(&a, &b, m, k, n, &dev).expect("gpu_matmul_f32");
assert_buf_close_f32(&c, &dev, &expected, 1e-3);
}
#[test]
fn matmul_f32_1024x1024_perf() {
let dim = 1024;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
.collect();
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let gpu_start = std::time::Instant::now();
let _c = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
let gpu_elapsed = gpu_start.elapsed();
let cpu_start = std::time::Instant::now();
let _c_cpu = cpu_matmul_naive(&a_data, &b_data, dim, dim, dim);
let cpu_elapsed = cpu_start.elapsed();
eprintln!(
"matmul {dim}x{dim}: GPU = {:.3}ms, CPU = {:.3}ms, speedup = {:.1}x",
gpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() * 1000.0,
cpu_elapsed.as_secs_f64() / gpu_elapsed.as_secs_f64(),
);
}
#[test]
fn matmul_f16_basic() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_f16");
assert_eq!(c.len(), 4);
assert_buf_close_f32(&c, &dev, &expected, 1e-2);
}
#[test]
fn matmul_f16_identity() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let i_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
let (dev, a) = setup_f32(&a_data);
let i_buf = cpu_to_gpu(&i_data, &dev).expect("cpu_to_gpu i");
let c = gpu_matmul_f16(&a, &i_buf, 2, 2, 2, &dev).expect("gpu_matmul_f16");
assert_buf_close_f32(&c, &dev, &a_data, 1e-3);
}
#[test]
fn matmul_f16_empty() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 0, 0, 0, &dev).expect("gpu_matmul_f16 empty");
assert_eq!(c.len(), 0);
}
#[test]
fn matmul_f16_k_zero() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu a");
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, 2, 0, 3, &dev).expect("gpu_matmul_f16 k=0");
assert_eq!(c.len(), 6);
let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
assert!(host.iter().all(|&x| x == 0.0));
}
#[test]
fn matmul_f16_wrong_a_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0]); let b = cpu_to_gpu(&[1.0, 2.0, 3.0, 4.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch {
op: "matmul_f16", ..
} => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f16_wrong_b_length() {
let (dev, a) = setup_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = cpu_to_gpu(&[1.0, 2.0, 3.0], &dev).expect("cpu_to_gpu b");
let err = gpu_matmul_f16(&a, &b, 2, 2, 2, &dev).unwrap_err();
match err {
GpuError::ShapeMismatch {
op: "matmul_f16", ..
} => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn matmul_f16_vs_f32_reference() {
let m = 64;
let k = 48;
let n = 32;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 13) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 3) % 100) as f32 / 100.0)
.collect();
let expected = cpu_matmul_naive(&a_data, &b_data, m, k, n);
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let c = gpu_matmul_f16(&a, &b, m, k, n, &dev).expect("gpu_matmul_f16");
let host = gpu_to_cpu(&c, &dev).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "output length mismatch");
let mut max_err: f32 = 0.0;
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
let abs_err = (got - exp).abs();
let rel_err = if exp.abs() > 1e-6 {
abs_err / exp.abs()
} else {
abs_err
};
max_err = max_err.max(abs_err);
assert!(
rel_err < 0.05 || abs_err < 0.1,
"element {i}: f16 got {got}, f32 ref {exp}, abs_err {abs_err}, rel_err {rel_err}",
);
}
eprintln!("matmul_f16_vs_f32: {m}x{k} @ {k}x{n}, max absolute error = {max_err:.6}",);
}
#[test]
fn matmul_f16_1024x1024_perf() {
let dim = 1024;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 7 + 13) % 1000) as f32 / 1000.0)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| ((i * 11 + 3) % 1000) as f32 / 1000.0)
.collect();
let (dev, a) = setup_f32(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let f32_start = std::time::Instant::now();
let _c32 = gpu_matmul_f32(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f32");
let f32_elapsed = f32_start.elapsed();
let f16_start = std::time::Instant::now();
let _c16 = gpu_matmul_f16(&a, &b, dim, dim, dim, &dev).expect("gpu_matmul_f16");
let f16_elapsed = f16_start.elapsed();
eprintln!(
"matmul {dim}x{dim}: f32 = {:.3}ms, f16 = {:.3}ms, f16 speedup = {:.1}x",
f32_elapsed.as_secs_f64() * 1000.0,
f16_elapsed.as_secs_f64() * 1000.0,
f32_elapsed.as_secs_f64() / f16_elapsed.as_secs_f64(),
);
}
fn upload_as_bf16(dev: &GpuDevice, data: &[f32]) -> cudarc::driver::CudaSlice<u16> {
let u16_data: Vec<u16> = data
.iter()
.map(|&x| half::bf16::from_f32(x).to_bits())
.collect();
dev.stream().clone_htod(&u16_data).expect("bf16 upload")
}
fn download_bf16_as_f32(dev: &GpuDevice, buf: &cudarc::driver::CudaSlice<u16>) -> Vec<f32> {
let bits: Vec<u16> = dev.stream().clone_dtoh(buf).expect("bf16 download");
bits.into_iter()
.map(|b| half::bf16::from_bits(b).to_f32())
.collect()
}
#[test]
fn matmul_bf16_bf16_basic_2x3_3x2() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = upload_as_bf16(&dev, &a_data);
let b = upload_as_bf16(&dev, &b_data);
let c = gpu_matmul_bf16_bf16(&a, &b, 2, 3, 2, &dev).expect("gpu_matmul_bf16_bf16");
let got = download_bf16_as_f32(&dev, &c);
assert_eq!(got.len(), expected.len());
for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
assert!(
(g - e).abs() < 1.0,
"element {i}: got {g}, expected {e}, diff {}",
(g - e).abs(),
);
}
}
#[test]
fn matmul_bf16_bf16_identity_rows() {
let i_data: Vec<f32> = {
let mut v = vec![0.0f32; 16];
for d in 0..4 {
v[d * 4 + d] = 1.0;
}
v
};
let x_data: Vec<f32> = (0..12).map(|i| (i as f32) - 6.0).collect();
let dev = GpuDevice::new(0).expect("CUDA device 0");
let i = upload_as_bf16(&dev, &i_data);
let x = upload_as_bf16(&dev, &x_data);
let c = gpu_matmul_bf16_bf16(&i, &x, 4, 4, 3, &dev).expect("matmul");
let got = download_bf16_as_f32(&dev, &c);
for (a, b) in got.iter().zip(x_data.iter()) {
assert_eq!(a, b, "identity matmul must preserve X exactly");
}
}
#[test]
fn matmul_bf16_bf16_nt_basic_2x3_2x3() {
let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
let expected: Vec<f32> = vec![50.0, 68.0, 122.0, 167.0];
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = upload_as_bf16(&dev, &a_data);
let b = upload_as_bf16(&dev, &b_data);
let c = gpu_matmul_bf16_bf16_nt(&a, &b, 2, 3, 2, &dev).expect("matmul_nt");
let got = download_bf16_as_f32(&dev, &c);
for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
assert!(
(g - e).abs() <= e.abs() * 0.02 + 1.0,
"nt[{i}]: got {g}, expected {e}",
);
}
}
#[test]
fn matmul_bf16_bf16_nt_equivalent_to_explicit_transpose() {
let m = 4;
let k = 3;
let n = 5;
let a_data: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.5 - 2.0).collect();
let b_data: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.25 + 1.0).collect();
let mut b_t: Vec<f32> = vec![0.0; k * n];
for i in 0..n {
for j in 0..k {
b_t[j * n + i] = b_data[i * k + j];
}
}
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = upload_as_bf16(&dev, &a_data);
let b = upload_as_bf16(&dev, &b_data);
let bt = upload_as_bf16(&dev, &b_t);
let c_nt = gpu_matmul_bf16_bf16_nt(&a, &b, m, k, n, &dev).unwrap();
let c_ref = gpu_matmul_bf16_bf16(&a, &bt, m, k, n, &dev).unwrap();
let nt = download_bf16_as_f32(&dev, &c_nt);
let rf = download_bf16_as_f32(&dev, &c_ref);
for (i, (&a, &b)) in nt.iter().zip(rf.iter()).enumerate() {
assert!((a - b).abs() < 0.01, "nt[{i}]={a} vs ref[{i}]={b}",);
}
}
#[test]
fn matmul_bf16_strided_batched_matches_per_batch_reference() {
let dev = GpuDevice::new(0).expect("cuda");
let a0: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let b0: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let a1: Vec<f32> = vec![0.5, 0.5, 0.5, 1.0, 1.0, 1.0];
let b1: Vec<f32> = vec![2.0, 2.0, 2.0, -1.0, -1.0, -1.0];
let a: Vec<f32> = [&a0[..], &a1[..]].concat();
let b: Vec<f32> = [&b0[..], &b1[..]].concat();
let a_gpu = upload_as_bf16(&dev, &a);
let b_gpu = upload_as_bf16(&dev, &b);
let c =
gpu_matmul_bf16_bf16_strided_batched_nt(&a_gpu, &b_gpu, 2, 3, 2, 2, 6, 6, 0.5, &dev)
.expect("batched");
let got = download_bf16_as_f32(&dev, &c);
let ref0 = gpu_matmul_bf16_bf16_nt(
&upload_as_bf16(&dev, &a0),
&upload_as_bf16(&dev, &b0),
2,
3,
2,
&dev,
)
.unwrap();
let ref1 = gpu_matmul_bf16_bf16_nt(
&upload_as_bf16(&dev, &a1),
&upload_as_bf16(&dev, &b1),
2,
3,
2,
&dev,
)
.unwrap();
let expected0 = download_bf16_as_f32(&dev, &ref0);
let expected1 = download_bf16_as_f32(&dev, &ref1);
for (i, (&g, &e)) in got[..4].iter().zip(expected0.iter()).enumerate() {
let scaled = e * 0.5;
assert!(
(g - scaled).abs() < scaled.abs() * 0.05 + 0.1,
"b0[{i}]: got {g}, expected {scaled}",
);
}
for (i, (&g, &e)) in got[4..].iter().zip(expected1.iter()).enumerate() {
let scaled = e * 0.5;
assert!(
(g - scaled).abs() < scaled.abs() * 0.05 + 0.1,
"b1[{i}]: got {g}, expected {scaled}",
);
}
}
#[test]
fn matmul_bf16_bf16_large_dims_finite() {
let dim = 512;
let a_data: Vec<f32> = (0..dim * dim)
.map(|i| (((i * 7 + 13) % 1000) as f32 / 1000.0) - 0.5)
.collect();
let b_data: Vec<f32> = (0..dim * dim)
.map(|i| (((i * 11 + 3) % 1000) as f32 / 1000.0) - 0.5)
.collect();
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = upload_as_bf16(&dev, &a_data);
let b = upload_as_bf16(&dev, &b_data);
let c = gpu_matmul_bf16_bf16(&a, &b, dim, dim, dim, &dev).expect("matmul");
let got = download_bf16_as_f32(&dev, &c);
assert_eq!(got.len(), dim * dim);
for &v in &got {
assert!(v.is_finite(), "non-finite output in bf16 matmul");
}
}
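// Host-side sanity check for the broadcast run coalescing (no GPU work is
// involved): when A has no leading dims it is broadcast across the whole
// batch, so the per-matrix offsets must coalesce into a single run whose A
// stride is 0 while B and C advance by one matrix per step.
#[test]
fn stride_runs_coalesce_broadcast_a() {
let (m, k, n) = (2usize, 4usize, 5usize);
let runs = compute_stride_runs(&[], &[3], &[3], m, k, n);
assert_eq!(runs.len(), 1, "constant deltas must form a single run");
let run = &runs[0];
assert_eq!(run.batch, 3);
assert_eq!((run.a_start, run.b_start, run.c_start), (0, 0, 0));
assert_eq!(run.stride_a, 0, "broadcast A must not advance");
assert_eq!(run.stride_b, (k * n) as i64);
assert_eq!(run.stride_c, (m * n) as i64);
}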
}