use std::sync::Mutex;
use oxicuda_driver::device::Device;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::primary_context::PrimaryContext;
pub use oxicuda_backend::{
BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
};
// Implementations that `reduce`, `unary` and `binary` dispatch to; compiled
// only when the `ptx` feature is enabled.
#[cfg(feature = "ptx")]
mod ptx_ops;
/// Compute backend that talks to a CUDA device through `oxicuda_driver`.
///
/// A fresh value is uninitialised; `init` flips `initialized` and, when a
/// device and primary context can be obtained, stores that context.
#[derive(Debug)]
pub struct CudaBackend {
/// `true` once `init` has run; operations that call `check_init` fail with
/// `BackendError::NotInitialized` while it is `false`.
initialized: bool,
/// Primary context retained for device 0 by `init`, or `None` when no
/// context could be obtained.
context: Mutex<Option<PrimaryContext>>,
}
impl CudaBackend {
    /// Builds a backend that is not yet initialised and holds no GPU context.
    ///
    /// Operations that go through `check_init` fail with
    /// `BackendError::NotInitialized` until `init` has been called.
    #[must_use]
    pub fn new() -> Self {
        let context = Mutex::new(None);
        Self {
            initialized: false,
            context,
        }
    }

    /// Guard called at the top of the compute and memory operations:
    /// `Ok(())` once `init` has run, `BackendError::NotInitialized` before.
    fn check_init(&self) -> BackendResult<()> {
        if !self.initialized {
            return Err(BackendError::NotInitialized);
        }
        Ok(())
    }

    /// Reports whether a primary context is currently stored.
    ///
    /// A poisoned lock is reported as "no context" instead of being propagated.
    fn has_gpu_context(&self) -> bool {
        match self.context.lock() {
            Ok(slot) => slot.is_some(),
            Err(_) => false,
        }
    }

