use std::cell::RefCell;
use std::marker::PhantomData;
use super::{
api::{
create_context, create_stream, cuInit, cuMemcpy, cuStreamDestroy, cu_write,
cublas::{create_handle, cublasDestroy_v2, cublasSetStream_v2, CublasHandle},
cumalloc, device, Context, CudaIntDevice, Module, Stream,
},
chosen_cu_idx, launch_kernel1d, AsCudaCvoidPtr, CUDAPtr, KernelCacheCU,
};
use crate::{
cache::Cache, flag::AllocFlag, Addons, AddonsReturn, Alloc, Buffer, CacheReturn, CloneBuf,
Device, PtrConv, Shape,
};
/// A CUDA device together with the driver-API and cuBLAS state created for it.
///
/// `CUDA::new` fills `device`, `ctx`, `stream` and `handle` and binds the cuBLAS
/// handle to the stream. The `Drop` implementation destroys the cuBLAS handle and
/// the stream.
#[derive(Debug)]
pub struct CUDA {
/// Kernel cache behind a `RefCell` so it can be mutated through `&CUDA`.
/// NOTE(review): presumably holds compiled kernels keyed by source — confirm
/// against `KernelCacheCU`.
pub kernel_cache: RefCell<KernelCacheCU>,
/// Modules associated with this device, mutable through `&CUDA`.
/// NOTE(review): presumably the loaded CUDA modules backing cached kernels — confirm.
pub modules: RefCell<Vec<Module>>,
/// The device obtained from the index passed to `CUDA::new`.
device: CudaIntDevice,
/// The context created for `device`.
ctx: Context,
/// The stream created in `CUDA::new`; the cuBLAS handle is bound to it.
stream: Stream,
/// The cuBLAS handle created in `CUDA::new`.
handle: CublasHandle,
/// Addons attached to this device, exposed through `AddonsReturn::addons`.
pub addons: Addons<CUDA>,
}
/// Short alias for [`CUDA`].
pub type CU = CUDA;
impl CUDA {
pub fn new(idx: usize) -> crate::Result<CUDA> {
unsafe { cuInit(0) }.to_result()?;
let device = device(idx as i32)?;
let ctx = create_context(&device)?;
let stream = create_stream()?;
let handle = create_handle()?;
unsafe { cublasSetStream_v2(handle.0, stream.0) }.to_result()?;
Ok(CUDA {
kernel_cache: Default::default(),
modules: Default::default(),
addons: Default::default(),
device,
ctx,
stream,
handle,
})
}
#[inline]
pub fn device(&self) -> &CudaIntDevice {
&self.device
}
#[inline]
pub fn ctx(&self) -> &Context {
&self.ctx
}
#[inline]
pub fn cublas_handle(&self) -> &CublasHandle {
&self.handle
}
#[inline]
pub fn stream(&self) -> &Stream {
&self.stream
}
#[inline]
pub fn launch_kernel1d(
&self,
len: usize,
src: &str,
fn_name: &str,
args: &[&dyn AsCudaCvoidPtr],
) -> crate::Result<()> {
launch_kernel1d(len, self, src, fn_name, args)
}
}
impl Device for CUDA {
    type Ptr<U, S: Shape> = CUDAPtr<U>;
    type Cache = Cache<CUDA>;

