use std::sync::{Arc, Mutex};
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::{
CudaContext, CudaEvent, CudaFunction, CudaSlice, CudaStream, DevicePtr, DevicePtrMut,
LaunchConfig, PushKernelArg,
};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions};
use crate::backend::{Backend, BackendBufferOps};
use crate::error::{backend_err, BackendOp, GpuError, Result};
/// GPU compute backend backed by the CUDA driver API via `cudarc`.
///
/// Owns one context, one stream, and one cuBLAS handle; all dispatches and
/// copies in this file go through the single `stream`.
pub struct CudaBackend {
    /// CUDA context for the selected device (device 0 in `create()`).
    ctx: Arc<CudaContext>,
    /// Default stream of `ctx`; shared by dispatches, copies, and batches.
    stream: Arc<CudaStream>,
    /// cuBLAS handle; `Mutex` because `cublas_matmul` takes `&self`.
    blas: Mutex<CudaBlas>,
    /// Cached device name, queried once at creation.
    device_name: String,
    /// Cached total device memory in bytes, queried once at creation.
    device_memory: u64,
}
/// A device memory allocation together with its byte size and the stream
/// used for host transfers.
pub struct CudaBuffer {
    /// Raw device allocation (byte-addressed).
    pub(crate) inner: CudaSlice<u8>,
    /// Byte size recorded at creation; reported by `byte_size()`.
    size: u64,
    /// Stream used for `read_back` copies; cloned from the owning backend.
    stream: Arc<CudaStream>,
}
/// A compiled, loaded CUDA kernel plus the block dimensions it was
/// configured with at compile time (see `CudaBackend::compile_cuda`).
pub struct CudaKernel {
    /// Launchable function loaded from the compiled PTX module.
    pub(crate) function: CudaFunction,
    /// Block dimensions used for every launch of this kernel.
    pub(crate) block_dim: (u32, u32, u32),
}
/// Recorder for a sequence of kernel launches submitted without intermediate
/// synchronization; created by `CudaBackend::begin_batch`.
pub struct CudaBatch {
    /// Stream the batched launches are enqueued on (shared with the backend).
    stream: Arc<CudaStream>,
}
impl CudaBackend {
    /// Compile CUDA C `source` with NVRTC and load `entry_point` from the
    /// resulting PTX as a launchable kernel.
    ///
    /// `block_dim` is recorded on the returned [`CudaKernel`] and reused as
    /// the block dimensions for every subsequent dispatch of that kernel.
    ///
    /// # Errors
    /// Fails if NVRTC compilation, module loading, or function lookup fails.
    pub fn compile_cuda(
        &self,
        source: &str,
        entry_point: &str,
        block_dim: (u32, u32, u32),
    ) -> Result<CudaKernel> {
        // Fast math trades strict IEEE accuracy for throughput; this applies
        // to every kernel compiled through this backend.
        let opts = CompileOptions {
            use_fast_math: Some(true),
            ..Default::default()
        };
        let ptx = compile_ptx_with_opts(source, opts)
            .map_err(|e| backend_err(BackendOp::CompileKernel, e))?;
        let module = self
            .ctx
            .load_module(ptx)
            .map_err(|e| backend_err(BackendOp::LoadModule, e))?;
        let function = module
            .load_function(entry_point)
            .map_err(|e| backend_err(BackendOp::LoadFunction, e))?;
        Ok(CudaKernel {
            function,
            block_dim,
        })
    }

