oxicuda_driver/primary_context.rs
1//! Primary context management (one per device, reference counted by driver).
2//!
3//! Every CUDA device has exactly one **primary context** that is shared
4//! among all users of that device within the same process. The primary
5//! context is reference-counted by the CUDA driver: it is created on the
6//! first [`PrimaryContext::retain`] call and destroyed when the last
7//! retainer releases it.
8//!
9//! Primary contexts are the recommended way to share a device context
10//! across multiple libraries and subsystems in the same process, because
11//! the driver ensures only one context exists per device.
12//!
13//! # Example
14//!
15//! ```rust,no_run
16//! use oxicuda_driver::device::Device;
17//! use oxicuda_driver::primary_context::PrimaryContext;
18//!
19//! oxicuda_driver::init()?;
20//! let dev = Device::get(0)?;
21//! let pctx = PrimaryContext::retain(&dev)?;
22//! let (active, flags) = pctx.get_state()?;
23//! println!("active={active}, flags={flags}");
24//! pctx.release()?;
25//! # Ok::<(), oxicuda_driver::error::CudaError>(())
26//! ```
27
28use std::ffi::c_int;
29
30use crate::device::Device;
31use crate::error::CudaResult;
32use crate::ffi::CUcontext;
33use crate::loader::try_driver;
34
35// ---------------------------------------------------------------------------
36// PrimaryContext
37// ---------------------------------------------------------------------------
38
39/// RAII wrapper for a CUDA primary context.
40///
41/// A primary context is the per-device, reference-counted context managed
42/// by the CUDA driver. Unlike a regular [`Context`](crate::context::Context),
43/// the primary context is shared among all callers that retain it on the
44/// same device.
45///
46/// [`PrimaryContext::retain`] increments the driver's reference count and
47/// [`PrimaryContext::release`] decrements it. When the count reaches zero,
48/// the driver destroys the context.
49#[derive(Debug)]
50pub struct PrimaryContext {
51 /// The device this primary context belongs to.
52 device: Device,
53 /// The raw CUDA context handle obtained from `cuDevicePrimaryCtxRetain`.
54 raw: CUcontext,
55}
56
57// SAFETY: The primary context handle is managed by the CUDA driver and
58// can be used from any thread when properly synchronised.
59unsafe impl Send for PrimaryContext {}
60
61impl PrimaryContext {
62 /// Retains the primary context on the given device.
63 ///
64 /// If the primary context does not yet exist, the driver creates it.
65 /// Each call to `retain` increments an internal reference count. The
66 /// context remains alive until all retainers call [`release`](Self::release).
67 ///
68 /// # Errors
69 ///
70 /// Returns an error if the driver cannot be loaded or the retain fails
71 /// (e.g., invalid device).
72 pub fn retain(device: &Device) -> CudaResult<Self> {
73 let driver = try_driver()?;
74 let mut raw = CUcontext::default();
75 crate::error::check(unsafe {
76 (driver.cu_device_primary_ctx_retain)(&mut raw, device.raw())
77 })?;
78 Ok(Self {
79 device: *device,
80 raw,
81 })
82 }
83
84 /// Releases this primary context, decrementing the driver's reference count.
85 ///
86 /// When the last retainer releases, the driver destroys the context and
87 /// frees its resources. After calling this method the `PrimaryContext`
88 /// is consumed and cannot be used further.
89 ///
90 /// # Errors
91 ///
92 /// Returns an error if the release call fails.
93 pub fn release(self) -> CudaResult<()> {
94 let driver = try_driver()?;
95 crate::error::check(unsafe {
96 (driver.cu_device_primary_ctx_release_v2)(self.device.raw())
97 })?;
98 // Prevent Drop from releasing again.
99 std::mem::forget(self);
100 Ok(())
101 }
102
103 /// Sets the flags for the primary context.
104 ///
105 /// The flags control scheduling behaviour (e.g., spin, yield, blocking).
106 /// See [`context::flags`](crate::context::flags) for available values.
107 ///
108 /// This must be called **before** the primary context is made active
109 /// (i.e., before any retain or before all retainers have released).
110 /// If the primary context is already active, this returns
111 /// [`CudaError::PrimaryContextActive`](crate::error::CudaError::PrimaryContextActive).
112 ///
113 /// # Errors
114 ///
115 /// Returns an error if the flags cannot be set.
116 pub fn set_flags(&self, flags: u32) -> CudaResult<()> {
117 let driver = try_driver()?;
118 crate::error::check(unsafe {
119 (driver.cu_device_primary_ctx_set_flags_v2)(self.device.raw(), flags)
120 })
121 }
122
123 /// Returns the current state of the primary context.
124 ///
125 /// Returns `(active, flags)` where:
126 /// - `active` is `true` if the context is currently retained by at
127 /// least one caller.
128 /// - `flags` are the scheduling flags currently in effect.
129 ///
130 /// # Errors
131 ///
132 /// Returns an error if the state query fails.
133 pub fn get_state(&self) -> CudaResult<(bool, u32)> {
134 let driver = try_driver()?;
135 let mut flags: u32 = 0;
136 let mut active: c_int = 0;
137 crate::error::check(unsafe {
138 (driver.cu_device_primary_ctx_get_state)(self.device.raw(), &mut flags, &mut active)
139 })?;
140 Ok((active != 0, flags))
141 }
142
143 /// Resets the primary context on this device.
144 ///
145 /// This destroys all allocations, modules, and state associated with
146 /// the primary context. The context is then re-created the next time
147 /// it is retained.
148 ///
149 /// # Errors
150 ///
151 /// Returns an error if the reset fails.
152 pub fn reset(&self) -> CudaResult<()> {
153 let driver = try_driver()?;
154 crate::error::check(unsafe { (driver.cu_device_primary_ctx_reset_v2)(self.device.raw()) })
155 }
156
157 /// Returns a reference to the device this primary context belongs to.
158 #[inline]
159 pub fn device(&self) -> &Device {
160 &self.device
161 }
162
163 /// Returns the raw `CUcontext` handle.
164 #[inline]
165 pub fn raw(&self) -> CUcontext {
166 self.raw
167 }
168}
169
170impl Drop for PrimaryContext {
171 /// Release the primary context on drop.
172 ///
173 /// Errors during release are logged but never propagated.
174 fn drop(&mut self) {
175 if let Ok(driver) = try_driver() {
176 let rc = unsafe { (driver.cu_device_primary_ctx_release_v2)(self.device.raw()) };
177 if rc != 0 {
178 tracing::warn!(
179 cuda_error = rc,
180 device = self.device.ordinal(),
181 "cuDevicePrimaryCtxRelease_v2 failed during PrimaryContext drop"
182 );
183 }
184 }
185 }
186}
187
188impl std::fmt::Display for PrimaryContext {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 write!(f, "PrimaryContext(device={})", self.device.ordinal())
191 }
192}
193
194// ---------------------------------------------------------------------------
195// Tests
196// ---------------------------------------------------------------------------
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn primary_context_display() {
        // A real `PrimaryContext` cannot be built without a GPU, so this only
        // proves at compile time that the `Display` and `Debug` impls exist.
        // The rendered text is asserted in `retain_and_release_on_real_gpu`.
        fn assert_display_and_debug<T: std::fmt::Display + std::fmt::Debug>() {}
        assert_display_and_debug::<PrimaryContext>();
    }

    #[test]
    fn primary_context_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PrimaryContext>();
    }

    #[test]
    fn retain_signature_compiles() {
        let _: fn(&Device) -> CudaResult<PrimaryContext> = PrimaryContext::retain;
    }

    #[test]
    fn set_flags_signature_compiles() {
        let _: fn(&PrimaryContext, u32) -> CudaResult<()> = PrimaryContext::set_flags;
    }

    #[test]
    fn get_state_signature_compiles() {
        let _: fn(&PrimaryContext) -> CudaResult<(bool, u32)> = PrimaryContext::get_state;
    }

    #[test]
    fn reset_signature_compiles() {
        let _: fn(&PrimaryContext) -> CudaResult<()> = PrimaryContext::reset;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn retain_and_release_on_real_gpu() {
        crate::init().ok();
        if let Ok(dev) = Device::get(0) {
            let pctx = PrimaryContext::retain(&dev).expect("failed to retain primary context");

            // The wrapper must report the device it was retained on, and its
            // `Display` output must use that device's ordinal.
            assert_eq!(pctx.device().ordinal(), dev.ordinal());
            assert_eq!(
                pctx.to_string(),
                format!("PrimaryContext(device={})", dev.ordinal())
            );

            let (active, _flags) = pctx.get_state().expect("failed to get state");
            assert!(active);
            pctx.release().expect("failed to release primary context");
        }
    }
}