cuda_oxide/
context.rs

1use crate::*;
2use num_enum::TryFromPrimitive;
3use std::{ptr::null_mut, rc::Rc};
4
5/// A CUDA application context.
6/// To start interacting with a device, you want to [`Context::enter`]
7#[derive(Debug)]
8pub struct Context {
9    pub(crate) inner: *mut sys::CUctx_st,
10}
11
12impl Context {
13    /// Creates a new [`Context`] for a given [`Device`]
14    pub fn new(device: &Device) -> CudaResult<Context> {
15        let mut inner = null_mut();
16        cuda_error(unsafe {
17            sys::cuCtxCreate_v2(
18                &mut inner as *mut _,
19                sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC,
20                device.handle,
21            )
22        })?;
23        Ok(Context { inner })
24    }
25
26    /// Gets the API version of the [`Context`].
27    /// This is not the compute capability of the device and probably not what you are looking for. See [`Device::compute_capability`]
28    pub fn version(&self) -> CudaResult<CudaVersion> {
29        let mut out = 0u32;
30        cuda_error(unsafe { sys::cuCtxGetApiVersion(self.inner, &mut out as *mut u32) })?;
31        Ok(out.into())
32    }
33
34    /// Synchronize a [`Context`], running all active handles to completion
35    pub fn synchronize(&self) -> CudaResult<()> {
36        cuda_error(unsafe { sys::cuCtxSynchronize() })
37    }
38
39    /// Set a CUDA context limit
40    pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> {
41        cuda_error(unsafe { sys::cuCtxSetLimit(limit as u32, value as sys::size_t) })
42    }
43
44    /// Get a CUDA context limit
45    pub fn get_limit(&self, limit: LimitType) -> CudaResult<u64> {
46        let mut out: sys::size_t = 0;
47        cuda_error(unsafe { sys::cuCtxGetLimit(&mut out as *mut sys::size_t, limit as u32) })?;
48        Ok(out as u64)
49    }
50
51    /// Enter a [`Context`], consuming a mutable reference to the context, and allowing thread-local operations to happen.
52    pub fn enter<'a>(&'a mut self) -> CudaResult<Rc<Handle<'a>>> {
53        cuda_error(unsafe { sys::cuCtxSetCurrent(self.inner) })?;
54        Ok(Rc::new(Handle {
55            context: self,
56            // async_stream_pool: RefCell::new(vec![]),
57        }))
58    }
59}
60
61impl Drop for Context {
62    fn drop(&mut self) {
63        if let Err(e) = cuda_error(unsafe { sys::cuCtxDestroy_v2(self.inner) }) {
64            eprintln!("CUDA: failed to destroy cuda context: {:?}", e);
65        }
66    }
67}
68
69/// A CUDA [`Context`] handle for executing thread-local operations.
70pub struct Handle<'a> {
71    pub(crate) context: &'a mut Context,
72    // async_stream_pool: RefCell<Vec<Stream<'a>>>,
73}
74
75impl<'a> Handle<'a> {
76    /// Get an immutable reference to the source context.
77    pub fn context(&self) -> &Context {
78        &self.context
79    }
80
81    // pub(crate) fn get_async_stream(self: &Rc<Handle<'a>>) -> CudaResult<Stream<'a>> {
82    //     let mut pool = self.async_stream_pool.borrow_mut();
83    //     if pool.is_empty() {
84    //         Stream::new(self)
85    //     } else {
86    //         Ok(pool.pop().unwrap())
87    //     }
88    // }
89
90    // pub(crate) fn reset_async_stream(self: &Rc<Handle<'a>>, stream: Stream<'a>) {
91    //     let mut pool = self.async_stream_pool.borrow_mut();
92    //     pool.push(stream);
93    // }
94}
95
96impl<'a> Drop for Handle<'a> {
97    fn drop(&mut self) {
98        if let Err(e) = cuda_error(unsafe { sys::cuCtxSetCurrent(null_mut()) }) {
99            eprintln!("CUDA: error dropping context handle: {:?}", e);
100        }
101    }
102}
103
104/// Context limit types
105#[derive(Clone, Copy, Debug, TryFromPrimitive)]
106#[repr(u32)]
107pub enum LimitType {
108    /// GPU thread stack size
109    StackSize = 0x00,
110    /// GPU printf FIFO size
111    PrintfFifoSize = 0x01,
112    /// GPU malloc heap size
113    MallocHeapSize = 0x02,
114    /// GPU device runtime launch synchronize depth
115    DevRuntimeSyncDepth = 0x03,
116    /// GPU device runtime pending launch count
117    DevRuntimePendingLaunchCount = 0x04,
118    /// A value between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint
119    MaxL2FetchGranularity = 0x05,
120    /// A size in bytes for L2 persisting lines cache size
121    PersistingL2CacheSize = 0x06,
122}