#[cfg(feature = "cuda")]
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys::cublasOperation_t};
#[cfg(feature = "cudnn")]
use cudarc::cudnn::Cudnn;
#[cfg(feature = "cuda")]
use cudarc::driver::{
CudaContext, CudaSlice, CudaStream, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
};
use super::Backend;
#[cfg(feature = "cuda")]
use super::cuda_kernels::{self, BLOCK_SIZE, CudaKernels};
use crate::device::DeviceCapabilities;
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use std::sync::OnceLock;
/// Process-wide, lazily-initialized CUDA backend for GPU 0.
/// A failed initialization is cached as `None` so the probe runs only once.
#[cfg(feature = "cuda")]
static CUDA_BACKEND: OnceLock<Option<CudaBackend>> = OnceLock::new();
#[cfg(feature = "cuda")]
/// Returns the shared CUDA backend, initializing it on first call.
///
/// Initialization is attempted exactly once per process; a failure is cached
/// so later calls return `None` immediately without retrying.
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    let slot = CUDA_BACKEND.get_or_init(|| {
        let backend = CudaBackend::new(0);
        if backend.is_some() {
            eprintln!("[AxonML] CUDA backend initialized (GPU 0)");
        }
        backend
    });
    slot.as_ref()
}
/// Stub used when the crate is built without the `cuda` feature: always `None`.
#[cfg(not(feature = "cuda"))]
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    None
}
/// Handle bundle for one CUDA device: driver context, the stream all work is
/// queued on, a cuBLAS handle, and the project's custom kernels.
#[cfg(feature = "cuda")]
pub struct CudaBackend {
    // Index of the GPU this backend was created for.
    device_index: usize,
    // Driver context for the device.
    ctx: Arc<CudaContext>,
    // Default stream of `ctx`; every launch/copy in this module uses it.
    stream: Arc<CudaStream>,
    // cuBLAS handle created on `stream`.
    blas: CudaBlas,
    // Custom kernel set loaded via `CudaKernels::load`.
    kernels: CudaKernels,
    // Optional cuDNN handle; `None` when cuDNN init failed (see `new`).
    #[cfg(feature = "cudnn")]
    cudnn_handle: Option<Arc<Cudnn>>,
}
/// Feature-off stand-in so the type name exists without the `cuda` feature;
/// carries only the requested device index.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaBackend {
    device_index: usize,
}
// SAFETY(review): asserts the cudarc handles (context/stream/cuBLAS/kernels)
// may be moved across threads. This is an assumption about cudarc's and the
// CUDA runtime's thread rules, not proven here — TODO confirm against cudarc docs.
#[cfg(feature = "cuda")]
unsafe impl Send for CudaBackend {}
// SAFETY(review): asserts shared references to the backend may be used from
// multiple threads concurrently. Soundness depends on the contained handles
// tolerating concurrent use — TODO confirm against cudarc/cuBLAS threading docs.
#[cfg(feature = "cuda")]
unsafe impl Sync for CudaBackend {}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaBackend {
    /// Manual `Debug`: only the device index is printed, since the GPU
    /// handles held by the struct have no useful textual form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CudaBackend");
        out.field("device_index", &self.device_index);
        out.finish()
    }
}
impl CudaBackend {
    /// Builds a backend for GPU `device_index`: acquires the driver context,
    /// its default stream, a cuBLAS handle, and loads the custom kernel set.
    /// Returns `None` if any of those required steps fails.
    #[cfg(feature = "cuda")]
    pub fn new(device_index: usize) -> Option<Self> {
        let ctx = CudaContext::new(device_index).ok()?;
        let stream = ctx.default_stream();
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let kernels = match CudaKernels::load(ctx.clone()) {
            Ok(k) => k,
            Err(e) => {
                eprintln!("[AxonML CUDA] Kernel loading failed: {:?}", e);
                return None;
            }
        };
        // cuDNN is optional: a failure here only disables the cuDNN path
        // (the log says conv falls back to im2col+GEMM) and does not abort
        // backend creation.
        #[cfg(feature = "cudnn")]
        let cudnn_handle = match Cudnn::new(stream.clone()) {
            Ok(handle) => {
                eprintln!("[AxonML] cuDNN handle initialized");
                Some(handle)
            }
            Err(e) => {
                eprintln!(
                    "[AxonML CUDA] cuDNN init failed: {:?} (falling back to im2col+GEMM)",
                    e
                );
                None
            }
        };
        Some(Self {
            device_index,
            ctx,
            stream,
            blas,
            kernels,
            #[cfg(feature = "cudnn")]
            cudnn_handle,
        })
    }
    /// Stub used without the `cuda` feature: always `None`.
    #[cfg(not(feature = "cuda"))]
    pub fn new(device_index: usize) -> Option<Self> {
        let _ = device_index;
        None
    }
    /// Index of the GPU this backend was created for.
    pub fn device_index(&self) -> usize {
        self.device_index
    }
    /// Driver context handle.
    #[cfg(feature = "cuda")]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
    /// Stream on which all of this backend's work is queued.
    #[cfg(feature = "cuda")]
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
    /// cuBLAS handle created on this backend's stream.
    #[cfg(feature = "cuda")]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }
    /// cuDNN handle, if initialization succeeded in `new`; `None` otherwise.
    #[cfg(feature = "cudnn")]
    pub fn cudnn(&self) -> Option<&Arc<Cudnn>> {
        self.cudnn_handle.as_ref()
    }
    /// Allocates `len` zero-initialized elements of `T` on the device.
    #[cfg(feature = "cuda")]
    pub fn alloc<T: DeviceRepr + ValidAsZeroBits>(
        &self,
        len: usize,
    ) -> Result<CudaSlice<T>, CudaError> {
        self.stream.alloc_zeros(len).map_err(CudaError::from)
    }
    /// Allocates `len` elements of `T` without initializing them; the
    /// contents are unspecified until written.
    #[cfg(feature = "cuda")]
    pub fn alloc_uninit<T: DeviceRepr>(&self, len: usize) -> Result<CudaSlice<T>, CudaError> {
        // SAFETY: callers must fully overwrite the slice before reading it.
        unsafe { self.stream.alloc(len).map_err(CudaError::from) }
    }
    /// Copies a host slice into a freshly allocated device buffer.
    #[cfg(feature = "cuda")]
    pub fn htod_copy<T: DeviceRepr>(&self, src: &[T]) -> Result<CudaSlice<T>, CudaError> {
        self.stream.clone_htod(src).map_err(CudaError::from)
    }
    /// Copies a device buffer back into a host `Vec`.
    #[cfg(feature = "cuda")]
    pub fn dtoh_copy<T: DeviceRepr>(&self, src: &CudaSlice<T>) -> Result<Vec<T>, CudaError> {
        self.stream.clone_dtoh(src).map_err(CudaError::from)
    }
}
#[cfg(feature = "cuda")]
impl Backend for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }
    fn is_available(&self) -> bool {
        true
    }
    /// Reports a synthetic device name plus live free/total memory figures.
    fn capabilities(&self) -> DeviceCapabilities {
        let name = format!("CUDA Device {}", self.device_index);
        // NOTE(review): `mem_get_info` reports the *current* context's memory;
        // this assumes this backend's context is current — TODO confirm.
        // On error it falls back to (0, 0) rather than failing.
        let (free, total) = cudarc::driver::result::mem_get_info().unwrap_or((0, 0));
        DeviceCapabilities {
            name,
            total_memory: total,
            available_memory: free,
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: 1024,
            compute_capability: None,
        }
    }
    /// Allocates `size` zeroed bytes on the device and returns the raw device
    /// address (NOT host-dereferenceable), or null on failure. The backing
    /// `CudaSlice` is leaked so the memory survives until `deallocate`.
    fn allocate(&self, size: usize) -> *mut u8 {
        match self.stream.alloc_zeros::<u8>(size) {
            Ok(slice) => {
                let ptr = slice.leak() as *mut u8;
                ptr
            }
            Err(_) => std::ptr::null_mut(),
        }
    }
    /// Frees memory returned by `allocate` by reconstituting the leaked
    /// `CudaSlice` and letting its `Drop` release the device memory.
    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if !ptr.is_null() {
            // SAFETY: `ptr`/`size` must originate from a matching `allocate`
            // call on this backend; `upgrade_device_ptr` re-takes ownership.
            unsafe {
                let slice: CudaSlice<u8> = self
                    .stream
                    .upgrade_device_ptr(ptr as cudarc::driver::sys::CUdeviceptr, size);
                drop(slice);
            }
        }
    }
    /// Synchronous host→device copy. Errors are discarded because the
    /// `Backend` trait signature offers no way to report them.
    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees `src` is a valid host buffer of `size`
        // bytes and `dst` a device address with at least `size` bytes.
        unsafe {
            let src_slice = std::slice::from_raw_parts(src, size);
            let _ = cudarc::driver::result::memcpy_htod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src_slice,
            );
        }
    }
    /// Synchronous device→host copy; errors are discarded (see above).
    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees `dst` is a valid host buffer of `size`
        // bytes and `src` a device address with at least `size` bytes.
        unsafe {
            let dst_slice = std::slice::from_raw_parts_mut(dst, size);
            let _ = cudarc::driver::result::memcpy_dtoh_sync(
                dst_slice,
                src as cudarc::driver::sys::CUdeviceptr,
            );
        }
    }
    /// Synchronous device→device copy; errors are discarded (see above).
    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees both addresses are valid device pointers
        // covering at least `size` bytes.
        unsafe {
            let _ = cudarc::driver::result::memcpy_dtod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src as cudarc::driver::sys::CUdeviceptr,
                size,
            );
        }
    }
    /// Blocks until all work queued on this backend's stream has finished.
    fn synchronize(&self) {
        let _ = self.stream.synchronize();
    }
}
#[cfg(feature = "cuda")]
/// Blocks until all work on the global backend's stream has completed.
/// Returns `true` when a CUDA backend exists, `false` otherwise; sync errors
/// are ignored.
pub fn cuda_sync() -> bool {
    match get_cuda_backend() {
        Some(backend) => {
            let _ = backend.stream.synchronize();
            true
        }
        None => false,
    }
}
/// Stub used without the `cuda` feature: nothing to synchronize, always `false`.
#[cfg(not(feature = "cuda"))]
pub fn cuda_sync() -> bool {
    false
}
/// Inert `Backend` implementation used without the `cuda` feature: reports
/// unavailability and turns every operation into a no-op.
#[cfg(not(feature = "cuda"))]
impl Backend for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }
    fn is_available(&self) -> bool {
        false
    }
    /// All-zero capabilities, with the name flagged as unavailable.
    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: format!("CUDA Device {} (unavailable)", self.device_index),
            total_memory: 0,
            available_memory: 0,
            supports_f16: false,
            supports_f64: false,
            max_threads_per_block: 0,
            compute_capability: None,
        }
    }
    /// Always fails: returns null.
    fn allocate(&self, _size: usize) -> *mut u8 {
        std::ptr::null_mut()
    }
    fn deallocate(&self, _ptr: *mut u8, _size: usize) {}
    fn copy_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
    fn copy_to_host(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
    fn copy_device_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
    fn synchronize(&self) {}
}
/// Errors surfaced by the CUDA backend: device discovery, memory management,
/// kernel loading/launching, and wrapped cuBLAS/driver failures.
#[derive(Debug)]
pub enum CudaError {
    DeviceNotFound,
    AllocationFailed,
    CopyFailed,
    KernelLaunchFailed,
    BlasError(String),
    DriverError(String),
    ModuleLoadFailed(String),
    KernelNotFound(String),
}

