use crate::device::Device;
use crate::error::{CudaResult, DropResult, ToResult};
use crate::private::Sealed;
use crate::CudaApiVersion;
use cuda_driver_sys::CUcontext;
use std::mem;
use std::mem::transmute;
use std::ptr;
/// Preferred split between L1 cache and shared memory for the current context.
///
/// Values are passed to `cuCtxSetCacheConfig` and read back from `cuCtxGetCacheConfig`
/// by reinterpreting them as the driver's `CUfunc_cache`. The discriminants must
/// therefore stay identical to the driver's values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CacheConfig {
    /// No preference between shared memory and L1 (`CU_FUNC_CACHE_PREFER_NONE`).
    PreferNone = 0,
    /// Prefer larger shared memory and smaller L1 (`CU_FUNC_CACHE_PREFER_SHARED`).
    PreferShared = 1,
    /// Prefer larger L1 and smaller shared memory (`CU_FUNC_CACHE_PREFER_L1`).
    PreferL1 = 2,
    /// Prefer equally sized L1 and shared memory (`CU_FUNC_CACHE_PREFER_EQUAL`).
    PreferEqual = 3,
    /// Placeholder that reserves room for future variants. It has no driver equivalent.
    #[doc(hidden)]
    __Nonexhaustive = 4,
}
/// Per-context resource limits that can be queried or adjusted.
///
/// Values are reinterpreted as the driver's `CUlimit` when passed to `cuCtxGetLimit` and
/// `cuCtxSetLimit`. The discriminants must therefore stay identical to the driver's values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ResourceLimit {
    /// GPU thread stack size (`CU_LIMIT_STACK_SIZE`).
    StackSize = 0,
    /// Size of the FIFO used by device-side `printf` (`CU_LIMIT_PRINTF_FIFO_SIZE`).
    PrintfFifoSize = 1,
    /// Size of the heap used by device-side `malloc`/`free` (`CU_LIMIT_MALLOC_HEAP_SIZE`).
    MallocHeapSize = 2,
    /// Maximum grid nesting depth for device-runtime synchronization
    /// (`CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH`).
    DeviceRuntimeSynchronizeDepth = 3,
    /// Maximum number of outstanding device-runtime launches
    /// (`CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT`).
    DeviceRuntimePendingLaunchCount = 4,
    /// L2 cache fetch granularity (`CU_LIMIT_MAX_L2_FETCH_GRANULARITY`).
    MaxL2FetchGranularity = 5,
    /// Placeholder that reserves room for future variants. It is not a real limit.
    #[doc(hidden)]
    __Nonexhaustive = 6,
}
/// Shared-memory bank size configuration for the current context.
///
/// Values are reinterpreted as the driver's `CUsharedconfig` by `cuCtxSetSharedMemConfig`
/// and `cuCtxGetSharedMemConfig`. The discriminants must therefore stay identical to the
/// driver's values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum SharedMemoryConfig {
    /// The device's default bank size (`CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE`).
    DefaultBankSize = 0,
    /// Four-byte banks (`CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE`).
    FourByteBankSize = 1,
    /// Eight-byte banks (`CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE`).
    EightByteBankSize = 2,
    /// Placeholder that reserves room for future variants. It has no driver equivalent.
    #[doc(hidden)]
    __Nonexhaustive = 3,
}
bitflags! {
    /// Flags passed to `cuCtxCreate_v2` when creating a context, and returned by
    /// `cuCtxGetFlags`.
    ///
    /// NOTE(review): `SCHED_AUTO` has the value zero, so `contains(SCHED_AUTO)` is true for
    /// every flag set. Test scheduling mode by comparing the masked bits instead.
    pub struct ContextFlags: u32 {
        /// Spin actively while waiting for results from the GPU (`CU_CTX_SCHED_SPIN`).
        const SCHED_SPIN = 1 << 0;
        /// Yield the host thread while waiting for results (`CU_CTX_SCHED_YIELD`).
        const SCHED_YIELD = 1 << 1;
        /// Block the host thread on a synchronization primitive while waiting
        /// (`CU_CTX_SCHED_BLOCKING_SYNC`).
        const SCHED_BLOCKING_SYNC = 1 << 2;
        /// Let the driver choose the scheduling policy (`CU_CTX_SCHED_AUTO`, value zero).
        const SCHED_AUTO = 0;
        /// Allow mapping of pinned host memory into the device address space
        /// (`CU_CTX_MAP_HOST`).
        const MAP_HOST = 1 << 3;
        /// Keep local memory allocations after resizing them for a kernel
        /// (`CU_CTX_LMEM_RESIZE_TO_MAX`).
        const LMEM_RESIZE_TO_MAX = 1 << 4;
    }
}
/// Owning handle to a CUDA context.
///
/// The wrapped `CUcontext` is released with `cuCtxDestroy_v2` when this value is dropped.
/// Use [`Context::drop`] to destroy it explicitly and observe any error.
#[derive(Debug)]
pub struct Context {
    inner: CUcontext,
}
impl Context {
    /// Creates a new context on `device` with the given `flags` (`cuCtxCreate_v2`).
    ///
    /// The driver also pushes the new context onto the calling thread's context stack,
    /// hence the name.
    ///
    /// # Errors
    /// Returns the driver error if the context cannot be created.
    pub fn create_and_push(flags: ContextFlags, device: Device) -> CudaResult<Context> {
        let mut handle: CUcontext = ptr::null_mut();
        // SAFETY: `handle` is a valid, writable out-parameter for the duration of the call.
        unsafe {
            cuda_driver_sys::cuCtxCreate_v2(
                &mut handle as *mut CUcontext,
                flags.bits(),
                device.into_inner(),
            )
        }
        .to_result()?;
        Ok(Context { inner: handle })
    }

