1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use crate::*;
use num_enum::TryFromPrimitive;
use std::{ptr::null_mut, rc::Rc};
/// A CUDA application context.
/// To start interacting with a device, you want to [`Context::enter`]
#[derive(Debug)]
pub struct Context {
    // Raw driver-API context handle (`CUcontext`). This struct owns it:
    // created in `Context::new` via `cuCtxCreate_v2` and destroyed in
    // `Drop` via `cuCtxDestroy_v2`.
    pub(crate) inner: *mut sys::CUctx_st,
}
impl Context {
    /// Creates a new [`Context`] for a given [`Device`].
    ///
    /// The context is created with blocking-sync scheduling
    /// (`CU_CTX_SCHED_BLOCKING_SYNC`).
    pub fn new(device: &Device) -> CudaResult<Context> {
        let mut raw = null_mut();
        let status = unsafe {
            sys::cuCtxCreate_v2(
                &mut raw,
                sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC,
                device.handle,
            )
        };
        cuda_error(status)?;
        Ok(Self { inner: raw })
    }

    /// Gets the API version of the [`Context`].
    /// This is not the compute capability of the device and probably not what
    /// you are looking for. See [`Device::compute_capability`]
    pub fn version(&self) -> CudaResult<CudaVersion> {
        let mut raw_version = 0u32;
        let status = unsafe { sys::cuCtxGetApiVersion(self.inner, &mut raw_version) };
        cuda_error(status)?;
        Ok(raw_version.into())
    }

    /// Synchronize a [`Context`], running all active handles to completion.
    pub fn synchronize(&self) -> CudaResult<()> {
        let status = unsafe { sys::cuCtxSynchronize() };
        cuda_error(status)
    }

    /// Set a CUDA context limit.
    pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> {
        // LimitType discriminants mirror the driver's CUlimit values.
        let status = unsafe { sys::cuCtxSetLimit(limit as u32, value as sys::size_t) };
        cuda_error(status)
    }

    /// Get a CUDA context limit.
    pub fn get_limit(&self, limit: LimitType) -> CudaResult<u64> {
        let mut value: sys::size_t = 0;
        let status = unsafe { sys::cuCtxGetLimit(&mut value, limit as u32) };
        cuda_error(status)?;
        Ok(value as u64)
    }

    /// Enter a [`Context`], consuming a mutable reference to the context, and
    /// allowing thread-local operations to happen.
    ///
    /// Binds this context to the calling thread via `cuCtxSetCurrent`; the
    /// returned [`Handle`] unbinds it again when dropped.
    pub fn enter<'a>(&'a mut self) -> CudaResult<Rc<Handle<'a>>> {
        let status = unsafe { sys::cuCtxSetCurrent(self.inner) };
        cuda_error(status)?;
        let handle = Handle { context: self };
        Ok(Rc::new(handle))
    }
}
impl Drop for Context {
    fn drop(&mut self) {
        // Best-effort teardown: Drop cannot propagate errors, so log and move on.
        let result = cuda_error(unsafe { sys::cuCtxDestroy_v2(self.inner) });
        if let Err(e) = result {
            eprintln!("CUDA: failed to destroy cuda context: {:?}", e);
        }
    }
}
/// A CUDA [`Context`] handle for executing thread-local operations.
pub struct Handle<'a> {
    // Exclusive borrow of the entered context. Holding `&mut Context` for the
    // handle's lifetime prevents the context from being re-entered, mutated,
    // or dropped while it is bound to the current thread.
    pub(crate) context: &'a mut Context,
    // async_stream_pool: RefCell<Vec<Stream<'a>>>,
}
impl<'a> Handle<'a> {
    /// Get an immutable reference to the source context.
    pub fn context(&self) -> &Context {
        // Reborrow the `&mut Context` field immutably. The previous
        // `&self.context` produced a needless `&&mut Context` borrow
        // (clippy::needless_borrow) that only worked via deref coercion.
        self.context
    }
    // pub(crate) fn get_async_stream(self: &Rc<Handle<'a>>) -> CudaResult<Stream<'a>> {
    //     let mut pool = self.async_stream_pool.borrow_mut();
    //     if pool.is_empty() {
    //         Stream::new(self)
    //     } else {
    //         Ok(pool.pop().unwrap())
    //     }
    // }
    // pub(crate) fn reset_async_stream(self: &Rc<Handle<'a>>, stream: Stream<'a>) {
    //     let mut pool = self.async_stream_pool.borrow_mut();
    //     pool.push(stream);
    // }
}
impl<'a> Drop for Handle<'a> {
    fn drop(&mut self) {
        // Unbind the context from the calling thread; Drop cannot return errors,
        // so failures are only reported to stderr.
        let result = cuda_error(unsafe { sys::cuCtxSetCurrent(null_mut()) });
        if let Err(e) = result {
            eprintln!("CUDA: error dropping context handle: {:?}", e);
        }
    }
}
/// Context limit types
///
/// Discriminant values mirror the CUDA driver's `CUlimit` enum — they are
/// passed straight through to `cuCtxSetLimit`/`cuCtxGetLimit` via `as u32`,
/// so they must not be renumbered.
#[derive(Clone, Copy, Debug, TryFromPrimitive)]
#[repr(u32)]
pub enum LimitType {
    /// GPU thread stack size
    StackSize = 0x00,
    /// GPU printf FIFO size
    PrintfFifoSize = 0x01,
    /// GPU malloc heap size
    MallocHeapSize = 0x02,
    /// GPU device runtime launch synchronize depth
    DevRuntimeSyncDepth = 0x03,
    /// GPU device runtime pending launch count
    DevRuntimePendingLaunchCount = 0x04,
    /// A value between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint
    MaxL2FetchGranularity = 0x05,
    /// A size in bytes for L2 persisting lines cache size
    PersistingL2CacheSize = 0x06,
}