impl std::fmt::Display for CudaError {
    /// Human-readable message; variants with a payload append the detail
    /// string after a fixed prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DeviceNotFound => f.write_str("CUDA device not found"),
            Self::AllocationFailed => f.write_str("CUDA memory allocation failed"),
            Self::CopyFailed => f.write_str("CUDA memory copy failed"),
            Self::KernelLaunchFailed => f.write_str("CUDA kernel launch failed"),
            Self::BlasError(msg) => write!(f, "cuBLAS error: {}", msg),
            Self::DriverError(msg) => write!(f, "CUDA driver error: {}", msg),
            Self::ModuleLoadFailed(msg) => write!(f, "CUDA module load failed: {}", msg),
            Self::KernelNotFound(msg) => write!(f, "CUDA kernel not found: {}", msg),
        }
    }
}

impl std::error::Error for CudaError {}
/// Lossy conversion: a driver error is captured as its display string.
#[cfg(feature = "cuda")]
impl From<cudarc::driver::DriverError> for CudaError {
    fn from(e: cudarc::driver::DriverError) -> Self {
        CudaError::DriverError(e.to_string())
    }
}
/// Lossy conversion: a cuBLAS error is captured via its `Debug` formatting.
#[cfg(feature = "cuda")]
impl From<cudarc::cublas::result::CublasError> for CudaError {
    fn from(e: cudarc::cublas::result::CublasError) -> Self {
        CudaError::BlasError(format!("{:?}", e))
    }
}
/// Probes whether CUDA is usable by attempting to create a context on GPU 0.
/// Always `false` when the `cuda` feature is disabled.
pub fn is_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        CudaContext::new(0).is_ok()
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
/// Number of CUDA devices visible to the driver; 0 on query failure or when
/// the `cuda` feature is disabled.
pub fn device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        cudarc::driver::result::device::get_count().unwrap_or(0) as usize
    }
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
/// Returns `true` when `index` names an attached CUDA device, i.e. when it
/// is strictly below [`device_count`].
pub fn is_device_available(index: usize) -> bool {
    let count = device_count();
    index < count
}
/// Capabilities for device `index`. With the `cuda` feature, tries to build a
/// full backend and query it live; otherwise — or when backend creation
/// fails — returns an optimistic static fallback with zeroed memory figures.
pub fn get_capabilities(index: usize) -> DeviceCapabilities {
    #[cfg(feature = "cuda")]
    {
        if let Some(backend) = CudaBackend::new(index) {
            return backend.capabilities();
        }
    }
    // NOTE(review): this fallback still claims f16/f64 support and 1024
    // threads per block even though the device could not be initialized.
    #[allow(unreachable_code)]
    DeviceCapabilities {
        name: format!("CUDA Device {}", index),
        total_memory: 0,
        available_memory: 0,
        supports_f16: true,
        supports_f64: true,
        max_threads_per_block: 1024,
        compute_capability: None,
    }
}
/// Currently a no-op even with CUDA enabled: per-stream handles are not
/// implemented, and synchronization goes through `cuda_sync()` instead.
#[cfg(feature = "cuda")]
pub fn stream_synchronize(_handle: usize) {
}
/// No-op stub without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn stream_synchronize(_handle: usize) {
}
#[cfg(feature = "cuda")]
impl CudaBackend {
/// f32 GEMM through the raw cuBLAS `sgemm` call; `m`/`n`/`k` and the leading
/// dimensions follow cuBLAS's column-major BLAS convention.
pub fn gemm_f32(
    &self,
    transa: bool,
    transb: bool,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &CudaSlice<f32>,
    lda: usize,
    b: &CudaSlice<f32>,
    ldb: usize,
    beta: f32,
    c: &mut CudaSlice<f32>,
    ldc: usize,
) -> Result<(), CudaError> {
    use cudarc::cublas::result::sgemm;
    use cudarc::driver::DevicePtr as _;
    use cudarc::driver::DevicePtrMut as _;
    let op_a = if transa {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        cublasOperation_t::CUBLAS_OP_N
    };
    let op_b = if transb {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        cublasOperation_t::CUBLAS_OP_N
    };
    // The `_g*` guards returned by device_ptr/device_ptr_mut are held alive
    // across the call so the raw addresses stay valid.
    let (a_ptr, _ga) = a.device_ptr(&self.stream);
    let (b_ptr, _gb) = b.device_ptr(&self.stream);
    let (c_ptr, _gc) = c.device_ptr_mut(&self.stream);
    // SAFETY: pointers come from live CudaSlices; dims/leading dims are the
    // caller's contract per the cuBLAS sgemm signature.
    unsafe {
        sgemm(
            *self.blas.handle(),
            op_a,
            op_b,
            m as i32,
            n as i32,
            k as i32,
            &alpha as *const f32,
            a_ptr as *const f32,
            lda as i32,
            b_ptr as *const f32,
            ldb as i32,
            &beta as *const f32,
            c_ptr as *mut f32,
            ldc as i32,
        )
        .map_err(CudaError::from)
    }
}
/// Batched f32 GEMM executed as `batch_count` individual cuBLAS `gemm` calls
/// (pointer-array batching is not used). Dimensions follow the cuBLAS
/// column-major convention, shared by every batch.
///
/// # Errors
/// Propagates the first cuBLAS failure as `CudaError::BlasError`.
///
/// # Panics
/// Debug builds assert that each slice array holds at least `batch_count`
/// entries; release builds would panic on the out-of-bounds index instead.
pub fn gemm_batched_f32(
    &self,
    transa: bool,
    transb: bool,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a_array: &[&CudaSlice<f32>],
    lda: usize,
    b_array: &[&CudaSlice<f32>],
    ldb: usize,
    beta: f32,
    c_array: &mut [&mut CudaSlice<f32>],
    ldc: usize,
    batch_count: usize,
) -> Result<(), CudaError> {
    // Fail fast with a clear message instead of an opaque index panic when a
    // caller passes fewer slices than `batch_count`.
    debug_assert!(a_array.len() >= batch_count, "a_array shorter than batch_count");
    debug_assert!(b_array.len() >= batch_count, "b_array shorter than batch_count");
    debug_assert!(c_array.len() >= batch_count, "c_array shorter than batch_count");
    // Single place for the bool -> cuBLAS op translation.
    let op = |t: bool| {
        if t {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        }
    };
    for i in 0..batch_count {
        let cfg = GemmConfig {
            transa: op(transa),
            transb: op(transb),
            m: m as i32,
            n: n as i32,
            k: k as i32,
            alpha,
            lda: lda as i32,
            ldb: ldb as i32,
            beta,
            ldc: ldc as i32,
        };
        // SAFETY: slices are live device buffers; dims/leading dims are the
        // caller's contract per the cuBLAS gemm signature.
        unsafe {
            self.blas
                .gemm(cfg, a_array[i], b_array[i], c_array[i])
                .map_err(CudaError::from)?;
        }
    }
    Ok(())
}
/// Strided-batched f32 GEMM via cuBLAS `sgemmStridedBatched`: batch `i` uses
/// base pointers offset by `i * stride_{a,b,c}` elements. Dimensions follow
/// the cuBLAS column-major convention.
pub fn gemm_strided_batched_f32(
    &self,
    transa: bool,
    transb: bool,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &CudaSlice<f32>,
    lda: usize,
    stride_a: i64,
    b: &CudaSlice<f32>,
    ldb: usize,
    stride_b: i64,
    beta: f32,
    c: &mut CudaSlice<f32>,
    ldc: usize,
    stride_c: i64,
    batch_count: usize,
) -> Result<(), CudaError> {
    use cudarc::cublas::result::sgemm_strided_batched;
    use cudarc::driver::DevicePtr as _;
    use cudarc::driver::DevicePtrMut as _;
    let op_a = if transa {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        cublasOperation_t::CUBLAS_OP_N
    };
    let op_b = if transb {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        cublasOperation_t::CUBLAS_OP_N
    };
    // Guards held alive across the call keep the raw addresses valid.
    let (a_devptr, _ga) = a.device_ptr(&self.stream);
    let (b_devptr, _gb) = b.device_ptr(&self.stream);
    let (c_devptr, _gc) = c.device_ptr_mut(&self.stream);
    let a_ptr = a_devptr as *const f32;
    let b_ptr = b_devptr as *const f32;
    let c_ptr = c_devptr as *mut f32;
    // SAFETY: pointers come from live CudaSlices; caller guarantees the
    // strides keep every batch inside its buffer.
    unsafe {
        sgemm_strided_batched(
            *self.blas.handle(),
            op_a,
            op_b,
            m as i32,
            n as i32,
            k as i32,
            &alpha as *const f32,
            a_ptr,
            lda as i32,
            stride_a,
            b_ptr,
            ldb as i32,
            stride_b,
            &beta as *const f32,
            c_ptr,
            ldc as i32,
            stride_c,
            batch_count as i32,
        )
        .map_err(CudaError::from)
    }
}
/// Launches the `add_f32` kernel with `(a, b, dst, len)` — elementwise
/// addition over `len` elements, per the kernel's name.
pub fn add_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("add_f32")
        .ok_or_else(|| CudaError::KernelNotFound("add_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `scale_f32` kernel with `(dst, alpha, len)` — in-place
/// scaling of `dst`, per the kernel's name.
pub fn scale_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    alpha: f32,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("scale_f32")
        .ok_or_else(|| CudaError::KernelNotFound("scale_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(dst)
            .arg(&alpha)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `mul_f32` kernel with `(a, b, dst, len)` — elementwise
/// multiplication, per the kernel's name.
pub fn mul_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("mul_f32")
        .ok_or_else(|| CudaError::KernelNotFound("mul_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `q4k_gemm_f32` kernel: GEMM with a Q4_K-quantized weight
/// blob `w` (raw bytes) against f32 activations `a`, writing f32 output `c`.
/// One thread per output element (grid covers `m_dim * out_dim`).
pub fn q4k_gemm_f32(
    &self,
    w: &CudaSlice<u8>,
    a: &CudaSlice<f32>,
    c: &mut CudaSlice<f32>,
    m_dim: usize,
    out_dim: usize,
    in_dim: usize,
) -> Result<(), CudaError> {
    // Q4_K blocks span 256 weights, so the reduction dim must be a multiple.
    debug_assert!(in_dim % 256 == 0, "Q4_K GEMM requires in_dim % 256 == 0");
    let func = self
        .kernels
        .get("q4k_gemm_f32")
        .ok_or_else(|| CudaError::KernelNotFound("q4k_gemm_f32".to_string()))?;
    let total = m_dim * out_dim;
    let cfg = cuda_kernels::launch_config(total);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(w)
            .arg(a)
            .arg(c)
            .arg(&(m_dim as u32))
            .arg(&(out_dim as u32))
            .arg(&(in_dim as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `q4k_gemv_f32` kernel: matrix-vector product with a
/// Q4_K-quantized weight blob. Blocks hold 4 warps and the grid is sized so
/// warps cover `out_dim` — presumably one output row per warp; confirm
/// against the kernel source.
pub fn q4k_gemv_f32(
    &self,
    w: &CudaSlice<u8>,
    a: &CudaSlice<f32>,
    c: &mut CudaSlice<f32>,
    out_dim: usize,
    in_dim: usize,
) -> Result<(), CudaError> {
    debug_assert!(in_dim % 256 == 0, "Q4_K GEMV requires in_dim % 256 == 0");
    let func = self
        .kernels
        .get("q4k_gemv_f32")
        .ok_or_else(|| CudaError::KernelNotFound("q4k_gemv_f32".to_string()))?;
    const WARPS_PER_CTA: u32 = 4;
    const THREADS_PER_CTA: u32 = WARPS_PER_CTA * 32;
    // Ceiling division so every output row is covered.
    let grid = ((out_dim as u32) + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (THREADS_PER_CTA, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(w)
            .arg(a)
            .arg(c)
            .arg(&(out_dim as u32))
            .arg(&(in_dim as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `q6k_gemm_f32` kernel: GEMM with a Q6_K-quantized weight
/// blob, mirroring `q4k_gemm_f32`. One thread per output element.
pub fn q6k_gemm_f32(
    &self,
    w: &CudaSlice<u8>,
    a: &CudaSlice<f32>,
    c: &mut CudaSlice<f32>,
    m_dim: usize,
    out_dim: usize,
    in_dim: usize,
) -> Result<(), CudaError> {
    // Q6_K blocks also span 256 weights.
    debug_assert!(in_dim % 256 == 0, "Q6_K GEMM requires in_dim % 256 == 0");
    let func = self
        .kernels
        .get("q6k_gemm_f32")
        .ok_or_else(|| CudaError::KernelNotFound("q6k_gemm_f32".to_string()))?;
    let total = m_dim * out_dim;
    let cfg = cuda_kernels::launch_config(total);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(w)
            .arg(a)
            .arg(c)
            .arg(&(m_dim as u32))
            .arg(&(out_dim as u32))
            .arg(&(in_dim as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `q6k_gemv_f32` kernel: matrix-vector product with a
/// Q6_K-quantized weight blob, using the same 4-warps-per-block layout as
/// `q4k_gemv_f32`.
pub fn q6k_gemv_f32(
    &self,
    w: &CudaSlice<u8>,
    a: &CudaSlice<f32>,
    c: &mut CudaSlice<f32>,
    out_dim: usize,
    in_dim: usize,
) -> Result<(), CudaError> {
    debug_assert!(in_dim % 256 == 0, "Q6_K GEMV requires in_dim % 256 == 0");
    let func = self
        .kernels
        .get("q6k_gemv_f32")
        .ok_or_else(|| CudaError::KernelNotFound("q6k_gemv_f32".to_string()))?;
    const WARPS_PER_CTA: u32 = 4;
    const THREADS_PER_CTA: u32 = WARPS_PER_CTA * 32;
    // Ceiling division so every output row is covered.
    let grid = ((out_dim as u32) + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (THREADS_PER_CTA, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(w)
            .arg(a)
            .arg(c)
            .arg(&(out_dim as u32))
            .arg(&(in_dim as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `relu_f32` kernel with `(src, dst, len)` — elementwise ReLU,
/// per the kernel's name.
pub fn relu_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("relu_f32")
        .ok_or_else(|| CudaError::KernelNotFound("relu_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `sigmoid_f32` kernel with `(src, dst, len)` — elementwise
/// sigmoid, per the kernel's name.
pub fn sigmoid_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("sigmoid_f32")
        .ok_or_else(|| CudaError::KernelNotFound("sigmoid_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `tanh_f32` kernel with `(src, dst, len)` — elementwise tanh,
/// per the kernel's name.
pub fn tanh_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("tanh_f32")
        .ok_or_else(|| CudaError::KernelNotFound("tanh_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `sub_f32` kernel with `(a, b, dst, len)` — elementwise
/// subtraction, per the kernel's name.
pub fn sub_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("sub_f32")
        .ok_or_else(|| CudaError::KernelNotFound("sub_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `div_f32` kernel with `(a, b, dst, len)` — elementwise
/// division, per the kernel's name.
pub fn div_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("div_f32")
        .ok_or_else(|| CudaError::KernelNotFound("div_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_add_f32` kernel with `(a, b, dst, n, b_len)` —
/// addition where the second operand of length `b_len` is broadcast over
/// `n` outputs, per the kernel's name.
pub fn broadcast_add_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    b_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_add_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(b_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_sub_f32` kernel with `(a, b, dst, n, b_len)` —
/// subtraction with the `b` operand broadcast, per the kernel's name.
pub fn broadcast_sub_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    b_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_sub_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(b_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_mul_f32` kernel with `(a, b, dst, n, b_len)` —
/// multiplication with the `b` operand broadcast, per the kernel's name.
pub fn broadcast_mul_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    b_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_mul_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(b_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_div_f32` kernel with `(a, b, dst, n, b_len)` —
/// division with the `b` operand broadcast, per the kernel's name.
pub fn broadcast_div_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    b_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_div_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(b_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_add_rev_f32` kernel with `(a, b, dst, n, a_len)`
/// — "rev" variant where the FIRST operand (length `a_len`) is the one
/// broadcast over `n` outputs, per the kernel's name.
pub fn broadcast_add_rev_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    a_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_add_rev_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_rev_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(a_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_sub_rev_f32` kernel with `(a, b, dst, n, a_len)`
/// — subtraction with the first operand broadcast, per the kernel's name.
pub fn broadcast_sub_rev_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    a_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_sub_rev_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_rev_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(a_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_mul_rev_f32` kernel with `(a, b, dst, n, a_len)`
/// — multiplication with the first operand broadcast, per the kernel's name.
pub fn broadcast_mul_rev_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    a_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_mul_rev_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_rev_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(a_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_div_rev_f32` kernel with `(a, b, dst, n, a_len)`
/// — division with the first operand broadcast, per the kernel's name.
pub fn broadcast_div_rev_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    a_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_div_rev_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_rev_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(a_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `neg_f32` kernel with `(src, dst, len)` — elementwise
/// negation, per the kernel's name.
pub fn neg_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("neg_f32")
        .ok_or_else(|| CudaError::KernelNotFound("neg_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `pow_f32` kernel with `(a, b, dst, len)` — elementwise
/// power with per-element exponents, per the kernel's name.
pub fn pow_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("pow_f32")
        .ok_or_else(|| CudaError::KernelNotFound("pow_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(a)
            .arg(b)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `pow_scalar_f32` kernel with `(src, exp, dst, len)` —
/// elementwise power with one scalar exponent, per the kernel's name.
pub fn pow_scalar_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    exp: f32,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("pow_scalar_f32")
        .ok_or_else(|| CudaError::KernelNotFound("pow_scalar_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(&exp)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `exp_f32` kernel with `(src, dst, len)` — elementwise
/// exponential, per the kernel's name.
pub fn exp_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("exp_f32")
        .ok_or_else(|| CudaError::KernelNotFound("exp_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `log_f32` kernel with `(src, dst, len)` — elementwise
/// logarithm, per the kernel's name.
pub fn log_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("log_f32")
        .ok_or_else(|| CudaError::KernelNotFound("log_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `sqrt_f32` kernel with `(src, dst, len)` — elementwise
/// square root, per the kernel's name.
pub fn sqrt_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("sqrt_f32")
        .ok_or_else(|| CudaError::KernelNotFound("sqrt_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `gelu_f32` kernel with `(src, dst, len)` — elementwise GELU,
/// per the kernel's name (exact vs. tanh approximation is kernel-defined).
pub fn gelu_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("gelu_f32")
        .ok_or_else(|| CudaError::KernelNotFound("gelu_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `silu_f32` kernel with `(src, dst, len)` — elementwise SiLU
/// (x·sigmoid(x)), per the kernel's name.
pub fn silu_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("silu_f32")
        .ok_or_else(|| CudaError::KernelNotFound("silu_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `add_scalar_f32` kernel with `(src, scalar, dst, len)` —
/// adds one scalar to every element, per the kernel's name.
pub fn add_scalar_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    scalar: f32,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("add_scalar_f32")
        .ok_or_else(|| CudaError::KernelNotFound("add_scalar_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(&scalar)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `relu_backward_f32` kernel with `(grad_output, input, dst,
/// len)` — ReLU gradient, gated by the forward INPUT, per the kernel's name.
pub fn relu_backward_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    input: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("relu_backward_f32")
        .ok_or_else(|| CudaError::KernelNotFound("relu_backward_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_output)
            .arg(input)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `sigmoid_backward_f32` kernel with `(grad_output, output,
/// dst, len)` — sigmoid gradient computed from the forward OUTPUT, per the
/// kernel's name.
pub fn sigmoid_backward_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    output: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("sigmoid_backward_f32")
        .ok_or_else(|| CudaError::KernelNotFound("sigmoid_backward_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_output)
            .arg(output)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `tanh_backward_f32` kernel with `(grad_output, output, dst,
/// len)` — tanh gradient computed from the forward OUTPUT, per the kernel's
/// name.
pub fn tanh_backward_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    output: &CudaSlice<f32>,
    len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("tanh_backward_f32")
        .ok_or_else(|| CudaError::KernelNotFound("tanh_backward_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_output)
            .arg(output)
            .arg(dst)
            .arg(&(len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `sum_dim_f32` kernel: reduces the middle axis of a tensor
/// viewed as `(outer_size, dim_size, inner_size)` into `dst` of
/// `outer_size * inner_size` elements (one thread per output), per the
/// kernel's name.
pub fn sum_dim_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    outer_size: usize,
    dim_size: usize,
    inner_size: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("sum_dim_f32")
        .ok_or_else(|| CudaError::KernelNotFound("sum_dim_f32".to_string()))?;
    let out_len = outer_size * inner_size;
    let cfg = cuda_kernels::launch_config(out_len);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(outer_size as u32))
            .arg(&(dim_size as u32))
            .arg(&(inner_size as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `softmax_row_f32` kernel: in-place row-wise softmax over
/// `data` viewed as `(num_rows, row_size)`. One block per row with
/// `BLOCK_SIZE * 4` bytes of shared memory (one f32 per thread) for the
/// reduction.
pub fn softmax_row_f32(
    &self,
    data: &mut CudaSlice<f32>,
    num_rows: usize,
    row_size: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("softmax_row_f32")
        .ok_or_else(|| CudaError::KernelNotFound("softmax_row_f32".to_string()))?;
    let cfg = LaunchConfig {
        grid_dim: (num_rows as u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: BLOCK_SIZE * 4,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(data)
            .arg(&(num_rows as u32))
            .arg(&(row_size as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `broadcast_copy_f32` kernel with `(src, dst, n, src_len)` —
/// tiles/broadcasts a `src_len`-element source across `n` outputs, per the
/// kernel's name.
pub fn broadcast_copy_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    n: usize,
    src_len: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("broadcast_copy_f32")
        .ok_or_else(|| CudaError::KernelNotFound("broadcast_copy_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(dst)
            .arg(&(n as u32))
            .arg(&(src_len as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `layer_norm_f32` kernel: LayerNorm over the last axis of
/// `input` viewed as `(num_rows, norm_size)` with affine params `gamma`/
/// `beta` and epsilon `eps`. One block per row, `BLOCK_SIZE * 4` bytes of
/// shared memory for the reduction.
pub fn layer_norm_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    input: &CudaSlice<f32>,
    gamma: &CudaSlice<f32>,
    beta: &CudaSlice<f32>,
    norm_size: usize,
    eps: f32,
    num_rows: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("layer_norm_f32")
        .ok_or_else(|| CudaError::KernelNotFound("layer_norm_f32".to_string()))?;
    let cfg = LaunchConfig {
        grid_dim: (num_rows as u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: BLOCK_SIZE * 4,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(input)
            .arg(gamma)
            .arg(beta)
            .arg(dst)
            .arg(&(norm_size as u32))
            .arg(&eps)
            .arg(&(num_rows as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `softmax_backward_row_f32` kernel: row-wise softmax gradient
/// from the forward output and upstream gradient, viewed as
/// `(num_rows, row_size)`. One block per row with `BLOCK_SIZE * 4` bytes of
/// shared memory for the reduction.
pub fn softmax_backward_row_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    softmax_output: &CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    num_rows: usize,
    row_size: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("softmax_backward_row_f32")
        .ok_or_else(|| CudaError::KernelNotFound("softmax_backward_row_f32".to_string()))?;
    let cfg = LaunchConfig {
        grid_dim: (num_rows as u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: BLOCK_SIZE * 4,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(softmax_output)
            .arg(grad_output)
            .arg(dst)
            .arg(&(num_rows as u32))
            .arg(&(row_size as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `layer_norm_backward_dinput_f32` kernel: input gradient of
/// LayerNorm over `(num_rows, norm_size)`. One block per row; shared memory
/// is `BLOCK_SIZE * 4 * 2` bytes — twice the single-buffer size, presumably
/// two per-thread f32 reduction buffers (confirm against the kernel source).
pub fn layer_norm_backward_dinput_f32(
    &self,
    d_input: &mut CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    input: &CudaSlice<f32>,
    gamma: &CudaSlice<f32>,
    norm_size: usize,
    eps: f32,
    num_rows: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("layer_norm_backward_dinput_f32")
        .ok_or_else(|| {
            CudaError::KernelNotFound("layer_norm_backward_dinput_f32".to_string())
        })?;
    let cfg = LaunchConfig {
        grid_dim: (num_rows as u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: BLOCK_SIZE * 4 * 2,
    };
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_output)
            .arg(input)
            .arg(gamma)
            .arg(d_input)
            .arg(&(norm_size as u32))
            .arg(&eps)
            .arg(&(num_rows as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the `layer_norm_backward_dweight_dbias_f32` kernel: gamma and
/// beta gradients of LayerNorm. The grid covers `norm_size` outputs — one
/// thread per feature accumulating across `num_rows`, per the launch sizing.
pub fn layer_norm_backward_dweight_dbias_f32(
    &self,
    d_weight: &mut CudaSlice<f32>,
    d_bias: &mut CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    input: &CudaSlice<f32>,
    norm_size: usize,
    eps: f32,
    num_rows: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("layer_norm_backward_dweight_dbias_f32")
        .ok_or_else(|| {
            CudaError::KernelNotFound("layer_norm_backward_dweight_dbias_f32".to_string())
        })?;
    let cfg = cuda_kernels::launch_config(norm_size);
    // SAFETY: arg order must match the kernel signature exactly.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_output)
            .arg(input)
            .arg(d_weight)
            .arg(d_bias)
            .arg(&(norm_size as u32))
            .arg(&eps)
            .arg(&(num_rows as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the contiguous gather kernel: `dst` is filled from `src`
/// according to `indices`, one thread per output element (`n` total).
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn gather_contiguous_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    src: &CudaSlice<f32>,
    indices: &CudaSlice<u32>,
    n: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("gather_contiguous_f32")
        .ok_or_else(|| CudaError::KernelNotFound("gather_contiguous_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(src)
            .arg(indices)
            .arg(dst)
            .arg(&(n as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the embedding backward scatter-add kernel: accumulates rows of
/// `grad_src` into `weight_grad` at positions given by `indices`.
///
/// `total_n` is the flat element count (tokens * emb_dim); one thread per
/// element via `launch_config`.
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn embedding_scatter_add_f32(
    &self,
    grad_src: &CudaSlice<f32>,
    indices: &CudaSlice<u32>,
    weight_grad: &mut CudaSlice<f32>,
    total_n: usize,
    emb_dim: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("embedding_scatter_add_f32")
        .ok_or_else(|| CudaError::KernelNotFound("embedding_scatter_add_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(total_n);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(grad_src)
            .arg(indices)
            .arg(weight_grad)
            .arg(&(total_n as u32))
            .arg(&(emb_dim as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
/// Launches a fused Adam/AdamW optimizer step over `n` parameters,
/// updating `param`, `exp_avg` (first moment) and `exp_avg_sq` (second
/// moment) in place from `grad`.
///
/// `bias_correction1`/`bias_correction2` are precomputed by the caller
/// (they depend on the step count, which this method never sees).
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn adam_step_f32(
    &self,
    param: &mut CudaSlice<f32>,
    grad: &CudaSlice<f32>,
    exp_avg: &mut CudaSlice<f32>,
    exp_avg_sq: &mut CudaSlice<f32>,
    n: usize,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,
    bias_correction2: f32,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("adam_step_f32")
        .ok_or_else(|| CudaError::KernelNotFound("adam_step_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(param)
            .arg(grad)
            .arg(exp_avg)
            .arg(exp_avg_sq)
            .arg(&(n as u32))
            .arg(&lr)
            .arg(&beta1)
            .arg(&beta2)
            .arg(&eps)
            .arg(&weight_decay)
            .arg(&bias_correction1)
            .arg(&bias_correction2)
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the squared-gradient-norm reduction kernel over `n` elements
/// of `data`, writing the result into `output`.
///
/// NOTE(review): `output` is not zeroed here — presumably the kernel
/// accumulates into it (or the caller zeroes it first); confirm the
/// contract against the kernel source and call sites.
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn grad_norm_sq_f32(
    &self,
    data: &CudaSlice<f32>,
    output: &mut CudaSlice<f32>,
    n: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("grad_norm_sq_f32")
        .ok_or_else(|| CudaError::KernelNotFound("grad_norm_sq_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(data)
            .arg(output)
            .arg(&(n as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches an in-place scale of `n` elements of `data` by `scale`
/// (used e.g. for gradient clipping / loss scaling).
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn grad_scale_f32(
    &self,
    data: &mut CudaSlice<f32>,
    n: usize,
    scale: f32,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("grad_scale_f32")
        .ok_or_else(|| CudaError::KernelNotFound("grad_scale_f32".to_string()))?;
    let cfg = cuda_kernels::launch_config(n);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(data)
            .arg(&(n as u32))
            .arg(&scale)
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the fused cross-entropy forward kernel: per-sample losses go
/// to `losses` and the softmax probabilities (needed by backward) to
/// `softmax_out`.
///
/// Launch geometry: one CTA per sample (`grid_dim = batch_size`) with
/// `BLOCK_SIZE * 4` bytes of shared memory for the per-row reduction.
/// Note the kernel receives only `num_classes`; the row index presumably
/// comes from the block index.
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn cross_entropy_fwd_f32(
    &self,
    logits: &CudaSlice<f32>,
    targets: &CudaSlice<f32>,
    losses: &mut CudaSlice<f32>,
    softmax_out: &mut CudaSlice<f32>,
    batch_size: usize,
    num_classes: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("cross_entropy_fwd_f32")
        .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_fwd_f32".to_string()))?;
    let cfg = LaunchConfig {
        grid_dim: (batch_size as u32, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: BLOCK_SIZE * 4,
    };
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(logits)
            .arg(targets)
            .arg(losses)
            .arg(softmax_out)
            .arg(&(num_classes as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Launches the cross-entropy backward kernel: computes logits gradients
/// into `grad_input` from the saved softmax probabilities, targets, and
/// upstream gradient. Flat launch: one thread per logit element.
///
/// # Errors
/// `KernelNotFound` or `DriverError` on lookup/launch failure.
pub fn cross_entropy_bwd_f32(
    &self,
    softmax_probs: &CudaSlice<f32>,
    targets: &CudaSlice<f32>,
    grad_output: &CudaSlice<f32>,
    grad_input: &mut CudaSlice<f32>,
    batch_size: usize,
    num_classes: usize,
) -> Result<(), CudaError> {
    let func = self
        .kernels
        .get("cross_entropy_bwd_f32")
        .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_bwd_f32".to_string()))?;
    // Total flat element count over the whole logits matrix.
    let total = batch_size * num_classes;
    let cfg = cuda_kernels::launch_config(total);
    // SAFETY: argument order/types must match the compiled kernel signature.
    unsafe {
        self.stream
            .launch_builder(func)
            .arg(softmax_probs)
            .arg(targets)
            .arg(grad_output)
            .arg(grad_input)
            .arg(&(batch_size as u32))
            .arg(&(num_classes as u32))
            .launch(cfg)
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
/// Zero-fills `dst` on the device using the stream's async memset.
///
/// # Errors
/// `DriverError` if the underlying memset fails.
#[cfg(feature = "cuda")]
pub fn memset_zeros_f32(&self, dst: &mut CudaSlice<f32>) -> Result<(), CudaError> {
    self.stream
        .memset_zeros(dst)
        .map_err(|e| CudaError::DriverError(e.to_string()))
}
/// Copies `count` f32 elements device-to-device, from
/// `src[src_offset..src_offset + count]` into
/// `dst[dst_offset..dst_offset + count]`, via a synchronous driver memcpy.
///
/// # Errors
/// Returns `CudaError::DriverError` if the underlying copy fails.
#[cfg(feature = "cuda")]
pub fn memcpy_dtod_f32(
    &self,
    dst: &mut CudaSlice<f32>,
    dst_offset: usize,
    src: &CudaSlice<f32>,
    src_offset: usize,
    count: usize,
) -> Result<(), CudaError> {
    // Empty copies are a no-op; skip the driver round-trip entirely.
    if count == 0 {
        return Ok(());
    }
    // Catch out-of-bounds ranges in debug builds before raw pointer
    // arithmetic hands the range to the driver (which cannot check it).
    debug_assert!(
        src_offset + count <= src.len(),
        "memcpy_dtod_f32: src range {src_offset}+{count} exceeds src len"
    );
    debug_assert!(
        dst_offset + count <= dst.len(),
        "memcpy_dtod_f32: dst range {dst_offset}+{count} exceeds dst len"
    );
    use cudarc::driver::DevicePtr as _;
    let (src_ptr, _guard_s) = src.device_ptr(&self.stream);
    let src_ptr =
        src_ptr + (src_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
    use cudarc::driver::DevicePtrMut as _;
    let (dst_ptr, _guard_d) = dst.device_ptr_mut(&self.stream);
    let dst_ptr =
        dst_ptr + (dst_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
    let size = count * std::mem::size_of::<f32>();
    // SAFETY: both pointers come from live CudaSlice allocations (guards
    // held above) and the debug asserts check the byte ranges fit inside
    // their respective allocations.
    unsafe {
        cudarc::driver::result::memcpy_dtod_sync(dst_ptr, src_ptr, size)
            .map_err(|e| CudaError::DriverError(e.to_string()))?;
    }
    Ok(())
}
}
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the causal attention-mask expansion kernel: fills `output`
    /// (`total_n` elements) from `mask` using `tgt_len`/`src_len`, one
    /// thread per output element.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn mask_expand_causal_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("mask_expand_causal_f32")
            .ok_or_else(|| CudaError::KernelNotFound("mask_expand_causal_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(mask)
                .arg(output)
                .arg(&(total_n as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the padding attention-mask expansion kernel; like the
    /// causal variant but also broadcast across `num_heads`.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn mask_expand_padding_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("mask_expand_padding_f32")
            .ok_or_else(|| CudaError::KernelNotFound("mask_expand_padding_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(mask)
                .arg(output)
                .arg(&(total_n as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the strided gather kernel: materializes a contiguous copy
    /// of a strided view (`strides`/`shape`/`ndim`/`offset`) of `src` into
    /// `dst`, one thread per output element.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn strided_gather_f32(
        &self,
        src: &CudaSlice<f32>,
        dst: &mut CudaSlice<f32>,
        strides: &CudaSlice<i64>,
        shape: &CudaSlice<u32>,
        ndim: usize,
        offset: usize,
        total_n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("strided_gather_f32")
            .ok_or_else(|| CudaError::KernelNotFound("strided_gather_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(src)
                .arg(dst)
                .arg(strides)
                .arg(shape)
                .arg(&(ndim as u32))
                .arg(&(offset as u32))
                .arg(&(total_n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the fused LSTM cell pointwise kernel: from pre-activation
    /// `gates` and previous cell state `c_prev`, writes the new hidden and
    /// cell states. `total` is the flat (batch * hidden) element count.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn lstm_gates_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        c_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("lstm_gates_f32")
            .ok_or_else(|| CudaError::KernelNotFound("lstm_gates_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates)
                .arg(c_prev)
                .arg(h_new)
                .arg(c_new)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the LSTM cell backward kernel: given the saved forward
    /// activations and incoming gradients (`grad_h`, `grad_c_next`),
    /// writes gradients for the gate pre-activations and the previous
    /// cell state.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn lstm_gates_backward_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        c_new: &CudaSlice<f32>,
        grad_h: &CudaSlice<f32>,
        grad_c_next: &CudaSlice<f32>,
        grad_gates: &mut CudaSlice<f32>,
        grad_c_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("lstm_gates_backward_f32")
            .ok_or_else(|| CudaError::KernelNotFound("lstm_gates_backward_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates)
                .arg(c_prev)
                .arg(c_new)
                .arg(grad_h)
                .arg(grad_c_next)
                .arg(grad_gates)
                .arg(grad_c_prev)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the fused GRU cell pointwise kernel: combines the
    /// input-to-hidden and hidden-to-hidden gate pre-activations with
    /// `h_prev` to produce `h_new`.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn gru_gates_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("gru_gates_f32")
            .ok_or_else(|| CudaError::KernelNotFound("gru_gates_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(h_new)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the GRU cell backward kernel: writes gradients for both
    /// gate pre-activation buffers and for the previous hidden state.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn gru_gates_backward_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        grad_h_new: &CudaSlice<f32>,
        grad_gates_ih: &mut CudaSlice<f32>,
        grad_gates_hh: &mut CudaSlice<f32>,
        grad_h_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("gru_gates_backward_f32")
            .ok_or_else(|| CudaError::KernelNotFound("gru_gates_backward_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(grad_h_new)
                .arg(grad_gates_ih)
                .arg(grad_gates_hh)
                .arg(grad_h_prev)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the BatchNorm statistics kernel over an (N, C, spatial)
    /// tensor: accumulates per-channel sums into `sum_out` and sums of
    /// squares into `sum_sq_out`. Flat launch over all N*C*spatial elements.
    ///
    /// NOTE(review): the output buffers are not zeroed here — presumably
    /// the kernel accumulates atomically and the caller zeroes them first;
    /// confirm against the kernel source.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn batchnorm_stats_f32(
        &self,
        x: &CudaSlice<f32>,
        sum_out: &mut CudaSlice<f32>,
        sum_sq_out: &mut CudaSlice<f32>,
        n: usize,
        c: usize,
        spatial: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("batchnorm_stats_f32")
            .ok_or_else(|| CudaError::KernelNotFound("batchnorm_stats_f32".to_string()))?;
        let total = n * c * spatial;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(x)
                .arg(sum_out)
                .arg(sum_sq_out)
                .arg(&(n as u32))
                .arg(&(c as u32))
                .arg(&(spatial as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the BatchNorm normalization kernel: applies the per-channel
    /// `mean`/`var` plus affine `gamma`/`beta` to `x`, writing `y`, one
    /// thread per element.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn batchnorm_norm_f32(
        &self,
        x: &CudaSlice<f32>,
        mean: &CudaSlice<f32>,
        var: &CudaSlice<f32>,
        gamma: &CudaSlice<f32>,
        beta: &CudaSlice<f32>,
        y: &mut CudaSlice<f32>,
        eps: f32,
        c: usize,
        spatial: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("batchnorm_norm_f32")
            .ok_or_else(|| CudaError::KernelNotFound("batchnorm_norm_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(x)
                .arg(mean)
                .arg(var)
                .arg(gamma)
                .arg(beta)
                .arg(y)
                .arg(&eps)
                .arg(&(c as u32))
                .arg(&(spatial as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the fused attention forward kernel (training path, no KV
    /// cache): softmax(Q·Kᵀ * scale [+ causal mask]) · V, with one launch
    /// unit per (batch, head, target-row).
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn fused_attention_fwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("fused_attention_fwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attention_fwd_f32".to_string()))?;
        let total_rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(total_rows);
        // The kernel takes the causal flag as a u32, not a bool.
        let is_causal_u32: u32 = if is_causal { 1 } else { 0 };
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k)
                .arg(v)
                .arg(output)
                .arg(&scale)
                .arg(&(batch_size as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .arg(&(head_dim as u32))
                .arg(&is_causal_u32)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the fused prefill attention kernel (inference, GQA-aware,
    /// reading K/V from the cache). One CTA of 32 threads (a single warp)
    /// per (query position, head).
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    #[allow(clippy::too_many_arguments)]
    pub fn fused_attn_prefill_f32(
        &self,
        q: &CudaSlice<f32>,
        k_cache: &CudaSlice<f32>,
        v_cache: &CudaSlice<f32>,
        out: &mut CudaSlice<f32>,
        seq_len: usize,
        total_kv_len: usize,
        n_heads: usize,
        n_kv_heads: usize,
        head_dim: usize,
        pos_offset: usize,
        swa_window: usize,
        scale: f32,
    ) -> Result<(), CudaError> {
        // Same GQA invariant fused_attn_decode_f32 checks: query heads are
        // grouped evenly onto KV heads.
        debug_assert!(
            n_kv_heads > 0 && n_heads % n_kv_heads == 0,
            "fused_attn_prefill_f32: n_heads ({n_heads}) must be a multiple of n_kv_heads ({n_kv_heads})"
        );
        let func = self
            .kernels
            .get("fused_attn_prefill_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attn_prefill_f32".to_string()))?;
        let total_ctas = seq_len * n_heads;
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total_ctas as u32, 1, 1),
            block_dim: (32, 1, 1),
            shared_mem_bytes: 0,
        };
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k_cache)
                .arg(v_cache)
                .arg(out)
                .arg(&(seq_len as u32))
                .arg(&(total_kv_len as u32))
                .arg(&(n_heads as u32))
                .arg(&(n_kv_heads as u32))
                .arg(&(head_dim as u32))
                .arg(&(pos_offset as u32))
                .arg(&(swa_window as u32))
                .arg(&scale)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the single-token decode attention kernel (inference,
    /// GQA-aware). One warp per query head.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    #[allow(clippy::too_many_arguments)]
    pub fn fused_attn_decode_f32(
        &self,
        q: &CudaSlice<f32>,
        k_cache: &CudaSlice<f32>,
        v_cache: &CudaSlice<f32>,
        out: &mut CudaSlice<f32>,
        kv_len: usize,
        n_heads: usize,
        n_kv_heads: usize,
        head_dim: usize,
        swa_window: usize,
        scale: f32,
    ) -> Result<(), CudaError> {
        // The decode kernel keeps the accumulator in registers with a
        // fixed MAX_DIMS budget; larger head_dim would corrupt results.
        debug_assert!(
            head_dim <= 512,
            "fused_attn_decode_f32: head_dim {head_dim} exceeds kernel MAX_DIMS budget"
        );
        debug_assert!(
            n_kv_heads > 0 && n_heads % n_kv_heads == 0,
            "fused_attn_decode_f32: n_heads ({n_heads}) must be a multiple of n_kv_heads ({n_kv_heads})"
        );
        let func = self
            .kernels
            .get("fused_attn_decode_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attn_decode_f32".to_string()))?;
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (n_heads as u32, 1, 1),
            block_dim: (32, 1, 1),
            shared_mem_bytes: 0,
        };
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k_cache)
                .arg(v_cache)
                .arg(out)
                .arg(&(kv_len as u32))
                .arg(&(n_heads as u32))
                .arg(&(n_kv_heads as u32))
                .arg(&(head_dim as u32))
                .arg(&(swa_window as u32))
                .arg(&scale)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the fused attention backward kernel: given the saved
    /// Q/K/V/O and upstream gradient `grad_o`, writes gradients for Q, K,
    /// and V. One launch unit per (batch, head, target-row), mirroring the
    /// forward launch.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn fused_attention_bwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        o: &CudaSlice<f32>,
        grad_o: &CudaSlice<f32>,
        grad_q: &mut CudaSlice<f32>,
        grad_k: &mut CudaSlice<f32>,
        grad_v: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("fused_attention_bwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attention_bwd_f32".to_string()))?;
        let total_rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(total_rows);
        // The kernel takes the causal flag as a u32, not a bool.
        let is_causal_u32: u32 = if is_causal { 1 } else { 0 };
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k)
                .arg(v)
                .arg(o)
                .arg(grad_o)
                .arg(grad_q)
                .arg(grad_k)
                .arg(grad_v)
                .arg(&scale)
                .arg(&(batch_size as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .arg(&(head_dim as u32))
                .arg(&is_causal_u32)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the im2col kernel: unfolds `input` patches into the `col`
    /// matrix using the packed `params` buffer (see `conv2d_forward` for
    /// the 10-element layout). One thread per col element (`n` total).
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn im2col_f32(
        &self,
        input: &CudaSlice<f32>,
        col: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("im2col_f32")
            .ok_or_else(|| CudaError::KernelNotFound("im2col_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(col)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the col2im kernel (inverse of im2col, used by conv
    /// backward): folds the `col` matrix back into `output`.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn col2im_f32(
        &self,
        col: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("col2im_f32")
            .ok_or_else(|| CudaError::KernelNotFound("col2im_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(col)
                .arg(output)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Launches the per-channel bias-add kernel over `n` elements of
    /// `data`, where each channel spans `spatial` contiguous elements.
    ///
    /// # Errors
    /// `KernelNotFound` or `DriverError` on lookup/launch failure.
    pub fn bias_add_channels_f32(
        &self,
        data: &mut CudaSlice<f32>,
        bias: &CudaSlice<f32>,
        spatial: usize,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("bias_add_channels_f32")
            .ok_or_else(|| CudaError::KernelNotFound("bias_add_channels_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        // SAFETY: argument order/types must match the compiled kernel signature.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(data)
                .arg(bias)
                .arg(&(spatial as u32))
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
    /// Host-to-host conv2d: NCHW input and OIHW weights in host slices,
    /// computed on the GPU via im2col + GEMM (+ optional bias), returning
    /// the NCHW output on the host. Returns `None` on any GPU failure.
    ///
    /// Dilation is not supported (implicitly 1), and padding is zero-fill.
    ///
    /// NOTE(review): input is copied host->device once per batch element
    /// inside the loop; a single bulk upload with per-batch offsets would
    /// cut transfer overhead — worth profiling before changing.
    pub fn conv2d_forward(
        &self,
        input: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
        batch_size: usize,
        in_channels: usize,
        in_height: usize,
        in_width: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        pad_h: usize,
        pad_w: usize,
    ) -> Option<Vec<f32>> {
        // Standard convolution output size (dilation = 1).
        let out_h = (in_height + 2 * pad_h - kernel_h) / stride_h + 1;
        let out_w = (in_width + 2 * pad_w - kernel_w) / stride_w + 1;
        // im2col matrix: (C*kh*kw) x (out_h*out_w), flattened.
        let col_h = in_channels * kernel_h * kernel_w;
        let col_w = out_h * out_w;
        let col_n = col_h * col_w;
        let spatial = out_h * out_w;
        let out_per_batch = out_channels * spatial;
        let in_per_batch = in_channels * in_height * in_width;
        use super::cuda_pool::pool_alloc;
        let weight_gpu = self.htod_copy(weight).ok()?;
        let bias_gpu = bias.and_then(|b| self.htod_copy(b).ok());
        // Packed geometry consumed by the im2col kernel; order must match
        // the kernel's expectations.
        let im2col_params: [u32; 10] = [
            in_height as u32,
            in_width as u32,
            kernel_h as u32,
            kernel_w as u32,
            pad_h as u32,
            pad_w as u32,
            stride_h as u32,
            stride_w as u32,
            out_h as u32,
            out_w as u32,
        ];
        let params_gpu = self.htod_copy(&im2col_params[..]).ok()?;
        // Scratch buffers come from the device pool and are reused across
        // the batch loop.
        let mut col_gpu = pool_alloc(col_n).ok()?;
        let mut batch_out_gpu = pool_alloc(out_per_batch).ok()?;
        let mut output = vec![0.0f32; batch_size * out_per_batch];
        for b in 0..batch_size {
            let input_slice = &input[b * in_per_batch..(b + 1) * in_per_batch];
            let input_gpu = self.htod_copy(input_slice).ok()?;
            self.im2col_f32(&input_gpu, &mut col_gpu, &params_gpu, col_n)
                .ok()?;
            // GEMM combining the weight matrix with the unfolded patches;
            // dims (m=col_w, n=out_channels, k=col_h) follow cuBLAS
            // column-major conventions for this layout.
            self.gemm_f32(
                false,
                false,
                col_w,
                out_channels,
                col_h,
                1.0,
                &col_gpu,
                col_w,
                &weight_gpu,
                col_h,
                0.0,
                &mut batch_out_gpu,
                col_w,
            )
            .ok()?;
            if let Some(ref bg) = bias_gpu {
                self.bias_add_channels_f32(&mut batch_out_gpu, bg, spatial, out_per_batch)
                    .ok()?;
            }
            // Pool buffers may be larger than requested; copy only the
            // meaningful prefix back to the host.
            let batch_result = self.dtoh_copy(&batch_out_gpu).ok()?;
            output[b * out_per_batch..(b + 1) * out_per_batch]
                .copy_from_slice(&batch_result[..out_per_batch]);
        }
        Some(output)
    }
}
/// Free-function convenience wrapper over [`CudaBackend::conv2d_forward`]
/// using the process-wide backend.
///
/// Returns `None` when no CUDA device is available or any step of the
/// convolution fails on the GPU.
#[cfg(feature = "cuda")]
pub fn cuda_conv2d_forward(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Option<Vec<f32>> {
    // Both the backend lookup and the convolution itself yield Option, so
    // chain them instead of early-returning with `?`.
    get_cuda_backend().and_then(|backend| {
        backend.conv2d_forward(
            input, weight, bias, batch_size, in_channels, in_height, in_width, out_channels,
            kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
        )
    })
}
/// CPU-only build stub for [`cuda_conv2d_forward`]: without the `cuda`
/// feature the convolution cannot be offloaded, so this always returns
/// `None` and callers must use their CPU fallback.
#[cfg(not(feature = "cuda"))]
pub fn cuda_conv2d_forward(
    _input: &[f32],
    _weight: &[f32],
    _bias: Option<&[f32]>,
    _batch_size: usize,
    _in_channels: usize,
    _in_height: usize,
    _in_width: usize,
    _out_channels: usize,
    _kernel_h: usize,
    _kernel_w: usize,
    _stride_h: usize,
    _stride_w: usize,
    _pad_h: usize,
    _pad_w: usize,
) -> Option<Vec<f32>> {
    None
}
#[cfg(feature = "cuda")]
impl CudaBackend {
pub fn maxpool2d_fwd_f32(
&self,
input: &CudaSlice<f32>,
output: &mut CudaSlice<f32>,
indices: &mut CudaSlice<i32>,
params: &CudaSlice<u32>,
channels: usize,
out_h: usize,
out_w: usize,
total: usize,
) -> Result<(), CudaError> {
let func = self
.kernels
.get("maxpool2d_fwd_f32")
.ok_or_else(|| CudaError::KernelNotFound("maxpool2d_fwd_f32".to_string()))?;
let cfg = cuda_kernels::launch_config(total);
unsafe {
self.stream
.launch_builder(func)
.arg(input)
.arg(output)
.arg(indices)
.arg(params)
.arg(&(channels as u32))
.arg(&(out_h as u32))
.arg(&(out_w as u32))
.arg(&(total as u32))
.launch(cfg)
.map(|_| ())
.map_err(|e| CudaError::DriverError(e.to_string()))?;
}
Ok(())
}
pub fn maxpool2d_bwd_f32(
&self,
grad_output: &CudaSlice<f32>,
indices: &CudaSlice<i32>,
grad_input: &mut CudaSlice<f32>,
total: usize,
) -> Result<(), CudaError> {
let func = self
.kernels
.get("maxpool2d_bwd_f32")
.ok_or_else(|| CudaError::KernelNotFound("maxpool2d_bwd_f32".to_string()))?;
let cfg = cuda_kernels::launch_config(total);
unsafe {
self.stream
.launch_builder(func)
.arg(grad_output)
.arg(indices)
.arg(grad_input)
.arg(&(total as u32))
.launch(cfg)
.map(|_| ())
.map_err(|e| CudaError::DriverError(e.to_string()))?;
}
Ok(())
}
pub fn avgpool2d_fwd_f32(
&self,
input: &CudaSlice<f32>,
output: &mut CudaSlice<f32>,
params: &CudaSlice<u32>,
channels: usize,
out_h: usize,
out_w: usize,
total: usize,
) -> Result<(), CudaError> {
let func = self
.kernels
.get("avgpool2d_fwd_f32")
.ok_or_else(|| CudaError::KernelNotFound("avgpool2d_fwd_f32".to_string()))?;
let cfg = cuda_kernels::launch_config(total);
unsafe {
self.stream
.launch_builder(func)
.arg(input)
.arg(output)
.arg(params)
.arg(&(channels as u32))
.arg(&(out_h as u32))
.arg(&(out_w as u32))
.arg(&(total as u32))
.launch(cfg)
.map(|_| ())
.map_err(|e| CudaError::DriverError(e.to_string()))?;
}
Ok(())
}
pub fn avgpool2d_bwd_f32(
&self,
grad_output: &CudaSlice<f32>,
grad_input: &mut CudaSlice<f32>,
params: &CudaSlice<u32>,
channels: usize,
out_h: usize,
out_w: usize,
total: usize,
) -> Result<(), CudaError> {
let func = self
.kernels
.get("avgpool2d_bwd_f32")
.ok_or_else(|| CudaError::KernelNotFound("avgpool2d_bwd_f32".to_string()))?;
let cfg = cuda_kernels::launch_config(total);
unsafe {
self.stream
.launch_builder(func)
.arg(grad_output)
.arg(grad_input)
.arg(params)
.arg(&(channels as u32))
.arg(&(out_h as u32))
.arg(&(out_w as u32))
.arg(&(total as u32))
.launch(cfg)
.map(|_| ())
.map_err(|e| CudaError::DriverError(e.to_string()))?;
}
Ok(())
}
}
/// Page-locked (pinned) host buffer of `f32`, allocated through the CUDA
/// driver (`cuMemAllocHost`). Pinned memory enables faster host<->device
/// transfers; freed via `cuMemFreeHost` in `Drop`.
#[cfg(feature = "cuda")]
pub struct PinnedBuffer {
    // Raw pinned host pointer; null for the empty buffer.
    ptr: *mut f32,
    // Number of f32 elements (not bytes).
    len: usize,
}
// SAFETY: PinnedBuffer uniquely owns its host allocation (the pointer is
// never shared outside the struct's own accessors), so moving it to
// another thread is sound.
#[cfg(feature = "cuda")]
unsafe impl Send for PinnedBuffer {}
// SAFETY: shared references only expose read access (`as_slice`/`as_ptr`);
// mutation requires `&mut self` — presumed sufficient for Sync, but
// confirm no accessor allows interior mutation through `&self`.
#[cfg(feature = "cuda")]
unsafe impl Sync for PinnedBuffer {}
#[cfg(feature = "cuda")]
impl PinnedBuffer {
    /// Allocates `byte_size` bytes of page-locked host memory via the CUDA
    /// driver, first making sure the global backend (and thus a driver
    /// context) exists.
    ///
    /// Shared by [`Self::from_slice`] and [`Self::alloc`] so the driver
    /// call and its error handling live in exactly one place.
    fn alloc_host_bytes(byte_size: usize) -> Result<*mut std::ffi::c_void, CudaError> {
        // cuMemAllocHost requires an initialized driver/context.
        let _ = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;
        let mut host_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        // SAFETY: host_ptr is a valid out-parameter; the result code is
        // checked before the pointer is returned to safe code.
        unsafe {
            let result = cudarc::driver::sys::cuMemAllocHost_v2(&mut host_ptr, byte_size);
            if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::AllocationFailed);
            }
        }
        Ok(host_ptr)
    }
    /// Creates a pinned buffer holding a copy of `data`. An empty slice
    /// yields the empty (null-pointer) buffer without touching the driver.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend is available;
    /// `AllocationFailed` if the driver cannot pin the memory.
    pub fn from_slice(data: &[f32]) -> Result<Self, CudaError> {
        use std::ptr;
        if data.is_empty() {
            return Ok(Self {
                ptr: ptr::null_mut(),
                len: 0,
            });
        }
        let byte_size = data.len() * std::mem::size_of::<f32>();
        let host_ptr = Self::alloc_host_bytes(byte_size)?;
        // SAFETY: host_ptr points to byte_size freshly allocated bytes,
        // which cannot overlap the caller's slice.
        unsafe {
            ptr::copy_nonoverlapping(data.as_ptr(), host_ptr as *mut f32, data.len());
        }
        Ok(Self {
            ptr: host_ptr as *mut f32,
            len: data.len(),
        })
    }
    /// Allocates a pinned buffer of `len` f32 elements. Contents are
    /// whatever the driver returns (not guaranteed zeroed — presumed
    /// uninitialized; write before reading). `len == 0` yields the empty
    /// buffer without touching the driver.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend is available;
    /// `AllocationFailed` if the driver cannot pin the memory.
    pub fn alloc(len: usize) -> Result<Self, CudaError> {
        use std::ptr;
        if len == 0 {
            return Ok(Self {
                ptr: ptr::null_mut(),
                len: 0,
            });
        }
        let byte_size = len * std::mem::size_of::<f32>();
        let host_ptr = Self::alloc_host_bytes(byte_size)?;
        Ok(Self {
            ptr: host_ptr as *mut f32,
            len,
        })
    }
    /// Borrows the buffer contents as a slice (empty for the null buffer).
    pub fn as_slice(&self) -> &[f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &[];
        }
        // SAFETY: ptr is a live allocation of exactly len f32 elements.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
    /// Mutably borrows the buffer contents (empty for the null buffer).
    pub fn as_slice_mut(&mut self) -> &mut [f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &mut [];
        }
        // SAFETY: ptr is a live allocation of exactly len f32 elements,
        // uniquely borrowed through &mut self.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
    /// Number of f32 elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }
    /// Whether the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Raw const pointer to the pinned memory (null for the empty buffer).
    pub fn as_ptr(&self) -> *const f32 {
        self.ptr
    }
    /// Raw mutable pointer to the pinned memory (null for the empty buffer).
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.ptr
    }
    /// Uploads the buffer contents to a fresh device allocation.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend is available, or whatever the
    /// host-to-device copy reports.
    pub fn to_gpu(&self) -> Result<CudaSlice<f32>, CudaError> {
        let backend = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;
        backend.htod_copy(self.as_slice())
    }
}
#[cfg(feature = "cuda")]
impl Drop for PinnedBuffer {
    // Releases the pinned host allocation. The cuMemFreeHost result is
    // deliberately ignored: Drop cannot propagate errors.
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            unsafe {
                let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut std::ffi::c_void);
            }
            // Null the pointer so a hypothetical double-drop is harmless.
            self.ptr = std::ptr::null_mut();
        }
    }
}
/// Pins a copy of `data` in page-locked host memory for fast transfers.
/// Thin wrapper around [`PinnedBuffer::from_slice`].
#[cfg(feature = "cuda")]
pub fn pin_memory(data: &[f32]) -> Result<PinnedBuffer, CudaError> {
    PinnedBuffer::from_slice(data)
}
/// CPU-only build stub: pinning requires the CUDA driver, so this always
/// fails with `DeviceNotFound`.
///
/// NOTE(review): the success type here is `()` while the cuda variant
/// returns `PinnedBuffer`, so code generic over both cfgs won't compile
/// the same way — consider a non-cuda `PinnedBuffer` stub to unify the
/// signatures.
#[cfg(not(feature = "cuda"))]
pub fn pin_memory(_data: &[f32]) -> Result<(), CudaError> {
    Err(CudaError::DeviceNotFound)
}
// GPU smoke tests: each test early-returns (or merely prints) when no CUDA
// device is present, so the suite passes on CPU-only CI.
#[cfg(test)]
mod tests {
    use super::*;
    // Just reports availability; informational, never fails.
    #[test]
    fn test_cuda_availability() {
        let available = is_available();
        println!("CUDA available: {}", available);
    }
    // Sanity bound on the enumerated device count.
    #[test]
    fn test_device_count() {
        let count = device_count();
        println!("CUDA device count: {}", count);
        assert!(count <= 16);
    }
    // Backend construction on GPU 0 must succeed whenever a device exists.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_backend_creation() {
        if is_available() {
            let backend = CudaBackend::new(0);
            assert!(backend.is_some());
        }
    }
    // Round-trip: host -> device -> host must be lossless.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_memory_operations() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let gpu_data = backend.htod_copy(&data).unwrap();
        let result = backend.dtoh_copy(&gpu_data).unwrap();
        assert_eq!(data, result);
    }
    // 2x3 · 3x2 GEMM in cuBLAS column-major layout:
    // A = [[1,2,3],[4,5,6]], B = [[1,2],[3,4],[5,6]] =>
    // C = [[22,28],[49,64]], stored column-major as [22,49,28,64].
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_gemm() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let a: Vec<f32> = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b: Vec<f32> = vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let c: Vec<f32> = vec![0.0; 4];
        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.htod_copy(&c).unwrap();
        // m=2, n=2, k=3 with lda=2, ldb=3, ldc=2.
        backend
            .gemm_f32(
                false, false, 2, 2, 3, 1.0, &a_gpu, 2, &b_gpu, 3, 0.0, &mut c_gpu, 2,
            )
            .unwrap();
        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 22.0).abs() < 1e-5, "result[0] = {}", result[0]);
        assert!((result[1] - 49.0).abs() < 1e-5, "result[1] = {}", result[1]);
        assert!((result[2] - 28.0).abs() < 1e-5, "result[2] = {}", result[2]);
        assert!((result[3] - 64.0).abs() < 1e-5, "result[3] = {}", result[3]);
    }
    // Elementwise add kernel.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_add_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();
        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();
        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 6.0).abs() < 1e-5);
        assert!((result[1] - 8.0).abs() < 1e-5);
        assert!((result[2] - 10.0).abs() < 1e-5);
        assert!((result[3] - 12.0).abs() < 1e-5);
    }
    // Elementwise multiply kernel.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_mul_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![2.0, 3.0, 4.0, 5.0];
        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();
        backend.mul_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();
        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-5);
        assert!((result[1] - 6.0).abs() < 1e-5);
        assert!((result[2] - 12.0).abs() < 1e-5);
        assert!((result[3] - 20.0).abs() < 1e-5);
    }
    // In-place scalar multiply kernel.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_scale_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let mut data_gpu = backend.htod_copy(&data).unwrap();
        backend.scale_f32(&mut data_gpu, 2.5, 4).unwrap();
        let result = backend.dtoh_copy(&data_gpu).unwrap();
        assert!((result[0] - 2.5).abs() < 1e-5);
        assert!((result[1] - 5.0).abs() < 1e-5);
        assert!((result[2] - 7.5).abs() < 1e-5);
        assert!((result[3] - 10.0).abs() < 1e-5);
    }
    // ReLU: negatives clamp to zero, non-negatives pass through.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_relu_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let input: Vec<f32> = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(5).unwrap();
        backend.relu_f32(&mut output_gpu, &input_gpu, 5).unwrap();
        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.0).abs() < 1e-5);
        assert!((result[2] - 0.0).abs() < 1e-5);
        assert!((result[3] - 1.0).abs() < 1e-5);
        assert!((result[4] - 2.0).abs() < 1e-5);
    }
    // Sigmoid against reference values sigma(0)=0.5, sigma(+-1)~=0.7311/0.2689.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_sigmoid_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();
        backend.sigmoid_f32(&mut output_gpu, &input_gpu, 3).unwrap();
        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.5).abs() < 1e-4);
        assert!((result[1] - 0.7311).abs() < 1e-3);
        assert!((result[2] - 0.2689).abs() < 1e-3);
    }
    // tanh against reference values tanh(0)=0, tanh(+-1)~=+-0.7616.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tanh_kernel() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();
        backend.tanh_f32(&mut output_gpu, &input_gpu, 3).unwrap();
        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.7616).abs() < 1e-3);
        assert!((result[2] - (-0.7616)).abs() < 1e-3);
    }
    // a[i] + b[i] == n for every i, so spot-check first/middle/last of a
    // million-element add (exercises multi-block launches).
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_large_tensor_add() {
        if !is_available() {
            return;
        }
        let backend = CudaBackend::new(0).unwrap();
        let n = 1_000_000;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(n).unwrap();
        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, n).unwrap();
        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - n as f32).abs() < 1e-3);
        assert!((result[n / 2] - n as f32).abs() < 1e-3);
        assert!((result[n - 1] - n as f32).abs() < 1e-3);
    }
    // conv2d: a 1x1 kernel selecting one input channel (all-ones input)
    // plus bias 0.5 gives 1.5 everywhere; a 3x3 all-ones kernel over
    // all-ones 3-channel input gives 27 at interior and 12 at the corner
    // (2x2 window * 3 channels) with padding 1.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_conv2d_forward() {
        if !is_available() {
            return;
        }
        let input = vec![1.0f32; 1 * 3 * 4 * 4];
        let mut weight = vec![0.0f32; 2 * 3 * 1 * 1];
        weight[0] = 1.0;
        weight[4] = 1.0;
        let bias = vec![0.5f32; 2];
        let result = cuda_conv2d_forward(
            &input,
            &weight,
            Some(&bias),
            1,
            3,
            4,
            4,
            2,
            1,
            1,
            1,
            1,
            0,
            0,
        );
        let out = result.expect("CUDA conv2d should succeed");
        assert_eq!(out.len(), 2 * 4 * 4);
        assert!(
            (out[0] - 1.5).abs() < 0.01,
            "1x1 conv ch0: expected 1.5, got {}",
            out[0]
        );
        assert!(
            (out[16] - 1.5).abs() < 0.01,
            "1x1 conv ch1: expected 1.5, got {}",
            out[16]
        );
        let input2 = vec![1.0f32; 1 * 3 * 8 * 8];
        let weight2 = vec![1.0f32; 2 * 3 * 3 * 3];
        let bias2 = vec![0.0f32; 2];
        let result2 = cuda_conv2d_forward(
            &input2,
            &weight2,
            Some(&bias2),
            1,
            3,
            8,
            8,
            2,
            3,
            3,
            1,
            1,
            1,
            1,
        );
        let out2 = result2.expect("CUDA 3x3 conv should succeed");
        assert_eq!(out2.len(), 2 * 8 * 8);
        let center = 4 * 8 + 4;
        assert!(
            (out2[center] - 27.0).abs() < 0.1,
            "3x3 conv center: expected 27.0, got {}",
            out2[center]
        );
        assert!(
            (out2[0] - 12.0).abs() < 0.1,
            "3x3 conv corner: expected 12.0, got {}",
            out2[0]
        );
    }
}