use std::ffi::c_int;
use crate::device::Device;
use crate::error::CudaResult;
use crate::ffi::CUcontext;
use crate::loader::try_driver;
/// A retained reference to a device's primary CUDA context.
///
/// Created by [`PrimaryContext::retain`] (`cuDevicePrimaryCtxRetain`). The
/// reference is handed back to the driver either explicitly through
/// [`PrimaryContext::release`] or, best-effort, when the value is dropped
/// (both use `cuDevicePrimaryCtxRelease_v2`).
#[derive(Debug)]
pub struct PrimaryContext {
// Device whose primary context was retained; every driver call on this type
// is keyed by this device handle.
device: Device,
// Raw context handle written by `cuDevicePrimaryCtxRetain`; exposed via `raw()`.
raw: CUcontext,
}
// SAFETY: `raw` is an opaque driver handle that this type never dereferences,
// and the driver calls made by this type only pass `device` by value. The CUDA
// driver API lets a context be made current on, and a primary context be
// released from, any host thread. `Sync` is not implemented.
// NOTE(review): this also asserts `Send` for `Device`; assumes `Device` is a
// plain, thread-agnostic handle — TODO confirm.
unsafe impl Send for PrimaryContext {}
impl PrimaryContext {
    /// Retains the primary context of `device`.
    ///
    /// Wraps `cuDevicePrimaryCtxRetain`. Every successful retain is balanced by
    /// exactly one release, done either by [`PrimaryContext::release`] or when the
    /// returned value is dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be loaded or the retain call fails.
    pub fn retain(device: &Device) -> CudaResult<Self> {
        let driver = try_driver()?;
        let mut raw = CUcontext::default();
        // SAFETY: `&mut raw` is a valid, writable out-pointer for the duration of
        // the call; the device handle is passed by value.
        crate::error::check(unsafe {
            (driver.cu_device_primary_ctx_retain)(&mut raw, device.raw())
        })?;
        Ok(Self {
            device: *device,
            raw,
        })
    }

    /// Explicitly releases this reference to the primary context.
    ///
    /// Wraps `cuDevicePrimaryCtxRelease_v2`. Unlike dropping the value, this
    /// reports the driver's result to the caller. The driver release is attempted
    /// exactly once: if it fails, the error is returned here and `Drop` does not
    /// retry it.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be loaded (in which case `self` is
    /// dropped normally) or if the release call fails.
    pub fn release(self) -> CudaResult<()> {
        let driver = try_driver()?;
        // Disarm `Drop` before calling the driver. Otherwise a failed release would
        // be re-attempted, and logged a second time, by the destructor when the
        // error is propagated.
        let this = std::mem::ManuallyDrop::new(self);
        // SAFETY: balances the retain performed in `retain`; `Drop` is disarmed
        // above, so no second release follows. The device handle is passed by value.
        crate::error::check(unsafe {
            (driver.cu_device_primary_ctx_release_v2)(this.device.raw())
        })
    }

    /// Sets the flags of this device's primary context.
    ///
    /// Wraps `cuDevicePrimaryCtxSetFlags_v2`; `flags` is passed to the driver
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be loaded or the call fails.
    pub fn set_flags(&self, flags: u32) -> CudaResult<()> {
        let driver = try_driver()?;
        // SAFETY: both arguments are passed by value; no pointers are involved.
        crate::error::check(unsafe {
            (driver.cu_device_primary_ctx_set_flags_v2)(self.device.raw(), flags)
        })
    }

    /// Queries the state of this device's primary context.
    ///
    /// Wraps `cuDevicePrimaryCtxGetState` and returns `(active, flags)`.
    /// `active` is `true` when the driver reports a non-zero active value, and
    /// `flags` is the raw flag word reported by the driver.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be loaded or the call fails.
    pub fn get_state(&self) -> CudaResult<(bool, u32)> {
        let driver = try_driver()?;
        let mut flags: u32 = 0;
        let mut active: c_int = 0;
        // SAFETY: `flags` and `active` are valid, writable out-pointers that live
        // across the call.
        crate::error::check(unsafe {
            (driver.cu_device_primary_ctx_get_state)(self.device.raw(), &mut flags, &mut active)
        })?;
        Ok((active != 0, flags))
    }

