use crate::driver::{result, sys};
use super::{alloc::DeviceRepr, device_ptr::DeviceSlice};
use std::{
marker::PhantomData,
ops::{Bound, RangeBounds},
string::String,
};
#[cfg(feature = "no-std")]
use spin::RwLock;
#[cfg(not(feature = "no-std"))]
use std::sync::RwLock;
use std::{collections::BTreeMap, marker::Unpin, pin::Pin, sync::Arc, vec::Vec};
/// A CUDA device together with the driver resources kept for it.
///
/// Created by [`CudaDevice::new`], which hands it out inside an [`Arc`]; every
/// resource listed below is released by its `Drop` impl.
#[derive(Debug)]
pub struct CudaDevice {
    /// Driver handle for the device, obtained from `result::device::get`.
    pub(crate) cu_device: sys::CUdevice,
    /// The device's primary context: retained in `new`, released on drop.
    pub(crate) cu_primary_ctx: sys::CUcontext,
    /// Stream that device-level work is issued on. `new` sets it to null
    /// (the CUDA default stream); drop only destroys it when non-null.
    pub(crate) stream: sys::CUstream,
    /// Event created with `CU_EVENT_DISABLE_TIMING`; recorded and waited on to
    /// order work between `stream` and forked [`CudaStream`]s.
    pub(crate) event: sys::CUevent,
    /// Loaded modules keyed by name; each one is unloaded on drop.
    pub(crate) modules: RwLock<BTreeMap<String, CudaModule>>,
}
// NOTE(review): these impls assert that the raw CUDA handles above may be moved
// and shared across threads; the compiler cannot verify this. Confirm against
// the CUDA driver's thread-safety guarantees.
unsafe impl Send for CudaDevice {}
unsafe impl Sync for CudaDevice {}
impl CudaDevice {
    /// Creates a handle to the device with index `ordinal`.
    ///
    /// Retains the device's primary context, makes it current on the calling
    /// thread, and creates the synchronization event used by stream forking.
    ///
    /// # Errors
    /// Returns the driver error if any of these fail:
    /// - driver initialization;
    /// - device lookup, including ordinals that do not fit in an `i32`;
    /// - making the context current;
    /// - event creation.
    ///
    /// A primary context retained before the failure is released again.
    pub fn new(ordinal: usize) -> Result<Arc<Self>, result::DriverError> {
        result::init()?;
        // Avoid `as i32`, which could wrap a huge ordinal onto a real device.
        // `i32::MAX` can never be a valid ordinal, so the driver reports the error.
        let cu_ordinal = i32::try_from(ordinal).unwrap_or(i32::MAX);
        let cu_device = result::device::get(cu_ordinal)?;
        let cu_primary_ctx = unsafe { result::primary_ctx::retain(cu_device) }?;
        // Build the device before the remaining fallible steps. If one of them
        // fails, `Drop` releases the retained primary context instead of leaking
        // it. `Drop` skips the still-null stream and event handles.
        let mut device = CudaDevice {
            cu_device,
            cu_primary_ctx,
            stream: std::ptr::null_mut(),
            event: std::ptr::null_mut(),
            modules: RwLock::new(BTreeMap::new()),
        };
        unsafe { result::ctx::set_current(device.cu_primary_ctx) }?;
        device.event = result::event::create(sys::CUevent_flags::CU_EVENT_DISABLE_TIMING)?;
        Ok(Arc::new(device))
    }
}
impl Drop for CudaDevice {
    /// Unloads all modules, destroys the stream and event (when present), and
    /// releases the primary context.
    ///
    /// # Panics
    /// Panics if any of the driver calls fail, or if the module lock is poisoned.
    fn drop(&mut self) {
        // The device is `Send + Sync` and shared via `Arc`, so the last
        // reference may be dropped on a thread where this context was never
        // made current. Module unloading acts on the current context, and the
        // null stream is the current context's default stream, so bind ours first.
        if !self.cu_primary_ctx.is_null() {
            unsafe { result::ctx::set_current(self.cu_primary_ctx) }.unwrap();
        }

        let modules = RwLock::get_mut(&mut self.modules);
        #[cfg(not(feature = "no-std"))]
        let modules = modules.unwrap();
        for (_, module) in modules.iter() {
            unsafe { result::module::unload(module.cu_module) }.unwrap();
        }
        modules.clear();

        let stream = std::mem::replace(&mut self.stream, std::ptr::null_mut());
        if !stream.is_null() {
            unsafe { result::stream::destroy(stream) }.unwrap();
        }

        let event = std::mem::replace(&mut self.event, std::ptr::null_mut());
        if !event.is_null() {
            unsafe { result::event::destroy(event) }.unwrap();
        }

        // Release last: everything above was created in this context.
        let ctx = std::mem::replace(&mut self.cu_primary_ctx, std::ptr::null_mut());
        if !ctx.is_null() {
            unsafe { result::primary_ctx::release(self.cu_device) }.unwrap();
        }
    }
}
/// An owned device allocation holding `len` elements of `T`.
///
/// Keeps an [`Arc`] to its [`CudaDevice`] so the device outlives the
/// allocation; the memory is freed asynchronously (`free_async`) on drop.
#[derive(Debug)]
pub struct CudaSlice<T> {
    /// Device address of the first element.
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    /// Number of `T` elements (not bytes).
    pub(crate) len: usize,
    /// Owning device, kept alive for as long as this allocation exists.
    pub(crate) device: Arc<CudaDevice>,
    /// Optional host-side `Vec`, wrapped in `Pin`, kept alive with the slice.
    /// NOTE(review): presumably the source buffer of an async host-to-device
    /// copy — confirm with the code that populates it.
    pub(crate) host_buf: Option<Pin<Vec<T>>>,
}
// NOTE(review): asserts the raw device pointer may cross threads whenever `T`
// can; the compiler cannot verify this.
unsafe impl<T: Send> Send for CudaSlice<T> {}
unsafe impl<T: Sync> Sync for CudaSlice<T> {}
impl<T> Drop for CudaSlice<T> {
    /// Queues the device allocation to be freed on the device's stream.
    ///
    /// # Panics
    /// Panics if binding the device's context or freeing the memory fails.
    fn drop(&mut self) {
        // `CudaSlice` is `Send`, so it may be dropped on a thread where the
        // device's context is not current. The device's stream may be the null
        // (default) stream, which is resolved against the current context, so
        // bind the device's primary context first.
        unsafe {
            result::ctx::set_current(self.device.cu_primary_ctx).unwrap();
            result::free_async(self.cu_device_ptr, self.device.stream).unwrap();
        }
    }
}
impl<T> CudaSlice<T> {
    /// Returns another shared handle to the device that owns this allocation.
    pub fn device(&self) -> Arc<CudaDevice> {
        Arc::clone(&self.device)
    }
}
impl<T: DeviceRepr> CudaSlice<T> {
    /// Allocates a fresh slice of the same length on the same device and
    /// fills it with a device-to-device copy of this slice.
    ///
    /// # Errors
    /// Propagates any driver error from the allocation or the copy.
    pub fn try_clone(&self) -> Result<Self, result::DriverError> {
        let fresh = unsafe { self.device.alloc(self.len) };
        fresh.and_then(|mut duplicate| {
            self.device.dtod_copy(self, &mut duplicate)?;
            Ok(duplicate)
        })
    }
}
impl<T: DeviceRepr> Clone for CudaSlice<T> {
    /// Deep-copies the device allocation using [`CudaSlice::try_clone`].
    ///
    /// # Panics
    /// Panics if the allocation or the device-to-device copy fails.
    fn clone(&self) -> Self {
        Self::try_clone(self).unwrap()
    }
}
impl<T: Clone + Default + DeviceRepr + Unpin> TryFrom<CudaSlice<T>> for Vec<T> {
type Error = result::DriverError;
fn try_from(value: CudaSlice<T>) -> Result<Self, Self::Error> {
value.device.clone().sync_reclaim(value)
}
}
/// A loaded CUDA module and the function handles looked up from it.
#[derive(Debug)]
pub(crate) struct CudaModule {
    /// Driver module handle; unloaded when the owning [`CudaDevice`] is dropped.
    pub(crate) cu_module: sys::CUmodule,
    /// Function handles, presumably keyed by function name — confirm at the
    /// insertion site.
    pub(crate) functions: BTreeMap<&'static str, sys::CUfunction>,
}
// NOTE(review): asserts the raw module/function handles may be moved and shared
// across threads; not checked by the compiler.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
/// A handle to a kernel function, paired with the device it belongs to.
///
/// Holding the `Arc<CudaDevice>` keeps the device alive, so its modules are
/// not unloaded by the device's `Drop` while this handle exists.
#[derive(Debug, Clone)]
pub struct CudaFunction {
    /// Raw driver function handle.
    pub(crate) cu_function: sys::CUfunction,
    /// Device the function was loaded on.
    pub(crate) device: Arc<CudaDevice>,
}
// NOTE(review): asserts the raw function handle may cross threads; not checked
// by the compiler.
unsafe impl Send for CudaFunction {}
unsafe impl Sync for CudaFunction {}
/// A separate stream created by [`CudaDevice::fork_default_stream`].
///
/// On drop, the device's stream is made to wait for the work queued here, and
/// then the stream handle is destroyed.
#[derive(Debug)]
pub struct CudaStream {
    /// Raw driver stream handle.
    pub stream: sys::CUstream,
    /// Device whose stream and event this stream synchronizes with.
    device: Arc<CudaDevice>,
}
impl CudaDevice {
    /// Creates a new non-blocking stream that waits for all work currently
    /// queued on the device's stream before running its own work.
    ///
    /// # Errors
    /// Propagates driver errors from binding the context, creating the stream,
    /// or the event record/wait.
    pub fn fork_default_stream(self: &Arc<Self>) -> Result<CudaStream, result::DriverError> {
        // Streams are created in the *current* context, which may not be this
        // device's on the calling thread (the device is shared via `Arc`).
        unsafe { result::ctx::set_current(self.cu_primary_ctx) }?;
        let stream = CudaStream {
            stream: result::stream::create(result::stream::StreamKind::NonBlocking)?,
            device: self.clone(),
        };
        stream.wait_for_default()?;
        Ok(stream)
    }

