// oxicuda_driver/primary_context.rs

//! Primary context management (one per device, reference-counted by the driver).
//!
//! Every CUDA device has exactly one **primary context** that is shared
//! among all users of that device within the same process. The primary
//! context is reference-counted by the CUDA driver: it is created on the
//! first [`PrimaryContext::retain`] call for that device and destroyed when
//! the last retainer releases it (explicitly or by dropping its handle).
//!
//! Primary contexts are the recommended way to share a device context
//! across multiple libraries and subsystems in the same process, because
//! the driver ensures only one primary context exists per device.
//!
//! # Example
//!
//! ```rust,no_run
//! use oxicuda_driver::device::Device;
//! use oxicuda_driver::primary_context::PrimaryContext;
//!
//! oxicuda_driver::init()?;
//! let dev = Device::get(0)?;
//! let pctx = PrimaryContext::retain(&dev)?;
//! let (active, flags) = pctx.get_state()?;
//! println!("active={active}, flags={flags}");
//! pctx.release()?;
//! # Ok::<(), oxicuda_driver::error::CudaError>(())
//! ```

use std::ffi::c_int;

use crate::device::Device;
use crate::error::CudaResult;
use crate::ffi::CUcontext;
use crate::loader::try_driver;

35// ---------------------------------------------------------------------------
36// PrimaryContext
37// ---------------------------------------------------------------------------
38
39/// RAII wrapper for a CUDA primary context.
40///
41/// A primary context is the per-device, reference-counted context managed
42/// by the CUDA driver. Unlike a regular [`Context`](crate::context::Context),
43/// the primary context is shared among all callers that retain it on the
44/// same device.
45///
46/// [`PrimaryContext::retain`] increments the driver's reference count and
47/// [`PrimaryContext::release`] decrements it. When the count reaches zero,
48/// the driver destroys the context.
49#[derive(Debug)]
50pub struct PrimaryContext {
51    /// The device this primary context belongs to.
52    device: Device,
53    /// The raw CUDA context handle obtained from `cuDevicePrimaryCtxRetain`.
54    raw: CUcontext,
55}
56
57// SAFETY: The primary context handle is managed by the CUDA driver and
58// can be used from any thread when properly synchronised.
59unsafe impl Send for PrimaryContext {}
60
61impl PrimaryContext {
62    /// Retains the primary context on the given device.
63    ///
64    /// If the primary context does not yet exist, the driver creates it.
65    /// Each call to `retain` increments an internal reference count. The
66    /// context remains alive until all retainers call [`release`](Self::release).
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if the driver cannot be loaded or the retain fails
71    /// (e.g., invalid device).
72    pub fn retain(device: &Device) -> CudaResult<Self> {
73        let driver = try_driver()?;
74        let mut raw = CUcontext::default();
75        crate::error::check(unsafe {
76            (driver.cu_device_primary_ctx_retain)(&mut raw, device.raw())
77        })?;
78        Ok(Self {
79            device: *device,
80            raw,
81        })
82    }
83
84    /// Releases this primary context, decrementing the driver's reference count.
85    ///
86    /// When the last retainer releases, the driver destroys the context and
87    /// frees its resources. After calling this method the `PrimaryContext`
88    /// is consumed and cannot be used further.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if the release call fails.
93    pub fn release(self) -> CudaResult<()> {
94        let driver = try_driver()?;
95        crate::error::check(unsafe {
96            (driver.cu_device_primary_ctx_release_v2)(self.device.raw())
97        })?;
98        // Prevent Drop from releasing again.
99        std::mem::forget(self);
100        Ok(())
101    }
102
103    /// Sets the flags for the primary context.
104    ///
105    /// The flags control scheduling behaviour (e.g., spin, yield, blocking).
106    /// See [`context::flags`](crate::context::flags) for available values.
107    ///
108    /// This must be called **before** the primary context is made active
109    /// (i.e., before any retain or before all retainers have released).
110    /// If the primary context is already active, this returns
111    /// [`CudaError::PrimaryContextActive`](crate::error::CudaError::PrimaryContextActive).
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the flags cannot be set.
116    pub fn set_flags(&self, flags: u32) -> CudaResult<()> {
117        let driver = try_driver()?;
118        crate::error::check(unsafe {
119            (driver.cu_device_primary_ctx_set_flags_v2)(self.device.raw(), flags)
120        })
121    }
122
123    /// Returns the current state of the primary context.
124    ///
125    /// Returns `(active, flags)` where:
126    /// - `active` is `true` if the context is currently retained by at
127    ///   least one caller.
128    /// - `flags` are the scheduling flags currently in effect.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the state query fails.
133    pub fn get_state(&self) -> CudaResult<(bool, u32)> {
134        let driver = try_driver()?;
135        let mut flags: u32 = 0;
136        let mut active: c_int = 0;
137        crate::error::check(unsafe {
138            (driver.cu_device_primary_ctx_get_state)(self.device.raw(), &mut flags, &mut active)
139        })?;
140        Ok((active != 0, flags))
141    }
142
143    /// Resets the primary context on this device.
144    ///
145    /// This destroys all allocations, modules, and state associated with
146    /// the primary context. The context is then re-created the next time
147    /// it is retained.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if the reset fails.
152    pub fn reset(&self) -> CudaResult<()> {
153        let driver = try_driver()?;
154        crate::error::check(unsafe { (driver.cu_device_primary_ctx_reset_v2)(self.device.raw()) })
155    }
156
157    /// Returns a reference to the device this primary context belongs to.
158    #[inline]
159    pub fn device(&self) -> &Device {
160        &self.device
161    }
162
163    /// Returns the raw `CUcontext` handle.
164    #[inline]
165    pub fn raw(&self) -> CUcontext {
166        self.raw
167    }
168}
169
170impl Drop for PrimaryContext {
171    /// Release the primary context on drop.
172    ///
173    /// Errors during release are logged but never propagated.
174    fn drop(&mut self) {
175        if let Ok(driver) = try_driver() {
176            let rc = unsafe { (driver.cu_device_primary_ctx_release_v2)(self.device.raw()) };
177            if rc != 0 {
178                tracing::warn!(
179                    cuda_error = rc,
180                    device = self.device.ordinal(),
181                    "cuDevicePrimaryCtxRelease_v2 failed during PrimaryContext drop"
182                );
183            }
184        }
185    }
186}
187
188impl std::fmt::Display for PrimaryContext {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        write!(f, "PrimaryContext(device={})", self.device.ordinal())
191    }
192}
193
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn primary_context_display() {
        // Building a real `PrimaryContext` requires a GPU. This test therefore
        // only pins the expected shape of the `Display` output; it does not
        // call the `Display` impl itself.
        let expected = format!("PrimaryContext(device={})", 0);
        for needle in ["PrimaryContext", "device=0"] {
            assert!(expected.contains(needle));
        }
    }

    #[test]
    fn primary_context_is_send() {
        fn require_send<T: Send>() {}
        require_send::<PrimaryContext>();
    }

    #[test]
    fn retain_signature_compiles() {
        let _retain: fn(&Device) -> CudaResult<PrimaryContext> = PrimaryContext::retain;
    }

    #[test]
    fn set_flags_signature_compiles() {
        let _set_flags: fn(&PrimaryContext, u32) -> CudaResult<()> = PrimaryContext::set_flags;
    }

    #[test]
    fn get_state_signature_compiles() {
        let _get_state: fn(&PrimaryContext) -> CudaResult<(bool, u32)> = PrimaryContext::get_state;
    }

    #[test]
    fn reset_signature_compiles() {
        let _reset: fn(&PrimaryContext) -> CudaResult<()> = PrimaryContext::reset;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn retain_and_release_on_real_gpu() {
        // Initialisation failure is tolerated: without a usable device there
        // is nothing to check.
        let _ = crate::init();
        let dev = match Device::get(0) {
            Ok(dev) => dev,
            Err(_) => return,
        };
        let pctx = PrimaryContext::retain(&dev).expect("failed to retain primary context");
        let (is_active, _) = pctx.get_state().expect("failed to get state");
        assert!(is_active);
        pctx.release().expect("failed to release primary context");
    }
}