    /// Launch `kernel` and block until it completes.
    ///
    /// Buffers are bound as pointer arguments in order, followed by
    /// `push_constants` reinterpreted as native-endian `u32` words.
    ///
    /// # Errors
    /// Fails if `push_constants` is not a multiple of 4 bytes long, if the
    /// launch is rejected, or if stream synchronization fails.
    pub fn dispatch_cuda(
        &self,
        kernel: &CudaKernel,
        buffers: &[&CudaBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        // Reject trailing bytes instead of silently dropping them
        // (`chunks_exact` below would ignore a partial final word).
        if let Some(pc) = push_constants {
            if pc.len() % 4 != 0 {
                return Err(backend_err(
                    BackendOp::LaunchKernel,
                    "push constant length must be a multiple of 4 bytes",
                ));
            }
        }
        let config = LaunchConfig {
            grid_dim: (workgroups[0], workgroups[1], workgroups[2]),
            block_dim: kernel.block_dim,
            shared_mem_bytes: 0,
        };
        // Decoded words must outlive the builder, which holds references.
        let pc_values: Vec<u32> = push_constants
            .map(|pc| {
                pc.chunks_exact(4)
                    .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            })
            .unwrap_or_default();
        // SAFETY: arguments are bound in the order (buffers..., constants...);
        // the caller must pass buffers and constants matching the compiled
        // kernel's parameter list.
        unsafe {
            let mut builder = self.stream.launch_builder(&kernel.function);
            for buf in buffers {
                builder.arg(&buf.inner);
            }
            for val in &pc_values {
                builder.arg(val);
            }
            builder
                .launch(config)
                .map_err(|e| backend_err(BackendOp::LaunchKernel, e))?;
        }
        // Synchronous dispatch contract: the kernel has finished on return.
        self.stream
            .synchronize()
            .map_err(|e| backend_err(BackendOp::StreamSync, e))?;
        Ok(())
    }

    /// Single-precision matrix multiply `C = A * B` via cuBLAS, where `A` is
    /// `m x k`, `B` is `k x n`, and `C` is `m x n`, all row-major `f32`.
    ///
    /// cuBLAS is column-major, so this computes `C^T = B^T * A^T` by passing
    /// the operands swapped with leading dimension `n`, which yields
    /// row-major `C` without any explicit transpose.
    ///
    /// # Errors
    /// Fails if any buffer is too small for its dimensions, if the cuBLAS
    /// mutex is poisoned, if `sgemm` fails, or if stream sync fails.
    #[allow(clippy::many_single_char_names)]
    pub fn cublas_matmul(
        &self,
        a: &CudaBuffer,
        b: &CudaBuffer,
        c: &mut CudaBuffer,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<()> {
        // Guard the raw sgemm call against out-of-bounds device access:
        // A is m*k, B is k*n, C is m*n f32 elements of 4 bytes each.
        let need = |rows: u32, cols: u32| u64::from(rows) * u64::from(cols) * 4;
        if a.size < need(m, k) || b.size < need(k, n) || c.size < need(m, n) {
            return Err(backend_err(
                BackendOp::CuBlas,
                "matmul buffer smaller than its m/n/k dimensions require",
            ));
        }
        #[allow(clippy::cast_possible_wrap)]
        // SAFETY: the pointers come from live device allocations whose sizes
        // were checked above; the guards keep them valid for the call.
        unsafe {
            let blas = self
                .blas
                .lock()
                .map_err(|_| backend_err(BackendOp::MutexPoisoned, "cublas"))?;
            let (a_ptr, _a_guard) = a.inner.device_ptr(&self.stream);
            let (b_ptr, _b_guard) = b.inner.device_ptr(&self.stream);
            let (c_ptr, _c_guard) = c.inner.device_ptr_mut(&self.stream);
            cudarc::cublas::result::sgemm(
                *blas.handle(),
                cublasOperation_t::CUBLAS_OP_N,
                cublasOperation_t::CUBLAS_OP_N,
                n as i32,
                m as i32,
                k as i32,
                &1.0f32,
                b_ptr as *const f32,
                n as i32,
                a_ptr as *const f32,
                k as i32,
                &0.0f32,
                c_ptr as *mut f32,
                n as i32,
            )
            .map_err(|e| backend_err(BackendOp::CuBlas, e))?;
        }
        self.stream
            .synchronize()
            .map_err(|e| backend_err(BackendOp::StreamSync, e))?;
        Ok(())
    }

    /// Begin recording a batch of launches on this backend's stream.
    /// Recorded dispatches are not awaited until the batch is submitted.
    #[allow(clippy::unnecessary_wraps)]
    pub fn begin_batch(&self) -> Result<CudaBatch> {
        Ok(CudaBatch {
            stream: Arc::clone(&self.stream),
        })
    }
}
impl Backend for CudaBackend {
    type Buffer = CudaBuffer;
    type Pipeline = CudaKernel;

    /// Initialize device 0: create the context, query its name and total
    /// memory, take the default stream, and build a cuBLAS handle on it.
    ///
    /// # Errors
    /// Fails if the device, its properties, or the cuBLAS handle cannot be
    /// obtained.
    fn create() -> Result<Self> {
        let ctx = CudaContext::new(0).map_err(|e| backend_err(BackendOp::CreateDevice, e))?;
        let device_name = ctx
            .name()
            .map_err(|e| backend_err(BackendOp::DeviceQuery, e))?;
        let device_memory =
            ctx.total_mem()
                .map_err(|e| backend_err(BackendOp::DeviceQuery, e))? as u64;
        let stream = ctx.default_stream();
        let blas = CudaBlas::new(stream.clone()).map_err(|e| backend_err(BackendOp::CuBlas, e))?;
        Ok(Self {
            ctx,
            stream,
            blas: Mutex::new(blas),
            device_name,
            device_memory,
        })
    }

    /// Copy `data` host-to-device into a fresh buffer of the same size.
    fn upload(&self, data: &[u8]) -> Result<Self::Buffer> {
        let size = data.len() as u64;
        let inner = self
            .stream
            .clone_htod(data)
            .map_err(|e| backend_err(BackendOp::CopyBuffer, e))?;
        Ok(CudaBuffer {
            inner,
            size,
            stream: Arc::clone(&self.stream),
        })
    }

    /// Allocate a zero-initialized device buffer of `size` bytes.
    fn alloc(&self, size: u64) -> Result<Self::Buffer> {
        let inner = self
            .stream
            .alloc_zeros::<u8>(size as usize)
            .map_err(|e| backend_err(BackendOp::CreateBuffer, e))?;
        Ok(CudaBuffer {
            inner,
            size,
            stream: Arc::clone(&self.stream),
        })
    }

    /// SPIR-V dispatch is unsupported on CUDA; callers must compile CUDA C
    /// via `compile_cuda` and use `dispatch_cuda` / `dispatch_pipeline`.
    fn dispatch(
        &self,
        _spirv: &[u32],
        _entry_point: &str,
        _buffers: &[&Self::Buffer],
        _workgroups: [u32; 3],
        _push_constants: Option<&[u8]>,
    ) -> Result<()> {
        Err(GpuError::BackendUnavailable(
            "CUDA cannot execute SPIR-V shaders — use compile_cuda() instead".into(),
        ))
    }

    /// SPIR-V pipeline creation is unsupported on CUDA; see `compile_cuda`.
    fn create_pipeline(
        &self,
        _spirv: &[u32],
        _entry_point: &str,
        _binding_count: usize,
        _push_constant_size: u32,
    ) -> Result<Self::Pipeline> {
        Err(GpuError::BackendUnavailable(
            "CUDA cannot compile SPIR-V pipelines — use compile_cuda() instead".into(),
        ))
    }

    /// Synchronous dispatch of a previously compiled CUDA kernel.
    fn dispatch_pipeline(
        &self,
        pipeline: &Self::Pipeline,
        buffers: &[&Self::Buffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        self.dispatch_cuda(pipeline, buffers, workgroups, push_constants)
    }

    fn device_name(&self) -> &str {
        &self.device_name
    }

    fn device_memory(&self) -> u64 {
        self.device_memory
    }

    /// CUDA warp size; 32 on all current NVIDIA architectures.
    fn subgroup_size(&self) -> u32 {
        32
    }

    /// Device-to-device copy of the first `size` bytes of `src` into a new
    /// zero-initialized buffer.
    ///
    /// # Errors
    /// Fails if `size` exceeds the source buffer's byte size, or if the
    /// allocation or copy fails.
    fn copy_buffer(&self, src: &Self::Buffer, size: u64) -> Result<Self::Buffer> {
        // Guard the raw dtod copy against reading past the source allocation.
        if size > src.size {
            return Err(backend_err(
                BackendOp::CopyBuffer,
                "copy size exceeds source buffer size",
            ));
        }
        let mut dst = self
            .stream
            .alloc_zeros::<u8>(size as usize)
            .map_err(|e| backend_err(BackendOp::CreateBuffer, e))?;
        self.stream
            .memcpy_dtod(&mut dst, &src.inner, size as usize)
            .map_err(|e| backend_err(BackendOp::CopyBuffer, e))?;
        Ok(CudaBuffer {
            inner: dst,
            size,
            stream: Arc::clone(&self.stream),
        })
    }
}
impl BackendBufferOps for CudaBuffer {
    /// Copy the buffer's full contents back to the host.
    fn read_back(&self) -> Result<Vec<u8>> {
        match self.stream.clone_dtoh(&self.inner) {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(backend_err(BackendOp::CopyBuffer, e)),
        }
    }

    /// Size of the buffer in bytes, as recorded when it was created.
    fn byte_size(&self) -> u64 {
        self.size
    }
}
impl CudaBatch {
    /// Record a kernel launch on the batch's stream without synchronizing.
    ///
    /// Argument binding matches `CudaBackend::dispatch_cuda`: buffers as
    /// pointers in order, then `push_constants` as native-endian `u32` words.
    ///
    /// # Errors
    /// Fails if `push_constants` is not a multiple of 4 bytes long or if the
    /// launch is rejected.
    pub fn record_dispatch(
        &mut self,
        kernel: &CudaKernel,
        buffers: &[&CudaBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        // Reject trailing bytes instead of silently dropping them
        // (`chunks_exact` below would ignore a partial final word).
        if let Some(pc) = push_constants {
            if pc.len() % 4 != 0 {
                return Err(backend_err(
                    BackendOp::LaunchKernel,
                    "push constant length must be a multiple of 4 bytes",
                ));
            }
        }
        let config = LaunchConfig {
            grid_dim: (workgroups[0], workgroups[1], workgroups[2]),
            block_dim: kernel.block_dim,
            shared_mem_bytes: 0,
        };
        // Decoded words must outlive the builder, which holds references.
        let pc_values: Vec<u32> = push_constants
            .map(|pc| {
                pc.chunks_exact(4)
                    .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            })
            .unwrap_or_default();
        // SAFETY: arguments are bound in the order (buffers..., constants...);
        // the caller must match the compiled kernel's parameter list.
        unsafe {
            let mut builder = self.stream.launch_builder(&kernel.function);
            for buf in buffers {
                builder.arg(&buf.inner);
            }
            for val in &pc_values {
                builder.arg(val);
            }
            builder
                .launch(config)
                .map_err(|e| backend_err(BackendOp::LaunchKernel, e))?;
        }
        // No synchronize here: batched work is awaited via submit_async().
        Ok(())
    }

    /// No-op: all launches in a batch are recorded on a single CUDA stream,
    /// and a CUDA stream executes its work in submission order, so no
    /// explicit barrier is needed between recorded dispatches.
    #[allow(
        clippy::unused_self,
        clippy::needless_pass_by_ref_mut,
        clippy::missing_const_for_fn
    )]
    pub fn record_barrier(&mut self) {
    }

    /// Finish recording: place an event after all recorded work and return a
    /// ticket that can poll or wait for its completion.
    ///
    /// # Errors
    /// Fails if the event cannot be recorded on the stream.
    pub fn submit_async(self) -> Result<CudaTicket> {
        let event = self
            .stream
            .record_event(None)
            .map_err(|e| backend_err(BackendOp::RecordEvent, e))?;
        Ok(CudaTicket {
            stream: self.stream,
            event,
        })
    }
}
/// Handle to an in-flight batch: the event recorded at submission time plus
/// the stream it was recorded on (kept alive for the batch's duration).
pub struct CudaTicket {
    /// Stream the batch was recorded on.
    stream: Arc<CudaStream>,
    /// Event recorded after the batch's last launch; completion of this
    /// event means the batch has finished.
    event: CudaEvent,
}
impl CudaTicket {
pub(crate) fn wait(self) -> Result<()> {
self.stream
.synchronize()
.map_err(|e| backend_err(BackendOp::StreamSync, e))
}
pub(crate) fn is_ready(&self) -> bool {
self.event.is_complete()
}
}