1use crate::*;
2use num_enum::TryFromPrimitive;
3use std::{ptr::null_mut, rc::Rc};
4
5#[derive(Debug)]
8pub struct Context {
9 pub(crate) inner: *mut sys::CUctx_st,
10}
11
12impl Context {
13 pub fn new(device: &Device) -> CudaResult<Context> {
15 let mut inner = null_mut();
16 cuda_error(unsafe {
17 sys::cuCtxCreate_v2(
18 &mut inner as *mut _,
19 sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC,
20 device.handle,
21 )
22 })?;
23 Ok(Context { inner })
24 }
25
26 pub fn version(&self) -> CudaResult<CudaVersion> {
29 let mut out = 0u32;
30 cuda_error(unsafe { sys::cuCtxGetApiVersion(self.inner, &mut out as *mut u32) })?;
31 Ok(out.into())
32 }
33
34 pub fn synchronize(&self) -> CudaResult<()> {
36 cuda_error(unsafe { sys::cuCtxSynchronize() })
37 }
38
39 pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> {
41 cuda_error(unsafe { sys::cuCtxSetLimit(limit as u32, value as sys::size_t) })
42 }
43
44 pub fn get_limit(&self, limit: LimitType) -> CudaResult<u64> {
46 let mut out: sys::size_t = 0;
47 cuda_error(unsafe { sys::cuCtxGetLimit(&mut out as *mut sys::size_t, limit as u32) })?;
48 Ok(out as u64)
49 }
50
51 pub fn enter<'a>(&'a mut self) -> CudaResult<Rc<Handle<'a>>> {
53 cuda_error(unsafe { sys::cuCtxSetCurrent(self.inner) })?;
54 Ok(Rc::new(Handle {
55 context: self,
56 }))
58 }
59}
60
61impl Drop for Context {
62 fn drop(&mut self) {
63 if let Err(e) = cuda_error(unsafe { sys::cuCtxDestroy_v2(self.inner) }) {
64 eprintln!("CUDA: failed to destroy cuda context: {:?}", e);
65 }
66 }
67}
68
69pub struct Handle<'a> {
71 pub(crate) context: &'a mut Context,
72 }
74
75impl<'a> Handle<'a> {
76 pub fn context(&self) -> &Context {
78 &self.context
79 }
80
81 }
95
96impl<'a> Drop for Handle<'a> {
97 fn drop(&mut self) {
98 if let Err(e) = cuda_error(unsafe { sys::cuCtxSetCurrent(null_mut()) }) {
99 eprintln!("CUDA: error dropping context handle: {:?}", e);
100 }
101 }
102}
103
104#[derive(Clone, Copy, Debug, TryFromPrimitive)]
106#[repr(u32)]
107pub enum LimitType {
108 StackSize = 0x00,
110 PrintfFifoSize = 0x01,
112 MallocHeapSize = 0x02,
114 DevRuntimeSyncDepth = 0x03,
116 DevRuntimePendingLaunchCount = 0x04,
118 MaxL2FetchGranularity = 0x05,
120 PersistingL2CacheSize = 0x06,
122}