use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DeviceRepr};
use cudarc::nvrtc::Ptx;
use std::sync::Arc;
/// Thin wrapper pairing a CUDA context with the stream on which this
/// wrapper issues its allocations, copies and synchronization.
///
/// Fields are dropped in declaration order (`ctx`, then `stream`); keep
/// the order unless that has been checked to be irrelevant.
pub struct CudaDevice {
    /// Context of the opened device; also used to load PTX modules.
    ctx: Arc<CudaContext>,
    /// Stream taken from `ctx.default_stream()` when the wrapper is built.
    stream: Arc<CudaStream>,
}
impl CudaDevice {
    /// Ordinal of the device opened by [`CudaDevice::new`].
    const DEFAULT_ORDINAL: usize = 0;

    /// Opens CUDA device 0 and binds this wrapper to that context's default stream.
    ///
    /// Returns `None` if the CUDA context cannot be created; the driver error is discarded.
    pub fn new() -> Option<Self> {
        Self::with_ordinal(Self::DEFAULT_ORDINAL)
    }

    /// Opens the CUDA device with the given `ordinal` and binds this wrapper
    /// to that context's default stream.
    ///
    /// Returns `None` if the CUDA context cannot be created; the driver error is discarded.
    pub fn with_ordinal(ordinal: usize) -> Option<Self> {
        let ctx = CudaContext::new(ordinal).ok()?;
        let stream = ctx.default_stream();
        Some(Self { ctx, stream })
    }

    /// Human-readable label of the form `"CUDA Device {ordinal}"`.
    ///
    /// The ordinal is read from the context instead of being hard-coded, so the
    /// label stays correct for devices opened via [`CudaDevice::with_ordinal`].
    pub fn name(&self) -> String {
        format!("CUDA Device {}", self.ctx.ordinal())
    }

    /// The underlying CUDA context.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// The stream that every operation of this wrapper is issued on.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Allocates `len` zero-initialised elements of `T` on this wrapper's stream.
    ///
    /// # Panics
    /// Panics with `"CUDA alloc failed"` if the allocation returns an error.
    pub fn alloc_zeros<T: DeviceRepr + cudarc::driver::ValidAsZeroBits>(
        &self,
        len: usize,
    ) -> CudaSlice<T> {
        self.stream
            .alloc_zeros::<T>(len)
            .expect("CUDA alloc failed")
    }

    /// Copies `data` from the host into a newly allocated device buffer.
    ///
    /// # Panics
    /// Panics with `"CUDA htod failed"` if the copy returns an error.
    pub fn htod_copy<T: DeviceRepr + Clone>(&self, data: &[T]) -> CudaSlice<T> {
        self.stream.clone_htod(data).expect("CUDA htod failed")
    }

    /// Copies the contents of `slice` from the device into a new host `Vec`.
    ///
    /// # Panics
    /// Panics with `"CUDA dtoh failed"` if the copy returns an error.
    pub fn dtoh_copy<T: DeviceRepr + Clone>(&self, slice: &CudaSlice<T>) -> Vec<T> {
        self.stream.clone_dtoh(slice).expect("CUDA dtoh failed")
    }

    /// Copies `data` from the host into the existing device buffer `dst`.
    ///
    /// NOTE(review): `dst` presumably must hold at least `data.len()` elements;
    /// confirm how `memcpy_htod` handles a shorter destination.
    ///
    /// # Panics
    /// Panics with `"CUDA memcpy_htod failed"` if the copy returns an error.
    pub fn htod_copy_into<T: DeviceRepr + Clone>(&self, data: &[T], dst: &mut CudaSlice<T>) {
        self.stream
            .memcpy_htod(data, dst)
            .expect("CUDA memcpy_htod failed");
    }

    /// Loads a compiled PTX module into this wrapper's context.
    ///
    /// # Panics
    /// Panics with `"Failed to load PTX module"` if loading returns an error.
    pub fn load_module(&self, ptx: Ptx) -> Arc<CudaModule> {
        self.ctx
            .load_module(ptx)
            .expect("Failed to load PTX module")
    }

    /// Looks up the kernel called `name` in `module`.
    ///
    /// # Panics
    /// Panics if the lookup fails. The message includes the requested name and
    /// the driver error, so a typo in the kernel name is easy to spot.
    pub fn load_function(module: &Arc<CudaModule>, name: &str) -> CudaFunction {
        module
            .load_function(name)
            .unwrap_or_else(|e| panic!("Function not found: `{name}` ({e:?})"))
    }

    /// Blocks until all work queued on this wrapper's stream has completed.
    ///
    /// # Panics
    /// Panics with `"CUDA sync failed"` if synchronization returns an error.
    pub fn synchronize(&self) {
        self.stream.synchronize().expect("CUDA sync failed");
    }
}