use crate::driver::{
result::{self, DriverError},
sys::{self, CUfunc_cache_enum, CUfunction_attribute_enum},
};
use std::{
ffi::CString,
marker::PhantomData,
ops::{Bound, RangeBounds},
string::String,
sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering},
sync::Arc,
vec::Vec,
};
/// Owning wrapper around a CUDA driver context (`CUcontext`) for one device.
///
/// Created via the primary-context API (`new`) or a standalone context
/// (`new_non_primary` / `new_cig`), and released/destroyed accordingly on drop.
#[derive(Debug)]
pub struct CudaContext {
    // Raw device handle this context was created for.
    pub(crate) cu_device: sys::CUdevice,
    // Raw driver context; nulled out by `Drop` after release/destroy.
    pub(crate) cu_ctx: sys::CUcontext,
    // Device index passed at construction time.
    pub(crate) ordinal: usize,
    // Whether the device reports memory-pool support (enables cuMemAllocAsync paths).
    pub(crate) has_async_alloc: bool,
    // true => release the primary context on drop; false => destroy the context.
    pub(crate) is_primary: bool,
    // Count of live streams created through `new_stream`/`fork` (not default streams).
    pub(crate) num_streams: AtomicUsize,
    // When true, slices carry read/write events for cross-stream ordering.
    pub(crate) event_tracking: AtomicBool,
    // Deferred driver error stored as the raw CUresult value; 0 means "no error".
    pub(crate) error_state: AtomicU32,
}
// SAFETY: all interior state is atomics plus raw driver handles; the CUDA driver
// API is assumed callable from any thread with the context bound — TODO confirm
// this matches the driver-version guarantees the crate targets.
unsafe impl Send for CudaContext {}
unsafe impl Sync for CudaContext {}
impl Drop for CudaContext {
    /// Releases (primary) or destroys (owned) the underlying context.
    /// Errors cannot propagate out of drop, so they are parked in `error_state`.
    fn drop(&mut self) {
        self.record_err(self.bind_to_thread());
        // Take the handle so a double-drop or partially-moved state is a no-op.
        let ctx = std::mem::replace(&mut self.cu_ctx, std::ptr::null_mut());
        if !ctx.is_null() {
            if self.is_primary {
                self.record_err(unsafe { result::primary_ctx::release(self.cu_device) });
            } else {
                self.record_err(unsafe { sys::cuCtxDestroy_v2(ctx).result() });
            }
        }
    }
}
impl PartialEq for CudaContext {
    /// Two contexts are equal exactly when the device handle, the raw context
    /// pointer, and the device ordinal all match.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.cu_device, self.cu_ctx, self.ordinal);
        let rhs = (other.cu_device, other.cu_ctx, other.ordinal);
        lhs == rhs
    }
}
impl Eq for CudaContext {}
impl CudaContext {
    /// Initializes the driver (idempotent), retains the device's *primary*
    /// context for device `ordinal`, and binds it to the calling thread.
    pub fn new(ordinal: usize) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;
        let cu_ctx = unsafe { result::primary_ctx::retain(cu_device) }?;
        // Query once whether stream-ordered allocation (memory pools) is available.
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: true,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
    /// Creates a standalone (non-primary) context with the given context flags.
    /// The context is destroyed (not released) when dropped.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010"
    ))]
    pub fn new_non_primary(ordinal: usize, flags: u32) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;
        // Newer toolkits use cuCtxCreate_v4 (null params => plain context);
        // older ones only have the v3 entry point.
        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), flags, cu_device) }?;
        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        )))]
        let cu_ctx = unsafe { result::ctx::create_v3(flags, cu_device) }?;
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
    /// Creates a context with CUDA-in-Graphics (CIG) parameters (CUDA >= 12.5).
    #[cfg(any(
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010"
    ))]
    pub fn new_cig(
        ordinal: usize,
        flags: u32,
        cig_params: &mut sys::CUctxCigParam,
    ) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;
        let mut ctx_create_params = sys::CUctxCreateParams_st {
            execAffinityParams: std::ptr::null_mut(),
            numExecAffinityParams: 0,
            cigParams: cig_params as *mut sys::CUctxCigParam,
        };
        let cu_ctx = unsafe { result::ctx::create_v4(&mut ctx_create_params, flags, cu_device) }?;
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
    /// Wraps an externally created context/device pair.
    ///
    /// # Safety
    /// `cu_ctx` and `cu_device` must be valid driver handles, and ownership is
    /// transferred: this wrapper will destroy the context on drop
    /// (`is_primary` is set to false).
    pub unsafe fn from_raw_context(
        ordinal: usize,
        cu_device: sys::CUdevice,
        cu_ctx: sys::CUcontext,
    ) -> Result<Arc<Self>, DriverError> {
        let has_async_alloc = {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
    /// Whether this wraps the device's primary context.
    pub fn is_primary(&self) -> bool {
        self.is_primary
    }
    /// Whether stream-ordered (async) allocation is available on this device.
    pub fn has_async_alloc(&self) -> bool {
        self.has_async_alloc
    }
    /// Number of CUDA devices visible to the driver.
    pub fn device_count() -> Result<i32, DriverError> {
        result::init()?;
        result::device::get_count()
    }
    /// Device index this context was created for.
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
    /// Human-readable device name.
    pub fn name(&self) -> Result<String, result::DriverError> {
        self.check_err()?;
        result::device::get_name(self.cu_device)
    }
    /// Device UUID.
    pub fn uuid(&self) -> Result<sys::CUuuid, result::DriverError> {
        self.check_err()?;
        result::device::get_uuid(self.cu_device)
    }
    /// (major, minor) compute capability of the device.
    pub fn compute_capability(&self) -> Result<(i32, i32), result::DriverError> {
        self.check_err()?;
        let capability_major =
            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let capability_minor =
            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;
        Ok((capability_major, capability_minor))
    }
    /// Total device memory in bytes.
    pub fn total_mem(&self) -> Result<usize, DriverError> {
        self.check_err()?;
        unsafe { result::device::total_mem(self.cu_device) }
    }
    /// (free, total) device memory in bytes; requires the context bound.
    pub fn mem_get_info(&self) -> Result<(usize, usize), DriverError> {
        self.bind_to_thread()?;
        result::mem_get_info()
    }
    /// Raw device handle.
    pub fn cu_device(&self) -> sys::CUdevice {
        self.cu_device
    }
    /// Raw context handle.
    pub fn cu_ctx(&self) -> sys::CUcontext {
        self.cu_ctx
    }
    /// Makes this context current on the calling thread (no-op when it already is).
    /// Also surfaces any deferred error recorded by earlier drops.
    pub fn bind_to_thread(&self) -> Result<(), DriverError> {
        self.check_err()?;
        if match result::ctx::get_current()? {
            Some(curr_ctx) => curr_ctx != self.cu_ctx,
            None => true,
        } {
            unsafe { result::ctx::set_current(self.cu_ctx) }?;
        }
        Ok(())
    }
    /// Queries a single device attribute.
    pub fn attribute(&self, attrib: sys::CUdevice_attribute) -> Result<i32, result::DriverError> {
        self.check_err()?;
        unsafe { result::device::get_attribute(self.cu_device, attrib) }
    }
    /// Blocks until all work previously submitted to this context completes.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::synchronize()
    }
    /// Convenience for setting the blocking-sync scheduling flag.
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000"
    )))]
    pub fn set_blocking_synchronize(&self) -> Result<(), DriverError> {
        self.set_flags(sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC)
    }
    /// Sets context scheduling flags (cuCtxSetFlags is only in newer toolkits).
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000"
    )))]
    pub fn set_flags(&self, flags: sys::CUctx_flags) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_flags(flags)
    }
    /// Reads a context resource limit.
    pub fn get_limit(&self, limit: sys::CUlimit) -> Result<usize, DriverError> {
        self.bind_to_thread()?;
        result::ctx::get_limit(limit)
    }
    /// Sets a context resource limit.
    pub fn set_limit(&self, limit: sys::CUlimit, value: usize) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_limit(limit, value)
    }
    /// Reads the context-wide cache configuration preference.
    pub fn get_cache_config(&self) -> Result<sys::CUfunc_cache, DriverError> {
        self.bind_to_thread()?;
        result::ctx::get_cache_config()
    }
    /// Sets the context-wide cache configuration preference.
    pub fn set_cache_config(&self, config: sys::CUfunc_cache) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_cache_config(config)
    }
    /// True when at least one non-default stream created via this context is alive.
    pub fn is_in_multi_stream_mode(&self) -> bool {
        self.num_streams.load(Ordering::Relaxed) > 0
    }
    /// True when per-slice read/write event tracking is enabled.
    pub fn is_event_tracking(&self) -> bool {
        self.event_tracking.load(Ordering::Relaxed)
    }
    /// True when this crate inserts events to order work across streams
    /// (multi-stream mode AND event tracking on).
    pub fn is_managing_stream_synchronization(&self) -> bool {
        self.is_in_multi_stream_mode() && self.is_event_tracking()
    }
    /// # Safety
    /// Toggling tracking while slices are in flight can change which
    /// synchronization the crate performs; caller must ensure that is sound.
    pub unsafe fn enable_event_tracking(&self) {
        self.event_tracking.store(true, Ordering::Relaxed);
    }
    /// # Safety
    /// See [`Self::enable_event_tracking`].
    pub unsafe fn disable_event_tracking(&self) {
        self.event_tracking.store(false, Ordering::Relaxed);
    }
    /// Takes (and clears) any deferred error recorded by `record_err`.
    pub fn check_err(&self) -> Result<(), DriverError> {
        // swap(0) atomically consumes the pending error so it is reported once.
        let error_state = self.error_state.swap(0, Ordering::Relaxed);
        if error_state == 0 {
            Ok(())
        } else {
            // SAFETY: error_state is only ever 0 or a value stored from a
            // `sys::cudaError_enum` in `record_err`, so the transmute round-trips.
            Err(result::DriverError(unsafe {
                std::mem::transmute::<u32, sys::cudaError_enum>(error_state)
            }))
        }
    }
    /// Stashes a driver error to be surfaced by the next `check_err` call.
    /// Note: a later error overwrites an earlier unreported one.
    pub fn record_err<T>(&self, result: Result<T, DriverError>) {
        if let Err(err) = result {
            self.error_state.store(err.0 as u32, Ordering::Relaxed)
        }
    }
}
/// Owning wrapper around a CUDA event (`CUevent`), tied to its creating context.
#[derive(Debug)]
pub struct CudaEvent {
    pub(crate) cu_event: sys::CUevent,
    // Keeps the context alive for the lifetime of the event.
    pub(crate) ctx: Arc<CudaContext>,
}
// SAFETY: only a raw event handle plus an Arc; driver event calls are made with
// the owning context bound — TODO confirm thread-safety guarantees of the
// targeted driver versions.
unsafe impl Send for CudaEvent {}
unsafe impl Sync for CudaEvent {}
impl Drop for CudaEvent {
    /// Destroys the event; errors are parked on the context via `record_err`.
    fn drop(&mut self) {
        self.ctx.record_err(self.ctx.bind_to_thread());
        self.ctx
            .record_err(unsafe { result::event::destroy(self.cu_event) });
    }
}
impl CudaContext {
    /// Creates a new event on this context.
    ///
    /// When `flags` is `None`, timing is disabled (`CU_EVENT_DISABLE_TIMING`),
    /// which is the cheapest configuration for pure synchronization events.
    pub fn new_event(
        self: &Arc<Self>,
        flags: Option<sys::CUevent_flags>,
    ) -> Result<CudaEvent, DriverError> {
        let flags = flags.unwrap_or(sys::CUevent_flags::CU_EVENT_DISABLE_TIMING);
        self.bind_to_thread()?;
        let cu_event = result::event::create(flags)?;
        Ok(CudaEvent {
            cu_event,
            ctx: self.clone(),
        })
    }
}
impl CudaEvent {
    /// Raw event handle.
    pub fn cu_event(&self) -> sys::CUevent {
        self.cu_event
    }
    /// Context this event belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
    /// Records this event on `stream`. Fails with `CUDA_ERROR_INVALID_CONTEXT`
    /// when the stream belongs to a different context.
    pub fn record(&self, stream: &CudaStream) -> Result<(), DriverError> {
        if self.ctx != stream.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        unsafe { result::event::record(self.cu_event, stream.cu_stream) }
    }
    /// Blocks the host until the most recently recorded occurrence completes.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        unsafe { result::event::synchronize(self.cu_event) }
    }
    /// Milliseconds elapsed between `self` and `end`; synchronizes both events
    /// first so the measurement is complete. Requires both events to share a context.
    pub fn elapsed_ms(&self, end: &Self) -> Result<f32, DriverError> {
        if self.ctx != end.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        self.synchronize()?;
        end.synchronize()?;
        unsafe { result::event::elapsed(self.cu_event, end.cu_event) }
    }
    /// Non-blocking completion query. NOTE(review): any driver error from the
    /// query is also reported as "not complete" here.
    pub fn is_complete(&self) -> bool {
        unsafe { result::event::query(self.cu_event) }.is_ok()
    }
}
/// Owning wrapper around a CUDA stream (`CUstream`), tied to its creating context.
///
/// A null handle represents the legacy default stream and the sentinel `0x2`
/// the per-thread default stream (see `per_thread_stream`); neither is
/// destroyed on drop.
#[derive(Debug, PartialEq, Eq)]
pub struct CudaStream {
    pub(crate) cu_stream: sys::CUstream,
    // Keeps the context alive for the lifetime of the stream.
    pub(crate) ctx: Arc<CudaContext>,
}
// SAFETY: only a raw stream handle plus an Arc; driver stream calls are made
// with the owning context bound — TODO confirm thread-safety guarantees of the
// targeted driver versions.
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
impl Drop for CudaStream {
    /// Destroys the stream unless it is one of the default-stream sentinels
    /// (null or 0x2), and decrements the context's live-stream counter.
    fn drop(&mut self) {
        self.ctx.record_err(self.ctx.bind_to_thread());
        let cu_stream = std::mem::replace(&mut self.cu_stream, std::ptr::null_mut());
        if !cu_stream.is_null() && cu_stream != (0x2 as _) {
            self.ctx.num_streams.fetch_sub(1, Ordering::Relaxed);
            self.ctx
                .record_err(unsafe { result::stream::destroy(cu_stream) });
        }
    }
}
impl CudaContext {
    /// Handle to the legacy default (NULL) stream. Never destroyed on drop and
    /// not counted in `num_streams`.
    pub fn default_stream(self: &Arc<Self>) -> Arc<CudaStream> {
        Arc::new(CudaStream {
            cu_stream: std::ptr::null_mut(),
            ctx: self.clone(),
        })
    }
    /// Handle to the special per-thread default stream (sentinel handle `0x2`).
    /// Never destroyed on drop and not counted in `num_streams`.
    pub fn per_thread_stream(self: &Arc<Self>) -> Arc<CudaStream> {
        Arc::new(CudaStream {
            cu_stream: 0x2 as _,
            ctx: self.clone(),
        })
    }
    /// Creates a new non-blocking stream on this context.
    ///
    /// On first transition into multi-stream mode (with event tracking on) the
    /// whole context is synchronized so prior default-stream work is ordered
    /// before work on the new stream.
    pub fn new_stream(self: &Arc<Self>) -> Result<Arc<CudaStream>, DriverError> {
        self.bind_to_thread()?;
        let prev_num_streams = self.num_streams.fetch_add(1, Ordering::Relaxed);
        let created = (|| {
            if prev_num_streams == 0 && self.is_event_tracking() {
                self.synchronize()?;
            }
            result::stream::create(result::stream::StreamKind::NonBlocking)
        })();
        match created {
            Ok(cu_stream) => Ok(Arc::new(CudaStream {
                cu_stream,
                ctx: self.clone(),
            })),
            Err(e) => {
                // BUGFIX: roll back the increment on failure. `CudaStream::drop`
                // only decrements for streams that were actually created, so the
                // old code left `num_streams` permanently inflated (and
                // `is_in_multi_stream_mode` permanently true) after a failed
                // synchronize/create.
                self.num_streams.fetch_sub(1, Ordering::Relaxed);
                Err(e)
            }
        }
    }
}
impl CudaStream {
    /// Creates a new non-blocking stream on the same context that waits on all
    /// work currently submitted to `self` (via an event join).
    pub fn fork(&self) -> Result<Arc<Self>, DriverError> {
        self.ctx.bind_to_thread()?;
        self.ctx.num_streams.fetch_add(1, Ordering::Relaxed);
        let cu_stream = match result::stream::create(result::stream::StreamKind::NonBlocking) {
            Ok(cu_stream) => cu_stream,
            Err(e) => {
                // BUGFIX: undo the increment when creation fails; Drop only
                // decrements for streams that exist, so the old code leaked the
                // counter and kept `is_in_multi_stream_mode` true forever.
                self.ctx.num_streams.fetch_sub(1, Ordering::Relaxed);
                return Err(e);
            }
        };
        let stream = Arc::new(CudaStream {
            cu_stream,
            ctx: self.ctx.clone(),
        });
        stream.join(self)?;
        Ok(stream)
    }
    /// Raw stream handle.
    pub fn cu_stream(&self) -> sys::CUstream {
        self.cu_stream
    }
    /// Context this stream belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }
    /// Blocks the host until all work submitted to this stream completes.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        unsafe { result::stream::synchronize(self.cu_stream) }
    }
    /// Creates an event and immediately records it on this stream.
    pub fn record_event(
        &self,
        flags: Option<sys::CUevent_flags>,
    ) -> Result<CudaEvent, DriverError> {
        let event = self.ctx.new_event(flags)?;
        event.record(self)?;
        Ok(event)
    }
    /// Makes all future work on this stream wait for `event`. Fails with
    /// `CUDA_ERROR_INVALID_CONTEXT` when the event belongs to another context.
    pub fn wait(&self, event: &CudaEvent) -> Result<(), DriverError> {
        if self.ctx != event.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        unsafe {
            result::stream::wait_event(
                self.cu_stream,
                event.cu_event,
                sys::CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
            )
        }
    }
    /// Orders this stream after all work currently submitted to `other`.
    pub fn join(&self, other: &CudaStream) -> Result<(), DriverError> {
        self.wait(&other.record_event(None)?)
    }
}
/// Owning, typed device allocation of `len` elements of `T`, ordered against
/// the stream that allocated it.
#[derive(Debug)]
pub struct CudaSlice<T> {
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    // Element count (not bytes).
    pub(crate) len: usize,
    // Event recorded after kernels that *read* this buffer (when tracking is on).
    pub(crate) read: Option<CudaEvent>,
    // Event recorded after kernels that *write* this buffer (when tracking is on).
    pub(crate) write: Option<CudaEvent>,
    // The stream whose order governs this allocation's lifetime.
    pub(crate) stream: Arc<CudaStream>,
    // Ties `T` to the allocation without owning host data.
    pub(crate) marker: PhantomData<*const T>,
}
// SAFETY: the device pointer is only dereferenced by driver calls; host-side
// state is an Arc and Option handles — TODO confirm cross-thread use matches
// the crate's synchronization model.
unsafe impl<T> Send for CudaSlice<T> {}
unsafe impl<T> Sync for CudaSlice<T> {}
impl<T> Drop for CudaSlice<T> {
    /// Frees the allocation: first orders the free after any tracked readers /
    /// writers, then frees asynchronously on the owning stream when memory
    /// pools are available, otherwise synchronizes and frees synchronously.
    fn drop(&mut self) {
        let ctx = &self.stream.ctx;
        if let Some(read) = self.read.as_ref() {
            ctx.record_err(self.stream.wait(read));
        }
        if let Some(write) = self.write.as_ref() {
            ctx.record_err(self.stream.wait(write));
        }
        if ctx.has_async_alloc {
            ctx.record_err(unsafe {
                result::free_async(self.cu_device_ptr, self.stream.cu_stream)
            });
        } else {
            ctx.record_err(self.stream.synchronize());
            ctx.record_err(unsafe { result::free_sync(self.cu_device_ptr) });
        }
    }
}
impl<T> CudaSlice<T> {
    /// Number of elements in the allocation.
    pub fn len(&self) -> usize {
        self.len
    }
    /// Size of the allocation in bytes.
    pub fn num_bytes(&self) -> usize {
        std::mem::size_of::<T>() * self.len
    }
    /// True when the allocation holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Ordinal of the device this memory lives on.
    pub fn ordinal(&self) -> usize {
        self.context().ordinal
    }
    /// The context that owns this allocation.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.stream.ctx
    }
    /// The stream this allocation's lifetime is ordered against.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
