use std::sync::Arc;
use baracuda_cuda_sys::types::CUcontext_flags;
use baracuda_cuda_sys::{driver, CUcontext};
use crate::device::Device;
use crate::error::{check, Result};
use crate::init::init;
/// An owned CUDA driver context.
///
/// Cloning is cheap: all clones share one reference-counted `ContextInner`, and the
/// underlying handle is released only when the last clone is dropped.
#[derive(Clone)]
pub struct Context {
    inner: Arc<ContextInner>,
}

// Forward straight to `ContextInner` (which already reports itself as "Context") so the
// output reads `Context { handle: .., device: .. }` rather than the doubly-nested
// `Context { inner: Context { .. } }` that `#[derive(Debug)]` would produce.
impl core::fmt::Debug for Context {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Debug::fmt(&*self.inner, f)
    }
}
/// Shared state behind a `Context`: the raw driver handle and the device it was
/// created on. Its `Drop` impl hands the handle back to the driver for destruction.
struct ContextInner {
    /// Raw context handle produced by the driver's context-create call.
    handle: CUcontext,
    /// Device passed to the constructor.
    device: Device,
}

// SAFETY (assumption — NOTE(review): confirm against the CUDA driver docs): the handle
// is an opaque token only ever passed back into driver entry points, which accept
// context handles from any host thread; the struct holds no other shared state.
unsafe impl Send for ContextInner {}
unsafe impl Sync for ContextInner {}

