use crate::device::Device;
use crate::error::CudaResult;
use crate::ffi::CUcontext;
use crate::loader::try_driver;
/// Flag bits accepted by [`Context::with_flags`].
///
/// The values are passed straight through to `cuCtxCreate_v2`. They mirror the CUDA
/// driver's `CU_CTX_*` context-creation flags.
/// NOTE(review): confirm against the `cuda.h` in use.
pub mod flags {
    /// Let the driver pick a scheduling policy (no explicit `SCHED_*` bit set).
    pub const SCHED_AUTO: u32 = 0;
    /// Spin-wait while waiting on GPU results.
    pub const SCHED_SPIN: u32 = 1 << 0;
    /// Yield the CPU thread while waiting on GPU results.
    pub const SCHED_YIELD: u32 = 1 << 1;
    /// Block on a synchronization primitive while waiting on GPU results.
    pub const SCHED_BLOCKING_SYNC: u32 = 1 << 2;
    /// Allow mapped (zero-copy) pinned host allocations in this context.
    pub const MAP_HOST: u32 = 1 << 3;
    /// Keep local-memory allocations at their high-water size instead of shrinking them.
    pub const LMEM_RESIZE_TO_MAX: u32 = 1 << 4;
}
/// An owned CUDA driver context tied to the [`Device`] it was created on.
///
/// Created through [`Context::new`] or [`Context::with_flags`]. The driver context is
/// destroyed when this value is dropped.
pub struct Context {
    /// The device passed to the constructor, copied in.
    device: Device,
    /// Raw driver handle filled in by `cuCtxCreate_v2`.
    raw: CUcontext,
}
impl Context {
pub fn new(device: &Device) -> CudaResult<Self> {
Self::with_flags(device, flags::SCHED_AUTO)
}
pub fn with_flags(device: &Device, flags: u32) -> CudaResult<Self> {
let driver = try_driver()?;
let mut raw = CUcontext::default();
crate::error::check(unsafe { (driver.cu_ctx_create_v2)(&mut raw, flags, device.raw()) })?;
Ok(Self {
raw,
device: *device,
})
}
pub fn set_current(&self) -> CudaResult<()> {
let driver = try_driver()?;
crate::error::check(unsafe { (driver.cu_ctx_set_current)(self.raw) })
}
pub fn current_raw() -> CudaResult<Option<CUcontext>> {
let driver = try_driver()?;
let mut ctx = CUcontext::default();
crate::error::check(unsafe { (driver.cu_ctx_get_current)(&mut ctx) })?;
if ctx.is_null() {
Ok(None)
} else {
Ok(Some(ctx))
}
}
pub fn synchronize(&self) -> CudaResult<()> {
self.set_current()?;
let driver = try_driver()?;
crate::error::check(unsafe { (driver.cu_ctx_synchronize)() })
}
pub fn scoped<F, R>(&self, f: F) -> CudaResult<R>
where
F: FnOnce() -> CudaResult<R>,
{
let prev = Self::current_raw()?;
self.set_current()?;
let result = f();
let restore_ctx = prev.unwrap_or_default();
if let Ok(driver) = try_driver() {
if let Err(e) = crate::error::check(unsafe { (driver.cu_ctx_set_current)(restore_ctx) })
{
tracing::warn!("failed to restore previous context: {e}");
}
}
result
}
#[inline]
pub fn device(&self) -> &Device {
&self.device
}
#[inline]
pub fn raw(&self) -> CUcontext {
self.raw
}
pub fn is_current(&self) -> CudaResult<bool> {
match Self::current_raw()? {
Some(ctx) => Ok(ctx == self.raw),
None => Ok(false),
}
}
}
impl Drop for Context {
fn drop(&mut self) {
if let Ok(driver) = try_driver() {
let result = unsafe { (driver.cu_ctx_destroy_v2)(self.raw) };
if result != 0 {
tracing::warn!(
"cuCtxDestroy_v2 failed with error code {result} during Context drop \
(device ordinal {})",
self.device.ordinal()
);
}
}
}
}
// SAFETY: `Context` holds only a driver handle (`CUcontext`) and a `Device`. Every method
// obtains the driver via `try_driver()` and passes the handle explicitly, and `Drop`
// destroys it on whichever thread drops the value. This relies on the CUDA driver
// allowing a context to be bound, used and destroyed from a thread other than the one
// that created it. NOTE(review): confirm against the CUDA driver documentation. No
// `Sync` impl is provided in this file, so shared `&Context` access across threads is
// not granted here.
unsafe impl Send for Context {}
impl std::fmt::Debug for Context {
    /// Renders as `Context { raw: <handle>, device: <device> }` for diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so the compiler flags this impl if a field is ever added.
        let Self { raw, device } = self;
        let mut out = f.debug_struct("Context");
        out.field("raw", raw);
        out.field("device", device);
        out.finish()
    }
}