impl<T: DeviceRepr> CudaSlice<T> {
    /// Device-to-device copy into a fresh allocation on the same stream.
    pub fn try_clone(&self) -> Result<Self, result::DriverError> {
        self.stream.clone_dtod(self)
    }
}
impl<T: DeviceRepr> Clone for CudaSlice<T> {
    /// Clone via device-to-device copy.
    /// NOTE: panics on driver failure; use `try_clone` to handle errors.
    fn clone(&self) -> Self {
        self.try_clone().unwrap()
    }
}
/// Copies device memory back into a host `Vec`, consuming the slice.
///
/// Only `DeviceRepr` is required: `clone_dtoh` fills the destination with a
/// device-to-host memcpy and never clones or default-constructs host values,
/// so the previous `Clone + Default` bounds were unnecessary. Relaxing them is
/// backward-compatible (strictly more types qualify).
impl<T: DeviceRepr> TryFrom<CudaSlice<T>> for Vec<T> {
    type Error = result::DriverError;
    fn try_from(value: CudaSlice<T>) -> Result<Self, Self::Error> {
        value.stream.clone_dtoh(&value)
    }
}
/// Immutable, borrowed window into a [`CudaSlice`] (element range of device memory).
#[derive(Debug)]
pub struct CudaView<'a, T> {
    // Device pointer to the first element of the window.
    pub(crate) ptr: sys::CUdeviceptr,
    // Window length in elements.
    pub(crate) len: usize,
    // Borrowed tracking events from the parent slice.
    pub(crate) read: &'a Option<CudaEvent>,
    pub(crate) write: &'a Option<CudaEvent>,
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a [T]>,
}
impl<T> CudaSlice<T> {
    /// Borrows the whole slice as an immutable view.
    pub fn as_view(&self) -> CudaView<'_, T> {
        CudaView {
            ptr: self.cu_device_ptr,
            len: self.len,
            read: &self.read,
            write: &self.write,
            stream: &self.stream,
            marker: PhantomData,
        }
    }
}
impl<T> CudaView<'_, T> {
    /// Number of elements visible through this view.
    pub fn len(&self) -> usize {
        self.len
    }
    /// True when the view covers no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Narrows the view to the element range `[start, end)`.
    ///
    /// Panics when `start > end` or `end` is past the end of the view.
    fn resize(&self, start: usize, end: usize) -> Self {
        assert!(start <= end && end <= self.len);
        let byte_offset = (start * std::mem::size_of::<T>()) as u64;
        Self {
            ptr: self.ptr + byte_offset,
            len: end - start,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        }
    }
}
/// Mutable, borrowed window into a [`CudaSlice`] (element range of device memory).
#[derive(Debug)]
pub struct CudaViewMut<'a, T> {
    // Device pointer to the first element of the window.
    pub(crate) ptr: sys::CUdeviceptr,
    // Window length in elements.
    pub(crate) len: usize,
    // Borrowed tracking events from the parent slice.
    pub(crate) read: &'a Option<CudaEvent>,
    pub(crate) write: &'a Option<CudaEvent>,
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a mut [T]>,
}
impl<T> CudaSlice<T> {
    /// Borrows the whole slice as a mutable view.
    pub fn as_view_mut(&mut self) -> CudaViewMut<'_, T> {
        CudaViewMut {
            ptr: self.cu_device_ptr,
            len: self.len,
            read: &self.read,
            write: &self.write,
            stream: &self.stream,
            marker: PhantomData,
        }
    }
}
impl<T> CudaViewMut<'_, T> {
    /// Number of elements visible through this view.
    pub fn len(&self) -> usize {
        self.len
    }
    /// True when the view covers no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Reborrows this mutable view as an immutable [`CudaView`] over the same
    /// window (same pointer, length, events, and stream).
    pub fn as_view<'b>(&'b self) -> CudaView<'b, T> {
        CudaView {
            stream: self.stream,
            ptr: self.ptr,
            len: self.len,
            read: self.read,
            write: self.write,
            marker: PhantomData,
        }
    }
}
/// Marker trait for types where the all-zero bit pattern is a valid value
/// (safe to `memset` to zero).
///
/// # Safety
/// Implementors must guarantee every field is valid when zeroed.
pub unsafe trait ValidAsZeroBits {}
unsafe impl ValidAsZeroBits for bool {}
unsafe impl ValidAsZeroBits for i8 {}
unsafe impl ValidAsZeroBits for i16 {}
unsafe impl ValidAsZeroBits for i32 {}
unsafe impl ValidAsZeroBits for i64 {}
unsafe impl ValidAsZeroBits for i128 {}
unsafe impl ValidAsZeroBits for isize {}
unsafe impl ValidAsZeroBits for u8 {}
unsafe impl ValidAsZeroBits for u16 {}
unsafe impl ValidAsZeroBits for u32 {}
unsafe impl ValidAsZeroBits for u64 {}
unsafe impl ValidAsZeroBits for u128 {}
unsafe impl ValidAsZeroBits for usize {}
unsafe impl ValidAsZeroBits for f32 {}
unsafe impl ValidAsZeroBits for f64 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::bf16 {}
// Arrays of zeroable elements are zeroable.
unsafe impl<T: ValidAsZeroBits, const M: usize> ValidAsZeroBits for [T; M] {}
// Recursively implements ValidAsZeroBits for every non-empty tuple prefix of
// the identifier list: (A,), (A, B), ... up to the full 12-tuple.
macro_rules! impl_tuples {
    // Base case: a single identifier left.
    ($t:tt) => {
        impl_tuples!(@ $t);
    };
    // Recurse on the tail, then emit the impl for the full list.
    ($l:tt $(,$t:tt)+) => {
        impl_tuples!($($t),+);
        impl_tuples!(@ $l $(,$t)+);
    };
    // Emit one impl for the given identifier list.
    (@ $($t:tt),+) => {
        unsafe impl<$($t: ValidAsZeroBits,)+> ValidAsZeroBits for ($($t,)+) {}
    };
}
impl_tuples!(A, B, C, D, E, F, G, H, I, J, K, L);
/// Marker trait for types whose in-memory representation can be copied to the
/// device as-is (plain-old-data layout).
///
/// # Safety
/// Implementors must have a device-compatible, padding-stable representation.
pub unsafe trait DeviceRepr {}
unsafe impl DeviceRepr for bool {}
unsafe impl DeviceRepr for i8 {}
unsafe impl DeviceRepr for i16 {}
unsafe impl DeviceRepr for i32 {}
unsafe impl DeviceRepr for i64 {}
unsafe impl DeviceRepr for i128 {}
unsafe impl DeviceRepr for isize {}
unsafe impl DeviceRepr for u8 {}
unsafe impl DeviceRepr for u16 {}
unsafe impl DeviceRepr for u32 {}
unsafe impl DeviceRepr for u64 {}
unsafe impl DeviceRepr for u128 {}
unsafe impl DeviceRepr for usize {}
unsafe impl DeviceRepr for f32 {}
unsafe impl DeviceRepr for f64 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::bf16 {}
#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E4M3 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E4M3 {}
#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E5M2 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E5M2 {}
#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1 {}
#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::E8M0 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::E8M0 {}
#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1x2 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1x2 {}
// Fixed-size arrays of device-representable elements are device-representable.
unsafe impl<const N: usize, T> DeviceRepr for [T; N] where T: DeviceRepr {}
/// Common length/stream interface shared by [`CudaSlice`], [`CudaView`], and
/// [`CudaViewMut`].
pub trait DeviceSlice<T> {
    /// Number of elements.
    fn len(&self) -> usize;
    /// Total size in bytes.
    fn num_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
    /// True when there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// The stream this memory is ordered against.
    fn stream(&self) -> &Arc<CudaStream>;
}
impl<T> DeviceSlice<T> for CudaSlice<T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
impl<T> DeviceSlice<T> for CudaView<'_, T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        self.stream
    }
}
impl<T> DeviceSlice<T> for CudaViewMut<'_, T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        self.stream
    }
}
/// Guard returned alongside raw device pointers / host slices: on drop it
/// either records an event on a stream or synchronizes the stream, ensuring
/// the borrow stays valid for the duration of the launched work.
#[derive(Debug)]
#[must_use]
pub enum SyncOnDrop<'a> {
    // Record `event` on `stream` when dropped; `None` means nothing to do.
    Record(Option<(&'a CudaEvent, &'a CudaStream)>),
    // Synchronize `stream` when dropped; `None` means nothing to do.
    Sync(Option<&'a CudaStream>),
}
impl<'a> SyncOnDrop<'a> {
    /// Record-on-drop guard; a `None` event yields a no-op guard.
    pub fn record_event(event: &'a Option<CudaEvent>, stream: &'a CudaStream) -> Self {
        SyncOnDrop::Record(event.as_ref().map(|e| (e, stream)))
    }
    /// Synchronize-on-drop guard for `stream`.
    pub fn sync_stream(stream: &'a CudaStream) -> Self {
        SyncOnDrop::Sync(Some(stream))
    }
}
impl Drop for SyncOnDrop<'_> {
    fn drop(&mut self) {
        match self {
            SyncOnDrop::Record(target) => {
                // `take` makes the drop idempotent; errors are parked on the context.
                if let Some((event, stream)) = std::mem::take(target) {
                    stream.ctx.record_err(event.record(stream));
                }
            }
            SyncOnDrop::Sync(target) => {
                if let Some(stream) = std::mem::take(target) {
                    stream.ctx.record_err(stream.synchronize());
                }
            }
        }
    }
}
/// Read access to a raw device pointer, paired with a guard that records the
/// read-event when dropped.
pub trait DevicePtr<T>: DeviceSlice<T> {
    /// Returns the raw pointer plus a [`SyncOnDrop`] that must outlive the
    /// launched work using the pointer.
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
impl<T> DevicePtr<T> for CudaSlice<T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // In managed multi-stream mode, order this read after the last write.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (
            self.cu_device_ptr,
            SyncOnDrop::record_event(&self.read, stream),
        )
    }
}
impl<T> DevicePtr<T> for CudaView<'_, T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Same write-before-read ordering as CudaSlice.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.read, stream))
    }
}
impl<T> DevicePtr<T> for CudaViewMut<'_, T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Same write-before-read ordering as CudaSlice.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.read, stream))
    }
}
/// Write access to a raw device pointer, paired with a guard that records the
/// write-event when dropped.
pub trait DevicePtrMut<T>: DeviceSlice<T> {
    /// Returns the raw pointer plus a [`SyncOnDrop`] that must outlive the
    /// launched work writing through the pointer.
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
impl<T> DevicePtrMut<T> for CudaSlice<T> {
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Writers must wait for both outstanding readers and writers.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(read) = self.read.as_ref() {
                stream.ctx.record_err(stream.wait(read));
            }
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (
            self.cu_device_ptr,
            SyncOnDrop::record_event(&self.write, stream),
        )
    }
}
impl<T> DevicePtrMut<T> for CudaViewMut<'_, T> {
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Writers must wait for both outstanding readers and writers.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(read) = self.read.as_ref() {
                stream.ctx.record_err(stream.wait(read));
            }
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.write, stream))
    }
}
/// Host memory that can be handed to async memcpy calls, paired with a guard
/// keeping the borrow alive until the transfer is ordered.
pub trait HostSlice<T> {
    /// Number of elements.
    fn len(&self) -> usize;
    /// True when there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// # Safety
    /// The returned borrow is read by an async transfer; the guard must be kept
    /// alive until the transfer is ordered/complete.
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>);
    /// # Safety
    /// The returned borrow is written by an async transfer; the guard must be
    /// kept alive until the transfer is ordered/complete.
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>);
}
// Plain (pageable) host memory: no event is available, so the guard is a no-op
// `Sync(None)` — callers rely on the driver staging pageable transfers.
impl<T, const N: usize> HostSlice<T> for [T; N] {
    fn len(&self) -> usize {
        N
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}
impl<T> HostSlice<T> for [T] {
    fn len(&self) -> usize {
        self.len()
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}
impl<T> HostSlice<T> for Vec<T> {
    fn len(&self) -> usize {
        self.len()
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}
/// Page-locked (pinned) host allocation of `len` elements of `T`, with an event
/// used to order host access against in-flight transfers.
#[derive(Debug)]
pub struct PinnedHostSlice<T> {
    pub(crate) ptr: *mut T,
    // Element count (not bytes).
    pub(crate) len: usize,
    // Recorded after transfers that use this buffer; synchronized before host access.
    pub(crate) event: CudaEvent,
}
// SAFETY: raw pointer to pinned host memory plus a CudaEvent; host access is
// gated on event synchronization — TODO confirm concurrent-use expectations.
unsafe impl<T> Send for PinnedHostSlice<T> {}
unsafe impl<T> Sync for PinnedHostSlice<T> {}
impl<T> Drop for PinnedHostSlice<T> {
    /// Waits for outstanding transfers, then frees the pinned allocation.
    fn drop(&mut self) {
        let ctx = &self.event.ctx;
        ctx.record_err(self.event.synchronize());
        ctx.record_err(unsafe { result::free_host(self.ptr as _) });
    }
}
impl CudaContext {
    /// Allocates page-locked (pinned), write-combined host memory for `len`
    /// elements of `T`.
    ///
    /// # Safety
    /// The returned memory is uninitialized; callers must write before reading.
    ///
    /// # Panics
    /// Panics when `len * size_of::<T>()` overflows `usize` or is not below
    /// `isize::MAX`.
    pub unsafe fn alloc_pinned<T: DeviceRepr>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<PinnedHostSlice<T>, DriverError> {
        self.bind_to_thread()?;
        // BUGFIX: validate the byte count *before* allocating. The old code
        // computed `len * size_of::<T>()` with an unchecked multiply (which
        // silently wraps in release builds) and only asserted the isize::MAX
        // bound after the allocation had already been made.
        let num_bytes = len
            .checked_mul(std::mem::size_of::<T>())
            .expect("pinned allocation size overflows usize");
        assert!(num_bytes < isize::MAX as usize);
        let ptr = result::malloc_host(num_bytes, sys::CU_MEMHOSTALLOC_WRITECOMBINED)?;
        let ptr = ptr as *mut T;
        assert!(!ptr.is_null());
        assert!(ptr.is_aligned());
        let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
        Ok(PinnedHostSlice { ptr, len, event })
    }
}
impl<T> PinnedHostSlice<T> {
    /// The context the backing pinned allocation belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.event.ctx
    }
    /// Number of elements.
    pub fn len(&self) -> usize {
        self.len
    }
    /// Total size in bytes.
    pub fn num_bytes(&self) -> usize {
        std::mem::size_of::<T>() * self.len
    }
    /// True when the slice holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl<T: ValidAsZeroBits> PinnedHostSlice<T> {
    /// Pointer to the pinned memory; waits for outstanding transfers first.
    pub fn as_ptr(&self) -> Result<*const T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }
    /// Mutable pointer to the pinned memory; waits for outstanding transfers first.
    pub fn as_mut_ptr(&mut self) -> Result<*mut T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }
    /// Borrows the pinned memory as a host slice after event synchronization.
    pub fn as_slice(&self) -> Result<&[T], DriverError> {
        self.event.synchronize()?;
        Ok(unsafe { std::slice::from_raw_parts(self.ptr, self.len) })
    }
    /// Mutably borrows the pinned memory as a host slice after event synchronization.
    pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
        self.event.synchronize()?;
        Ok(unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) })
    }
}
// Pinned memory *does* have an event: the stream first waits on it (ordering
// after previous uses), and the guard re-records it after the new transfer.
impl<T> HostSlice<T> for PinnedHostSlice<T> {
    fn len(&self) -> usize {
        self.len
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts_mut(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
impl CudaStream {
    /// Allocates a zero-length device buffer (useful as a placeholder handle).
    pub fn null<T>(self: &Arc<Self>) -> Result<CudaSlice<T>, result::DriverError> {
        self.ctx.bind_to_thread()?;
        let cu_device_ptr = if self.ctx.has_async_alloc {
            unsafe { result::malloc_async(self.cu_stream, 0) }?
        } else {
            unsafe { result::malloc_sync(0) }?
        };
        Ok(CudaSlice {
            cu_device_ptr,
            len: 0,
            read: None,
            write: None,
            stream: self.clone(),
            marker: PhantomData,
        })
    }
    /// Allocates `len` uninitialized elements on the device.
    ///
    /// # Safety
    /// The memory is uninitialized; it must be written before being read.
    pub unsafe fn alloc<T: DeviceRepr>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<CudaSlice<T>, DriverError> {
        self.ctx.bind_to_thread()?;
        let cu_device_ptr = if self.ctx.has_async_alloc {
            result::malloc_async(self.cu_stream, len * std::mem::size_of::<T>())?
        } else {
            result::malloc_sync(len * std::mem::size_of::<T>())?
        };
        // Tracking events are only attached when event tracking is enabled.
        let (read, write) = if self.ctx.is_event_tracking() {
            (
                Some(self.ctx.new_event(None)?),
                Some(self.ctx.new_event(None)?),
            )
        } else {
            (None, None)
        };
        Ok(CudaSlice {
            cu_device_ptr,
            len,
            read,
            write,
            stream: self.clone(),
            marker: PhantomData,
        })
    }
    /// Allocates `len` elements and memsets them to zero.
    pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(len) }?;
        self.memset_zeros(&mut dst)?;
        Ok(dst)
    }
    /// Asynchronously zeroes all bytes of `dst` on this stream.
    pub fn memset_zeros<T: DeviceRepr + ValidAsZeroBits, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        dst: &mut Dst,
    ) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        let num_bytes = dst.num_bytes();
        let (dptr, _record) = dst.device_ptr_mut(self);
        unsafe { result::memset_d8_async(dptr, 0, num_bytes, self.cu_stream) }?;
        Ok(())
    }
    /// Allocates a device buffer and copies `src` into it (host-to-device).
    #[deprecated = "Use clone_htod"]
    pub fn memcpy_stod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
        self: &Arc<Self>,
        src: &Src,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(src.len()) }?;
        self.memcpy_htod(src, &mut dst)?;
        Ok(dst)
    }
    /// Allocates a device buffer and copies `src` into it (host-to-device).
    pub fn clone_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
        self: &Arc<Self>,
        src: &Src,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(src.len()) }?;
        self.memcpy_htod(src, &mut dst)?;
        Ok(dst)
    }
    /// Asynchronous host-to-device copy of `src.len()` elements.
    ///
    /// # Panics
    /// Panics when `dst` is shorter than `src`.
    pub fn memcpy_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        src: &Src,
        dst: &mut Dst,
    ) -> Result<(), DriverError> {
        assert!(dst.len() >= src.len());
        self.ctx.bind_to_thread()?;
        // The guards (_record_src/_record_dst) order the borrow/events after the copy.
        let (src, _record_src) = unsafe { src.stream_synced_slice(self) };
        let (dst, _record_dst) = dst.device_ptr_mut(self);
        unsafe { result::memcpy_htod_async(dst, src, self.cu_stream) }
    }
    /// Copies a device buffer into a new host `Vec`.
    #[deprecated = "Use clone_dtoh"]
    pub fn memcpy_dtov<T: DeviceRepr, Src: DevicePtr<T>>(
        self: &Arc<Self>,
        src: &Src,
    ) -> Result<Vec<T>, DriverError> {
        let mut dst = Vec::with_capacity(src.len());
        // The uninitialized contents are fully overwritten by memcpy_dtoh below.
        #[allow(clippy::uninit_vec)]
        unsafe {
            dst.set_len(src.len())
        };
        self.memcpy_dtoh(src, &mut dst)?;
        Ok(dst)
    }
    /// Copies a device buffer into a new host `Vec`.
    pub fn clone_dtoh<T: DeviceRepr, Src: DevicePtr<T>>(
        self: &Arc<Self>,
        src: &Src,
    ) -> Result<Vec<T>, DriverError> {
        let mut dst = Vec::with_capacity(src.len());
        // The uninitialized contents are fully overwritten by memcpy_dtoh below.
        #[allow(clippy::uninit_vec)]
        unsafe {
            dst.set_len(src.len())
        };
        self.memcpy_dtoh(src, &mut dst)?;
        Ok(dst)
    }
    /// Asynchronous device-to-host copy of `src.len()` elements.
    ///
    /// # Panics
    /// Panics when `dst` is shorter than `src`.
    pub fn memcpy_dtoh<T: DeviceRepr, Src: DevicePtr<T>, Dst: HostSlice<T> + ?Sized>(
        self: &Arc<Self>,
        src: &Src,
        dst: &mut Dst,
    ) -> Result<(), DriverError> {
        assert!(dst.len() >= src.len());
        self.ctx.bind_to_thread()?;
        let (src, _record_src) = src.device_ptr(self);
        let (dst, _record_dst) = unsafe { dst.stream_synced_mut_slice(self) };
        unsafe { result::memcpy_dtoh_async(dst, src, self.cu_stream) }
    }
    /// Asynchronous device-to-device copy; falls back to a peer copy when the
    /// source lives on a different context than this stream's.
    ///
    /// # Panics
    /// Panics when `dst` is shorter than `src`.
    pub fn memcpy_dtod<T, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        src: &Src,
        dst: &mut Dst,
    ) -> Result<(), DriverError> {
        assert!(dst.len() >= src.len());
        self.ctx.bind_to_thread()?;
        let num_bytes = src.num_bytes();
        let src_ctx = src.stream().context();
        let dst_ctx = self.context();
        let (src, _record_src) = src.device_ptr(self);
        let (dst, _record_dst) = dst.device_ptr_mut(self);
        if src_ctx == dst_ctx {
            unsafe { result::memcpy_dtod_async(dst, src, num_bytes, self.cu_stream) }
        } else {
            // Cross-context copy: hand both contexts to the driver's peer copy.
            unsafe {
                result::memcpy_peer_async(
                    dst_ctx.cu_ctx,
                    dst,
                    src_ctx.cu_ctx,
                    src,
                    num_bytes,
                    self.cu_stream,
                )
            }
        }
    }
    /// Allocates a new device buffer and copies `src` into it (device-to-device).
    pub fn clone_dtod<T: DeviceRepr, Src: DevicePtr<T>>(
        self: &Arc<Self>,
        src: &Src,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(src.len()) }?;
        self.memcpy_dtod(src, &mut dst)?;
        Ok(dst)
    }
}
impl<T> CudaSlice<T> {
    /// Immutable sub-view over `bounds`. Panics when out of range.
    pub fn slice(&self, bounds: impl RangeBounds<usize>) -> CudaView<'_, T> {
        self.as_view().slice(bounds)
    }
    /// Immutable sub-view over `bounds`, or `None` when out of range.
    pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'_, T>> {
        self.as_view().try_slice(bounds)
    }
    /// Mutable sub-view over `bounds`. Panics when out of range.
    pub fn slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'_, T> {
        self.try_slice_mut(bounds).unwrap()
    }
    /// Mutable sub-view over `bounds`, or `None` when out of range.
    pub fn try_slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> Option<CudaViewMut<'_, T>> {
        to_range(bounds, self.len).map(|(start, end)| CudaViewMut {
            // Offset the device pointer by whole elements.
            ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            read: &self.read,
            write: &self.write,
            stream: &self.stream,
            marker: PhantomData,
        })
    }
    /// Reinterprets the buffer as `len` elements of `S`.
    ///
    /// # Safety
    /// Caller must ensure the bytes form valid `S` values; returns `None` when
    /// `len * size_of::<S>()` exceeds the buffer's byte size.
    pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
        self.as_view().transmute(len)
    }
    /// Mutable version of [`Self::transmute`].
    ///
    /// # Safety
    /// Same contract as [`Self::transmute`].
    pub unsafe fn transmute_mut<S>(&mut self, len: usize) -> Option<CudaViewMut<'_, S>> {
        (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
            CudaViewMut {
                ptr: self.cu_device_ptr,
                len,
                read: &self.read,
                write: &self.write,
                stream: &self.stream,
                marker: PhantomData,
            },
        )
    }
    /// Splits into `[0, mid)` and `[mid, len)` views. Panics when `mid > len`.
    pub fn split_at(&self, mid: usize) -> (CudaView<'_, T>, CudaView<'_, T>) {
        self.as_view().split_at(mid)
    }
    /// Non-panicking [`Self::split_at`].
    pub fn try_split_at(&self, mid: usize) -> Option<(CudaView<'_, T>, CudaView<'_, T>)> {
        self.as_view().try_split_at(mid)
    }
    /// Mutable split. Panics when `mid > len`.
    pub fn split_at_mut(&mut self, mid: usize) -> (CudaViewMut<'_, T>, CudaViewMut<'_, T>) {
        self.try_split_at_mut(mid).unwrap()
    }
    /// Non-panicking mutable split. NOTE(review): both halves alias the same
    /// read/write events, so event tracking treats them as one buffer.
    pub fn try_split_at_mut(
        &mut self,
        mid: usize,
    ) -> Option<(CudaViewMut<'_, T>, CudaViewMut<'_, T>)> {
        let length = self.len;
        (mid <= length).then(|| {
            let a = CudaViewMut {
                ptr: self.cu_device_ptr,
                len: mid,
                read: &self.read,
                write: &self.write,
                stream: &self.stream,
                marker: PhantomData,
            };
            let b = CudaViewMut {
                ptr: self.cu_device_ptr + (mid * std::mem::size_of::<T>()) as u64,
                len: length - mid,
                read: &self.read,
                write: &self.write,
                stream: &self.stream,
                marker: PhantomData,
            };
            (a, b)
        })
    }
}
impl<'a, T> CudaView<'a, T> {
    /// Sub-view over `bounds`.
    ///
    /// # Panics
    /// Panics if `bounds` is out of range for this view.
    pub fn slice(&self, bounds: impl RangeBounds<usize>) -> Self {
        self.try_slice(bounds).unwrap()
    }

