singe_cuda/context.rs
1use std::{ffi::CString, mem, ptr, sync::Arc};
2
3use singe_cuda_sys::driver;
4
5use crate::{
6 device::Device,
7 error::{Error, Result},
8 graph::Graph,
9 jit::JitOptions,
10 library::Library,
11 module::{Module, ModuleImage},
12 nvrtc::{self, CompilationArtifact, OutputKind},
13 try_ffi,
14 types::Limit,
15};
16
bitflags::bitflags! {
    /// Context creation flags.
    ///
    /// Each flag mirrors a `CUctx_flags` value from the CUDA Driver API. The
    /// bits are passed to `cuCtxCreate` by
    /// [`Context::create_for_device_with_flags`] and read back by
    /// [`Context::flags`] (bits unknown to this type are dropped there).
    ///
    /// The `SCHEDULE_*` flags are alternative host-thread wait policies
    /// (CUDA groups them under `CU_CTX_SCHED_MASK`); set at most one of them.
    ///
    /// `SCHEDULE_AUTO` maps to `CU_CTX_SCHED_AUTO`, which the CUDA headers
    /// define as `0`. `bitflags` treats a zero-bit flag as always contained,
    /// so `flags.contains(ContextFlags::SCHEDULE_AUTO)` is `true` for every
    /// value. To detect automatic scheduling, check instead that none of the
    /// other `SCHEDULE_*` flags is set.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct ContextFlags: u32 {
        /// `CU_CTX_SCHED_AUTO`: let the driver choose the wait policy (zero bits; see above).
        const SCHEDULE_AUTO = driver::CUctx_flags::CU_CTX_SCHED_AUTO as _;
        /// `CU_CTX_SCHED_SPIN`: spin actively while waiting for GPU results.
        const SCHEDULE_SPIN = driver::CUctx_flags::CU_CTX_SCHED_SPIN as _;
        /// `CU_CTX_SCHED_YIELD`: yield the host thread while waiting for GPU results.
        const SCHEDULE_YIELD = driver::CUctx_flags::CU_CTX_SCHED_YIELD as _;
        /// `CU_CTX_SCHED_BLOCKING_SYNC`: block the host thread on a synchronization
        /// primitive while waiting for GPU results.
        const SCHEDULE_BLOCKING_SYNC = driver::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC as _;
        /// `CU_CTX_MAP_HOST`: support mapped pinned host allocations.
        const MAP_HOST = driver::CUctx_flags::CU_CTX_MAP_HOST as _;
        /// `CU_CTX_LMEM_RESIZE_TO_MAX`: do not shrink local memory after it has been
        /// grown for a kernel.
        const LOCAL_MEMORY_RESIZE_TO_MAX = driver::CUctx_flags::CU_CTX_LMEM_RESIZE_TO_MAX as _;
        /// `CU_CTX_COREDUMP_ENABLE`: request a GPU coredump if this context raises an
        /// exception.
        const COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_COREDUMP_ENABLE as _;
        /// `CU_CTX_USER_COREDUMP_ENABLE`: enable user-triggered GPU coredumps for this
        /// context.
        const USER_COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_USER_COREDUMP_ENABLE as _;
        /// `CU_CTX_SYNC_MEMOPS`: make synchronous memory operations issued on this
        /// context always synchronize.
        const SYNC_MEMORY_OPERATIONS = driver::CUctx_flags::CU_CTX_SYNC_MEMOPS as _;
    }
}
32
/// A shared CUDA driver context.
///
/// Unlike cuBLAS, cuDNN, cuFFT, and similar library handles, a CUDA context is
/// the underlying execution environment for a device. It is intended to be
/// shared by streams, modules, libraries, events, allocations, and higher-level
/// library wrappers.
///
/// This type is therefore reference-counted: the constructors return
/// [`Arc<Self>`], and the type remains `Send + Sync`. Shared references do not
/// mutate Rust-visible state on the [`Context`] object itself. Methods such as
/// `bind` update the calling thread's current CUDA context in the driver.
///
/// When the last reference is dropped, the wrapper releases the context:
/// - A context created by this wrapper is destroyed with `cuCtxDestroy`.
/// - A retained primary context is released with `cuDevicePrimaryCtxRelease`.
///
/// Prefer one long-lived context per device and share it across dependent CUDA
/// objects instead of creating many short-lived contexts.
#[derive(Debug)]
pub struct Context {
    /// Raw driver handle. Never null: every constructor rejects null handles.
    handle: driver::CUcontext,
    /// Device the context belongs to. Primary contexts are released by device id.
    device: Device,
    /// Selects which driver call `Drop` uses to release `handle`.
    ownership: ContextOwnership,
}
53
/// How a context wrapper discharges its release responsibility when dropped.
///
/// Private counterpart of [`RawContextOwnership`], which is the form that
/// crosses the raw-handle API (`Context::from_raw` / `Context::into_raw_parts`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextOwnership {
    /// Created by this wrapper; destroyed with `cuCtxDestroy`.
    Created,
    /// A retained primary context; released with `cuDevicePrimaryCtxRelease`.
    Primary,
}

