oxicuda_driver/context.rs
1//! CUDA context management with RAII semantics.
2//!
3//! A CUDA **context** is the primary interface through which a CPU thread
4//! interacts with a GPU. It owns driver state such as loaded modules, allocated
5//! memory, and streams. This module provides the [`Context`] type, an RAII
6//! wrapper around `CUcontext` that automatically calls `cuCtxDestroy` on drop.
7//!
8//! # Thread safety
9//!
//! CUDA contexts are not tied to the thread that created them: a context can be
//! made current on another thread with `cuCtxSetCurrent`. The [`Context`] type
//! therefore implements [`Send`], so it can be moved to another thread. It does
//! **not** implement [`Sync`], which also means [`Arc<Context>`](std::sync::Arc)
//! is neither `Send` nor `Sync` and cannot be used to share a context between
//! threads. To use a context on another thread, move the `Context` there and call
//! [`set_current`](Context::set_current) on that thread before issuing work.
15//!
16//! # Examples
17//!
18//! ```no_run
19//! use oxicuda_driver::context::Context;
20//! use oxicuda_driver::device::Device;
21//!
22//! oxicuda_driver::init()?;
23//! let device = Device::get(0)?;
24//! let ctx = Context::new(&device)?;
25//! ctx.set_current()?;
26//! // ... launch kernels, allocate memory ...
27//! ctx.synchronize()?;
28//! # Ok::<(), oxicuda_driver::error::CudaError>(())
29//! ```
30
31use crate::device::Device;
32use crate::error::CudaResult;
33use crate::ffi::CUcontext;
34use crate::loader::try_driver;
35
36// ---------------------------------------------------------------------------
37// Scheduling flags
38// ---------------------------------------------------------------------------
39
/// Flags accepted by [`Context::with_flags`]; combine them with bitwise OR.
///
/// The `SCHED_*` values select how the host thread waits for GPU work to
/// finish (the CUDA driver expects at most one of them). `MAP_HOST` and
/// `LMEM_RESIZE_TO_MAX` are independent bits that adjust other context
/// behaviour.
pub mod flags {
    /// No explicit wait policy: the driver picks one heuristically.
    pub const SCHED_AUTO: u32 = 0;

    /// Busy-wait while waiting for GPU results: lowest latency, but it keeps
    /// a CPU core fully occupied.
    pub const SCHED_SPIN: u32 = 1 << 0;

    /// Give up the CPU time-slice while waiting, letting other host threads
    /// run. Suited to multi-threaded applications.
    pub const SCHED_YIELD: u32 = 1 << 1;

    /// Park the waiting thread on a synchronisation primitive: minimal CPU
    /// usage at the cost of slightly higher wake-up latency.
    pub const SCHED_BLOCKING_SYNC: u32 = 1 << 2;

    /// Allow mapped (device-accessible) pinned host allocations in the
    /// context.
    pub const MAP_HOST: u32 = 1 << 3;