    /// Makes the device's stream wait for all work currently queued on
    /// `stream`. The wait is enqueued on the device; the host is not blocked.
    ///
    /// # Errors
    /// Propagates driver errors from binding the context, recording the event,
    /// or enqueueing the wait.
    pub fn wait_for(self: &Arc<Self>, stream: &CudaStream) -> Result<(), result::DriverError> {
        // NOTE(review): the single device `event` is reused for every wait;
        // confirm concurrent callers cannot interleave record/wait pairs.
        unsafe {
            // `self.stream` may be the null (default) stream, which is resolved
            // against the current context, so bind the device's context first.
            result::ctx::set_current(self.cu_primary_ctx)?;
            result::event::record(self.event, stream.stream)?;
            result::stream::wait_event(
                self.stream,
                self.event,
                sys::CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
            )
        }
    }
}
impl CudaStream {
    /// Makes this stream wait for all work currently queued on the device's
    /// stream. The wait is enqueued on the device; the host is not blocked.
    ///
    /// # Errors
    /// Propagates driver errors from binding the context, recording the event,
    /// or enqueueing the wait.
    pub fn wait_for_default(&self) -> Result<(), result::DriverError> {
        unsafe {
            // The device's stream may be the null (default) stream, which is
            // resolved against the current context, so bind the device's first.
            result::ctx::set_current(self.device.cu_primary_ctx)?;
            result::event::record(self.device.event, self.device.stream)?;
            result::stream::wait_event(
                self.stream,
                self.device.event,
                sys::CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
            )
        }
    }
}
impl Drop for CudaStream {
    /// Joins this stream's outstanding work back into the device's stream,
    /// then destroys the stream handle.
    ///
    /// # Panics
    /// Panics if either driver call fails.
    fn drop(&mut self) {
        let raw_stream = self.stream;
        // The wait must be enqueued before the handle is destroyed.
        self.device.wait_for(self).unwrap();
        unsafe { result::stream::destroy(raw_stream) }.unwrap();
    }
}
/// An immutable view of a sub-range of a [`CudaSlice`], borrowed for `'a`.
#[derive(Debug)]
pub struct CudaView<'a, T> {
    /// Shared borrow of the parent slice's base pointer; ties the view's
    /// lifetime to the parent allocation.
    pub(crate) root: &'a sys::CUdeviceptr,
    /// Device address of the view's first element.
    pub(crate) ptr: sys::CUdeviceptr,
    /// Number of `T` elements in the view.
    pub(crate) len: usize,
    marker: PhantomData<T>,
}
impl<T> CudaSlice<T> {
    /// Returns an immutable view of the elements in `range`.
    ///
    /// # Panics
    /// Panics if `range` is reversed or extends past `self.len()`.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> CudaView<'_, T> {
        self.try_slice(range).unwrap()
    }

    /// Returns an immutable view of the elements in `range`, or `None` if
    /// `range` is reversed or extends past `self.len()`.
    pub fn try_slice(&self, range: impl RangeBounds<usize>) -> Option<CudaView<'_, T>> {
        range.bounds(..self.len()).map(|(start, end)| CudaView {
            root: &self.cu_device_ptr,
            ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            marker: PhantomData,
        })
    }

    /// Reinterprets the start of this allocation as `len` elements of `S`.
    ///
    /// Returns `None` if `len * size_of::<S>()` exceeds `self.num_bytes()`.
    /// This includes the case where that product overflows `usize`; plain
    /// multiplication would wrap there in release builds and slip past the check.
    ///
    /// # Safety
    /// The caller must ensure the underlying bytes are valid values of `S` for
    /// however the returned view is used.
    pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
        let fits = matches!(
            len.checked_mul(std::mem::size_of::<S>()),
            Some(bytes) if bytes <= self.num_bytes()
        );
        fits.then_some(CudaView {
            root: &self.cu_device_ptr,
            ptr: self.cu_device_ptr,
            len,
            marker: PhantomData,
        })
    }
}
impl<'a, T> CudaView<'a, T> {
    /// Returns a sub-view of the elements in `range`, relative to this view.
    ///
    /// # Panics
    /// Panics if `range` is reversed or extends past `self.len()`.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> CudaView<'a, T> {
        self.try_slice(range).unwrap()
    }

    /// Returns a sub-view of the elements in `range`, relative to this view, or
    /// `None` if `range` is reversed or extends past `self.len()`.
    pub fn try_slice(&self, range: impl RangeBounds<usize>) -> Option<CudaView<'a, T>> {
        let (first, past_last) = range.bounds(..self.len())?;
        let byte_offset = first * std::mem::size_of::<T>();
        Some(CudaView {
            root: self.root,
            ptr: self.ptr + byte_offset as u64,
            len: past_last - first,
            marker: PhantomData,
        })
    }
}
/// A mutable view of a sub-range of a [`CudaSlice`], exclusively borrowed for `'a`.
#[derive(Debug)]
pub struct CudaViewMut<'a, T> {
    /// Exclusive borrow of the parent slice's base pointer; prevents other
    /// views of the parent while this one is alive.
    pub(crate) root: &'a mut sys::CUdeviceptr,
    /// Device address of the view's first element.
    pub(crate) ptr: sys::CUdeviceptr,
    /// Number of `T` elements in the view.
    pub(crate) len: usize,
    marker: PhantomData<T>,
}
impl<T> CudaSlice<T> {
    /// Returns a mutable view of the elements in `range`.
    ///
    /// # Panics
    /// Panics if `range` is reversed or extends past `self.len()`.
    pub fn slice_mut(&mut self, range: impl RangeBounds<usize>) -> CudaViewMut<'_, T> {
        self.try_slice_mut(range).unwrap()
    }

    /// Returns a mutable view of the elements in `range`, or `None` if `range`
    /// is reversed or extends past `self.len()`.
    pub fn try_slice_mut(&mut self, range: impl RangeBounds<usize>) -> Option<CudaViewMut<'_, T>> {
        range.bounds(..self.len()).map(|(start, end)| CudaViewMut {
            // `ptr` is read before `root` mutably borrows `cu_device_ptr`.
            ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
            root: &mut self.cu_device_ptr,
            len: end - start,
            marker: PhantomData,
        })
    }

    /// Mutably reinterprets the start of this allocation as `len` elements of `S`.
    ///
    /// Returns `None` if `len * size_of::<S>()` exceeds `self.num_bytes()`.
    /// This includes the case where that product overflows `usize`; plain
    /// multiplication would wrap there in release builds and slip past the check.
    ///
    /// # Safety
    /// The caller must ensure the underlying bytes are valid values of `S`, and
    /// that whatever is written through the view is valid for `T` if the slice
    /// is later used as `T`.
    pub unsafe fn transmute_mut<S>(&mut self, len: usize) -> Option<CudaViewMut<'_, S>> {
        let fits = matches!(
            len.checked_mul(std::mem::size_of::<S>()),
            Some(bytes) if bytes <= self.num_bytes()
        );
        fits.then_some(CudaViewMut {
            // `ptr` is read before `root` mutably borrows `cu_device_ptr`.
            ptr: self.cu_device_ptr,
            root: &mut self.cu_device_ptr,
            len,
            marker: PhantomData,
        })
    }
}
// NOTE(review): each method here pairs `'b: 'a` with `&'b self` / `&'b mut self`
// on a `CudaViewMut<'a, T>`. Well-formedness already requires `'a: 'b`, so this
// forces `'b == 'a`: sub-slicing keeps the parent view borrowed for its whole
// lifetime. Confirm this restriction is intended.
impl<'a, T> CudaViewMut<'a, T> {
    /// Returns an immutable sub-view of the elements in `range`, relative to
    /// this view.
    ///
    /// # Panics
    /// Panics if `range` is reversed or extends past `self.len()`.
    pub fn slice<'b: 'a>(&'b self, range: impl RangeBounds<usize>) -> CudaView<'a, T> {
        self.try_slice(range).unwrap()
    }
    /// Returns an immutable sub-view of the elements in `range`, or `None` if
    /// `range` is reversed or extends past `self.len()`.
    pub fn try_slice<'b: 'a>(&'b self, range: impl RangeBounds<usize>) -> Option<CudaView<'a, T>> {
        range.bounds(..self.len()).map(|(start, end)| CudaView {
            // Shared reborrow of the exclusive root pointer.
            root: self.root,
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            marker: PhantomData,
        })
    }
    /// Returns a mutable sub-view of the elements in `range`, relative to this
    /// view.
    ///
    /// # Panics
    /// Panics if `range` is reversed or extends past `self.len()`.
    pub fn slice_mut<'b: 'a>(&'b mut self, range: impl RangeBounds<usize>) -> CudaViewMut<'a, T> {
        self.try_slice_mut(range).unwrap()
    }
    /// Returns a mutable sub-view of the elements in `range`, or `None` if
    /// `range` is reversed or extends past `self.len()`.
    pub fn try_slice_mut<'b: 'a>(
        &'b mut self,
        range: impl RangeBounds<usize>,
    ) -> Option<CudaViewMut<'a, T>> {
        range.bounds(..self.len()).map(|(start, end)| CudaViewMut {
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            // Exclusive reborrow of the root pointer for the sub-view.
            root: self.root,
            len: end - start,
            marker: PhantomData,
        })
    }
}
/// Resolves an arbitrary `RangeBounds<usize>` into a concrete half-open
/// `(start, end)` pair, checked against a range of valid indices.
trait RangeHelper: RangeBounds<usize> {
    /// First index covered by the range, or `valid_start` when the start is
    /// unbounded. Returns `None` for `Excluded(usize::MAX)`, whose first index
    /// is not representable.
    fn inclusive_start(&self, valid_start: usize) -> Option<usize>;

    /// One past the last index covered, or `valid_end` when the end is
    /// unbounded. Returns `None` for `Included(usize::MAX)`, whose one-past
    /// index is not representable.
    fn exclusive_end(&self, valid_end: usize) -> Option<usize>;

    /// Returns `Some((start, end))` when this range lies inside `valid` and is
    /// well-formed, otherwise `None`.
    ///
    /// An empty range `s..s` is accepted, but an inclusive range whose end
    /// precedes its start (e.g. `1..=0`) is rejected.
    fn bounds(&self, valid: impl RangeHelper) -> Option<(usize, usize)> {
        let vs = valid.inclusive_start(0)?;
        let ve = valid.exclusive_end(usize::MAX)?;
        let s = self.inclusive_start(vs)?;
        let e = self.exclusive_end(ve)?;
        let inside = s >= vs && e <= ve;
        let well_formed = s < e || (s == e && !matches!(self.end_bound(), Bound::Included(_)));
        (inside && well_formed).then_some((s, e))
    }
}

impl<R: RangeBounds<usize>> RangeHelper for R {
    fn inclusive_start(&self, valid_start: usize) -> Option<usize> {
        match self.start_bound() {
            Bound::Included(&n) => Some(n),
            // `n + 1` would panic (debug) or wrap to 0 (release) at usize::MAX.
            Bound::Excluded(&n) => n.checked_add(1),
            Bound::Unbounded => Some(valid_start),
        }
    }

    fn exclusive_end(&self, valid_end: usize) -> Option<usize> {
        match self.end_bound() {
            // `n + 1` would panic (debug) or wrap to 0 (release) at usize::MAX.
            Bound::Included(&n) => n.checked_add(1),
            Bound::Excluded(&n) => Some(n),
            Bound::Unbounded => Some(valid_end),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `RangeHelper::bounds` resolves user ranges against a valid index range.
    #[test]
    #[allow(clippy::reversed_empty_ranges)] // reversed ranges are exercised on purpose
    fn test_bounds_helper() {
        // Ranges inside the valid span resolve to half-open (start, end) pairs.
        assert_eq!((..2usize).bounds(0..usize::MAX), Some((0, 2)));
        assert_eq!((1..2usize).bounds(..usize::MAX), Some((1, 2)));
        assert_eq!((..).bounds(1..10), Some((1, 10)));
        assert_eq!((2..=2usize).bounds(0..usize::MAX), Some((2, 3)));
        // An empty exclusive range is accepted.
        assert_eq!((2..2usize).bounds(0..usize::MAX), Some((2, 2)));
        // Out-of-span and reversed ranges are rejected.
        assert_eq!((2..=2usize).bounds(0..=1), None);
        assert_eq!((1..0usize).bounds(0..usize::MAX), None);
        assert_eq!((1..=0usize).bounds(0..usize::MAX), None);
    }

    /// Transmuted views must fit within the allocation's byte size.
    #[test]
    fn test_transmutes() {
        let dev = CudaDevice::new(0).unwrap();
        // 100 bytes hold exactly 25 `f32`s.
        let mut slice = dev.alloc_zeros::<u8>(100).unwrap();
        unsafe {
            assert!(slice.transmute::<f32>(25).is_some());
            assert!(slice.transmute::<f32>(26).is_none());
            assert!(slice.transmute_mut::<f32>(25).is_some());
            assert!(slice.transmute_mut::<f32>(26).is_none());
        }
    }
}