    /// Fallible version of [`Self::slice`]; `None` when `bounds` is invalid.
    pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<Self> {
        let (start, end) = to_range(bounds, self.len)?;
        Some(self.resize(start, end))
    }

    /// Reinterprets the viewed bytes as `len` elements of `S`; `None` when that
    /// would exceed the view.
    ///
    /// # Safety
    /// Caller must ensure the bytes form valid values of `S`.
    pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'a, S>> {
        if len * std::mem::size_of::<S>() > self.len * std::mem::size_of::<T>() {
            return None;
        }
        Some(CudaView {
            ptr: self.ptr,
            len,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        })
    }

    /// Splits into two views at `mid`.
    ///
    /// # Panics
    /// Panics if `mid > self.len`.
    pub fn split_at(&self, mid: usize) -> (Self, Self) {
        self.try_split_at(mid).unwrap()
    }

    /// Fallible version of [`Self::split_at`].
    pub fn try_split_at(&self, mid: usize) -> Option<(Self, Self)> {
        if mid > self.len() {
            return None;
        }
        Some((self.resize(0, mid), self.resize(mid, self.len)))
    }
}
impl<'a, T> CudaViewMut<'a, T> {
    /// Immutable sub-view over `bounds`.
    ///
    /// # Panics
    /// Panics if `bounds` is out of range for this view.
    pub fn slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> CudaView<'b, T> {
        self.try_slice(bounds).unwrap()
    }

    /// Fallible version of [`Self::slice`]; `None` when `bounds` is invalid.
    pub fn try_slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'b, T>> {
        let (start, end) = to_range(bounds, self.len)?;
        Some(self.as_view().resize(start, end))
    }

    /// Reinterprets the viewed bytes as `len` elements of `S`; `None` when that
    /// would exceed the view.
    ///
    /// # Safety
    /// Caller must ensure the bytes form valid values of `S`.
    pub unsafe fn transmute<'b, S>(&'b self, len: usize) -> Option<CudaView<'b, S>> {
        if len * std::mem::size_of::<S>() > self.len * std::mem::size_of::<T>() {
            return None;
        }
        Some(CudaView {
            ptr: self.ptr,
            len,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        })
    }

    /// Mutable sub-view over `bounds`.
    ///
    /// # Panics
    /// Panics if `bounds` is out of range for this view.
    pub fn slice_mut<'b>(&'b mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'b, T> {
        self.try_slice_mut(bounds).unwrap()
    }

    /// Fallible version of [`Self::slice_mut`]; `None` when `bounds` is invalid.
    pub fn try_slice_mut<'b>(
        &'b mut self,
        bounds: impl RangeBounds<usize>,
    ) -> Option<CudaViewMut<'b, T>> {
        let (start, end) = to_range(bounds, self.len)?;
        // Device pointers are byte-addressed, so scale the element offset.
        Some(CudaViewMut {
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        })
    }

    /// Splits into two mutable views at `mid`.
    ///
    /// # Panics
    /// Panics if `mid > self.len`.
    pub fn split_at_mut<'b>(&'b mut self, mid: usize) -> (CudaViewMut<'b, T>, CudaViewMut<'b, T>) {
        self.try_split_at_mut(mid).unwrap()
    }

    /// Fallible version of [`Self::split_at_mut`]; the halves cover `[0, mid)`
    /// and `[mid, len)` and do not overlap.
    pub fn try_split_at_mut<'b>(
        &'b mut self,
        mid: usize,
    ) -> Option<(CudaViewMut<'b, T>, CudaViewMut<'b, T>)> {
        if mid > self.len {
            return None;
        }
        let head = CudaViewMut {
            ptr: self.ptr,
            len: mid,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        };
        let tail = CudaViewMut {
            ptr: self.ptr + (mid * std::mem::size_of::<T>()) as u64,
            len: self.len - mid,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        };
        Some((head, tail))
    }

    /// Mutable variant of [`Self::transmute`].
    ///
    /// # Safety
    /// Caller must ensure the bytes form valid values of `S`.
    pub unsafe fn transmute_mut<'b, S>(&'b mut self, len: usize) -> Option<CudaViewMut<'b, S>> {
        if len * std::mem::size_of::<S>() > self.len * std::mem::size_of::<T>() {
            return None;
        }
        Some(CudaViewMut {
            ptr: self.ptr,
            len,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        })
    }
}
/// Resolves a [`RangeBounds`] into a concrete `(start, end)` pair (end exclusive)
/// against a container of length `len`.
///
/// Returns `None` when the range is invalid: start past end, end past `len`, or
/// a bound of `usize::MAX` that cannot be converted to an exclusive/inclusive
/// counterpart. Previously `n + 1` could overflow for bounds at `usize::MAX`
/// (panic in debug, silent wrap in release); `checked_add` now rejects those
/// ranges instead.
pub(super) fn to_range(range: impl RangeBounds<usize>, len: usize) -> Option<(usize, usize)> {
    let start = match range.start_bound() {
        Bound::Included(&n) => n,
        // Excluded start bound `n` means the range begins at `n + 1`.
        Bound::Excluded(&n) => n.checked_add(1)?,
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        // Inclusive end bound `n` means the exclusive end is `n + 1`.
        Bound::Included(&n) => n.checked_add(1)?,
        Bound::Excluded(&n) => n,
        Bound::Unbounded => len,
    };
    (start <= end && end <= len).then_some((start, end))
}
/// An owned handle to a loaded CUDA module (`CUmodule`).
///
/// Keeps an [`Arc`] to the [`CudaContext`] it was loaded into so the context
/// outlives the module; the raw handle is unloaded in `Drop`.
#[derive(Debug)]
pub struct CudaModule {
    // Raw driver handle; valid until `Drop` unloads it.
    pub(crate) cu_module: sys::CUmodule,
    // Owning context — kept alive so the module can be unloaded in the right context.
    pub(crate) ctx: Arc<CudaContext>,
}
// SAFETY: `CUmodule` is a handle into driver-managed state; the driver API is
// callable from any thread once the context is bound — TODO confirm against the
// crate's threading policy for the other handle types.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
impl Drop for CudaModule {
    fn drop(&mut self) {
        // Bind the owning context to this thread first: unloading must happen in
        // the context the module was loaded into. Errors are recorded on the
        // context rather than panicking inside `drop`.
        self.ctx.record_err(self.ctx.bind_to_thread());
        self.ctx
            .record_err(unsafe { result::module::unload(self.cu_module) });
    }
}
impl CudaContext {
    /// Loads a compiled module (`Ptx`) into this context and returns a handle to it.
    ///
    /// Binds this context to the calling thread before loading.
    ///
    /// # Errors
    /// Returns any error the driver reports while loading. PTX source containing
    /// an interior NUL byte, or a file path that is not valid UTF-8 / contains a
    /// NUL, is reported as `CUDA_ERROR_INVALID_VALUE` instead of panicking
    /// (consistent with [`CudaModule::get_global`]).
    #[cfg(feature = "nvrtc")]
    pub fn load_module(
        self: &Arc<Self>,
        ptx: crate::nvrtc::Ptx,
    ) -> Result<Arc<CudaModule>, result::DriverError> {
        self.bind_to_thread()?;
        let cu_module = match ptx.0 {
            crate::nvrtc::PtxKind::Image(image) => unsafe {
                result::module::load_data(image.as_ptr() as *const _)
            },
            crate::nvrtc::PtxKind::Src(src) => {
                // A NUL byte cannot appear in a C string; surface it as an
                // invalid-value error rather than panicking.
                let c_src = CString::new(src)
                    .map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
                unsafe { result::module::load_data(c_src.as_ptr() as *const _) }
            }
            crate::nvrtc::PtxKind::File(path) => {
                let path_str = path
                    .to_str()
                    .ok_or(DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
                let name_c = CString::new(path_str)
                    .map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
                result::module::load(name_c)
            }
            crate::nvrtc::PtxKind::Binary(data) => unsafe {
                result::module::load_data(data.as_ptr() as *const _)
            },
        }?;
        Ok(Arc::new(CudaModule {
            cu_module,
            ctx: self.clone(),
        }))
    }
}
/// A handle to a kernel (`CUfunction`) loaded from a [`CudaModule`].
///
/// Cloning is cheap: it copies the raw handle and bumps the module's [`Arc`].
#[derive(Debug, Clone)]
pub struct CudaFunction {
    // Raw driver handle; remains valid while `module` is alive.
    pub(crate) cu_function: sys::CUfunction,
    // Keeps the owning module (and transitively its context) alive.
    #[allow(unused)]
    pub(crate) module: Arc<CudaModule>,
}
// SAFETY: `CUfunction` is a handle into driver-managed state tied to `module`,
// which is itself Send + Sync — TODO confirm against the crate's threading policy.
unsafe impl Send for CudaFunction {}
unsafe impl Sync for CudaFunction {}
impl CudaModule {
    /// Looks up the kernel named `fn_name` in this module.
    ///
    /// # Errors
    /// Returns the driver's error when the function is not found. A name
    /// containing an interior NUL byte is reported as
    /// `CUDA_ERROR_INVALID_VALUE` instead of panicking (previously this
    /// `unwrap()`ed the `CString` conversion, inconsistent with
    /// [`Self::get_global`]).
    pub fn load_function(self: &Arc<Self>, fn_name: &str) -> Result<CudaFunction, DriverError> {
        let fn_name_c = CString::new(fn_name)
            .map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
        let cu_function = unsafe { result::module::get_function(self.cu_module, fn_name_c) }?;
        Ok(CudaFunction {
            cu_function,
            module: self.clone(),
        })
    }