    /// Keep the local-memory allocation after a launch instead of shrinking
    /// it. Deprecated by the driver; provided for completeness.
    pub const LMEM_RESIZE_TO_MAX: u32 = 1 << 4;
}
66
67// ---------------------------------------------------------------------------
68// Context
69// ---------------------------------------------------------------------------
70
71/// RAII wrapper for a CUDA context.
72///
73/// A context is created on a specific [`Device`] and becomes the active
74/// context for the calling thread. When the `Context` is dropped,
75/// `cuCtxDestroy_v2` is called automatically.
76///
77/// # Examples
78///
79/// ```no_run
80/// use oxicuda_driver::context::Context;
81/// use oxicuda_driver::device::Device;
82///
83/// oxicuda_driver::init()?;
84/// let dev = Device::get(0)?;
85/// let ctx = Context::new(&dev)?;
86/// println!("Context on device {}", ctx.device().ordinal());
87/// ctx.synchronize()?;
88/// // ctx is destroyed when it goes out of scope
89/// # Ok::<(), oxicuda_driver::error::CudaError>(())
90/// ```
91pub struct Context {
92 /// The raw CUDA context handle.
93 raw: CUcontext,
94 /// The device this context was created on.
95 device: Device,
96}
97
98impl Context {
99 // -- Construction --------------------------------------------------------
100
101 /// Create a new context on the given device with default flags
102 /// ([`flags::SCHED_AUTO`]).
103 ///
104 /// The new context is automatically pushed onto the calling thread's
105 /// context stack and becomes the current context.
106 ///
107 /// # Errors
108 ///
109 /// Returns an error if the driver cannot create the context (e.g., device
110 /// is invalid, out of resources).
111 pub fn new(device: &Device) -> CudaResult<Self> {
112 Self::with_flags(device, flags::SCHED_AUTO)
113 }
114
115 /// Create a new context on the given device with specific scheduling flags.
116 ///
117 /// See the [`flags`] module for available values. Multiple flags can be
118 /// combined with bitwise OR.
119 ///
120 /// # Errors
121 ///
122 /// Returns an error if the driver cannot create the context.
123 ///
124 /// # Examples
125 ///
126 /// ```no_run
127 /// use oxicuda_driver::context::{Context, flags};
128 /// use oxicuda_driver::device::Device;
129 ///
130 /// oxicuda_driver::init()?;
131 /// let dev = Device::get(0)?;
132 /// let ctx = Context::with_flags(&dev, flags::SCHED_BLOCKING_SYNC)?;
133 /// # Ok::<(), oxicuda_driver::error::CudaError>(())
134 /// ```
135 pub fn with_flags(device: &Device, flags: u32) -> CudaResult<Self> {
136 let driver = try_driver()?;
137 let mut raw = CUcontext::default();
138 crate::error::check(unsafe { (driver.cu_ctx_create_v2)(&mut raw, flags, device.raw()) })?;
139 Ok(Self {
140 raw,
141 device: *device,
142 })
143 }
144
145 // -- Current context management -----------------------------------------
146
147 /// Set this context as the current context for the calling thread.
148 ///
149 /// Any previous context on this thread is detached (but not destroyed).
150 ///
151 /// # Errors
152 ///
153 /// Returns an error if the driver call fails.
154 pub fn set_current(&self) -> CudaResult<()> {
155 let driver = try_driver()?;
156 crate::error::check(unsafe { (driver.cu_ctx_set_current)(self.raw) })
157 }
158
159 /// Get the raw handle of the current context for the calling thread.
160 ///
161 /// Returns `None` if no context is bound to the current thread.
162 ///
163 /// # Errors
164 ///
165 /// Returns an error if the driver call fails.
166 pub fn current_raw() -> CudaResult<Option<CUcontext>> {
167 let driver = try_driver()?;
168 let mut ctx = CUcontext::default();
169 crate::error::check(unsafe { (driver.cu_ctx_get_current)(&mut ctx) })?;
170 if ctx.is_null() {
171 Ok(None)
172 } else {
173 Ok(Some(ctx))
174 }
175 }
176
177 // -- Synchronisation ----------------------------------------------------
178
179 /// Block until all pending GPU operations in this context have completed.
180 ///
181 /// This sets the context as current before synchronising to ensure the
182 /// correct context is targeted.
183 ///
184 /// # Errors
185 ///
186 /// Returns an error if any GPU operation failed or the driver call fails.
187 pub fn synchronize(&self) -> CudaResult<()> {
188 self.set_current()?;
189 let driver = try_driver()?;
190 crate::error::check(unsafe { (driver.cu_ctx_synchronize)() })
191 }
192
193 // -- Scoped execution ---------------------------------------------------
194
195 /// Execute a closure with this context set as current, then restore the
196 /// previous context.
197 ///
198 /// This is useful when temporarily switching contexts. The previous
199 /// context (if any) is restored even if the closure returns an error.
200 ///
201 /// # Errors
202 ///
203 /// Propagates any error from the closure. Context-restoration errors are
204 /// logged but do not override the closure result.
205 ///
206 /// # Examples
207 ///
208 /// ```no_run
209 /// use oxicuda_driver::context::Context;
210 /// use oxicuda_driver::device::Device;
211 ///
212 /// oxicuda_driver::init()?;
213 /// let dev = Device::get(0)?;
214 /// let ctx = Context::new(&dev)?;
215 /// let result = ctx.scoped(|| {
216 /// // ctx is current here
217 /// Ok(42)
218 /// })?;
219 /// assert_eq!(result, 42);
220 /// # Ok::<(), oxicuda_driver::error::CudaError>(())
221 /// ```
222 pub fn scoped<F, R>(&self, f: F) -> CudaResult<R>
223 where
224 F: FnOnce() -> CudaResult<R>,
225 {
226 // Save the currently active context (may be None).
227 let prev = Self::current_raw()?;
228
229 // Activate this context.
230 self.set_current()?;
231
232 // Run the user closure.
233 let result = f();
234
235 // Restore the previous context. A null CUcontext detaches any context
236 // from the current thread, which is the correct behaviour when there
237 // was no previous context.
238 let restore_ctx = prev.unwrap_or_default();
239 if let Ok(driver) = try_driver() {
240 if let Err(e) = crate::error::check(unsafe { (driver.cu_ctx_set_current)(restore_ctx) })
241 {
242 tracing::warn!("failed to restore previous context: {e}");
243 }
244 }
245
246 result
247 }
248
249 // -- Accessors ----------------------------------------------------------
250
251 /// Get a reference to the [`Device`] this context was created on.
252 #[inline]
253 pub fn device(&self) -> &Device {
254 &self.device
255 }
256
257 /// Get the raw `CUcontext` handle for use with FFI calls.
258 #[inline]
259 pub fn raw(&self) -> CUcontext {
260 self.raw
261 }
262
263 /// Returns `true` if this context is the current context on the calling
264 /// thread.
265 ///
266 /// # Errors
267 ///
268 /// Returns an error if the driver call fails.
269 pub fn is_current(&self) -> CudaResult<bool> {
270 match Self::current_raw()? {
271 Some(ctx) => Ok(ctx == self.raw),
272 None => Ok(false),
273 }
274 }
275}
276
277// ---------------------------------------------------------------------------
278// Drop
279// ---------------------------------------------------------------------------
280
281impl Drop for Context {
282 /// Destroy the CUDA context.
283 ///
284 /// Errors during destruction are logged via `tracing::warn` but never
285 /// propagated (destructors must not panic).
286 fn drop(&mut self) {
287 if let Ok(driver) = try_driver() {
288 let result = unsafe { (driver.cu_ctx_destroy_v2)(self.raw) };
289 if result != 0 {
290 tracing::warn!(
291 "cuCtxDestroy_v2 failed with error code {result} during Context drop \
292 (device ordinal {})",
293 self.device.ordinal()
294 );
295 }
296 }
297 }
298}
299
300// ---------------------------------------------------------------------------
301// Trait impls
302// ---------------------------------------------------------------------------
303
304// SAFETY: CUDA contexts can be migrated between threads via cuCtxSetCurrent.
305// The caller is responsible for calling set_current() on the new thread.
306unsafe impl Send for Context {}
307
308impl std::fmt::Debug for Context {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 f.debug_struct("Context")
311 .field("raw", &self.raw)
312 .field("device", &self.device)
313 .finish()
314 }
315}