/// Release responsibility attached to a raw CUDA context handle.
///
/// Passed to `Context::from_raw` and returned by `Context::into_raw_parts`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RawContextOwnership {
    /// The context must eventually be destroyed with `cuCtxDestroy`.
    Created,
    /// The context is a device's primary context; one retain must eventually be
    /// released with `cuDevicePrimaryCtxRelease`.
    Primary,
}

impl From<RawContextOwnership> for ContextOwnership {
    fn from(raw: RawContextOwnership) -> Self {
        match raw {
            RawContextOwnership::Primary => ContextOwnership::Primary,
            RawContextOwnership::Created => ContextOwnership::Created,
        }
    }
}

impl From<ContextOwnership> for RawContextOwnership {
    fn from(owned: ContextOwnership) -> Self {
        match owned {
            ContextOwnership::Primary => RawContextOwnership::Primary,
            ContextOwnership::Created => RawContextOwnership::Created,
        }
    }
}
84
85impl Context {
86 pub fn create() -> Result<Arc<Self>> {
87 Self::create_with_flags(ContextFlags::empty())
88 }
89
90 pub fn create_with_flags(flags: ContextFlags) -> Result<Arc<Self>> {
91 let device = Device::current()?;
92 Self::create_for_device_with_flags(device, flags)
93 }
94
95 pub fn create_for_device(device: Device) -> Result<Arc<Self>> {
96 Self::create_for_device_with_flags(device, ContextFlags::empty())
97 }
98
99 pub fn create_for_device_with_flags(device: Device, flags: ContextFlags) -> Result<Arc<Self>> {
100 unsafe {
101 try_ffi!(driver::cuInit(0))?;
102
103 let mut handle = ptr::null_mut();
104 try_ffi!(driver::cuCtxCreate_v4(
105 &raw mut handle,
106 ptr::null_mut(), // CUctxCreateParams
107 flags.bits(),
108 device.id() as _,
109 ))?;
110
111 if handle.is_null() {
112 return Err(Error::NullHandle);
113 }
114
115 Ok(Arc::new(Self {
116 handle,
117 device,
118 ownership: ContextOwnership::Created,
119 }))
120 }
121 }
122
123 pub fn retain_primary_for_device(device: Device) -> Result<Arc<Self>> {
124 unsafe {
125 try_ffi!(driver::cuInit(0))?;
126
127 let mut handle = ptr::null_mut();
128 try_ffi!(driver::cuDevicePrimaryCtxRetain(
129 &raw mut handle,
130 device.id() as _,
131 ))?;
132
133 if handle.is_null() {
134 return Err(Error::NullHandle);
135 }
136
137 try_ffi!(driver::cuCtxSetCurrent(handle))?;
138
139 Ok(Arc::new(Self {
140 handle,
141 device,
142 ownership: ContextOwnership::Primary,
143 }))
144 }
145 }
146
147 /// Binds this CUDA context to the calling CPU thread.
148 ///
149 /// The "current context" is thread-local driver state. Calling this method
150 /// does not mutate the Rust [`Context`] value itself; it makes this context
151 /// current for subsequent CUDA driver and interoperating runtime calls on
152 /// the current host thread.
153 ///
154 /// # Errors
155 ///
156 /// Returns an error if CUDA Driver cannot query or set the current context.
157 pub fn bind(&self) -> Result<()> {
158 unsafe {
159 let mut current_ctx = ptr::null_mut();
160 try_ffi!(driver::cuCtxGetCurrent(&raw mut current_ctx))?;
161 if current_ctx == self.as_raw() {
162 return Ok(());
163 }
164 try_ffi!(driver::cuCtxSetCurrent(self.as_raw()))?;
165 }
166 Ok(())
167 }
168
169 /// Loads the corresponding module from the given image into the current context.
170 /// The image may be a cubin or fatbin as output by **nvcc**, or a NUL-terminated PTX string, either as output by **nvcc** or hand-written, or Tile IR data.
171 ///
172 /// # Errors
173 ///
174 /// Returns an error if the context cannot be bound, CUDA cannot load the module, or a
175 /// previous asynchronous launch reported an error.
176 pub fn load_module(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Module> {
177 self.bind()?;
178
179 unsafe {
180 let mut module_handle = ptr::null_mut();
181 try_ffi!(driver::cuModuleLoadData(
182 &raw mut module_handle,
183 image.as_ptr() as _,
184 ))?;
185 if module_handle.is_null() {
186 return Err(Error::NullHandle);
187 }
188 Module::from_raw(module_handle, Arc::clone(self))
189 }
190 }
191
192 /// Creates an empty CUDA graph associated with this context.
193 ///
194 /// Prefer this over [`RawGraph::create`](crate::graph::RawGraph::create)
195 /// for ordinary Singe code. The returned graph carries its context
196 /// association into instantiated executable graphs, allowing launches and
197 /// uploads to reject streams from another context before calling CUDA.
198 ///
199 /// # Errors
200 ///
201 /// Returns an error if the context cannot be bound or CUDA cannot create the graph.
202 pub fn create_graph(self: &Arc<Self>) -> Result<Graph> {
203 Graph::create_in_context(Arc::clone(self))
204 }
205
206 pub fn unload_module(self: &Arc<Self>, module: Module) -> Result<()> {
207 drop(module);
208 Ok(())
209 }
210
211 /// Loads the corresponding module from the given image into the current context.
212 /// The image may be a cubin or fatbin as output by **nvcc**, or a NUL-terminated PTX string, either as output by **nvcc** or hand-written, or Tile IR data.
213 ///
214 /// # Errors
215 ///
216 /// Returns an error if the context cannot be bound, CUDA cannot load the module, JIT options
217 /// are rejected, or a previous asynchronous launch reported an error.
218 pub fn load_module_with_options(
219 self: &Arc<Self>,
220 image: &ModuleImage<'_>,
221 mut jit_options: JitOptions<'_>,
222 ) -> Result<Module> {
223 self.bind()?;
224
225 let mut jit_options = jit_options.build();
226 unsafe {
227 let mut module_handle = ptr::null_mut();
228 try_ffi!(driver::cuModuleLoadDataEx(
229 &raw mut module_handle,
230 image.as_ptr() as _,
231 jit_options.names.len() as _,
232 jit_options.names.as_mut_ptr() as _,
233 jit_options.values.as_mut_ptr() as _,
234 ))?;
235 if module_handle.is_null() {
236 return Err(Error::NullHandle);
237 }
238 Module::from_raw(module_handle, Arc::clone(self))
239 }
240 }
241
242 pub fn load_nvrtc_module(
243 self: &Arc<Self>,
244 program: &nvrtc::Program,
245 output: OutputKind,
246 ) -> Result<Module> {
247 self.load_nvrtc_module_with_options(program, output, JitOptions::default())
248 }
249
250 pub fn load_nvrtc_module_with_options(
251 self: &Arc<Self>,
252 program: &nvrtc::Program,
253 output: OutputKind,
254 jit_options: JitOptions<'_>,
255 ) -> Result<Module> {
256 let image = module_loadable_image(program.artifact(output)?)?;
257 self.load_module_with_options(&image, jit_options)
258 }
259
260 pub fn load_library(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Library> {
261 self.load_library_with_options(image, JitOptions::default())
262 }
263
264 /// Loads the corresponding library from the given image based on the application defined library loading mode:
265 ///
266 /// * If module loading is set to EAGER by the environment variables described in "Module loading", the library is loaded eagerly into all contexts at the time of the call and future contexts at the time of creation until the library
267 /// is unloaded with [`sys::cuLibraryUnload`](singe_cuda_sys::driver::cuLibraryUnload).
268 /// * If the environment variables are set to LAZY, the library is not immediately loaded into existing contexts and is loaded only when a function is needed for that context,
269 /// such as a kernel launch.
270 ///
271 /// These environment variables are described in the CUDA programming guide under the "CUDA environment variables" section.
272 ///
273 /// The code may be a cubin or fatbin emitted by **nvcc**, a NUL-terminated PTX string emitted by **nvcc** or written by hand, or Tile IR data.
274 /// A fatbin must also contain relocatable code when doing separate compilation.
275 ///
276 /// If the library contains managed variables and no device in the system supports them, this call returns [`crate::error::Status::NotSupported`].
277 pub fn load_library_with_options(
278 self: &Arc<Self>,
279 image: &ModuleImage<'_>,
280 mut jit_options: JitOptions<'_>,
281 ) -> Result<Library> {
282 self.bind()?;
283
284 let mut jit_options = jit_options.build();
285 let mut handle = ptr::null_mut();
286 unsafe {
287 try_ffi!(driver::cuLibraryLoadData(
288 &raw mut handle,
289 image.as_ptr() as _,
290 jit_options.names.as_mut_ptr() as _,
291 jit_options.values.as_mut_ptr() as _,
292 jit_options.names.len() as _,
293 ptr::null_mut(),
294 ptr::null_mut(),
295 0,
296 ))?;
297 }
298 if handle.is_null() {
299 return Err(Error::NullHandle);
300 }
301 unsafe { Library::from_raw(handle, Arc::clone(self)) }
302 }
303
304 pub fn load_nvrtc_library(
305 self: &Arc<Self>,
306 program: &nvrtc::Program,
307 output: OutputKind,
308 ) -> Result<Library> {
309 self.load_nvrtc_library_with_options(program, output, JitOptions::default())
310 }
311
312 pub fn load_nvrtc_library_with_options(
313 self: &Arc<Self>,
314 program: &nvrtc::Program,
315 output: OutputKind,
316 jit_options: JitOptions<'_>,
317 ) -> Result<Library> {
318 let image = library_loadable_image(program.artifact(output)?)?;
319 self.load_library_with_options(&image, jit_options)
320 }
321
322 /// Loads the corresponding library from the given file based on the application defined library loading mode:
323 ///
324 /// * If module loading is set to EAGER by the environment variables described in "Module loading", the library is loaded eagerly into all contexts at the time of the call and future contexts at the time of creation until the library
325 /// is unloaded with [`sys::cuLibraryUnload`](singe_cuda_sys::driver::cuLibraryUnload).
326 /// * If the environment variables are set to LAZY, the library is not immediately loaded into existing contexts and is loaded only when a function is needed for that context,
327 /// such as a kernel launch.
328 ///
329 /// These environment variables are described in the CUDA programming guide under the "CUDA environment variables" section.
330 ///
331 /// The file must be a cubin emitted by **nvcc**, a PTX file emitted by **nvcc** or written by hand, a fatbin emitted by **nvcc** or written by hand, or a Tile IR file.
332 /// A fatbin must also contain relocatable code when doing separate compilation.
333 ///
334 /// If the library contains managed variables and no device in the system supports them, this call returns [`crate::error::Status::NotSupported`].
335 ///
336 /// # Errors
337 ///
338 /// Returns an error if this context cannot be bound, if `path` contains an
339 /// interior NUL byte, or if CUDA Driver cannot load the library.
340 pub fn load_library_from_file(self: &Arc<Self>, path: &str) -> Result<Library> {
341 self.bind()?;
342 let path = CString::new(path)?;
343 let mut handle = ptr::null_mut();
344 unsafe {
345 try_ffi!(driver::cuLibraryLoadFromFile(
346 &raw mut handle,
347 path.as_ptr(),
348 ptr::null_mut(),
349 ptr::null_mut(),
350 0,
351 ptr::null_mut(),
352 ptr::null_mut(),
353 0,
354 ))?;
355 }
356 if handle.is_null() {
357 return Err(Error::NullHandle);
358 }
359 unsafe { Library::from_raw(handle, Arc::clone(self)) }
360 }
361
362 /// Blocks until the current context has completed all preceding requested tasks.
363 /// If the current context is the primary context, child contexts that have been created are also synchronized.
364 /// [`Context::synchronize`] returns an error if one of the preceding tasks failed.
365 /// If the context was created with [`ContextFlags::SCHEDULE_BLOCKING_SYNC`], the CPU thread blocks until the GPU context has finished its work.
366 ///
367 /// # Errors
368 ///
369 /// Returns an error if the context cannot be bound, a preceding task failed, or a previous
370 /// asynchronous launch reported an error.
371 pub fn synchronize(&self) -> Result<()> {
372 self.bind()?;
373 unsafe {
374 try_ffi!(driver::cuCtxSynchronize())?;
375 }
376 Ok(())
377 }
378
379 /// Returns the flags of the current context.
380 /// See [`ContextFlags`] for flag values.
381 ///
382 /// # Errors
383 ///
384 /// Returns an error if the context cannot be bound, CUDA cannot query the flags, or a
385 /// previous asynchronous launch reported an error.
386 pub fn flags(&self) -> Result<ContextFlags> {
387 self.bind()?;
388 unsafe {
389 let mut flags = 0;
390 try_ffi!(driver::cuCtxGetFlags(&raw mut flags))?;
391 Ok(ContextFlags::from_bits_truncate(flags))
392 }
393 }
394
395 /// Returns the current size of limit.
396 /// The supported [`Limit`] values are:
397 ///
398 /// * [`Limit::StackSize`]: stack size in bytes of each GPU thread.
399 /// * [`Limit::PrintfFifoSize`]: size in bytes of the FIFO used by the `printf()` device system call.
400 /// * [`Limit::MallocHeapSize`]: size in bytes of the heap used by the `malloc()` and `free()` device system calls.
401 /// * [`Limit::DevRuntimeSyncDepth`]: maximum grid depth at which a thread can issue the device runtime call [`Device::synchronize`] to wait on child grid launches to complete.
402 /// * [`Limit::DevRuntimePendingLaunchCount`]: maximum number of outstanding device runtime launches that can be made from this context.
403 /// * [`Limit::MaxL2FetchGranularity`]: L2 cache fetch granularity.
404 /// * [`Limit::PersistingL2CacheSize`]: persisting L2 cache size in bytes.
405 ///
406 /// # Errors
407 ///
408 /// Returns an error if the context cannot be bound, `limit` is unsupported, CUDA cannot query
409 /// the limit, or a previous asynchronous launch reported an error.
410 pub fn limit(&self, limit: Limit) -> Result<usize> {
411 self.bind()?;
412 unsafe {
413 let mut value = 0;
414 try_ffi!(driver::cuCtxGetLimit(&raw mut value, limit.into()))?;
415 Ok(value as usize)
416 }
417 }
418
419 /// Setting limit to value is a request by the application to update the current limit maintained by the context.
420 /// The driver may modify the requested value to meet hardware requirements, such as clamping to minimum or maximum values or rounding up to the nearest element size.
421 /// Use [`Context::limit`] to query the effective value.
422 ///
423 /// Setting each [`Limit`] has its own restrictions.
424 ///
425 /// * [`Limit::StackSize`] controls the stack size in bytes of each GPU thread.
426 /// The driver automatically increases the per-thread stack size for each
427 /// kernel launch as needed.
428 /// This size is not reset back to the original value after each launch.
429 /// Setting this value will take
430 /// effect immediately, and if necessary, the device will block until all preceding requested tasks are complete.
431 ///
432 /// * [`Limit::PrintfFifoSize`] controls the size in bytes of the FIFO used by the `printf()` device system call.
433 /// Configure [`Limit::PrintfFifoSize`] before launching any kernel that uses the `printf()` device system call; otherwise [`crate::error::Status::InvalidValue`] is returned.
434 ///
435 /// * [`Limit::MallocHeapSize`] controls the size in bytes of the heap used by the `malloc()` and `free()` device system calls.
436 /// Configure [`Limit::MallocHeapSize`] before launching any kernel that uses the `malloc()` or `free()` device system calls; otherwise [`crate::error::Status::InvalidValue`] is returned.
437 ///
438 /// * [`Limit::DevRuntimeSyncDepth`] controls the maximum nesting depth of a grid at which a thread can safely call [`Device::synchronize`].
439 /// Setting this limit must be performed before any launch of a kernel that uses the device runtime and calls [`Device::synchronize`] above the default sync depth, two levels of grids.
440 /// Calls to [`Device::synchronize`] fail if this limit is violated.
441 /// This limit can be set smaller than the default or up to the maximum launch depth of 24.
442 /// Additional sync-depth levels require the driver to reserve large amounts of device memory that can no longer be used for application allocations.
443 /// If these reservations of device memory fail, [`Context::set_limit`] returns [`crate::error::Status::OutOfMemory`], and the limit can be reset to a lower value.
444 /// This limit is only applicable to devices of compute capability < 9.0.
445 /// Setting this limit on devices of other compute capability versions returns [`crate::error::Status::UnsupportedLimit`].
446 ///
447 /// * [`Limit::DevRuntimePendingLaunchCount`] controls the maximum number of outstanding device runtime launches that can be made from the current context.
448 /// A grid is outstanding from launch until it is known to have completed.
449 /// Device runtime launches that violate this limit fail.
450 /// If a module using the device runtime needs more pending launches than the default 2048 launches, this limit can be increased.
451 /// Sustaining additional pending launches requires the driver to reserve larger amounts of device memory up front, which can no longer be used for allocations.
452 /// If these reservations fail, [`Context::set_limit`] returns [`crate::error::Status::OutOfMemory`], and the limit can be reset to a lower value.
453 /// This limit is only applicable to devices of compute capability 3.5 and higher.
454 /// Attempting to set this limit on devices of compute capability less than 3.5 returns [`crate::error::Status::UnsupportedLimit`].
455 ///
456 /// * [`Limit::MaxL2FetchGranularity`] controls the L2 cache fetch granularity.
457 /// Values can range from 0B to 128B.
458 /// Performance hint that may be ignored or clamped depending on the platform.
459 ///
460 /// * [`Limit::PersistingL2CacheSize`] controls size in bytes available for persisting L2 cache.
461 /// Performance hint that may be ignored or clamped depending on the platform.
462 ///
463 /// # Errors
464 ///
465 /// Returns an error if the context cannot be bound, `limit` is unsupported, CUDA rejects the
466 /// requested value, or a previous asynchronous launch reported an error.
467 pub fn set_limit(&self, limit: Limit, value: usize) -> Result<()> {
468 self.bind()?;
469 unsafe {
470 try_ffi!(driver::cuCtxSetLimit(limit.into(), value as _))?;
471 }
472 Ok(())
473 }
474
475 pub const fn device(&self) -> Device {
476 self.device
477 }
478
479 pub const fn as_raw(&self) -> driver::CUcontext {
480 self.handle
481 }
482
483 /// Takes ownership of a raw CUDA context.
484 ///
485 /// # Safety
486 ///
487 /// `handle` must be a valid CUDA context for `device`, and no other Rust
488 /// wrapper may own the same release responsibility. `ownership` must match
489 /// how the context should be released: created contexts are destroyed with
490 /// `cuCtxDestroy`, while primary contexts are released with
491 /// `cuDevicePrimaryCtxRelease`.
492 pub unsafe fn from_raw(
493 handle: driver::CUcontext,
494 device: Device,
495 ownership: RawContextOwnership,
496 ) -> Result<Arc<Self>> {
497 if handle.is_null() {
498 return Err(Error::NullHandle);
499 }
500
501 Ok(Arc::new(Self {
502 handle,
503 device,
504 ownership: ownership.into(),
505 }))
506 }
507
508 /// Transfers ownership of the raw CUDA context to the caller.
509 ///
510 /// The caller becomes responsible for releasing the returned context
511 /// according to the returned ownership mode.
512 pub fn into_raw_parts(self) -> (driver::CUcontext, Device, RawContextOwnership) {
513 let raw = (self.handle, self.device, self.ownership.into());
514 mem::forget(self);
515 raw
516 }
517}
518
// CUDA driver contexts are shared execution environments, not per-thread
// library handles. The Rust wrapper only stores the raw context pointer, the
// owning device, and the ownership mode. CUDA maintains current-context
// selection as thread-local driver state.
//
// SAFETY (Send): the fields are never mutated after construction. The single
// release in `Drop` runs on whichever thread drops the last reference. This
// relies on the CUDA Driver API accepting the handle from any host thread.
unsafe impl Send for Context {}
// SAFETY (Sync): every `&self` method either reads immutable fields or issues
// driver calls whose only Rust-visible effect is on the calling thread's
// current context (see `bind`). Concurrent shared access therefore cannot
// race on wrapper state.
unsafe impl Sync for Context {}
525
526impl Drop for Context {
527 fn drop(&mut self) {
528 unsafe {
529 let result = match self.ownership {
530 ContextOwnership::Created => try_ffi!(driver::cuCtxDestroy_v2(self.handle)),
531 ContextOwnership::Primary => {
532 try_ffi!(driver::cuDevicePrimaryCtxRelease_v2(self.device.id() as _))
533 }
534 };
535
536 if let Err(err) = result {
537 #[cfg(debug_assertions)]
538 eprintln!("failed to destroy CUDA context wrapper: {err}");
539 }
540 }
541 }
542}
543
544impl PartialEq for Context {
545 fn eq(&self, other: &Self) -> bool {
546 self.as_raw() == other.as_raw()
547 }
548}
549
550impl Eq for Context {}
551
552fn module_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
553 match artifact {
554 CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
555 CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
556 }
557}
558
559fn library_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
560 match artifact {
561 CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
562 CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
563 }
564}