baracuda_driver/
context.rs1use std::sync::Arc;
9
10use baracuda_cuda_sys::types::CUcontext_flags;
11use baracuda_cuda_sys::{driver, CUcontext};
12
13use crate::device::Device;
14use crate::error::{check, Result};
15use crate::init::init;
16
17#[derive(Clone, Debug)]
21pub struct Context {
22 inner: Arc<ContextInner>,
23}
24
25struct ContextInner {
26 handle: CUcontext,
27 device: Device,
28}
29
30unsafe impl Send for ContextInner {}
35unsafe impl Sync for ContextInner {}
36
37impl core::fmt::Debug for ContextInner {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 f.debug_struct("Context")
40 .field("handle", &self.handle)
41 .field("device", &self.device)
42 .finish()
43 }
44}
45
46impl Context {
47 pub fn new(device: &Device) -> Result<Self> {
49 Self::with_flags(device, CUcontext_flags::SCHED_AUTO)
50 }
51
52 pub fn with_flags(device: &Device, flags: u32) -> Result<Self> {
56 init()?;
57 let d = driver()?;
58 let cu = d.cu_ctx_create()?;
59 let mut ctx: CUcontext = core::ptr::null_mut();
60 check(unsafe { cu(&mut ctx, flags, device.0) })?;
62 Ok(Self {
63 inner: Arc::new(ContextInner {
64 handle: ctx,
65 device: *device,
66 }),
67 })
68 }
69
70 pub fn current() -> Result<Option<CUcontext>> {
77 init()?;
78 let d = driver()?;
79 let cu = d.cu_ctx_get_current()?;
80 let mut ctx: CUcontext = core::ptr::null_mut();
81 check(unsafe { cu(&mut ctx) })?;
82 if ctx.is_null() {
83 Ok(None)
84 } else {
85 Ok(Some(ctx))
86 }
87 }
88
89 pub fn set_current(&self) -> Result<()> {
91 let d = driver()?;
92 let cu = d.cu_ctx_set_current()?;
93 check(unsafe { cu(self.inner.handle) })
96 }
97
98 pub fn push(&self) -> Result<()> {
100 let d = driver()?;
101 let cu = d.cu_ctx_push_current()?;
102 check(unsafe { cu(self.inner.handle) })
103 }
104
105 pub fn pop() -> Result<CUcontext> {
107 init()?;
108 let d = driver()?;
109 let cu = d.cu_ctx_pop_current()?;
110 let mut ctx: CUcontext = core::ptr::null_mut();
111 check(unsafe { cu(&mut ctx) })?;
112 Ok(ctx)
113 }
114
115 pub fn synchronize(&self) -> Result<()> {
118 self.set_current()?;
119 let d = driver()?;
120 let cu = d.cu_ctx_synchronize()?;
121 check(unsafe { cu() })
122 }
123
124 pub fn api_version(&self) -> Result<u32> {
126 let d = driver()?;
127 let cu = d.cu_ctx_get_api_version()?;
128 let mut v: core::ffi::c_uint = 0;
129 check(unsafe { cu(self.inner.handle, &mut v) })?;
130 Ok(v)
131 }
132
133 pub fn current_device() -> Result<Device> {
136 let d = driver()?;
137 let cu = d.cu_ctx_get_device()?;
138 let mut dev = baracuda_cuda_sys::CUdevice::default();
139 check(unsafe { cu(&mut dev) })?;
140 Ok(Device(dev))
141 }
142
143 pub fn current_flags() -> Result<u32> {
148 let d = driver()?;
149 let cu = d.cu_ctx_get_flags()?;
150 let mut f: core::ffi::c_uint = 0;
151 check(unsafe { cu(&mut f) })?;
152 Ok(f)
153 }
154
155 pub fn get_limit(limit: u32) -> Result<usize> {
158 let d = driver()?;
159 let cu = d.cu_ctx_get_limit()?;
160 let mut v: usize = 0;
161 check(unsafe { cu(&mut v, limit) })?;
162 Ok(v)
163 }
164
165 pub fn set_limit(limit: u32, value: usize) -> Result<()> {
169 let d = driver()?;
170 let cu = d.cu_ctx_set_limit()?;
171 check(unsafe { cu(limit, value) })
172 }
173
174 pub fn cache_config() -> Result<u32> {
177 let d = driver()?;
178 let cu = d.cu_ctx_get_cache_config()?;
179 let mut c: core::ffi::c_uint = 0;
180 check(unsafe { cu(&mut c) })?;
181 Ok(c)
182 }
183
184 pub fn set_cache_config(config: u32) -> Result<()> {
186 let d = driver()?;
187 let cu = d.cu_ctx_set_cache_config()?;
188 check(unsafe { cu(config) })
189 }
190
191 pub fn stream_priority_range() -> Result<(i32, i32)> {
194 let d = driver()?;
195 let cu = d.cu_ctx_get_stream_priority_range()?;
196 let mut least: core::ffi::c_int = 0;
197 let mut greatest: core::ffi::c_int = 0;
198 check(unsafe { cu(&mut least, &mut greatest) })?;
199 Ok((least, greatest))
200 }
201
202 pub fn enable_peer_access(peer: &Context) -> Result<()> {
206 let d = driver()?;
207 let cu = d.cu_ctx_enable_peer_access()?;
208 check(unsafe { cu(peer.inner.handle, 0) })
209 }
210
211 pub fn disable_peer_access(peer: &Context) -> Result<()> {
213 let d = driver()?;
214 let cu = d.cu_ctx_disable_peer_access()?;
215 check(unsafe { cu(peer.inner.handle) })
216 }
217
218 #[inline]
220 pub fn device(&self) -> Device {
221 self.inner.device
222 }
223
224 #[inline]
226 pub fn as_raw(&self) -> CUcontext {
227 self.inner.handle
228 }
229
230 pub fn id(&self) -> Result<u64> {
233 let d = driver()?;
234 let cu = d.cu_ctx_get_id()?;
235 let mut out: u64 = 0;
236 check(unsafe { cu(self.inner.handle, &mut out) })?;
237 Ok(out)
238 }
239
240 pub fn record_event(&self, event: &crate::Event) -> Result<()> {
243 let d = driver()?;
244 let cu = d.cu_ctx_record_event()?;
245 check(unsafe { cu(self.inner.handle, event.as_raw()) })
246 }
247
248 pub fn wait_event(&self, event: &crate::Event) -> Result<()> {
250 let d = driver()?;
251 let cu = d.cu_ctx_wait_event()?;
252 check(unsafe { cu(self.inner.handle, event.as_raw()) })
253 }
254}
255
256impl Drop for ContextInner {
257 fn drop(&mut self) {
258 if let Ok(d) = driver() {
259 if let Ok(cu) = d.cu_ctx_destroy() {
260 let _ = unsafe { cu(self.handle) };
263 }
264 }
265 }
266}
267
268#[derive(Debug)]
277pub struct PrimaryContext {
278 handle: CUcontext,
279 device: Device,
280}
281
282unsafe impl Send for PrimaryContext {}
283unsafe impl Sync for PrimaryContext {}
284
285impl PrimaryContext {
286 pub fn retain(device: &Device) -> Result<Self> {
289 init()?;
290 let d = driver()?;
291 let cu = d.cu_device_primary_ctx_retain()?;
292 let mut handle: CUcontext = core::ptr::null_mut();
293 check(unsafe { cu(&mut handle, device.0) })?;
294 Ok(Self {
295 handle,
296 device: *device,
297 })
298 }
299
300 pub fn reset(device: &Device) -> Result<()> {
308 init()?;
309 let d = driver()?;
310 let cu = d.cu_device_primary_ctx_reset()?;
311 check(unsafe { cu(device.0) })
312 }
313
314 pub fn device(&self) -> Device {
316 self.device
317 }
318
319 #[inline]
321 pub fn as_raw(&self) -> CUcontext {
322 self.handle
323 }
324}
325
326impl Drop for PrimaryContext {
327 fn drop(&mut self) {
328 if let Ok(d) = driver() {
329 if let Ok(cu) = d.cu_device_primary_ctx_release() {
330 let _ = unsafe { cu(self.device.0) };
331 }
332 }
333 }
334}