    /// Resets this device's primary context.
    ///
    /// Wraps `cuDevicePrimaryCtxReset_v2`.
    /// NOTE(review): per the CUDA driver docs, this resets the primary context
    /// for every holder on the device, not only this handle. Callers must make
    /// sure nothing else still depends on its state.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be loaded or the call fails.
    pub fn reset(&self) -> CudaResult<()> {
        let driver = try_driver()?;
        // SAFETY: the device handle is passed by value; no pointers are involved.
        crate::error::check(unsafe { (driver.cu_device_primary_ctx_reset_v2)(self.device.raw()) })
    }

    /// The device whose primary context this value holds.
    #[inline]
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// The raw `CUcontext` handle obtained from `cuDevicePrimaryCtxRetain`.
    #[inline]
    pub fn raw(&self) -> CUcontext {
        self.raw
    }
}
impl Drop for PrimaryContext {
    /// Best-effort release of the retained primary context.
    ///
    /// If the driver is unavailable, nothing is done. A non-zero result from
    /// `cuDevicePrimaryCtxRelease_v2` is logged as a warning rather than
    /// panicking, since destructors cannot report errors.
    fn drop(&mut self) {
        let driver = match try_driver() {
            Ok(loaded) => loaded,
            // Without a driver there is nothing we can release.
            Err(_) => return,
        };
        // SAFETY: releases the reference taken by `PrimaryContext::retain` when this
        // value was constructed; the device handle is passed by value.
        let status = unsafe { (driver.cu_device_primary_ctx_release_v2)(self.device.raw()) };
        if status != 0 {
            tracing::warn!(
                cuda_error = status,
                device = self.device.ordinal(),
                "cuDevicePrimaryCtxRelease_v2 failed during PrimaryContext drop"
            );
        }
    }
}
impl std::fmt::Display for PrimaryContext {
    /// Renders as `PrimaryContext(device=<ordinal>)`, using the device ordinal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ordinal = self.device.ordinal();
        f.write_fmt(format_args!("PrimaryContext(device={})", ordinal))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Building a `PrimaryContext` needs a real device, so the rendered text is
    // pinned in the GPU test below; here we check that `Display` is implemented
    // (the previous version only formatted a string literal and asserted on it).
    #[test]
    fn primary_context_display() {
        fn assert_display<T: std::fmt::Display>() {}
        assert_display::<PrimaryContext>();
    }

    #[test]
    fn primary_context_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PrimaryContext>();
    }

    #[test]
    fn retain_signature_compiles() {
        let _: fn(&Device) -> CudaResult<PrimaryContext> = PrimaryContext::retain;
    }

    #[test]
    fn set_flags_signature_compiles() {
        let _: fn(&PrimaryContext, u32) -> CudaResult<()> = PrimaryContext::set_flags;
    }

    #[test]
    fn get_state_signature_compiles() {
        let _: fn(&PrimaryContext) -> CudaResult<(bool, u32)> = PrimaryContext::get_state;
    }

    #[test]
    fn reset_signature_compiles() {
        let _: fn(&PrimaryContext) -> CudaResult<()> = PrimaryContext::reset;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn retain_and_release_on_real_gpu() {
        crate::init().ok();
        if let Ok(dev) = Device::get(0) {
            let pctx = PrimaryContext::retain(&dev).expect("failed to retain primary context");
            // Exercise the real `Display` impl against the device ordinal.
            assert_eq!(
                pctx.to_string(),
                format!("PrimaryContext(device={})", dev.ordinal())
            );
            let (active, _flags) = pctx.get_state().expect("failed to get state");
            assert!(active);
            pctx.release().expect("failed to release primary context");
        }
    }
}