    /// Creates a [`CUDA`] device for the index reported by `chosen_cu_idx`.
    ///
    /// Any error from `CUDA::new` is returned unchanged.
    fn new() -> crate::Result<Self> {
        let device_idx = chosen_cu_idx();
        CUDA::new(device_idx)
    }
}
impl AddonsReturn for CUDA {
    /// Borrows the addons stored on this device.
    #[inline]
    fn addons(&self) -> &Addons<Self>
    where
        Self: Device,
    {
        let Self { addons, .. } = self;
        addons
    }
}
impl PtrConv for CUDA {
    /// Builds a pointer of element type `Conv` that carries the same raw pointer and
    /// length as `ptr`, tagged with the given `flag`.
    ///
    /// # Safety
    ///
    /// The raw pointer and `len` are copied as-is; nothing is reallocated and `len`
    /// is not rescaled for the new element type, so the result refers to the same
    /// allocation as `ptr`.
    /// NOTE(review): the exact obligations on the caller are defined by the
    /// `PtrConv` trait — confirm there.
    #[inline]
    unsafe fn convert<T, IS: Shape, Conv, OS: Shape>(
        ptr: &Self::Ptr<T, IS>,
        flag: AllocFlag,
    ) -> Self::Ptr<Conv, OS> {
        let (raw, len) = (ptr.ptr, ptr.len);
        CUDAPtr {
            ptr: raw,
            len,
            flag,
            p: PhantomData,
        }
    }
}
impl Default for CUDA {
    /// Creates a [`CUDA`] device for the index reported by `chosen_cu_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `CUDA::new` fails for that index.
    #[inline]
    fn default() -> Self {
        // The message previously named `CUSTOS_CL_DEVICE_IDX`, the OpenCL-style
        // prefix, which sends users to the wrong variable for a CUDA device.
        // NOTE(review): confirm the exact variable name read by `chosen_cu_idx`.
        CUDA::new(chosen_cu_idx()).expect(
            "A valid CUDA device index should be set via the environment variable `CUSTOS_CU_DEVICE_IDX`",
        )
    }
}
impl Drop for CUDA {
    /// Clears the cached nodes, then destroys the cuBLAS handle and the stream.
    fn drop(&mut self) {
        // Empty the cache before the driver objects below are torn down.
        // NOTE(review): presumably so cached device buffers are released while the
        // device state is still alive — confirm.
        self.cache_mut().nodes.clear();

        let cublas = self.handle.0;
        let stream = self.stream.0;
        // Status codes are deliberately not inspected: `drop` cannot report errors.
        unsafe {
            cublasDestroy_v2(cublas);
            cuStreamDestroy(stream);
        }
    }
}
impl<T> Alloc<'_, T> for CUDA {
    /// Allocates device memory for `len` elements of `T` and tags it with `flag`.
    ///
    /// # Panics
    ///
    /// Panics if `cumalloc` returns an error.
    fn alloc(&self, len: usize, flag: AllocFlag) -> CUDAPtr<T> {
        CUDAPtr {
            ptr: cumalloc::<T>(len).unwrap(),
            len,
            flag,
            p: PhantomData,
        }
    }

    /// Allocates device memory for `data.len()` elements and writes `data` into it.
    ///
    /// The returned pointer carries `AllocFlag::None`.
    ///
    /// # Panics
    ///
    /// Panics if `cumalloc` or `cu_write` returns an error.
    fn with_slice(&self, data: &[T]) -> CUDAPtr<T> {
        let len = data.len();
        let device_ptr = cumalloc::<T>(len).unwrap();
        cu_write(device_ptr, data).unwrap();

        CUDAPtr {
            ptr: device_ptr,
            len,
            flag: AllocFlag::None,
            p: PhantomData,
        }
    }
}
impl<'a, T> CloneBuf<'a, T> for CUDA {
    /// Returns a new buffer of the same length as `buf` holding a copy of its
    /// device memory.
    ///
    /// # Panics
    ///
    /// Panics if `cuMemcpy` reports an error. The status used to be discarded, so a
    /// failed copy silently produced a buffer with undefined contents. Allocation
    /// failures inside `Buffer::new` are handled by that function.
    fn clone_buf(&'a self, buf: &Buffer<'a, T, CUDA>) -> Buffer<'a, T, CUDA> {
        let cloned = Buffer::new(self, buf.len());
        let byte_count = buf.len() * std::mem::size_of::<T>();

        // `ptrs().2` is the pointer handed to `cuMemcpy`; both buffers were created
        // for `buf.len()` elements of `T`, so `byte_count` covers exactly one buffer.
        unsafe { cuMemcpy(cloned.ptrs().2, buf.ptrs().2, byte_count) }
            .to_result()
            .expect("copying device memory with `cuMemcpy` failed");

        cloned
    }
}