    /// Returns the API version this context was created with (`cuCtxGetApiVersion`).
    ///
    /// # Errors
    /// Returns the driver error if the query fails.
    pub fn get_api_version(&self) -> CudaResult<CudaApiVersion> {
        let mut raw_version = 0u32;
        // SAFETY: `raw_version` is a valid, writable out-parameter; `self.inner` is our handle.
        unsafe { cuda_driver_sys::cuCtxGetApiVersion(self.inner, &mut raw_version as *mut u32) }
            .to_result()?;
        Ok(CudaApiVersion {
            version: raw_version as i32,
        })
    }

    /// Returns a non-owning handle to the same context.
    ///
    /// The returned handle does not keep the context alive. It must not be used after this
    /// `Context` has been destroyed.
    pub fn get_unowned(&self) -> UnownedContext {
        let inner = self.inner;
        UnownedContext { inner }
    }

    /// Destroys `ctx`, returning it together with the error if destruction fails.
    ///
    /// A context whose handle is already null is treated as already destroyed.
    ///
    /// # Errors
    /// On failure, returns the driver error together with a `Context` that still owns the
    /// handle, so the caller can retry or leak it.
    pub fn drop(mut ctx: Context) -> DropResult<Context> {
        if ctx.inner.is_null() {
            return Ok(());
        }
        // Null out the stored handle first so that `ctx`'s own destructor is a no-op.
        let handle = mem::replace(&mut ctx.inner, ptr::null_mut());
        // SAFETY: `handle` is the non-null context previously owned by `ctx`.
        let outcome = unsafe { cuda_driver_sys::cuCtxDestroy_v2(handle) }.to_result();
        if let Err(e) = outcome {
            return Err((e, Context { inner: handle }));
        }
        mem::forget(ctx);
        Ok(())
    }
}
impl Drop for Context {
    /// Destroys the wrapped context with `cuCtxDestroy_v2`, unless it was already released.
    ///
    /// # Panics
    /// Panics with "Failed to destroy context" if the driver reports an error. The exception
    /// is when the thread is already unwinding from another panic: a second panic during
    /// unwinding aborts the whole process, so in that case the error is deliberately
    /// ignored. Use `Context::drop` to observe destruction errors without panicking.
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }
        // Clear the stored handle before destroying it so it can never be destroyed twice.
        let inner = mem::replace(&mut self.inner, ptr::null_mut());
        // SAFETY: `inner` is the non-null context handle owned by this value.
        let result = unsafe { cuda_driver_sys::cuCtxDestroy_v2(inner) }.to_result();
        if !std::thread::panicking() {
            result.expect("Failed to destroy context");
        }
    }
}
/// Common interface for values that wrap a raw CUDA context handle, whether owning
/// (`Context`) or non-owning (`UnownedContext`).
///
/// The trait is sealed through its `crate::private::Sealed` supertrait, so only the
/// implementations provided by this crate exist.
pub trait ContextHandle: Sealed {
    /// Returns the raw `CUcontext` this value wraps.
    #[doc(hidden)]
    fn get_inner(&self) -> CUcontext;
}
// Both context wrappers are allowed to implement the sealed `ContextHandle` trait.
impl Sealed for Context {}
impl Sealed for UnownedContext {}