impl core::fmt::Debug for ContextInner {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ContextInner { handle, device } = self;
        // Reported under the public type name users interact with.
        let mut builder = f.debug_struct("Context");
        builder.field("handle", handle);
        builder.field("device", device);
        builder.finish()
    }
}
impl Context {
pub fn new(device: &Device) -> Result<Self> {
Self::with_flags(device, CUcontext_flags::SCHED_AUTO)
}
pub fn with_flags(device: &Device, flags: u32) -> Result<Self> {
init()?;
let d = driver()?;
let cu = d.cu_ctx_create()?;
let mut ctx: CUcontext = core::ptr::null_mut();
check(unsafe { cu(&mut ctx, flags, device.0) })?;
Ok(Self {
inner: Arc::new(ContextInner {
handle: ctx,
device: *device,
}),
})
}
pub fn current() -> Result<Option<CUcontext>> {
init()?;
let d = driver()?;
let cu = d.cu_ctx_get_current()?;
let mut ctx: CUcontext = core::ptr::null_mut();
check(unsafe { cu(&mut ctx) })?;
if ctx.is_null() {
Ok(None)
} else {
Ok(Some(ctx))
}
}
pub fn set_current(&self) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_set_current()?;
check(unsafe { cu(self.inner.handle) })
}
pub fn push(&self) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_push_current()?;
check(unsafe { cu(self.inner.handle) })
}
pub fn pop() -> Result<CUcontext> {
init()?;
let d = driver()?;
let cu = d.cu_ctx_pop_current()?;
let mut ctx: CUcontext = core::ptr::null_mut();
check(unsafe { cu(&mut ctx) })?;
Ok(ctx)
}
pub fn synchronize(&self) -> Result<()> {
self.set_current()?;
let d = driver()?;
let cu = d.cu_ctx_synchronize()?;
check(unsafe { cu() })
}
pub fn api_version(&self) -> Result<u32> {
let d = driver()?;
let cu = d.cu_ctx_get_api_version()?;
let mut v: core::ffi::c_uint = 0;
check(unsafe { cu(self.inner.handle, &mut v) })?;
Ok(v)
}
pub fn current_device() -> Result<Device> {
let d = driver()?;
let cu = d.cu_ctx_get_device()?;
let mut dev = baracuda_cuda_sys::CUdevice::default();
check(unsafe { cu(&mut dev) })?;
Ok(Device(dev))
}
pub fn current_flags() -> Result<u32> {
let d = driver()?;
let cu = d.cu_ctx_get_flags()?;
let mut f: core::ffi::c_uint = 0;
check(unsafe { cu(&mut f) })?;
Ok(f)
}
pub fn get_limit(limit: u32) -> Result<usize> {
let d = driver()?;
let cu = d.cu_ctx_get_limit()?;
let mut v: usize = 0;
check(unsafe { cu(&mut v, limit) })?;
Ok(v)
}
pub fn set_limit(limit: u32, value: usize) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_set_limit()?;
check(unsafe { cu(limit, value) })
}
pub fn cache_config() -> Result<u32> {
let d = driver()?;
let cu = d.cu_ctx_get_cache_config()?;
let mut c: core::ffi::c_uint = 0;
check(unsafe { cu(&mut c) })?;
Ok(c)
}
pub fn set_cache_config(config: u32) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_set_cache_config()?;
check(unsafe { cu(config) })
}
pub fn stream_priority_range() -> Result<(i32, i32)> {
let d = driver()?;
let cu = d.cu_ctx_get_stream_priority_range()?;
let mut least: core::ffi::c_int = 0;
let mut greatest: core::ffi::c_int = 0;
check(unsafe { cu(&mut least, &mut greatest) })?;
Ok((least, greatest))
}
pub fn enable_peer_access(peer: &Context) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_enable_peer_access()?;
check(unsafe { cu(peer.inner.handle, 0) })
}
pub fn disable_peer_access(peer: &Context) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_disable_peer_access()?;
check(unsafe { cu(peer.inner.handle) })
}
#[inline]
pub fn device(&self) -> Device {
self.inner.device
}
#[inline]
pub fn as_raw(&self) -> CUcontext {
self.inner.handle
}
pub fn id(&self) -> Result<u64> {
let d = driver()?;
let cu = d.cu_ctx_get_id()?;
let mut out: u64 = 0;
check(unsafe { cu(self.inner.handle, &mut out) })?;
Ok(out)
}
pub fn record_event(&self, event: &crate::Event) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_record_event()?;
check(unsafe { cu(self.inner.handle, event.as_raw()) })
}
pub fn wait_event(&self, event: &crate::Event) -> Result<()> {
let d = driver()?;
let cu = d.cu_ctx_wait_event()?;
check(unsafe { cu(self.inner.handle, event.as_raw()) })
}
}
impl Drop for ContextInner {
    /// Destroys the context through the driver (`cuCtxDestroy`).
    ///
    /// Best-effort: `Drop` cannot report errors, so a missing driver, a missing
    /// symbol, or a failing destroy call are all silently ignored.
    fn drop(&mut self) {
        let drv = match driver() {
            Ok(drv) => drv,
            Err(_) => return,
        };
        let destroy = match drv.cu_ctx_destroy() {
            Ok(destroy) => destroy,
            Err(_) => return,
        };
        // SAFETY: this is the last owner of `handle`; it is not used afterwards.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// A retained reference to a device's primary context.
///
/// Obtained from `PrimaryContext::retain`; dropping it releases that retain on
/// `device`. Deliberately not `Clone`: each value accounts for exactly one retain.
#[derive(Debug)]
pub struct PrimaryContext {
    /// Raw handle written by the retain call.
    handle: CUcontext,
    /// Device whose primary context was retained; also the key used to release it.
    device: Device,
}
// SAFETY (assumption — NOTE(review): confirm against the CUDA driver docs): the handle is
// an opaque token only passed back into driver entry points, which accept context handles
// from any host thread; no other shared state is held.
unsafe impl Send for PrimaryContext {}
unsafe impl Sync for PrimaryContext {}
impl PrimaryContext {
pub fn retain(device: &Device) -> Result<Self> {
init()?;
let d = driver()?;
let cu = d.cu_device_primary_ctx_retain()?;
let mut handle: CUcontext = core::ptr::null_mut();
check(unsafe { cu(&mut handle, device.0) })?;
Ok(Self {
handle,
device: *device,
})
}
pub fn reset(device: &Device) -> Result<()> {
init()?;
let d = driver()?;
let cu = d.cu_device_primary_ctx_reset()?;
check(unsafe { cu(device.0) })
}
pub fn device(&self) -> Device {
self.device
}
#[inline]
pub fn as_raw(&self) -> CUcontext {
self.handle
}
}
impl Drop for PrimaryContext {
    /// Releases the retain taken by `PrimaryContext::retain`
    /// (driver `cuDevicePrimaryCtxRelease`).
    ///
    /// Best-effort: any failure to reach the driver, resolve the symbol, or release is
    /// ignored because `Drop` cannot return an error.
    fn drop(&mut self) {
        let drv = match driver() {
            Ok(drv) => drv,
            Err(_) => return,
        };
        let release = match drv.cu_device_primary_ctx_release() {
            Ok(release) => release,
            Err(_) => return,
        };
        // SAFETY: balances the single retain this value represents.
        let _ = unsafe { release(self.device.0) };
    }
}