    /// Returns a mutable byte view of the module-level global named `name`.
    ///
    /// The view is tied to `stream` and carries no read/write events of its own.
    ///
    /// # Errors
    /// `CUDA_ERROR_INVALID_VALUE` when `name` contains a NUL byte; otherwise any
    /// error the driver reports while resolving the global.
    pub fn get_global<'a>(
        self: &'a Arc<Self>,
        name: &str,
        stream: &'a Arc<CudaStream>,
    ) -> Result<CudaViewMut<'a, u8>, DriverError> {
        let name_c =
            CString::new(name).map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
        let (cu_device_ptr, bytes) = unsafe { result::module::get_global(self.cu_module, name_c) }?;
        Ok(CudaViewMut {
            ptr: cu_device_ptr,
            len: bytes,
            read: &None,
            write: &None,
            stream,
            marker: PhantomData,
        })
    }
}
impl CudaFunction {
    /// Queries how many bytes of dynamic shared memory are available per block
    /// when launching `num_blocks` blocks of `block_size` threads.
    pub fn occupancy_available_dynamic_smem_per_block(
        &self,
        num_blocks: u32,
        block_size: u32,
    ) -> Result<usize, result::DriverError> {
        let mut dynamic_smem_size: usize = 0;
        // SAFETY: `cu_function` is a valid handle kept alive by `self.module`,
        // and the out-pointer refers to a live local.
        unsafe {
            sys::cuOccupancyAvailableDynamicSMemPerBlock(
                &mut dynamic_smem_size,
                self.cu_function,
                num_blocks as std::ffi::c_int,
                block_size as std::ffi::c_int,
            )
            .result()?
        };
        Ok(dynamic_smem_size)
    }