impl ContextHandle for Context {
    /// Returns the owned raw handle without transferring ownership.
    fn get_inner(&self) -> CUcontext {
        let Self { inner } = self;
        *inner
    }
}

impl ContextHandle for UnownedContext {
    /// Returns the borrowed raw handle.
    fn get_inner(&self) -> CUcontext {
        let Self { inner } = self;
        *inner
    }
}
/// Non-owning handle to a CUDA context.
///
/// This type has no `Drop` impl in this module, so cloning or dropping it never destroys
/// the underlying context. It is produced by `Context::get_unowned`, `ContextStack::pop`
/// and `CurrentContext::get_current`. Its lifetime is not tied to the owning `Context`,
/// so it may refer to a context that has since been destroyed.
#[derive(Clone, Debug)]
pub struct UnownedContext {
    inner: CUcontext,
}

// SAFETY: NOTE(review) — the handle is only copied and passed to driver calls. This relies
// on the CUDA driver accepting a `CUcontext` from any host thread; confirm against the
// driver's multi-threading guarantees.
unsafe impl Send for UnownedContext {}
unsafe impl Sync for UnownedContext {}
impl UnownedContext {
    /// Returns the API version of the referenced context (`cuCtxGetApiVersion`).
    ///
    /// # Errors
    /// Returns the driver error if the query fails, e.g. if the handle is no longer valid.
    pub fn get_api_version(&self) -> CudaResult<CudaApiVersion> {
        let mut raw_version = 0u32;
        // SAFETY: `raw_version` is a valid, writable out-parameter.
        unsafe { cuda_driver_sys::cuCtxGetApiVersion(self.inner, &mut raw_version as *mut u32) }
            .to_result()?;
        let version = raw_version as i32;
        Ok(CudaApiVersion { version })
    }
}
/// Namespace for manipulating the calling host thread's CUDA context stack.
#[derive(Debug)]
pub struct ContextStack;
impl ContextStack {
    /// Pops the top context off the calling thread's stack (`cuCtxPopCurrent_v2`).
    ///
    /// The popped context is returned as a non-owning handle; ownership is not transferred.
    ///
    /// # Errors
    /// Returns the driver error if the pop fails.
    pub fn pop() -> CudaResult<UnownedContext> {
        let mut popped: CUcontext = ptr::null_mut();
        // SAFETY: `popped` is a valid, writable out-parameter.
        unsafe { cuda_driver_sys::cuCtxPopCurrent_v2(&mut popped as *mut CUcontext) }
            .to_result()?;
        Ok(UnownedContext { inner: popped })
    }

