1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use crate::*;
use num_enum::TryFromPrimitive;
use std::{ptr::null_mut, rc::Rc};
/// Owning wrapper around a raw CUDA driver context handle.
///
/// The handle is released with `cuCtxDestroy_v2` when this value is dropped.
#[derive(Debug)]
pub struct Context {
    /// Raw driver handle (`CUcontext`, i.e. `*mut CUctx_st`).
    pub(crate) inner: sys::CUcontext,
}
impl Context {
    /// Creates a new context on `device` using `CU_CTX_SCHED_BLOCKING_SYNC` scheduling.
    ///
    /// Per the CUDA driver API, `cuCtxCreate` also makes the new context current on the
    /// calling thread.
    ///
    /// # Errors
    /// Returns the driver error if `cuCtxCreate_v2` fails.
    pub fn new(device: &Device) -> CudaResult<Context> {
        let mut inner = null_mut();
        // SAFETY: `inner` is a live, writable out-pointer for the duration of the call.
        cuda_error(unsafe {
            sys::cuCtxCreate_v2(
                &mut inner as *mut _,
                sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC,
                device.handle,
            )
        })?;
        Ok(Context { inner })
    }

    /// Returns the API version reported by `cuCtxGetApiVersion` for this context.
    ///
    /// # Errors
    /// Returns the driver error if the query fails.
    pub fn version(&self) -> CudaResult<CudaVersion> {
        let mut out = 0u32;
        // SAFETY: `self.inner` is the context this wrapper owns; `out` is a valid out-pointer.
        cuda_error(unsafe { sys::cuCtxGetApiVersion(self.inner, &mut out as *mut u32) })?;
        Ok(out.into())
    }

    /// Blocks until all previously queued work in *this* context has completed.
    ///
    /// `cuCtxSynchronize` acts on the thread's current context, so `self` is temporarily
    /// pushed as current for the call.
    ///
    /// # Errors
    /// Returns the driver error from the push, the synchronize, or the pop.
    pub fn synchronize(&self) -> CudaResult<()> {
        // SAFETY: no pointer arguments; `with_current` guarantees a current context.
        self.with_current(|| cuda_error(unsafe { sys::cuCtxSynchronize() }))
    }

    /// Sets `limit` to `value` on *this* context.
    ///
    /// `cuCtxSetLimit` acts on the thread's current context, so `self` is temporarily
    /// pushed as current for the call.
    ///
    /// # Errors
    /// Returns the driver error from the push, the set, or the pop.
    pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> {
        // SAFETY: plain value arguments; `with_current` guarantees a current context.
        self.with_current(|| cuda_error(unsafe { sys::cuCtxSetLimit(limit as u32, value) }))
    }

    /// Reads the current value of `limit` from *this* context.
    ///
    /// `cuCtxGetLimit` acts on the thread's current context, so `self` is temporarily
    /// pushed as current for the call.
    ///
    /// # Errors
    /// Returns the driver error from the push, the query, or the pop.
    pub fn get_limit(&self, limit: LimitType) -> CudaResult<u64> {
        self.with_current(|| {
            let mut out = 0u64;
            // SAFETY: `out` is a valid out-pointer; `with_current` guarantees a current context.
            cuda_error(unsafe { sys::cuCtxGetLimit(&mut out as *mut u64, limit as u32) })?;
            Ok(out)
        })
    }

    /// Makes this context current on the calling thread via `cuCtxSetCurrent`.
    ///
    /// The returned handle borrows the context mutably for its whole lifetime.
    ///
    /// # Errors
    /// Returns the driver error if `cuCtxSetCurrent` fails.
    pub fn enter(&mut self) -> CudaResult<Rc<Handle<'_>>> {
        // SAFETY: `self.inner` is the context this wrapper owns.
        cuda_error(unsafe { sys::cuCtxSetCurrent(self.inner) })?;
        Ok(Rc::new(Handle { context: self }))
    }

    /// Runs `op` with `self` pushed onto the calling thread's context stack, then pops it.
    ///
    /// This leaves the stack as it was, even when `op` fails. If both `op` and the pop fail,
    /// `op`'s error is reported.
    fn with_current<T>(&self, op: impl FnOnce() -> CudaResult<T>) -> CudaResult<T> {
        // SAFETY: `self.inner` is the context this wrapper owns.
        cuda_error(unsafe { sys::cuCtxPushCurrent_v2(self.inner) })?;
        let result = op();
        let mut popped = null_mut();
        // SAFETY: `popped` is a valid out-pointer; the push above succeeded, so the stack is
        // non-empty.
        let pop = cuda_error(unsafe { sys::cuCtxPopCurrent_v2(&mut popped as *mut _) });
        let value = result?;
        pop.map(|_| value)
    }
}
impl Drop for Context {
    /// Releases the driver context.
    ///
    /// Failures are reported on stderr instead of panicking, because `drop` cannot return
    /// an error.
    fn drop(&mut self) {
        // SAFETY: `self.inner` is the context handle owned by this wrapper; `drop` runs at
        // most once.
        let status = unsafe { sys::cuCtxDestroy_v2(self.inner) };
        match cuda_error(status) {
            Ok(_) => {}
            Err(err) => eprintln!("CUDA: failed to destroy cuda context: {:?}", err),
        }
    }
}
/// Guard produced by entering a [`Context`]; it keeps the context mutably borrowed.
///
/// Dropping the guard clears the calling thread's current context.
pub struct Handle<'ctx> {
    /// The context that was made current.
    pub(crate) context: &'ctx mut Context,
}
impl Handle<'_> {
    /// Shared access to the context held by this handle.
    pub fn context(&self) -> &Context {
        &*self.context
    }
}
impl Drop for Handle<'_> {
    /// Unbinds the calling thread from any context by setting the current context to null.
    ///
    /// This does not restore a previously current context. Errors are logged to stderr
    /// because `drop` cannot return them.
    fn drop(&mut self) {
        // SAFETY: a null handle is passed by value; no memory is dereferenced here.
        let status = unsafe { sys::cuCtxSetCurrent(null_mut()) };
        if let Err(err) = cuda_error(status) {
            eprintln!("CUDA: error dropping context handle: {:?}", err);
        }
    }
}
/// Context resource limits for `cuCtxGetLimit` / `cuCtxSetLimit`.
///
/// The discriminant is handed to the driver unchanged (`limit as u32`), so every value must
/// equal the matching `CUlimit` constant.
#[derive(Clone, Copy, Debug, TryFromPrimitive)]
#[repr(u32)]
pub enum LimitType {
    /// `CU_LIMIT_STACK_SIZE`
    StackSize = 0,
    /// `CU_LIMIT_PRINTF_FIFO_SIZE`
    PrintfFifoSize = 1,
    /// `CU_LIMIT_MALLOC_HEAP_SIZE`
    MallocHeapSize = 2,
    /// `CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH`
    DevRuntimeSyncDepth = 3,
    /// `CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT`
    DevRuntimePendingLaunchCount = 4,
    /// `CU_LIMIT_MAX_L2_FETCH_GRANULARITY`
    MaxL2FetchGranularity = 5,
    /// `CU_LIMIT_PERSISTING_L2_CACHE_SIZE`
    PersistingL2CacheSize = 6,
}