pub mod legacy;
use crate::{
device::Device,
error::{CudaResult, DropResult, ToResult},
private::Sealed,
sys as cuda, CudaApiVersion,
};
use legacy::StreamPriorityRange;
use std::{
mem::{self, transmute, MaybeUninit},
ptr,
};
/// Something that can hand out a raw CUDA context handle (`CUcontext`).
///
/// Implemented for [`Context`] and accepted by [`CurrentContext::set_current`]. The
/// `Sealed` supertrait comes from `crate::private`; presumably it exists so that the
/// trait cannot be implemented outside this crate (confirm the module's visibility).
pub trait ContextHandle: Sealed {
/// Returns the raw `CUcontext` handle without transferring ownership of it.
fn get_inner(&self) -> cuda::CUcontext;
}
impl Sealed for Context {}
impl ContextHandle for Context {
fn get_inner(&self) -> cuda::CUcontext {
self.inner
}
}
/// Preferred cache configuration for the current context.
///
/// Read with `cuCtxGetCacheConfig` by [`CurrentContext::get_cache_config`] and written with
/// `cuCtxSetCacheConfig` by [`CurrentContext::set_cache_config`]. Values are `transmute`d
/// into (and written by the driver as) `CUfunc_cache`, so every discriminant must equal the
/// corresponding driver value; do not renumber or reorder these variants.
/// NOTE(review): the variant names presumably mirror the `CU_FUNC_CACHE_PREFER_*`
/// constants (NONE=0, SHARED=1, L1=2, EQUAL=3) — confirm against the CUDA headers.
#[repr(u32)]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum CacheConfig {
PreferNone = 0,
PreferShared = 1,
PreferL1 = 2,
PreferEqual = 3,
}
/// A per-context resource limit.
///
/// Queried with `cuCtxGetLimit` by [`CurrentContext::get_resource_limit`] and set with
/// `cuCtxSetLimit` by [`CurrentContext::set_resource_limit`]. Each value is `transmute`d
/// into the driver's limit enum (`CUlimit`), so the discriminants must match the driver's
/// values exactly; do not renumber or reorder these variants.
/// NOTE(review): the variant names presumably mirror the `CU_LIMIT_*` constants (0..=5) —
/// confirm against the CUDA headers.
#[repr(u32)]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum ResourceLimit {
StackSize = 0,
PrintfFifoSize = 1,
MallocHeapSize = 2,
DeviceRuntimeSynchronizeDepth = 3,
DeviceRuntimePendingLaunchCount = 4,
MaxL2FetchGranularity = 5,
}
/// Shared-memory bank-size configuration for the current context.
///
/// Read with `cuCtxGetSharedMemConfig` by [`CurrentContext::get_shared_memory_config`]
/// (the driver writes directly into a value of this type) and written with
/// `cuCtxSetSharedMemConfig` by [`CurrentContext::set_shared_memory_config`] via
/// `transmute`. The discriminants must therefore equal the driver's `CUsharedconfig`
/// values; do not renumber or reorder these variants.
/// NOTE(review): names presumably mirror `CU_SHARED_MEM_CONFIG_*` (0..=2) — confirm.
#[repr(u32)]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum SharedMemoryConfig {
DefaultBankSize = 0,
FourByteBankSize = 1,
EightByteBankSize = 2,
}
bitflags::bitflags! {
/// Flags for a device's primary context.
///
/// Written with `cuDevicePrimaryCtxSetFlags_v2` by [`Context::set_flags`] and read with
/// `cuCtxGetFlags` by [`CurrentContext::get_flags`], which silently drops any bit not
/// declared here (`from_bits_truncate`). The bit values are passed to the driver as-is.
/// NOTE(review): names presumably mirror the driver's `CU_CTX_*` flags — confirm.
pub struct ContextFlags: u32 {
const SCHED_SPIN = 0x01;
const SCHED_YIELD = 0x02;
const SCHED_BLOCKING_SYNC = 0x04;
/// Zero-valued, so it is the empty flag set rather than a distinct bit:
/// every `ContextFlags` value `contains(SCHED_AUTO)`.
const SCHED_AUTO = 0x00;
const MAP_HOST = 0x08;
const LMEM_RESIZE_TO_MAX = 0x10;
}
}
/// An owned reference to a device's primary CUDA context.
///
/// Created by [`Context::new`], which calls `cuDevicePrimaryCtxRetain`; the reference is
/// given back with `cuDevicePrimaryCtxRelease_v2` either by the `Drop` impl (errors
/// ignored) or explicitly by [`Context::drop`] (errors reported).
#[derive(Debug)]
pub struct Context {
// Raw context handle. Null once the primary-context reference has been released.
inner: cuda::CUcontext,
// Raw handle of the device whose primary context this is; needed for release/set-flags.
device: cuda::CUdevice,
}
// SAFETY: NOTE(review): these impls assert that a `Context` may be sent to and shared
// between threads. Nothing in this file synchronizes access to `inner`; soundness relies on
// the CUDA driver accepting the same `CUcontext` handle from any host thread — confirm
// against the CUDA driver documentation.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
impl Clone for Context {
fn clone(&self) -> Self {
Self::new(Device {
device: self.device,
})
.expect("Failed to clone context")
}
}
impl Context {
pub fn new(device: Device) -> CudaResult<Self> {
let mut inner = MaybeUninit::uninit();
unsafe {
cuda::cuDevicePrimaryCtxRetain(inner.as_mut_ptr(), device.as_raw()).to_result()?;
let inner = inner.assume_init();
cuda::cuCtxSetCurrent(inner);
Ok(Self {
inner,
device: device.as_raw(),
})
}
}
pub unsafe fn reset(device: &Device) -> CudaResult<()> {
cuda::cuDevicePrimaryCtxReset_v2(device.as_raw()).to_result()
}
pub fn set_flags(&self, flags: ContextFlags) -> CudaResult<()> {
unsafe { cuda::cuDevicePrimaryCtxSetFlags_v2(self.device, flags.bits()).to_result() }
}
pub fn as_raw(&self) -> cuda::CUcontext {
self.inner
}
pub fn get_api_version(&self) -> CudaResult<CudaApiVersion> {
unsafe {
let mut api_version = 0u32;
cuda::cuCtxGetApiVersion(self.inner, &mut api_version as *mut u32).to_result()?;
Ok(CudaApiVersion {
version: api_version as i32,
})
}
}
pub fn drop(mut ctx: Context) -> DropResult<Context> {
if ctx.inner.is_null() {
return Ok(());
}
unsafe {
let inner = mem::replace(&mut ctx.inner, ptr::null_mut());
match cuda::cuDevicePrimaryCtxRelease_v2(ctx.device).to_result() {
Ok(()) => {
mem::forget(ctx);
Ok(())
}
Err(e) => Err((
e,
Context {
inner,
device: ctx.device,
},
)),
}
}
}
}
impl Drop for Context {
fn drop(&mut self) {
if self.inner.is_null() {
return;
}
unsafe {
self.inner = ptr::null_mut();
cuda::cuDevicePrimaryCtxRelease_v2(self.device);
}
}
}
/// Namespace for operations on the calling thread's current CUDA context.
///
/// Every associated function wraps a `cuCtx*` driver call that takes no context argument;
/// per the CUDA driver API, such calls act on whichever context is current on the calling
/// thread (see [`CurrentContext::set_current`] and [`Context::new`]).
#[derive(Debug)]
pub struct CurrentContext;
impl CurrentContext {
    /// Reads the cache configuration of the current context (`cuCtxGetCacheConfig`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_cache_config() -> CudaResult<CacheConfig> {
        let mut config = CacheConfig::PreferNone;
        let out = (&mut config as *mut CacheConfig).cast::<cuda::CUfunc_cache>();
        // SAFETY: `out` points at a live, writable `#[repr(u32)]` `CacheConfig`.
        // NOTE(review): the driver writes straight into the Rust enum; any value other than
        // the four declared discriminants would be an invalid `CacheConfig`. Confirm the
        // driver can only report those values.
        unsafe { cuda::cuCtxGetCacheConfig(out).to_result()? };
        Ok(config)
    }

    /// Returns the device of the current context (`cuCtxGetDevice`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_device() -> CudaResult<Device> {
        let mut raw_device: cuda::CUdevice = 0;
        // SAFETY: `raw_device` is a valid, writable out-parameter.
        unsafe { cuda::cuCtxGetDevice(&mut raw_device).to_result()? };
        Ok(Device { device: raw_device })
    }

    /// Returns the flags of the current context (`cuCtxGetFlags`).
    ///
    /// Bits not declared in [`ContextFlags`] are discarded.
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_flags() -> CudaResult<ContextFlags> {
        let mut bits: u32 = 0;
        // SAFETY: `bits` is a valid, writable out-parameter.
        unsafe { cuda::cuCtxGetFlags(&mut bits).to_result()? };
        Ok(ContextFlags::from_bits_truncate(bits))
    }

    /// Returns the current value of `resource` for the current context (`cuCtxGetLimit`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_resource_limit(resource: ResourceLimit) -> CudaResult<usize> {
        let mut value = 0usize;
        // SAFETY: `value` is a valid out-parameter; `ResourceLimit` is `#[repr(u32)]` and is
        // reinterpreted as the driver's limit enum.
        unsafe { cuda::cuCtxGetLimit(&mut value, transmute(resource)).to_result()? };
        Ok(value)
    }

    /// Reads the shared-memory bank configuration of the current context
    /// (`cuCtxGetSharedMemConfig`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_shared_memory_config() -> CudaResult<SharedMemoryConfig> {
        let mut config = SharedMemoryConfig::DefaultBankSize;
        let out = (&mut config as *mut SharedMemoryConfig).cast::<cuda::CUsharedconfig>();
        // SAFETY: `out` points at a live, writable `#[repr(u32)]` `SharedMemoryConfig`.
        // NOTE(review): as with the cache config, an undeclared driver value would be an
        // invalid enum; confirm the driver only reports the declared values.
        unsafe { cuda::cuCtxGetSharedMemConfig(out).to_result()? };
        Ok(config)
    }

    /// Returns the stream priority range of the current context
    /// (`cuCtxGetStreamPriorityRange`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn get_stream_priority_range() -> CudaResult<StreamPriorityRange> {
        let (mut least, mut greatest) = (0i32, 0i32);
        // SAFETY: both locals are valid, writable out-parameters.
        unsafe { cuda::cuCtxGetStreamPriorityRange(&mut least, &mut greatest).to_result()? };
        Ok(StreamPriorityRange { least, greatest })
    }

    /// Sets the cache configuration of the current context (`cuCtxSetCacheConfig`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn set_cache_config(cfg: CacheConfig) -> CudaResult<()> {
        // SAFETY: `CacheConfig` is `#[repr(u32)]` with driver-matching discriminants.
        let status = unsafe { cuda::cuCtxSetCacheConfig(transmute(cfg)) };
        status.to_result()
    }

    /// Sets `resource` to `limit` for the current context (`cuCtxSetLimit`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn set_resource_limit(resource: ResourceLimit, limit: usize) -> CudaResult<()> {
        // SAFETY: `ResourceLimit` is `#[repr(u32)]` with driver-matching discriminants.
        let status = unsafe { cuda::cuCtxSetLimit(transmute(resource), limit) };
        status.to_result()
    }

    /// Sets the shared-memory bank configuration of the current context
    /// (`cuCtxSetSharedMemConfig`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn set_shared_memory_config(cfg: SharedMemoryConfig) -> CudaResult<()> {
        // SAFETY: `SharedMemoryConfig` is `#[repr(u32)]` with driver-matching discriminants.
        let status = unsafe { cuda::cuCtxSetSharedMemConfig(transmute(cfg)) };
        status.to_result()
    }

    /// Makes the context behind `c` current on the calling thread (`cuCtxSetCurrent`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn set_current<C: ContextHandle>(c: &C) -> CudaResult<()> {
        let handle = c.get_inner();
        // SAFETY: `handle` comes from a live `ContextHandle` borrowed for this call.
        let status = unsafe { cuda::cuCtxSetCurrent(handle) };
        status.to_result()
    }

    /// Blocks until the current context has completed all preceding work
    /// (`cuCtxSynchronize`).
    ///
    /// # Errors
    ///
    /// Returns the driver error if the call fails.
    pub fn synchronize() -> CudaResult<()> {
        // SAFETY: argument-less FFI call.
        let status = unsafe { cuda::cuCtxSynchronize() };
        status.to_result()
    }
}