    /// Pushes `ctx` onto the calling thread's context stack (`cuCtxPushCurrent_v2`).
    ///
    /// # Errors
    /// Returns the driver error if the push fails.
    pub fn push<C: ContextHandle>(ctx: &C) -> CudaResult<()> {
        // SAFETY: the raw handle comes from one of this crate's context wrappers.
        unsafe { cuda_driver_sys::cuCtxPushCurrent_v2(ctx.get_inner()) }.to_result()
    }
}
/// Range of stream priorities supported by the current context, as reported by
/// `cuCtxGetStreamPriorityRange`.
///
/// Per the CUDA documentation, numerically lower values denote higher priorities, so
/// `greatest` is typically less than or equal to `least`.
///
/// This is a small plain-data pair, so it is `Copy` and can be passed around by value.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub struct StreamPriorityRange {
    /// Numerical value of the lowest (least urgent) stream priority.
    pub least: i32,
    /// Numerical value of the highest (most urgent) stream priority.
    pub greatest: i32,
}
#[derive(Debug)]
pub struct CurrentContext;
impl CurrentContext {
pub fn get_cache_config() -> CudaResult<CacheConfig> {
unsafe {
let mut config = CacheConfig::PreferNone;
cuda_driver_sys::cuCtxGetCacheConfig(
&mut config as *mut CacheConfig as *mut cuda_driver_sys::CUfunc_cache,
)
.to_result()?;
Ok(config)
}
}
pub fn get_device() -> CudaResult<Device> {
unsafe {
let mut device = Device { device: 0 };
cuda_driver_sys::cuCtxGetDevice(&mut device.device as *mut cuda_driver_sys::CUdevice)
.to_result()?;
Ok(device)
}
}
pub fn get_flags() -> CudaResult<ContextFlags> {
unsafe {
let mut flags = 0u32;
cuda_driver_sys::cuCtxGetFlags(&mut flags as *mut u32).to_result()?;
Ok(ContextFlags::from_bits_truncate(flags))
}
}
pub fn get_resource_limit(resource: ResourceLimit) -> CudaResult<usize> {
unsafe {
let mut limit: usize = 0;
cuda_driver_sys::cuCtxGetLimit(&mut limit as *mut usize, transmute(resource))
.to_result()?;
Ok(limit)
}
}
pub fn get_shared_memory_config() -> CudaResult<SharedMemoryConfig> {
unsafe {
let mut cfg = SharedMemoryConfig::DefaultBankSize;
cuda_driver_sys::cuCtxGetSharedMemConfig(
&mut cfg as *mut SharedMemoryConfig as *mut cuda_driver_sys::CUsharedconfig,
)
.to_result()?;
Ok(cfg)
}
}
pub fn get_stream_priority_range() -> CudaResult<StreamPriorityRange> {
unsafe {
let mut range = StreamPriorityRange {
least: 0,
greatest: 0,
};
cuda_driver_sys::cuCtxGetStreamPriorityRange(
&mut range.least as *mut i32,
&mut range.greatest as *mut i32,
)
.to_result()?;
Ok(range)
}
}
pub fn set_cache_config(cfg: CacheConfig) -> CudaResult<()> {
unsafe { cuda_driver_sys::cuCtxSetCacheConfig(transmute(cfg)).to_result() }
}
pub fn set_resource_limit(resource: ResourceLimit, limit: usize) -> CudaResult<()> {
unsafe {
cuda_driver_sys::cuCtxSetLimit(transmute(resource), limit).to_result()?;
Ok(())
}
}
pub fn set_shared_memory_config(cfg: SharedMemoryConfig) -> CudaResult<()> {
unsafe { cuda_driver_sys::cuCtxSetSharedMemConfig(transmute(cfg)).to_result() }
}
pub fn get_current() -> CudaResult<UnownedContext> {
unsafe {
let mut ctx: CUcontext = ptr::null_mut();
cuda_driver_sys::cuCtxGetCurrent(&mut ctx as *mut CUcontext).to_result()?;
Ok(UnownedContext { inner: ctx })
}
}
pub fn set_current<C: ContextHandle>(c: &C) -> CudaResult<()> {
unsafe {
cuda_driver_sys::cuCtxSetCurrent(c.get_inner()).to_result()?;
Ok(())
}
}
pub fn synchronize() -> CudaResult<()> {
unsafe {
cuda_driver_sys::cuCtxSynchronize().to_result()?;
Ok(())
}
}
}