    /// Queries the max number of active blocks per multiprocessor for the given
    /// `block_size` and `dynamic_smem_size`. `flags` defaults to
    /// `CU_OCCUPANCY_DEFAULT` when `None`.
    pub fn occupancy_max_active_blocks_per_multiprocessor(
        &self,
        block_size: u32,
        dynamic_smem_size: usize,
        flags: Option<sys::CUoccupancy_flags_enum>,
    ) -> Result<u32, result::DriverError> {
        let mut num_blocks: std::ffi::c_int = 0;
        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
        // SAFETY: valid function handle and live out-pointer, as above.
        unsafe {
            sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
                &mut num_blocks,
                self.cu_function,
                block_size as std::ffi::c_int,
                dynamic_smem_size,
                flags as std::ffi::c_uint,
            )
            .result()?
        };
        Ok(num_blocks as u32)
    }

    // Cluster occupancy queries require CUDA >= 11.8, hence the feature gate.
    #[cfg(not(any(
        feature = "cuda-11070",
        feature = "cuda-11060",
        feature = "cuda-11050",
        feature = "cuda-11040"
    )))]
    /// Queries the max number of active clusters for launching `config` on `stream`.
    pub fn occupancy_max_active_clusters(
        &self,
        config: crate::driver::LaunchConfig,
        stream: &CudaStream,
    ) -> Result<u32, result::DriverError> {
        let mut num_clusters: std::ffi::c_int = 0;
        // Translate the crate's LaunchConfig into the driver's launch-config struct;
        // no launch attributes are passed.
        let cfg = sys::CUlaunchConfig {
            gridDimX: config.grid_dim.0,
            gridDimY: config.grid_dim.1,
            gridDimZ: config.grid_dim.2,
            blockDimX: config.block_dim.0,
            blockDimY: config.block_dim.1,
            blockDimZ: config.block_dim.2,
            sharedMemBytes: config.shared_mem_bytes,
            hStream: stream.cu_stream,
            attrs: std::ptr::null_mut(),
            numAttrs: 0,
        };
        // SAFETY: valid function handle, live out-pointer, `cfg` outlives the call.
        unsafe {
            sys::cuOccupancyMaxActiveClusters(&mut num_clusters, self.cu_function, &cfg).result()?
        };
        Ok(num_clusters as u32)
    }

    /// Suggests a `(min_grid_size, block_size)` pair that maximizes occupancy.
    /// `block_size_to_dynamic_smem_size` maps a candidate block size to its
    /// dynamic shared memory requirement. `flags` defaults to
    /// `CU_OCCUPANCY_DEFAULT` when `None`.
    pub fn occupancy_max_potential_block_size(
        &self,
        block_size_to_dynamic_smem_size: extern "C" fn(block_size: std::ffi::c_int) -> usize,
        dynamic_smem_size: usize,
        block_size_limit: u32,
        flags: Option<sys::CUoccupancy_flags_enum>,
    ) -> Result<(u32, u32), result::DriverError> {
        let mut min_grid_size: std::ffi::c_int = 0;
        let mut block_size: std::ffi::c_int = 0;
        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
        // SAFETY: valid function handle; the callback is an `extern "C"` fn pointer
        // with the signature the driver expects; out-pointers are live locals.
        unsafe {
            sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
                &mut min_grid_size,
                &mut block_size,
                self.cu_function,
                Some(block_size_to_dynamic_smem_size),
                dynamic_smem_size,
                block_size_limit as std::ffi::c_int,
                flags as std::ffi::c_uint,
            )
            .result()?
        };
        Ok((min_grid_size as u32, block_size as u32))
    }

    // Cluster occupancy queries require CUDA >= 11.8, hence the feature gate.
    #[cfg(not(any(
        feature = "cuda-11070",
        feature = "cuda-11060",
        feature = "cuda-11050",
        feature = "cuda-11040"
    )))]
    /// Suggests the maximum cluster size for launching `config` on `stream`.
    pub fn occupancy_max_potential_cluster_size(
        &self,
        config: crate::driver::LaunchConfig,
        stream: &CudaStream,
    ) -> Result<u32, result::DriverError> {
        let mut cluster_size: std::ffi::c_int = 0;
        let cfg = sys::CUlaunchConfig {
            gridDimX: config.grid_dim.0,
            gridDimY: config.grid_dim.1,
            gridDimZ: config.grid_dim.2,
            blockDimX: config.block_dim.0,
            blockDimY: config.block_dim.1,
            blockDimZ: config.block_dim.2,
            sharedMemBytes: config.shared_mem_bytes,
            hStream: stream.cu_stream,
            attrs: std::ptr::null_mut(),
            numAttrs: 0,
        };
        // SAFETY: valid function handle, live out-pointer, `cfg` outlives the call.
        unsafe {
            sys::cuOccupancyMaxPotentialClusterSize(&mut cluster_size, self.cu_function, &cfg)
                .result()?
        };
        Ok(cluster_size as u32)
    }

    /// Reads one function attribute (register count, shared memory size, ...).
    ///
    /// Binds the owning context to this thread first.
    // NOTE(review): this is the only method in this impl that calls
    // `bind_to_thread` before the driver call; `set_attribute`,
    // `set_function_cache_config`, and the occupancy queries above do not.
    // Verify whether those also require the context to be current.
    pub fn get_attribute(
        &self,
        attribute: CUfunction_attribute_enum,
    ) -> Result<i32, result::DriverError> {
        self.module.ctx.bind_to_thread()?;
        unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
    }

    /// Number of registers used per thread.
    pub fn num_regs(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
    }

    /// Static shared memory required per block, in bytes.
    pub fn shared_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
    }

    /// Constant memory required, in bytes.
    pub fn const_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
    }

    /// Local memory required per thread, in bytes.
    pub fn local_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
    }

    /// Maximum threads per block this kernel can be launched with.
    pub fn max_threads_per_block(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
    }

    /// PTX virtual architecture version the kernel was compiled for.
    pub fn ptx_version(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
    }

    /// Binary (cubin) architecture version of the kernel.
    pub fn binary_version(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
    }

    /// Sets one function attribute to `value`.
    pub fn set_attribute(
        &self,
        attribute: CUfunction_attribute_enum,
        value: i32,
    ) -> Result<(), result::DriverError> {
        unsafe { result::function::set_function_attribute(self.cu_function, attribute, value) }
    }

    /// Sets the preferred cache configuration for this kernel.
    pub fn set_function_cache_config(
        &self,
        attribute: CUfunc_cache_enum,
    ) -> Result<(), result::DriverError> {
        unsafe { result::function::set_function_cache_config(self.cu_function, attribute) }
    }
}
impl<T> CudaSlice<T> {
    /// Consumes the slice and returns the raw device pointer WITHOUT freeing the
    /// allocation. The caller becomes responsible for the memory (e.g. by later
    /// re-wrapping it with `CudaStream::upgrade_device_ptr`).
    pub fn leak(self) -> sys::CUdeviceptr {
        // ManuallyDrop suppresses CudaSlice's Drop, which would otherwise free
        // the device memory we intend to hand out.
        let mut s = std::mem::ManuallyDrop::new(self);
        let ptr = s.cu_device_ptr;
        // Wait for any outstanding reads/writes recorded against this slice so
        // the caller receives a pointer with no pending work. Errors are
        // recorded on the context rather than returned.
        if let Some(read) = s.read.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(read));
        }
        if let Some(write) = s.write.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(write));
        }
        // SAFETY: `s` is ManuallyDrop, so these fields will not be dropped again;
        // each is dropped exactly once here. Only the device allocation itself is
        // intentionally leaked — the events and the stream Arc are released.
        unsafe {
            std::ptr::drop_in_place(&mut s.read);
            std::ptr::drop_in_place(&mut s.write);
            std::ptr::drop_in_place(&mut s.stream);
        }
        ptr
    }
}
impl CudaStream {
    /// Wraps a raw device pointer (e.g. one obtained from [`CudaSlice::leak`])
    /// back into an owned [`CudaSlice`] of `len` elements on this stream.
    ///
    /// # Safety
    /// `cu_device_ptr` must point to a valid device allocation of at least
    /// `len` elements of `T` that is not owned elsewhere; the returned slice
    /// will free it on drop.
    pub unsafe fn upgrade_device_ptr<T>(
        self: &Arc<Self>,
        cu_device_ptr: sys::CUdeviceptr,
        len: usize,
    ) -> CudaSlice<T> {
        // When event tracking is enabled, give the slice a fresh read event and
        // a fresh write event (in that order); otherwise track nothing.
        let mut read = None;
        let mut write = None;
        if self.ctx.is_event_tracking() {
            read = Some(self.ctx.new_event(None).unwrap());
            write = Some(self.ctx.new_event(None).unwrap());
        }
        CudaSlice {
            cu_device_ptr,
            len,
            read,
            write,
            stream: self.clone(),
            marker: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use super::*;

    /// Transmutes succeed exactly while the byte budget of the source covers the
    /// requested element count (100 u8 bytes -> at most 25 f32s).
    #[test]
    fn test_transmutes() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
        assert!(unsafe { slice.transmute::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute::<f32>(26) }.is_none());
        assert!(unsafe { slice.transmute_mut::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute_mut::<f32>(26) }.is_none());
        {
            let view = slice.slice(0..100);
            assert!(unsafe { view.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view.transmute::<f32>(26) }.is_none());
        }
        {
            let mut view_mut = slice.slice_mut(0..100);
            assert!(unsafe { view_mut.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute::<f32>(26) }.is_none());
            assert!(unsafe { view_mut.transmute_mut::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute_mut::<f32>(26) }.is_none());
        }
    }

    /// A context can be shared across threads; each thread binds it and allocates.
    #[test]
    fn test_threading() {
        let ctx1 = CudaContext::new(0).unwrap();
        let ctx2 = ctx1.clone();
        let thread1 = std::thread::spawn(move || {
            ctx1.bind_to_thread()?;
            ctx1.default_stream().alloc_zeros::<f32>(10)
        });
        let thread2 = std::thread::spawn(move || {
            ctx2.bind_to_thread()?;
            ctx2.default_stream().alloc_zeros::<f32>(10)
        });
        let _: crate::driver::CudaSlice<f32> = thread1.join().unwrap().unwrap();
        let _: crate::driver::CudaSlice<f32> = thread2.join().unwrap().unwrap();
    }

    /// A freshly built context holds exactly one strong reference.
    #[test]
    fn test_post_build_arc_count() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    /// Tracks the Arc bookkeeping: the default stream holds the ctx, and an
    /// allocation holds both the stream and (via events) the ctx.
    #[test]
    fn test_post_alloc_arc_counts() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
        let stream = ctx.default_stream();
        assert_eq!(Arc::strong_count(&ctx), 2);
        let t = stream.alloc_zeros::<f32>(1).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 4);
        assert_eq!(Arc::strong_count(&stream), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&ctx), 2);
        assert_eq!(Arc::strong_count(&stream), 1);
        drop(stream);
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    /// Free-memory accounting: allocation shrinks free memory, dropping and
    /// synchronizing restores it. Sensitive to other concurrent GPU work, hence
    /// the "run by itself" marker.
    #[test]
    #[ignore = "must be executed by itself"]
    fn test_post_alloc_memory() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let (free1, total1) = ctx.mem_get_info().unwrap();
        let t = stream.clone_htod(&[0.0f32; 5]).unwrap();
        let (free2, total2) = ctx.mem_get_info().unwrap();
        assert_eq!(total1, total2);
        assert!(free2 < free1);
        drop(t);
        ctx.synchronize().unwrap();
        let (free3, total3) = ctx.mem_get_info().unwrap();
        assert_eq!(total2, total3);
        assert!(free3 > free2);
        assert_eq!(free3, free1);
    }

    /// Device-to-device copies into mutable sub-views land at the right offsets.
    #[test]
    fn test_ctx_copy_to_views() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let smalls = [
            stream.clone_htod(&[-1.0f32, -0.8]).unwrap(),
            stream.clone_htod(&[-0.6, -0.4]).unwrap(),
            stream.clone_htod(&[-0.2, 0.0]).unwrap(),
            stream.clone_htod(&[0.2, 0.4]).unwrap(),
            stream.clone_htod(&[0.6, 0.8]).unwrap(),
        ];
        let mut big = stream.alloc_zeros::<f32>(10).unwrap();
        let mut offset = 0;
        for small in smalls.iter() {
            let mut sub = big.slice_mut(offset..offset + small.len());
            stream.memcpy_dtod(small, &mut sub).unwrap();
            offset += small.len();
        }
        assert_eq!(
            stream.clone_dtoh(&big).unwrap(),
            [-1.0, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
        );
    }

    /// `leak` + `upgrade_device_ptr` round-trips: the upgraded slice sees the
    /// original data, and a shorter length is allowed.
    #[test]
    fn test_leak_and_upgrade() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let a = stream.clone_htod(&[1.0f32, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let ptr = a.leak();
        let b = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 3) };
        assert_eq!(stream.clone_dtoh(&b).unwrap(), &[1.0, 2.0, 3.0]);
        let ptr = b.leak();
        let c = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 5) };
        assert_eq!(stream.clone_dtoh(&c).unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// A slice must free itself in its own context even when a different
    /// context is currently bound to the thread.
    #[test]
    fn test_slice_is_freed_with_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let ctx1 = CudaContext::new(0).unwrap();
        ctx1.bind_to_thread().unwrap();
        drop(ctx0);
        drop(slice);
        drop(ctx1);
    }

    /// Copies must run against the slice's own context, not whichever context
    /// was created last.
    #[test]
    fn test_copy_uses_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let _ctx1 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let _out = ctx0.default_stream().clone_dtoh(&slice).unwrap();
    }

    /// Pinned host memory round-trips host -> device -> host unchanged.
    #[test]
    fn test_htod_copy_pinned() {
        let truth = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut pinned = unsafe { ctx.alloc_pinned::<f32>(10) }.unwrap();
        pinned.as_mut_slice().unwrap().clone_from_slice(&truth);
        assert_eq!(pinned.as_slice().unwrap(), &truth);
        let dst = stream.clone_htod(&pinned).unwrap();
        let host = stream.clone_dtoh(&dst).unwrap();
        assert_eq!(&host, &truth);
    }

    /// Pinned transfers should be noticeably faster than pageable ones.
    /// NOTE: timing-based and inherently flaky on loaded machines.
    /// Fixed: the pageable-copy call previously read `¬_pinned` — an
    /// HTML-entity corruption of `&not_pinned` that did not compile.
    #[test]
    fn test_pinned_copy_is_faster() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let n = 100_000;
        let n_samples = 5;
        let not_pinned = std::vec![0.0f32; n];
        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&not_pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let unpinned_elapsed = start.elapsed() / n_samples;
        let pinned = unsafe { ctx.alloc_pinned::<f32>(n) }.unwrap();
        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let pinned_elapsed = start.elapsed() / n_samples;
        assert!(
            pinned_elapsed.as_secs_f32() * 1.5 < unpinned_elapsed.as_secs_f32(),
            "{unpinned_elapsed:?} vs {pinned_elapsed:?}"
        );
    }

    /// Contexts created via `CudaContext::new` use the device's primary context.
    #[test]
    fn test_primary_context_is_primary() {
        let ctx = CudaContext::new(0).unwrap();
        assert!(ctx.is_primary());
    }

    /// Creates a raw non-primary context via `cuCtxCreate` (v4 on newer CUDA
    /// toolkits, v3 otherwise) for the `from_raw_context` tests below.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn create_non_primary_context() -> (sys::CUdevice, sys::CUcontext) {
        result::init().unwrap();
        let cu_device = result::device::get(0).unwrap();
        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), 0, cu_device) }
            .expect("cuCtxCreate_v4 failed");
        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        )))]
        let cu_ctx =
            unsafe { result::ctx::create_v3(0, cu_device) }.expect("cuCtxCreate_v3 failed");
        assert!(!cu_ctx.is_null());
        (cu_device, cu_ctx)
    }

    /// A context adopted via `from_raw_context` is non-primary and is destroyed
    /// cleanly on drop.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_creates_and_destroys() {
        let (cu_device, cu_ctx) = create_non_primary_context();
        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();
        assert!(!ctx.is_primary());
        ctx.bind_to_thread().unwrap();
        drop(ctx);
    }

    /// An adopted raw context works from another thread after `bind_to_thread`.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_bind_to_thread() {
        let (cu_device, cu_ctx) = create_non_primary_context();
        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();
        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
        });
        handle.join().unwrap();
    }

    /// `new_non_primary` builds a non-primary context that drops cleanly.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_creates_and_destroys() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        assert!(!ctx.is_primary());
        ctx.bind_to_thread().unwrap();
        drop(ctx);
    }

    /// Host <-> device copies work on a non-primary context.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_htod_dtoh() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let stream = ctx.default_stream();
        let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
        let result = stream.clone_dtoh(&data).unwrap();
        assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
    }

    /// A non-primary context is usable from a second thread.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_cross_thread() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[4.0f32, 5.0, 6.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![4.0f32, 5.0, 6.0]);
        });
        handle.join().unwrap();
    }
}