Skip to main content

baracuda_driver/
context.rs

1//! CUDA contexts — both primary (shared with the Runtime API) and explicit.
2//!
3//! A [`Context`] owns the handle returned by `cuCtxCreate`. Contexts are
4//! reference-counted via `Arc` so multiple streams/events/modules can
5//! share ownership; the underlying `cuCtxDestroy` runs when the last clone
6//! drops.
7
8use std::sync::Arc;
9
10use baracuda_cuda_sys::types::CUcontext_flags;
11use baracuda_cuda_sys::{driver, CUcontext};
12
13use crate::device::Device;
14use crate::error::{check, Result};
15use crate::init::init;
16
17/// A CUDA context created by `cuCtxCreate`.
18///
19/// Multiple [`Context`] clones refer to the same underlying driver context.
20#[derive(Clone, Debug)]
21pub struct Context {
22    inner: Arc<ContextInner>,
23}
24
25struct ContextInner {
26    handle: CUcontext,
27    device: Device,
28}
29
30// SAFETY: CUcontext is a raw pointer, but NVIDIA documents that a context
31// object may be shared between threads so long as each thread calls
32// `cuCtxSetCurrent` before issuing work. Concurrent kernel submission on
33// different streams is explicitly supported.
34unsafe impl Send for ContextInner {}
35unsafe impl Sync for ContextInner {}
36
37impl core::fmt::Debug for ContextInner {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        f.debug_struct("Context")
40            .field("handle", &self.handle)
41            .field("device", &self.device)
42            .finish()
43    }
44}
45
46impl Context {
47    /// Create a new context on `device` with default scheduling flags.
48    pub fn new(device: &Device) -> Result<Self> {
49        Self::with_flags(device, CUcontext_flags::SCHED_AUTO)
50    }
51
52    /// Create a new context on `device`, passing `flags` verbatim to
53    /// `cuCtxCreate`. See [`baracuda_cuda_sys::types::CUcontext_flags`] for
54    /// the permitted values.
55    pub fn with_flags(device: &Device, flags: u32) -> Result<Self> {
56        init()?;
57        let d = driver()?;
58        let cu = d.cu_ctx_create()?;
59        let mut ctx: CUcontext = core::ptr::null_mut();
60        // SAFETY: `ctx` is a writable pointer; `device.0` is a live CUdevice.
61        check(unsafe { cu(&mut ctx, flags, device.0) })?;
62        Ok(Self {
63            inner: Arc::new(ContextInner {
64                handle: ctx,
65                device: *device,
66            }),
67        })
68    }
69
70    /// Retrieve the thread's currently-current context, if any. Returns
71    /// `Ok(None)` when no context is current.
72    ///
73    /// **Note:** the returned `Context` is a _non-owning_ view — its `Drop`
74    /// will not call `cuCtxDestroy` on the handle. Use this only for
75    /// interop inspection, not lifecycle management.
76    pub fn current() -> Result<Option<CUcontext>> {
77        init()?;
78        let d = driver()?;
79        let cu = d.cu_ctx_get_current()?;
80        let mut ctx: CUcontext = core::ptr::null_mut();
81        check(unsafe { cu(&mut ctx) })?;
82        if ctx.is_null() {
83            Ok(None)
84        } else {
85            Ok(Some(ctx))
86        }
87    }
88
89    /// Make this context current on the calling thread.
90    pub fn set_current(&self) -> Result<()> {
91        let d = driver()?;
92        let cu = d.cu_ctx_set_current()?;
93        // SAFETY: `self.inner.handle` is alive for at least the duration of
94        // this call (held by Arc).
95        check(unsafe { cu(self.inner.handle) })
96    }
97
98    /// Push this context onto the thread's context stack.
99    pub fn push(&self) -> Result<()> {
100        let d = driver()?;
101        let cu = d.cu_ctx_push_current()?;
102        check(unsafe { cu(self.inner.handle) })
103    }
104
105    /// Pop the top context off the thread's context stack.
106    pub fn pop() -> Result<CUcontext> {
107        init()?;
108        let d = driver()?;
109        let cu = d.cu_ctx_pop_current()?;
110        let mut ctx: CUcontext = core::ptr::null_mut();
111        check(unsafe { cu(&mut ctx) })?;
112        Ok(ctx)
113    }
114
115    /// Block the calling thread until all work previously submitted to
116    /// streams in this context has completed.
117    pub fn synchronize(&self) -> Result<()> {
118        self.set_current()?;
119        let d = driver()?;
120        let cu = d.cu_ctx_synchronize()?;
121        check(unsafe { cu() })
122    }
123
124    /// API version this context was created with (major*1000 + minor*10, e.g. 12060).
125    pub fn api_version(&self) -> Result<u32> {
126        let d = driver()?;
127        let cu = d.cu_ctx_get_api_version()?;
128        let mut v: core::ffi::c_uint = 0;
129        check(unsafe { cu(self.inner.handle, &mut v) })?;
130        Ok(v)
131    }
132
133    /// Device ordinal of the thread's currently-current context.
134    /// Fails with `CUDA_ERROR_INVALID_CONTEXT` if no context is current.
135    pub fn current_device() -> Result<Device> {
136        let d = driver()?;
137        let cu = d.cu_ctx_get_device()?;
138        let mut dev = baracuda_cuda_sys::CUdevice::default();
139        check(unsafe { cu(&mut dev) })?;
140        Ok(Device(dev))
141    }
142
143    /// Flags the current context was created with (`SCHED_*`, `MAP_HOST`, etc.).
144    ///
145    /// Operates on the thread's current context, so make sure you've made
146    /// this one current first.
147    pub fn current_flags() -> Result<u32> {
148        let d = driver()?;
149        let cu = d.cu_ctx_get_flags()?;
150        let mut f: core::ffi::c_uint = 0;
151        check(unsafe { cu(&mut f) })?;
152        Ok(f)
153    }
154
155    /// Query a resource limit of the current context. `limit` is one of
156    /// [`baracuda_cuda_sys::types::CUlimit`].
157    pub fn get_limit(limit: u32) -> Result<usize> {
158        let d = driver()?;
159        let cu = d.cu_ctx_get_limit()?;
160        let mut v: usize = 0;
161        check(unsafe { cu(&mut v, limit) })?;
162        Ok(v)
163    }
164
165    /// Set a resource limit of the current context. `limit` is one of
166    /// [`baracuda_cuda_sys::types::CUlimit`]. Not all limits are
167    /// writable on every device.
168    pub fn set_limit(limit: u32, value: usize) -> Result<()> {
169        let d = driver()?;
170        let cu = d.cu_ctx_set_limit()?;
171        check(unsafe { cu(limit, value) })
172    }
173
174    /// Current context's L1/shared-memory preference. Values are from
175    /// [`baracuda_cuda_sys::types::CUfunc_cache`].
176    pub fn cache_config() -> Result<u32> {
177        let d = driver()?;
178        let cu = d.cu_ctx_get_cache_config()?;
179        let mut c: core::ffi::c_uint = 0;
180        check(unsafe { cu(&mut c) })?;
181        Ok(c)
182    }
183
184    /// Set the current context's L1/shared-memory preference.
185    pub fn set_cache_config(config: u32) -> Result<()> {
186        let d = driver()?;
187        let cu = d.cu_ctx_set_cache_config()?;
188        check(unsafe { cu(config) })
189    }
190
191    /// Hardware-supported stream priority range `(least_priority, greatest_priority)`.
192    /// On most GPUs that's `(0, -1)` — lower numbers = higher priority.
193    pub fn stream_priority_range() -> Result<(i32, i32)> {
194        let d = driver()?;
195        let cu = d.cu_ctx_get_stream_priority_range()?;
196        let mut least: core::ffi::c_int = 0;
197        let mut greatest: core::ffi::c_int = 0;
198        check(unsafe { cu(&mut least, &mut greatest) })?;
199        Ok((least, greatest))
200    }
201
202    /// Enable peer access from the current context to `peer`'s context.
203    /// After this call, kernels in the current context can read/write
204    /// allocations owned by `peer`.
205    pub fn enable_peer_access(peer: &Context) -> Result<()> {
206        let d = driver()?;
207        let cu = d.cu_ctx_enable_peer_access()?;
208        check(unsafe { cu(peer.inner.handle, 0) })
209    }
210
211    /// Revert [`enable_peer_access`](Self::enable_peer_access).
212    pub fn disable_peer_access(peer: &Context) -> Result<()> {
213        let d = driver()?;
214        let cu = d.cu_ctx_disable_peer_access()?;
215        check(unsafe { cu(peer.inner.handle) })
216    }
217
218    /// The [`Device`] this context was created on.
219    #[inline]
220    pub fn device(&self) -> Device {
221        self.inner.device
222    }
223
224    /// Raw `CUcontext`. Use with care.
225    #[inline]
226    pub fn as_raw(&self) -> CUcontext {
227        self.inner.handle
228    }
229
230    /// Driver-assigned 64-bit context ID. Useful for correlating
231    /// CUPTI / Nsight traces against this `Context`.
232    pub fn id(&self) -> Result<u64> {
233        let d = driver()?;
234        let cu = d.cu_ctx_get_id()?;
235        let mut out: u64 = 0;
236        check(unsafe { cu(self.inner.handle, &mut out) })?;
237        Ok(out)
238    }
239
240    /// Record `event` on this context (rather than tying it to a specific
241    /// stream). CUDA 12+.
242    pub fn record_event(&self, event: &crate::Event) -> Result<()> {
243        let d = driver()?;
244        let cu = d.cu_ctx_record_event()?;
245        check(unsafe { cu(self.inner.handle, event.as_raw()) })
246    }
247
248    /// Make this context wait on `event`. CUDA 12+.
249    pub fn wait_event(&self, event: &crate::Event) -> Result<()> {
250        let d = driver()?;
251        let cu = d.cu_ctx_wait_event()?;
252        check(unsafe { cu(self.inner.handle, event.as_raw()) })
253    }
254}
255
256impl Drop for ContextInner {
257    fn drop(&mut self) {
258        if let Ok(d) = driver() {
259            if let Ok(cu) = d.cu_ctx_destroy() {
260                // SAFETY: `self.handle` was produced by cuCtxCreate and has
261                // not been destroyed elsewhere (we're dropping the last Arc).
262                let _ = unsafe { cu(self.handle) };
263            }
264        }
265    }
266}
267
268/// A retained reference to a device's _primary_ context — the one shared
269/// with the CUDA Runtime API (`cudart`). Use this when you need to mix
270/// driver-API kernels/streams with framework code that relies on the
271/// runtime API (most ML frameworks do).
272///
273/// Each [`PrimaryContext::retain`] bumps a refcount on the device's
274/// primary context; `Drop` calls `cuDevicePrimaryCtxRelease`. The context
275/// itself is destroyed when the refcount hits zero.
276#[derive(Debug)]
277pub struct PrimaryContext {
278    handle: CUcontext,
279    device: Device,
280}
281
282unsafe impl Send for PrimaryContext {}
283unsafe impl Sync for PrimaryContext {}
284
285impl PrimaryContext {
286    /// Increment the refcount on `device`'s primary context and return a
287    /// handle to it. Equivalent to `cuDevicePrimaryCtxRetain`.
288    pub fn retain(device: &Device) -> Result<Self> {
289        init()?;
290        let d = driver()?;
291        let cu = d.cu_device_primary_ctx_retain()?;
292        let mut handle: CUcontext = core::ptr::null_mut();
293        check(unsafe { cu(&mut handle, device.0) })?;
294        Ok(Self {
295            handle,
296            device: *device,
297        })
298    }
299
300    /// Forcibly destroy the primary context on `device`, releasing all
301    /// resources and resetting refcounts. Any outstanding handles
302    /// returned by [`retain`] become dangling — only call this when you
303    /// know nobody else (Runtime API, other libraries) is using the
304    /// primary context. Equivalent to `cuDevicePrimaryCtxReset`.
305    ///
306    /// [`retain`]: Self::retain
307    pub fn reset(device: &Device) -> Result<()> {
308        init()?;
309        let d = driver()?;
310        let cu = d.cu_device_primary_ctx_reset()?;
311        check(unsafe { cu(device.0) })
312    }
313
314    /// Underlying device.
315    pub fn device(&self) -> Device {
316        self.device
317    }
318
319    /// Raw `CUcontext` — same handle the Runtime API would use.
320    #[inline]
321    pub fn as_raw(&self) -> CUcontext {
322        self.handle
323    }
324}
325
326impl Drop for PrimaryContext {
327    fn drop(&mut self) {
328        if let Ok(d) = driver() {
329            if let Ok(cu) = d.cu_device_primary_ctx_release() {
330                let _ = unsafe { cu(self.device.0) };
331            }
332        }
333    }
334}