Skip to main content

oxicuda_driver/
context.rs

1//! CUDA context management with RAII semantics.
2//!
3//! A CUDA **context** is the primary interface through which a CPU thread
4//! interacts with a GPU. It owns driver state such as loaded modules, allocated
5//! memory, and streams. This module provides the [`Context`] type, an RAII
6//! wrapper around `CUcontext` that automatically calls `cuCtxDestroy` on drop.
7//!
8//! # Thread safety
9//!
10//! CUDA contexts can be migrated between threads via `cuCtxSetCurrent`. The
11//! [`Context`] type implements [`Send`] to allow this. It does **not** implement
12//! [`Sync`] because the driver binds a context to a single thread at a time.
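//! Because `Arc<T>` is only `Send` when `T: Send + Sync`, an `Arc<Context>` cannot
//! cross threads; share a context through `Arc<Mutex<Context>>` (see
//! [`std::sync::Mutex`]) and call [`set_current`](Context::set_current) on each
//! thread that uses it.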
15//!
16//! # Examples
17//!
18//! ```no_run
19//! use oxicuda_driver::context::Context;
20//! use oxicuda_driver::device::Device;
21//!
22//! oxicuda_driver::init()?;
23//! let device = Device::get(0)?;
24//! let ctx = Context::new(&device)?;
25//! ctx.set_current()?;
26//! // ... launch kernels, allocate memory ...
27//! ctx.synchronize()?;
28//! # Ok::<(), oxicuda_driver::error::CudaError>(())
29//! ```
30
31use crate::device::Device;
32use crate::error::CudaResult;
33use crate::ffi::CUcontext;
34use crate::loader::try_driver;
35
36// ---------------------------------------------------------------------------
37// Scheduling flags
38// ---------------------------------------------------------------------------
39
/// Bit flags accepted by [`Context::with_flags`].
///
/// The `SCHED_*` values select how the calling CPU thread waits for the GPU;
/// the remaining values toggle per-context behaviour. Flags are plain `u32`
/// bits and may be OR-ed together.
pub mod flags {
    /// No bits set: the driver picks the scheduling policy itself.
    pub const SCHED_AUTO: u32 = 0;

    /// Busy-wait for GPU results. Gives the lowest latency at the cost of
    /// keeping one CPU core fully occupied.
    pub const SCHED_SPIN: u32 = 1 << 0;

    /// Give up the CPU time-slice to other threads while waiting. A good fit
    /// for multi-threaded applications.
    pub const SCHED_YIELD: u32 = 1 << 1;

    /// Park the calling thread on a synchronisation primitive while waiting.
    /// Uses the least CPU, with slightly higher latency.
    pub const SCHED_BLOCKING_SYNC: u32 = 1 << 2;

    /// Allow mapped pinned allocations in this context.
    pub const MAP_HOST: u32 = 1 << 3;

    /// Retain the local memory allocation after a launch. Deprecated; kept
    /// only for completeness.
    pub const LMEM_RESIZE_TO_MAX: u32 = 1 << 4;
}
66
67// ---------------------------------------------------------------------------
68// Context
69// ---------------------------------------------------------------------------
70
/// RAII wrapper for a CUDA context.
///
/// A context is created on a specific [`Device`] and becomes the active
/// context for the calling thread. The wrapper owns the underlying handle:
/// when the `Context` is dropped, `cuCtxDestroy_v2` is called automatically,
/// so the handle returned by [`Context::raw`] must not be used afterwards.
///
/// The value stores a copy of the [`Device`] it was created on, which is
/// available through [`Context::device`].
///
/// # Examples
///
/// ```no_run
/// use oxicuda_driver::context::Context;
/// use oxicuda_driver::device::Device;
///
/// oxicuda_driver::init()?;
/// let dev = Device::get(0)?;
/// let ctx = Context::new(&dev)?;
/// println!("Context on device {}", ctx.device().ordinal());
/// ctx.synchronize()?;
/// // ctx is destroyed when it goes out of scope
/// # Ok::<(), oxicuda_driver::error::CudaError>(())
/// ```
pub struct Context {
    /// The raw CUDA context handle; passed to `cuCtxDestroy_v2` on drop.
    raw: CUcontext,
    /// The device this context was created on (stored by value).
    device: Device,
}
97
98impl Context {
99    // -- Construction --------------------------------------------------------
100
101    /// Create a new context on the given device with default flags
102    /// ([`flags::SCHED_AUTO`]).
103    ///
104    /// The new context is automatically pushed onto the calling thread's
105    /// context stack and becomes the current context.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the driver cannot create the context (e.g., device
110    /// is invalid, out of resources).
111    pub fn new(device: &Device) -> CudaResult<Self> {
112        Self::with_flags(device, flags::SCHED_AUTO)
113    }
114
115    /// Create a new context on the given device with specific scheduling flags.
116    ///
117    /// See the [`flags`] module for available values. Multiple flags can be
118    /// combined with bitwise OR.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the driver cannot create the context.
123    ///
124    /// # Examples
125    ///
126    /// ```no_run
127    /// use oxicuda_driver::context::{Context, flags};
128    /// use oxicuda_driver::device::Device;
129    ///
130    /// oxicuda_driver::init()?;
131    /// let dev = Device::get(0)?;
132    /// let ctx = Context::with_flags(&dev, flags::SCHED_BLOCKING_SYNC)?;
133    /// # Ok::<(), oxicuda_driver::error::CudaError>(())
134    /// ```
135    pub fn with_flags(device: &Device, flags: u32) -> CudaResult<Self> {
136        let driver = try_driver()?;
137        let mut raw = CUcontext::default();
138        crate::error::check(unsafe { (driver.cu_ctx_create_v2)(&mut raw, flags, device.raw()) })?;
139        Ok(Self {
140            raw,
141            device: *device,
142        })
143    }
144
145    // -- Current context management -----------------------------------------
146
147    /// Set this context as the current context for the calling thread.
148    ///
149    /// Any previous context on this thread is detached (but not destroyed).
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if the driver call fails.
154    pub fn set_current(&self) -> CudaResult<()> {
155        let driver = try_driver()?;
156        crate::error::check(unsafe { (driver.cu_ctx_set_current)(self.raw) })
157    }
158
159    /// Get the raw handle of the current context for the calling thread.
160    ///
161    /// Returns `None` if no context is bound to the current thread.
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if the driver call fails.
166    pub fn current_raw() -> CudaResult<Option<CUcontext>> {
167        let driver = try_driver()?;
168        let mut ctx = CUcontext::default();
169        crate::error::check(unsafe { (driver.cu_ctx_get_current)(&mut ctx) })?;
170        if ctx.is_null() {
171            Ok(None)
172        } else {
173            Ok(Some(ctx))
174        }
175    }
176
177    // -- Synchronisation ----------------------------------------------------
178
179    /// Block until all pending GPU operations in this context have completed.
180    ///
181    /// This sets the context as current before synchronising to ensure the
182    /// correct context is targeted.
183    ///
184    /// # Errors
185    ///
186    /// Returns an error if any GPU operation failed or the driver call fails.
187    pub fn synchronize(&self) -> CudaResult<()> {
188        self.set_current()?;
189        let driver = try_driver()?;
190        crate::error::check(unsafe { (driver.cu_ctx_synchronize)() })
191    }
192
193    // -- Scoped execution ---------------------------------------------------
194
195    /// Execute a closure with this context set as current, then restore the
196    /// previous context.
197    ///
198    /// This is useful when temporarily switching contexts. The previous
199    /// context (if any) is restored even if the closure returns an error.
200    ///
201    /// # Errors
202    ///
203    /// Propagates any error from the closure. Context-restoration errors are
204    /// logged but do not override the closure result.
205    ///
206    /// # Examples
207    ///
208    /// ```no_run
209    /// use oxicuda_driver::context::Context;
210    /// use oxicuda_driver::device::Device;
211    ///
212    /// oxicuda_driver::init()?;
213    /// let dev = Device::get(0)?;
214    /// let ctx = Context::new(&dev)?;
215    /// let result = ctx.scoped(|| {
216    ///     // ctx is current here
217    ///     Ok(42)
218    /// })?;
219    /// assert_eq!(result, 42);
220    /// # Ok::<(), oxicuda_driver::error::CudaError>(())
221    /// ```
222    pub fn scoped<F, R>(&self, f: F) -> CudaResult<R>
223    where
224        F: FnOnce() -> CudaResult<R>,
225    {
226        // Save the currently active context (may be None).
227        let prev = Self::current_raw()?;
228
229        // Activate this context.
230        self.set_current()?;
231
232        // Run the user closure.
233        let result = f();
234
235        // Restore the previous context. A null CUcontext detaches any context
236        // from the current thread, which is the correct behaviour when there
237        // was no previous context.
238        let restore_ctx = prev.unwrap_or_default();
239        if let Ok(driver) = try_driver() {
240            if let Err(e) = crate::error::check(unsafe { (driver.cu_ctx_set_current)(restore_ctx) })
241            {
242                tracing::warn!("failed to restore previous context: {e}");
243            }
244        }
245
246        result
247    }
248
249    // -- Accessors ----------------------------------------------------------
250
251    /// Get a reference to the [`Device`] this context was created on.
252    #[inline]
253    pub fn device(&self) -> &Device {
254        &self.device
255    }
256
257    /// Get the raw `CUcontext` handle for use with FFI calls.
258    #[inline]
259    pub fn raw(&self) -> CUcontext {
260        self.raw
261    }
262
263    /// Returns `true` if this context is the current context on the calling
264    /// thread.
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if the driver call fails.
269    pub fn is_current(&self) -> CudaResult<bool> {
270        match Self::current_raw()? {
271            Some(ctx) => Ok(ctx == self.raw),
272            None => Ok(false),
273        }
274    }
275}
276
277// ---------------------------------------------------------------------------
278// Drop
279// ---------------------------------------------------------------------------
280
281impl Drop for Context {
282    /// Destroy the CUDA context.
283    ///
284    /// Errors during destruction are logged via `tracing::warn` but never
285    /// propagated (destructors must not panic).
286    fn drop(&mut self) {
287        if let Ok(driver) = try_driver() {
288            let result = unsafe { (driver.cu_ctx_destroy_v2)(self.raw) };
289            if result != 0 {
290                tracing::warn!(
291                    "cuCtxDestroy_v2 failed with error code {result} during Context drop \
292                     (device ordinal {})",
293                    self.device.ordinal()
294                );
295            }
296        }
297    }
298}
299
300// ---------------------------------------------------------------------------
301// Trait impls
302// ---------------------------------------------------------------------------
303
// SAFETY: CUDA contexts can be migrated between threads via cuCtxSetCurrent.
// The caller is responsible for calling set_current() on the new thread.
//
// Only `Send` is asserted here. No `Sync` impl appears in this part of the
// file, so (presumably, given that `Send` needed a manual impl) `&Context` and
// `Arc<Context>` cannot be handed to another thread; move the value itself or
// put it behind a `Mutex`.
unsafe impl Send for Context {}
307
308impl std::fmt::Debug for Context {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("Context")
311            .field("raw", &self.raw)
312            .field("device", &self.device)
313            .finish()
314    }
315}