    /// Passes the stored primary context to the driver's `cu_ctx_set_current`
    /// and returns the device that context belongs to.
    ///
    /// # Errors
    ///
    /// Returns `BackendError::DeviceError` when the lock is poisoned, when no
    /// context is stored, when the driver cannot be loaded, or when the driver
    /// call reports a failure.
    #[cfg(any(feature = "blas", feature = "dnn", feature = "ptx"))]
    fn activate_gpu(&self) -> BackendResult<Device> {
        let slot = match self.context.lock() {
            Ok(slot) => slot,
            Err(_) => {
                return Err(BackendError::DeviceError(
                    "backend context lock poisoned".into(),
                ));
            }
        };
        let primary = match slot.as_ref() {
            Some(primary) => primary,
            None => {
                return Err(BackendError::DeviceError(
                    "no CUDA GPU context available".into(),
                ));
            }
        };
        let api = match try_driver() {
            Ok(api) => api,
            Err(e) => return Err(BackendError::DeviceError(e.to_string())),
        };
        // SAFETY: the raw handle comes from the `PrimaryContext` that stays
        // alive behind the lock guard held for the duration of this call.
        let status = unsafe { (api.cu_ctx_set_current)(primary.raw()) };
        if let Err(e) = oxicuda_driver::error::check(status) {
            return Err(BackendError::DeviceError(e.to_string()));
        }
        Ok(*primary.device())
    }
}
impl Default for CudaBackend {
/// Same as [`CudaBackend::new`]: an uninitialised backend with no stored context.
fn default() -> Self {
Self::new()
}
}
impl ComputeBackend for CudaBackend {
/// Identifier of this backend; always the literal `"cuda"`.
fn name(&self) -> &str {
"cuda"
}
fn init(&mut self) -> BackendResult<()> {
if self.initialized {
return Ok(());
}
if let Ok(()) = oxicuda_driver::init() {
if let Ok(dev) = Device::get(0) {
if let Ok(ctx) = PrimaryContext::retain(&dev) {
if let Ok(api) = try_driver() {
let raw = ctx.raw();
let _ =
oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(raw) });
}
if let Ok(mut guard) = self.context.lock() {
*guard = Some(ctx);
}
}
}
}
self.initialized = true;
Ok(())
}
/// Returns `true` once `init` has been called on this backend.
fn is_initialized(&self) -> bool {
self.initialized
}
/// General matrix multiply entry point.
///
/// Fails with `BackendError::NotInitialized` before `init`. With the `blas`
/// feature the call is forwarded unchanged to `gemm_impl`; without it the
/// arguments are ignored and `BackendError::Unsupported` is returned.
fn gemm(
    &self,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    b_ptr: u64,
    ldb: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
) -> BackendResult<()> {
    self.check_init()?;

    #[cfg(not(feature = "blas"))]
    {
        // Consume every argument so the feature-less build has no unused-variable warnings.
        let _ = (
            trans_a, trans_b, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc,
        );
        Err(BackendError::Unsupported(String::from(
            "GEMM requires the 'blas' feature",
        )))
    }
    #[cfg(feature = "blas")]
    {
        gemm_impl(
            self, trans_a, trans_b, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc,
        )
    }
}
/// 2-D forward convolution entry point.
///
/// After the initialisation check, the three shapes must each have 4
/// elements and `stride` / `padding` 2 elements each; the first violation
/// (in parameter order) is reported as `BackendError::InvalidArgument`.
/// With the `dnn` feature the call is forwarded to `conv2d_forward_impl`;
/// without it `BackendError::Unsupported` is returned.
fn conv2d_forward(
    &self,
    input_ptr: u64,
    input_shape: &[usize],
    filter_ptr: u64,
    filter_shape: &[usize],
    output_ptr: u64,
    output_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
) -> BackendResult<()> {
    self.check_init()?;

    // Evaluate the length checks in order and keep only the first complaint.
    let length_problem = if input_shape.len() != 4 {
        Some(format!(
            "input_shape must have 4 elements (NCHW), got {}",
            input_shape.len()
        ))
    } else if filter_shape.len() != 4 {
        Some(format!(
            "filter_shape must have 4 elements (KCRS), got {}",
            filter_shape.len()
        ))
    } else if output_shape.len() != 4 {
        Some(format!(
            "output_shape must have 4 elements (NKPQ), got {}",
            output_shape.len()
        ))
    } else if stride.len() != 2 {
        Some(format!(
            "stride must have 2 elements, got {}",
            stride.len()
        ))
    } else if padding.len() != 2 {
        Some(format!(
            "padding must have 2 elements, got {}",
            padding.len()
        ))
    } else {
        None
    };
    if let Some(message) = length_problem {
        return Err(BackendError::InvalidArgument(message));
    }

    #[cfg(not(feature = "dnn"))]
    {
        // The shapes were read by the checks above; only the pointers are otherwise unused.
        let _ = (input_ptr, filter_ptr, output_ptr);
        Err(BackendError::Unsupported(String::from(
            "conv2d_forward requires the 'dnn' feature",
        )))
    }
    #[cfg(feature = "dnn")]
    {
        conv2d_forward_impl(
            self,
            input_ptr,
            input_shape,
            filter_ptr,
            filter_shape,
            output_ptr,
            output_shape,
            stride,
            padding,
        )
    }
}
/// Attention entry point.
///
/// After the initialisation check, `seq_q`, `seq_kv` and `head_dim` must be
/// non-zero and `scale` must be finite and strictly positive; otherwise
/// `BackendError::InvalidArgument` is returned. With the `dnn` feature the
/// call is forwarded to `attention_impl`; without it
/// `BackendError::Unsupported` is returned.
fn attention(
    &self,
    q_ptr: u64,
    k_ptr: u64,
    v_ptr: u64,
    o_ptr: u64,
    batch: usize,
    heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
    scale: f64,
    causal: bool,
) -> BackendResult<()> {
    self.check_init()?;

    if [seq_q, seq_kv, head_dim].contains(&0) {
        return Err(BackendError::InvalidArgument(String::from(
            "sequence lengths and head_dim must be > 0",
        )));
    }

    // NaN fails `scale > 0.0`, infinities fail `is_finite`, so both are rejected.
    let scale_is_usable = scale.is_finite() && scale > 0.0;
    if !scale_is_usable {
        return Err(BackendError::InvalidArgument(format!(
            "scale must be a positive finite number, got {scale}"
        )));
    }

    #[cfg(not(feature = "dnn"))]
    {
        // Consume every argument so the feature-less build has no unused-variable warnings.
        let _ = (
            q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq_q, seq_kv, head_dim, scale, causal,
        );
        Err(BackendError::Unsupported(String::from(
            "attention requires the 'dnn' feature",
        )))
    }
    #[cfg(feature = "dnn")]
    {
        attention_impl(
            self, q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq_q, seq_kv, head_dim, scale,
            causal,
        )
    }
}
/// Reduction entry point.
///
/// After the initialisation check, `shape` must be non-empty and `axis`
/// must index into it; otherwise `BackendError::InvalidArgument` is
/// returned. With the `ptx` feature the call is forwarded to
/// `ptx_ops::reduce_axis`; without it `BackendError::Unsupported` is returned.
fn reduce(
    &self,
    op: ReduceOp,
    input_ptr: u64,
    output_ptr: u64,
    shape: &[usize],
    axis: usize,
) -> BackendResult<()> {
    self.check_init()?;

    let rank = shape.len();
    if rank == 0 {
        return Err(BackendError::InvalidArgument(String::from(
            "shape must not be empty",
        )));
    }
    if axis >= rank {
        return Err(BackendError::InvalidArgument(format!(
            "axis {} out of bounds for shape with {} dimensions",
            axis, rank
        )));
    }

    #[cfg(not(feature = "ptx"))]
    {
        // `shape` and `axis` were read above; the rest is otherwise unused.
        let _ = (op, input_ptr, output_ptr);
        Err(BackendError::Unsupported(String::from(
            "reduce requires the 'ptx' feature",
        )))
    }
    #[cfg(feature = "ptx")]
    {
        ptx_ops::reduce_axis(self, op, input_ptr, output_ptr, shape, axis)
    }
}
/// Element-wise unary operation entry point.
///
/// After the initialisation check, `n == 0` succeeds immediately without
/// touching the device. With the `ptx` feature the call is forwarded to
/// `ptx_ops::unary_elementwise`; without it `BackendError::Unsupported` is
/// returned.
fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
    self.check_init()?;

    let nothing_to_do = n == 0;
    if nothing_to_do {
        return Ok(());
    }

    #[cfg(not(feature = "ptx"))]
    {
        let _ = (op, input_ptr, output_ptr);
        Err(BackendError::Unsupported(String::from(
            "unary requires the 'ptx' feature",
        )))
    }
    #[cfg(feature = "ptx")]
    {
        ptx_ops::unary_elementwise(self, op, input_ptr, output_ptr, n)
    }
}
/// Element-wise binary operation entry point.
///
/// After the initialisation check, `n == 0` succeeds immediately without
/// touching the device. With the `ptx` feature the call is forwarded to
/// `ptx_ops::binary_elementwise`; without it `BackendError::Unsupported` is
/// returned.
fn binary(
    &self,
    op: BinaryOp,
    a_ptr: u64,
    b_ptr: u64,
    output_ptr: u64,
    n: usize,
) -> BackendResult<()> {
    self.check_init()?;

    let nothing_to_do = n == 0;
    if nothing_to_do {
        return Ok(());
    }

    #[cfg(not(feature = "ptx"))]
    {
        let _ = (op, a_ptr, b_ptr, output_ptr);
        Err(BackendError::Unsupported(String::from(
            "binary requires the 'ptx' feature",
        )))
    }
    #[cfg(feature = "ptx")]
    {
        ptx_ops::binary_elementwise(self, op, a_ptr, b_ptr, output_ptr, n)
    }
}
fn synchronize(&self) -> BackendResult<()> {
self.check_init()?;
if !self.has_gpu_context() {
return Ok(());
}
let api = try_driver().map_err(|e| BackendError::DeviceError(e.to_string()))?;
oxicuda_driver::error::check(unsafe { (api.cu_ctx_synchronize)() })
.map_err(|e| BackendError::DeviceError(e.to_string()))
}
fn alloc(&self, bytes: usize) -> BackendResult<u64> {
self.check_init()?;
if bytes == 0 {
return Err(BackendError::InvalidArgument(
"cannot allocate 0 bytes".into(),
));
}
let api = try_driver().map_err(|e| BackendError::DeviceError(e.to_string()))?;
let mut ptr: CUdeviceptr = 0;
oxicuda_driver::error::check(unsafe { (api.cu_mem_alloc_v2)(&mut ptr, bytes) }).map_err(
|e| match e {
oxicuda_driver::CudaError::OutOfMemory => BackendError::OutOfMemory,
other => BackendError::DeviceError(other.to_string()),
},
)?;
Ok(ptr)
}
fn free(&self, ptr: u64) -> BackendResult<()> {
self.check_init()?;
let api = try_driver().map_err(|e| BackendError::DeviceError(e.to_string()))?;
oxicuda_driver::error::check(unsafe { (api.cu_mem_free_v2)(ptr) })
.map_err(|e| BackendError::DeviceError(e.to_string()))
}
fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
self.check_init()?;
if src.is_empty() {
return Ok(());
}
let api = try_driver().map_err(|e| BackendError::DeviceError(e.to_string()))?;
oxicuda_driver::error::check(unsafe {
(api.cu_memcpy_htod_v2)(dst, src.as_ptr().cast(), src.len())
})
.map_err(|e| BackendError::DeviceError(e.to_string()))
}
fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
self.check_init()?;
if dst.is_empty() {
return Ok(());
}
let api = try_driver().map_err(|e| BackendError::DeviceError(e.to_string()))?;
oxicuda_driver::error::check(unsafe {
(api.cu_memcpy_dtoh_v2)(dst.as_mut_ptr().cast(), src, dst.len())
})
.map_err(|e| BackendError::DeviceError(e.to_string()))
}
}
/// Creates the `Arc<Context>` that the BLAS / DNN handle constructors take,
/// then re-activates the backend's own stored context.
///
/// NOTE(review): presumably `Context::new` changes which context is current,
/// which is why `activate_gpu` is called again afterwards — confirm against
/// `oxicuda_driver`.
///
/// # Errors
///
/// `BackendError::DeviceError` when the context cannot be created, plus any
/// error from `activate_gpu`.
#[cfg(any(feature = "blas", feature = "dnn"))]
fn handle_context_token(
    backend: &CudaBackend,
    device: Device,
) -> BackendResult<std::sync::Arc<oxicuda_driver::Context>> {
    let token = match oxicuda_driver::Context::new(&device) {
        Ok(token) => token,
        Err(e) => {
            return Err(BackendError::DeviceError(format!(
                "context token creation failed: {e}"
            )));
        }
    };
    backend.activate_gpu()?;
    Ok(std::sync::Arc::new(token))
}
/// Runs an `f64`, column-major GEMM through `oxicuda_blas`.
///
/// Steps, in order: reject zero `m`/`n`/`k`; narrow every size and leading
/// dimension to `u32`; activate the backend's context and build a BLAS
/// handle; describe A, B and C (A and B with their stored, i.e.
/// pre-transpose, dimensions); call the BLAS `gemm`.
///
/// # Errors
///
/// `InvalidArgument` for zero or oversized dimensions, `DeviceError` when the
/// GPU context or BLAS handle is unavailable, and whatever `map_blas_error`
/// produces for a failed BLAS call.
// The argument list mirrors the `ComputeBackend::gemm` signature.
#[cfg(feature = "blas")]
#[allow(clippy::too_many_arguments)]
fn gemm_impl(
    backend: &CudaBackend,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    b_ptr: u64,
    ldb: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
) -> BackendResult<()> {
    use oxicuda_blas::BlasHandle;
    use oxicuda_blas::level3::gemm_api::gemm as blas_gemm;
    use oxicuda_blas::types::{Layout, MatrixDesc, MatrixDescMut, Transpose};

    if [m, n, k].contains(&0) {
        return Err(BackendError::InvalidArgument(String::from(
            "GEMM dimensions m, n, k must all be > 0",
        )));
    }

    let narrow = |value: usize, what: &str| -> BackendResult<u32> {
        match u32::try_from(value) {
            Ok(small) => Ok(small),
            Err(_) => Err(BackendError::InvalidArgument(format!(
                "GEMM {what} exceeds u32 range"
            ))),
        }
    };
    // Tuple fields are evaluated left to right, so the first oversized value
    // in the order m, n, k, lda, ldb, ldc is the one reported.
    let (rows, cols, inner) = (narrow(m, "m")?, narrow(n, "n")?, narrow(k, "k")?);
    let (lead_a, lead_b, lead_c) = (
        narrow(lda, "lda")?,
        narrow(ldb, "ldb")?,
        narrow(ldc, "ldc")?,
    );

    let device = backend.activate_gpu()?;
    let ctx = handle_context_token(backend, device)?;
    let handle = match BlasHandle::new(&ctx) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(BackendError::DeviceError(format!(
                "BLAS handle creation failed: {e}"
            )));
        }
    };

    let to_blas = |t: BackendTranspose| match t {
        BackendTranspose::NoTrans => Transpose::NoTrans,
        BackendTranspose::Trans => Transpose::Trans,
        BackendTranspose::ConjTrans => Transpose::ConjTrans,
    };
    let op_a = to_blas(trans_a);
    let op_b = to_blas(trans_b);

    // Turns the dimensions of op(X) into the dimensions X is stored with.
    let as_stored = |op: Transpose, (eff_rows, eff_cols): (u32, u32)| match op {
        Transpose::NoTrans => (eff_rows, eff_cols),
        Transpose::Trans | Transpose::ConjTrans => (eff_cols, eff_rows),
    };
    let (a_rows, a_cols) = as_stored(op_a, (rows, inner));
    let (b_rows, b_cols) = as_stored(op_b, (inner, cols));

    let a_desc = MatrixDesc::<f64>::from_raw(
        a_ptr as CUdeviceptr,
        a_rows,
        a_cols,
        lead_a,
        Layout::ColMajor,
    );
    let b_desc = MatrixDesc::<f64>::from_raw(
        b_ptr as CUdeviceptr,
        b_rows,
        b_cols,
        lead_b,
        Layout::ColMajor,
    );
    let mut c_desc = MatrixDescMut::<f64>::from_raw(
        c_ptr as CUdeviceptr,
        rows,
        cols,
        lead_c,
        Layout::ColMajor,
    );

    let outcome = blas_gemm(
        &handle, op_a, op_b, alpha, &a_desc, &b_desc, beta, &mut c_desc,
    );
    outcome.map_err(map_blas_error)
}
/// Translates a `BlasError` into the backend's error type.
///
/// Dimension and argument problems (including an undersized buffer) become
/// `InvalidArgument`, unsupported operations become `Unsupported`, and every
/// other variant is stringified into `DeviceError`.
#[cfg(feature = "blas")]
fn map_blas_error(e: oxicuda_blas::BlasError) -> BackendError {
    use oxicuda_blas::BlasError;
    match e {
        BlasError::UnsupportedOperation(reason) => BackendError::Unsupported(reason),
        BlasError::InvalidDimension(reason)
        | BlasError::DimensionMismatch(reason)
        | BlasError::InvalidArgument(reason) => BackendError::InvalidArgument(reason),
        BlasError::BufferTooSmall { expected, actual } => {
            let message = format!("BLAS buffer too small: expected {expected}, got {actual}");
            BackendError::InvalidArgument(message)
        }
        other => BackendError::DeviceError(other.to_string()),
    }
}
/// Runs an `f32` 2-D forward convolution through `oxicuda_dnn`.
///
/// All three tensors are described with the `Nchw` layout and dense strides
/// derived from their dimensions. The convolution is first attempted without
/// a workspace; if the library answers `WorkspaceRequired(bytes)`, a device
/// buffer of that size is allocated and the call is retried once.
///
/// All argument validation (u32 narrowing, stride products, padding/stride
/// lengths) happens before the GPU is touched.
///
/// # Errors
///
/// * `InvalidArgument` — a dimension, padding or stride value does not fit in
///   `u32`, a dense stride product overflows `u32`, or a shape / padding /
///   stride slice has the wrong length.
/// * `DeviceError` — no GPU context, handle creation or workspace allocation
///   failed.
/// * Whatever `map_dnn_error` produces for a failed library call.
// The argument list mirrors the `ComputeBackend::conv2d_forward` signature.
#[cfg(feature = "dnn")]
#[allow(clippy::too_many_arguments)]
fn conv2d_forward_impl(
    backend: &CudaBackend,
    input_ptr: u64,
    input_shape: &[usize],
    filter_ptr: u64,
    filter_shape: &[usize],
    output_ptr: u64,
    output_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
) -> BackendResult<()> {
    use oxicuda_dnn::conv::api::conv_forward;
    use oxicuda_dnn::handle::DnnHandle;
    use oxicuda_dnn::types::{ConvolutionDescriptor, TensorDesc, TensorDescMut, TensorLayout};
    use oxicuda_memory::DeviceBuffer;

    let to_u32 = |v: usize, what: &str| -> BackendResult<u32> {
        u32::try_from(v)
            .map_err(|_| BackendError::InvalidArgument(format!("conv2d {what} exceeds u32 range")))
    };
    let dims_u32 = |s: &[usize], what: &str| -> BackendResult<Vec<u32>> {
        s.iter().map(|&v| to_u32(v, what)).collect()
    };
    // Dense strides for a 4-D tensor. The products are checked: a plain `*`
    // would panic in debug builds and silently wrap in release builds.
    let nchw_strides = |d: &[u32], what: &str| -> BackendResult<Vec<u32>> {
        let (c, h, w) = match *d {
            [_, c, h, w] => (c, h, w),
            _ => {
                return Err(BackendError::InvalidArgument(format!(
                    "conv2d {what} must have 4 dimensions, got {}",
                    d.len()
                )));
            }
        };
        let plane = h.checked_mul(w);
        let image = plane.and_then(|hw| c.checked_mul(hw));
        match (image, plane) {
            (Some(image), Some(plane)) => Ok(vec![image, plane, w, 1]),
            _ => Err(BackendError::InvalidArgument(format!(
                "conv2d {what} strides exceed u32 range"
            ))),
        }
    };
    // Splits a two-element slice without panicking on a wrong length.
    let pair = |s: &[usize], what: &str| -> BackendResult<(usize, usize)> {
        match *s {
            [first, second] => Ok((first, second)),
            _ => Err(BackendError::InvalidArgument(format!(
                "conv2d {what} must have 2 elements, got {}",
                s.len()
            ))),
        }
    };

    let in_dims = dims_u32(input_shape, "input dim")?;
    let filt_dims = dims_u32(filter_shape, "filter dim")?;
    let out_dims = dims_u32(output_shape, "output dim")?;
    let in_strides = nchw_strides(&in_dims, "input")?;
    let filt_strides = nchw_strides(&filt_dims, "filter")?;
    let out_strides = nchw_strides(&out_dims, "output")?;

    let (pad_h, pad_w) = pair(padding, "padding")?;
    let (stride_h, stride_w) = pair(stride, "stride")?;
    let pad_h = to_u32(pad_h, "pad_h")?;
    let pad_w = to_u32(pad_w, "pad_w")?;
    let stride_h = to_u32(stride_h, "stride_h")?;
    let stride_w = to_u32(stride_w, "stride_w")?;

    let device = backend.activate_gpu()?;
    let ctx = handle_context_token(backend, device)?;
    let handle = DnnHandle::new(&ctx)
        .map_err(|e| BackendError::DeviceError(format!("DNN handle creation failed: {e}")))?;

    // NOTE(review): the three trailing `1`s are presumably dilation_h,
    // dilation_w and group count — confirm against `ConvolutionDescriptor::conv2d`.
    let conv_desc = ConvolutionDescriptor::conv2d(pad_h, pad_w, stride_h, stride_w, 1, 1, 1)
        .map_err(map_dnn_error)?;

    // One attempt at the convolution; descriptors are rebuilt per attempt
    // because `from_raw` takes the dimension and stride vectors by value.
    let run = |handle: &DnnHandle,
               workspace: Option<&mut DeviceBuffer<u8>>|
     -> oxicuda_dnn::error::DnnResult<()> {
        let input = TensorDesc::<f32>::from_raw(
            input_ptr as CUdeviceptr,
            in_dims.clone(),
            in_strides.clone(),
            TensorLayout::Nchw,
        )?;
        let filter = TensorDesc::<f32>::from_raw(
            filter_ptr as CUdeviceptr,
            filt_dims.clone(),
            filt_strides.clone(),
            TensorLayout::Nchw,
        )?;
        let mut output = TensorDescMut::<f32>::from_raw(
            output_ptr as CUdeviceptr,
            out_dims.clone(),
            out_strides.clone(),
            TensorLayout::Nchw,
        )?;
        conv_forward(handle, &input, &filter, &mut output, &conv_desc, workspace)
    };

    match run(&handle, None) {
        Ok(()) => Ok(()),
        Err(oxicuda_dnn::error::DnnError::WorkspaceRequired(bytes)) => {
            // `max(1)` keeps the allocation non-empty even if 0 bytes are requested.
            let mut ws = DeviceBuffer::<u8>::alloc(bytes.max(1))
                .map_err(|e| BackendError::DeviceError(format!("workspace alloc failed: {e}")))?;
            run(&handle, Some(&mut ws)).map_err(map_dnn_error)
        }
        Err(other) => Err(map_dnn_error(other)),
    }
}
/// Runs an `f32` flash-attention forward pass through `oxicuda_dnn`.
///
/// Q and O are described as `[batch, heads, seq_q, head_dim]`, K and V as
/// `[batch, heads, seq_kv, head_dim]`, all with dense strides. The tile
/// configuration starts from `FlashAttentionConfig::auto`, is clamped to the
/// sequence lengths and then halved (larger tile first) until its reported
/// shared-memory use fits `SHARED_BUDGET`.
///
/// # Errors
///
/// * `InvalidArgument` — `batch` or `heads` is zero, a size does not fit in
///   `u32`, or a stride / element-count product overflows.
/// * `Unsupported` — the configuration cannot be shrunk below the
///   shared-memory budget.
/// * `DeviceError` — no GPU context, handle creation or the LSE buffer
///   allocation failed.
/// * Whatever `map_dnn_error` produces for a failed library call.
// The argument list mirrors the `ComputeBackend::attention` signature.
#[cfg(feature = "dnn")]
#[allow(clippy::too_many_arguments)]
fn attention_impl(
    backend: &CudaBackend,
    q_ptr: u64,
    k_ptr: u64,
    v_ptr: u64,
    o_ptr: u64,
    batch: usize,
    heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
    scale: f64,
    causal: bool,
) -> BackendResult<()> {
    use oxicuda_dnn::attn::flash_attn::forward::{FlashAttentionConfig, flash_attention_forward};
    use oxicuda_dnn::handle::DnnHandle;
    use oxicuda_dnn::types::{TensorDesc, TensorDescMut, TensorLayout};
    use oxicuda_memory::DeviceBuffer;

    if batch == 0 || heads == 0 {
        return Err(BackendError::InvalidArgument(
            "attention batch and heads must be > 0".into(),
        ));
    }
    let to_u32 = |v: usize, what: &str| -> BackendResult<u32> {
        u32::try_from(v).map_err(|_| {
            BackendError::InvalidArgument(format!("attention {what} exceeds u32 range"))
        })
    };
    let batch_u = to_u32(batch, "batch")?;
    let heads_u = to_u32(heads, "heads")?;
    let seq_q_u = to_u32(seq_q, "seq_q")?;
    let seq_kv_u = to_u32(seq_kv, "seq_kv")?;
    let head_dim_u = to_u32(head_dim, "head_dim")?;

    let device = backend.activate_gpu()?;
    let ctx = handle_context_token(backend, device)?;
    let mut handle = DnnHandle::new(&ctx)
        .map_err(|e| BackendError::DeviceError(format!("DNN handle creation failed: {e}")))?;

    let sm = handle.sm_version();
    let mut config = FlashAttentionConfig::auto(head_dim_u, seq_q_u, seq_kv_u, causal, sm);
    config.num_heads = heads_u;
    config.sm_scale = scale as f32;
    // Tiles never need to exceed the sequence they cover, and never drop below 1.
    config.block_m = config.block_m.min(seq_q_u).max(1);
    config.block_n = config.block_n.min(seq_kv_u).max(1);

    // NOTE(review): 46 KiB is presumably chosen to stay under a 48 KiB
    // per-block shared-memory limit — confirm the intended target.
    const SHARED_BUDGET: u32 = 46 * 1024;
    // Halve the larger tile (block_m on ties) until the configuration fits or
    // both tiles have reached 1.
    while config.shared_mem_bytes() > SHARED_BUDGET && (config.block_m > 1 || config.block_n > 1) {
        if config.block_m >= config.block_n && config.block_m > 1 {
            config.block_m = (config.block_m / 2).max(1);
        } else if config.block_n > 1 {
            config.block_n = (config.block_n / 2).max(1);
        } else {
            break;
        }
    }
    if config.shared_mem_bytes() > SHARED_BUDGET {
        return Err(BackendError::Unsupported(format!(
            "attention head_dim {head_dim} too large for the device shared-memory budget"
        )));
    }

    // Dense strides for a `[batch, heads, seq, head_dim]` tensor. The products
    // are checked: a plain `*` would panic in debug builds and silently wrap
    // in release builds.
    let attn_strides = |seq: u32, what: &str| -> BackendResult<Vec<u32>> {
        let per_head = seq.checked_mul(head_dim_u);
        let per_batch = per_head.and_then(|block| heads_u.checked_mul(block));
        match (per_batch, per_head) {
            (Some(per_batch), Some(per_head)) => Ok(vec![per_batch, per_head, head_dim_u, 1]),
            _ => Err(BackendError::InvalidArgument(format!(
                "attention {what} strides exceed u32 range"
            ))),
        }
    };
    let q_strides = attn_strides(seq_q_u, "query/output")?;
    let kv_strides = attn_strides(seq_kv_u, "key/value")?;

    let q = TensorDesc::<f32>::from_raw(
        q_ptr as CUdeviceptr,
        vec![batch_u, heads_u, seq_q_u, head_dim_u],
        q_strides.clone(),
        TensorLayout::Nchw,
    )
    .map_err(map_dnn_error)?;
    let k = TensorDesc::<f32>::from_raw(
        k_ptr as CUdeviceptr,
        vec![batch_u, heads_u, seq_kv_u, head_dim_u],
        kv_strides.clone(),
        TensorLayout::Nchw,
    )
    .map_err(map_dnn_error)?;
    let v = TensorDesc::<f32>::from_raw(
        v_ptr as CUdeviceptr,
        vec![batch_u, heads_u, seq_kv_u, head_dim_u],
        kv_strides,
        TensorLayout::Nchw,
    )
    .map_err(map_dnn_error)?;
    let mut o = TensorDescMut::<f32>::from_raw(
        o_ptr as CUdeviceptr,
        vec![batch_u, heads_u, seq_q_u, head_dim_u],
        q_strides,
        TensorLayout::Nchw,
    )
    .map_err(map_dnn_error)?;

    // One LSE value per (batch, head, query row); checked so a 32-bit `usize`
    // cannot wrap into an undersized allocation.
    let lse_elems = batch
        .checked_mul(heads)
        .and_then(|bh| bh.checked_mul(seq_q))
        .ok_or_else(|| {
            BackendError::InvalidArgument(
                "attention batch * heads * seq_q exceeds usize range".into(),
            )
        })?;
    let mut lse = DeviceBuffer::<f32>::alloc(lse_elems.max(1))
        .map_err(|e| BackendError::DeviceError(format!("LSE buffer alloc failed: {e}")))?;

    let result = flash_attention_forward(&handle, &q, &k, &v, &mut o, &mut lse, &config)
        .map_err(map_dnn_error);
    // NOTE(review): this no-op mutable borrow appears to exist only so the
    // `mut` on `handle` does not trigger an unused-mut warning — confirm and
    // drop both if `sm_version` takes `&self`.
    let _ = &mut handle;
    result
}
/// Translates a `DnnError` into the backend's error type.
///
/// Dimension and argument problems (including an undersized buffer) become
/// `InvalidArgument`, unsupported operations become `Unsupported`, and an
/// unhandled workspace request or any other variant becomes `DeviceError`.
#[cfg(feature = "dnn")]
fn map_dnn_error(e: oxicuda_dnn::error::DnnError) -> BackendError {
    use oxicuda_dnn::error::DnnError;
    match e {
        DnnError::UnsupportedOperation(reason) => BackendError::Unsupported(reason),
        DnnError::InvalidDimension(reason) => BackendError::InvalidArgument(reason),
        DnnError::InvalidArgument(reason) => BackendError::InvalidArgument(reason),
        DnnError::BufferTooSmall { expected, actual } => {
            let message = format!("DNN buffer too small: expected {expected}, got {actual}");
            BackendError::InvalidArgument(message)
        }
        DnnError::WorkspaceRequired(bytes) => {
            let message = format!("DNN workspace required: {bytes} bytes");
            BackendError::DeviceError(message)
        }
        other => BackendError::DeviceError(other.to_string()),
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly constructed backend must not report itself as initialised.
#[test]
fn cuda_backend_new_is_uninitialized() {
    assert!(!CudaBackend::new().is_initialized());
}
/// `init` succeeds (with or without a GPU) and flips the initialised flag.
#[test]
fn cuda_backend_init_sets_initialized() {
    let mut backend = CudaBackend::new();
    assert!(backend.init().is_ok());
    assert!(backend.is_initialized());
}
/// Calling `init` a second time still succeeds and leaves the backend initialised.
#[test]
fn cuda_backend_double_init_is_noop() {
    let mut backend = CudaBackend::new();
    for _attempt in 0..2 {
        assert!(backend.init().is_ok());
    }
    assert!(backend.is_initialized());
}
/// The backend identifies itself as `"cuda"`.
#[test]
fn cuda_backend_name() {
    assert_eq!(CudaBackend::new().name(), "cuda");
}
/// `Default` yields the same uninitialised, `"cuda"`-named backend as `new`.
#[test]
fn cuda_backend_default() {
    let via_default = CudaBackend::default();
    assert!(!via_default.is_initialized());
    assert_eq!(via_default.name(), "cuda");
}
/// The backend can be used behind `Box<dyn ComputeBackend>`.
#[test]
fn trait_is_object_safe() {
    let boxed: Box<dyn ComputeBackend> = {
        let mut concrete = CudaBackend::new();
        concrete.init().ok();
        Box::new(concrete)
    };
    assert_eq!(boxed.name(), "cuda");
    assert!(boxed.is_initialized());
    assert!(boxed.synchronize().is_ok());
}
/// Before `init`, every guarded operation reports `NotInitialized`.
#[test]
fn operations_fail_when_not_initialized() {
    let backend = CudaBackend::new();

    assert_eq!(backend.synchronize(), Err(BackendError::NotInitialized));

    let gemm_outcome = backend.gemm(
        BackendTranspose::NoTrans,
        BackendTranspose::NoTrans,
        1,
        1,
        1,
        1.0,
        0,
        1,
        0,
        1,
        0.0,
        0,
        1,
    );
    assert_eq!(gemm_outcome, Err(BackendError::NotInitialized));

    assert_eq!(backend.alloc(1024), Err(BackendError::NotInitialized));
    assert_eq!(backend.free(0), Err(BackendError::NotInitialized));
    assert_eq!(
        backend.copy_htod(0, &[1, 2, 3]),
        Err(BackendError::NotInitialized)
    );

    let mut host = [0u8; 4];
    assert_eq!(
        backend.copy_dtoh(&mut host, 0),
        Err(BackendError::NotInitialized)
    );
}
/// Wrong-length input shape, filter shape and stride are each rejected as
/// `InvalidArgument`.
#[test]
fn conv2d_validates_shapes() {
    let mut backend = CudaBackend::new();
    backend.init().ok();

    // Output shape and padding are always valid here; only the three varied
    // arguments can trigger the rejection.
    let rejects = |input: &[usize], filter: &[usize], stride: &[usize]| -> bool {
        let outcome = backend.conv2d_forward(
            0,
            input,
            0,
            filter,
            0,
            &[1, 64, 30, 30],
            stride,
            &[0, 0],
        );
        matches!(outcome, Err(BackendError::InvalidArgument(_)))
    };

    // Input shape with three elements.
    assert!(rejects(&[1, 3, 32], &[64, 3, 3, 3], &[1, 1]));
    // Filter shape with three elements.
    assert!(rejects(&[1, 3, 32, 32], &[64, 3, 3], &[1, 1]));
    // Stride with one element.
    assert!(rejects(&[1, 3, 32, 32], &[64, 3, 3, 3], &[1]));
}
/// A zero query length, a zero scale and a NaN scale are each rejected as
/// `InvalidArgument`.
#[test]
fn attention_validates_params() {
    let mut backend = CudaBackend::new();
    backend.init().ok();

    // (seq_q, scale) pairs; everything else stays fixed and valid.
    let bad_inputs = [(0usize, 0.125f64), (128, 0.0), (128, f64::NAN)];
    for (seq_q, scale) in bad_inputs {
        let outcome = backend.attention(0, 0, 0, 0, 1, 1, seq_q, 128, 64, scale, false);
        assert!(matches!(outcome, Err(BackendError::InvalidArgument(_))));
    }
}
/// An out-of-range axis and an empty shape are each rejected as
/// `InvalidArgument`.
#[test]
fn reduce_validates_axis() {
    let mut backend = CudaBackend::new();
    backend.init().ok();

    let bad_inputs: [(&[usize], usize); 2] = [(&[10, 20], 2), (&[], 0)];
    for (shape, axis) in bad_inputs {
        let outcome = backend.reduce(ReduceOp::Sum, 0, 0, shape, axis);
        assert!(matches!(outcome, Err(BackendError::InvalidArgument(_))));
    }
}
/// Element-wise operations over zero elements succeed without doing anything.
#[test]
fn unary_binary_empty_is_noop() {
    let mut backend = CudaBackend::new();
    backend.init().ok();

    let unary_outcome = backend.unary(UnaryOp::Relu, 0, 0, 0);
    assert!(unary_outcome.is_ok());

    let binary_outcome = backend.binary(BinaryOp::Add, 0, 0, 0, 0);
    assert!(binary_outcome.is_ok());
}
/// A zero-byte allocation request is rejected as `InvalidArgument`.
#[test]
fn alloc_zero_bytes_is_error() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    assert!(matches!(
        backend.alloc(0),
        Err(BackendError::InvalidArgument(_))
    ));
}
/// Copies of zero bytes succeed in both directions.
#[test]
fn copy_empty_is_noop() {
    let mut backend = CudaBackend::new();
    backend.init().ok();

    let upload = backend.copy_htod(0, &[]);
    assert!(upload.is_ok());

    let mut nothing: [u8; 0] = [];
    let download = backend.copy_dtoh(&mut nothing, 0);
    assert!(download.is_ok());
}
/// The derived `Debug` output names the type and its `initialized` field.
#[test]
fn debug_impl() {
    let backend = CudaBackend::new();
    let rendered = format!("{:?}", backend);
    for needle in ["CudaBackend", "initialized"] {
        assert!(rendered.contains(needle));
    }
}
// Test helper: `true` when the backend holds a GPU context, which the GEMM
// tests use to decide whether to run or return early.
#[cfg(feature = "blas")]
fn gpu_available(backend: &CudaBackend) -> bool {
backend.has_gpu_context()
}
/// Multiplies a 2x2 identity by a 2x2 matrix on the GPU and checks the output.
///
/// Skipped when no GPU context is available. The check is lenient on
/// purpose: a result equal to `b` passes, any other result only has to be
/// finite, and a `DeviceError` from the GEMM call is tolerated.
#[cfg(feature = "blas")]
#[test]
fn gemm_wiring_identity_multiply() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !gpu_available(&backend) {
        return;
    }

    let identity: Vec<f64> = vec![1.0, 0.0, 0.0, 1.0];
    let b: Vec<f64> = vec![1.0, 3.0, 2.0, 4.0];
    let zeros: Vec<f64> = vec![0.0; 4];
    let bytes = 4 * std::mem::size_of::<f64>();

    let a_ptr = backend.alloc(bytes).expect("alloc a");
    let b_ptr = backend.alloc(bytes).expect("alloc b");
    let c_ptr = backend.alloc(bytes).expect("alloc c");
    backend
        .copy_htod(a_ptr, bytemuck_cast(&identity))
        .expect("copy a");
    backend.copy_htod(b_ptr, bytemuck_cast(&b)).expect("copy b");
    backend
        .copy_htod(c_ptr, bytemuck_cast(&zeros))
        .expect("zero c");

    let result = backend.gemm(
        BackendTranspose::NoTrans,
        BackendTranspose::NoTrans,
        2,
        2,
        2,
        1.0,
        a_ptr,
        2,
        b_ptr,
        2,
        0.0,
        c_ptr,
        2,
    );

    // Read the output before freeing, but only when the GEMM call succeeded.
    let read_back = result.is_ok().then(|| {
        backend.synchronize().ok();
        let mut raw = vec![0u8; bytes];
        backend.copy_dtoh(&mut raw, c_ptr).ok();
        raw
    });

    for device_ptr in [a_ptr, b_ptr, c_ptr] {
        backend.free(device_ptr).ok();
    }

    match result {
        Err(BackendError::DeviceError(_)) => {}
        Err(e) => panic!("unexpected GEMM error: {e:?}"),
        Ok(()) => {
            let raw = read_back.expect("read back");
            let c: Vec<f64> = raw
                .chunks_exact(8)
                .map(|chunk| {
                    let mut word = [0u8; 8];
                    word.copy_from_slice(chunk);
                    f64::from_le_bytes(word)
                })
                .collect();
            if c != b {
                for value in &c {
                    assert!(value.is_finite(), "GEMM produced non-finite output");
                }
            }
        }
    }
}
/// With a GPU present, a GEMM whose `m` is zero is rejected as
/// `InvalidArgument`. Skipped when no GPU context is available.
#[cfg(feature = "blas")]
#[test]
fn gemm_wiring_rejects_zero_dims() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if gpu_available(&backend) {
        let outcome = backend.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            2,
            2,
            1.0,
            0,
            1,
            0,
            1,
            0.0,
            0,
            1,
        );
        assert!(matches!(outcome, Err(BackendError::InvalidArgument(_))));
    }
}
/// Runs a 1x1 convolution (3 input channels, 4 filters, 8x8 image) end to end.
///
/// Skipped when no GPU context is available. Success, `DeviceError` and
/// `Unsupported` are all accepted; any other error fails the test.
#[cfg(feature = "dnn")]
#[test]
fn conv2d_wiring_executes() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !backend.has_gpu_context() {
        return;
    }

    let (n, c, h, w, kk) = (1usize, 3usize, 8usize, 8usize, 4usize);
    let input = vec![1.0f32; n * c * h * w];
    let filter = vec![0.5f32; kk * c];
    let output_elems = n * kk * h * w;
    let elem_size = std::mem::size_of::<f32>();

    let in_ptr = backend.alloc(input.len() * elem_size).expect("alloc in");
    let filt_ptr = backend
        .alloc(filter.len() * elem_size)
        .expect("alloc filt");
    let out_ptr = backend.alloc(output_elems * elem_size).expect("alloc out");
    backend
        .copy_htod(in_ptr, f32_bytes(&input))
        .expect("copy in");
    backend
        .copy_htod(filt_ptr, f32_bytes(&filter))
        .expect("copy filt");

    let result = backend.conv2d_forward(
        in_ptr,
        &[n, c, h, w],
        filt_ptr,
        &[kk, c, 1, 1],
        out_ptr,
        &[n, kk, h, w],
        &[1, 1],
        &[0, 0],
    );
    if result.is_ok() {
        backend.synchronize().ok();
    }

    for device_ptr in [in_ptr, filt_ptr, out_ptr] {
        backend.free(device_ptr).ok();
    }

    if let Err(e) = result {
        let tolerated = matches!(
            e,
            BackendError::DeviceError(_) | BackendError::Unsupported(_)
        );
        if !tolerated {
            panic!("unexpected conv2d error: {e:?}");
        }
    }
}
/// Runs a small non-causal attention (1 batch, 2 heads, 16 tokens, head_dim 64)
/// end to end.
///
/// Skipped when no GPU context is available. Success, `DeviceError` and
/// `Unsupported` are all accepted; any other error fails the test.
#[cfg(feature = "dnn")]
#[test]
fn attention_wiring_executes() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !backend.has_gpu_context() {
        return;
    }

    let (batch, heads, seq, head_dim) = (1usize, 2usize, 16usize, 64usize);
    let data = vec![0.1f32; batch * heads * seq * head_dim];
    let tensor_bytes = data.len() * std::mem::size_of::<f32>();

    let q_ptr = backend.alloc(tensor_bytes).expect("alloc q");
    let k_ptr = backend.alloc(tensor_bytes).expect("alloc k");
    let v_ptr = backend.alloc(tensor_bytes).expect("alloc v");
    let o_ptr = backend.alloc(tensor_bytes).expect("alloc o");
    backend.copy_htod(q_ptr, f32_bytes(&data)).expect("copy q");
    backend.copy_htod(k_ptr, f32_bytes(&data)).expect("copy k");
    backend.copy_htod(v_ptr, f32_bytes(&data)).expect("copy v");

    let scale = 1.0 / (head_dim as f64).sqrt();
    let result = backend.attention(
        q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq, seq, head_dim, scale, false,
    );
    if result.is_ok() {
        backend.synchronize().ok();
    }

    for device_ptr in [q_ptr, k_ptr, v_ptr, o_ptr] {
        backend.free(device_ptr).ok();
    }

    if let Err(e) = result {
        let tolerated = matches!(
            e,
            BackendError::DeviceError(_) | BackendError::Unsupported(_)
        );
        if !tolerated {
            panic!("unexpected attention error: {e:?}");
        }
    }
}
/// Applies ReLU to eight values on the GPU and compares against `max(x, 0)`.
///
/// Skipped when no GPU context is available; a `DeviceError` from the
/// operation is tolerated, any other error fails the test.
#[cfg(feature = "ptx")]
#[test]
fn unary_wiring_relu() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !backend.has_gpu_context() {
        return;
    }

    let input: Vec<f32> = vec![-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, -5.0, 4.0];
    let count = input.len();
    let byte_len = count * std::mem::size_of::<f32>();

    let src = backend.alloc(byte_len).expect("alloc in");
    let dst = backend.alloc(byte_len).expect("alloc out");
    backend.copy_htod(src, f32_bytes(&input)).expect("copy in");

    let result = backend.unary(UnaryOp::Relu, src, dst, count);

    // Read the output before freeing, but only when the operation succeeded.
    let read_back = result.is_ok().then(|| {
        backend.synchronize().ok();
        let mut raw = vec![0u8; byte_len];
        backend.copy_dtoh(&mut raw, dst).ok();
        raw
    });

    backend.free(src).ok();
    backend.free(dst).ok();

    match result {
        Err(BackendError::DeviceError(_)) => {}
        Err(e) => panic!("unexpected unary error: {e:?}"),
        Ok(()) => {
            let raw = read_back.expect("read back");
            let decoded = raw.chunks_exact(4).map(|chunk| {
                let mut word = [0u8; 4];
                word.copy_from_slice(chunk);
                f32::from_le_bytes(word)
            });
            for (i, (got, &x)) in decoded.zip(&input).enumerate() {
                let want = x.max(0.0);
                assert!(
                    (got - want).abs() < 1e-5,
                    "relu[{i}] = {got}, expected {want}"
                );
            }
        }
    }
}
/// Adds two six-element vectors on the GPU and compares against the host sum.
///
/// Skipped when no GPU context is available; a `DeviceError` from the
/// operation is tolerated, any other error fails the test.
#[cfg(feature = "ptx")]
#[test]
fn binary_wiring_add() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !backend.has_gpu_context() {
        return;
    }

    let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0];
    let count = a.len();
    let byte_len = count * std::mem::size_of::<f32>();

    let a_ptr = backend.alloc(byte_len).expect("alloc a");
    let b_ptr = backend.alloc(byte_len).expect("alloc b");
    let c_ptr = backend.alloc(byte_len).expect("alloc c");
    backend.copy_htod(a_ptr, f32_bytes(&a)).expect("copy a");
    backend.copy_htod(b_ptr, f32_bytes(&b)).expect("copy b");

    let result = backend.binary(BinaryOp::Add, a_ptr, b_ptr, c_ptr, count);

    // Read the output before freeing, but only when the operation succeeded.
    let read_back = result.is_ok().then(|| {
        backend.synchronize().ok();
        let mut raw = vec![0u8; byte_len];
        backend.copy_dtoh(&mut raw, c_ptr).ok();
        raw
    });

    for device_ptr in [a_ptr, b_ptr, c_ptr] {
        backend.free(device_ptr).ok();
    }

    match result {
        Err(BackendError::DeviceError(_)) => {}
        Err(e) => panic!("unexpected binary error: {e:?}"),
        Ok(()) => {
            let raw = read_back.expect("read back");
            let decoded = raw.chunks_exact(4).map(|chunk| {
                let mut word = [0u8; 4];
                word.copy_from_slice(chunk);
                f32::from_le_bytes(word)
            });
            let expected = a.iter().zip(&b).map(|(x, y)| x + y);
            for (i, (got, want)) in decoded.zip(expected).enumerate() {
                assert!((got - want).abs() < 1e-4, "add[{i}] = {got}, want {want}");
            }
        }
    }
}
/// Sums a 2x4 tensor along axis 1 on the GPU and expects the row sums 10 and 26.
///
/// Skipped when no GPU context is available; a `DeviceError` from the
/// operation is tolerated, any other error fails the test.
#[cfg(feature = "ptx")]
#[test]
fn reduce_wiring_sum_axis() {
    let mut backend = CudaBackend::new();
    backend.init().ok();
    if !backend.has_gpu_context() {
        return;
    }

    let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let elem_size = std::mem::size_of::<f32>();
    let out_bytes = 2 * elem_size;

    let in_ptr = backend.alloc(input.len() * elem_size).expect("alloc in");
    let out_ptr = backend.alloc(out_bytes).expect("alloc out");
    backend
        .copy_htod(in_ptr, f32_bytes(&input))
        .expect("copy in");

    let result = backend.reduce(ReduceOp::Sum, in_ptr, out_ptr, &[2, 4], 1);

    // Read the output before freeing, but only when the operation succeeded.
    let read_back = result.is_ok().then(|| {
        backend.synchronize().ok();
        let mut raw = vec![0u8; out_bytes];
        backend.copy_dtoh(&mut raw, out_ptr).ok();
        raw
    });

    backend.free(in_ptr).ok();
    backend.free(out_ptr).ok();

    match result {
        Err(BackendError::DeviceError(_)) => {}
        Err(e) => panic!("unexpected reduce error: {e:?}"),
        Ok(()) => {
            let raw = read_back.expect("read back");
            let got: Vec<f32> = raw
                .chunks_exact(4)
                .map(|chunk| {
                    let mut word = [0u8; 4];
                    word.copy_from_slice(chunk);
                    f32::from_le_bytes(word)
                })
                .collect();
            assert!((got[0] - 10.0).abs() < 1e-4, "row0 sum = {}", got[0]);
            assert!((got[1] - 26.0).abs() < 1e-4, "row1 sum = {}", got[1]);
        }
    }
}
/// Reinterprets a slice of `f32` as its underlying bytes (native byte order)
/// without copying.
#[cfg(any(feature = "ptx", feature = "dnn"))]
fn f32_bytes(data: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(data);
    let start = data.as_ptr().cast::<u8>();
    // SAFETY: `start` points at the `byte_len` initialised bytes backing
    // `data`; `u8` has alignment 1 and no invalid bit patterns, and the
    // returned slice borrows from `data`, so it cannot outlive it.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
/// Reinterprets a slice of `f64` as its underlying bytes (native byte order)
/// without copying.
#[cfg(feature = "blas")]
fn bytemuck_cast(data: &[f64]) -> &[u8] {
    let byte_len = std::mem::size_of_val(data);
    let start = data.as_ptr().cast::<u8>();
    // SAFETY: `start` points at the `byte_len` initialised bytes backing
    // `data`; `u8` has alignment 1 and no invalid bit patterns, and the
    // returned slice borrows from `data`, so it cannot outlive it.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
}