baracuda_driver/
module.rs1use core::ffi::{c_char, c_void};
4use std::ffi::CString;
5use std::sync::Arc;
6
7use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUmodule};
8
9use crate::context::Context;
10use crate::error::{check, Result};
11
12#[derive(Clone)]
14pub struct Module {
15 inner: Arc<ModuleInner>,
16}
17
18struct ModuleInner {
19 handle: CUmodule,
20 context: Context,
21}
22
23unsafe impl Send for ModuleInner {}
24unsafe impl Sync for ModuleInner {}
25
26impl core::fmt::Debug for ModuleInner {
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.debug_struct("Module")
29 .field("handle", &self.handle)
30 .finish_non_exhaustive()
31 }
32}
33
34impl core::fmt::Debug for Module {
35 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36 self.inner.fmt(f)
37 }
38}
39
40impl Module {
41 pub fn load_raw(context: &Context, image: &[u8]) -> Result<Self> {
47 context.set_current()?;
48 let d = driver()?;
49 let cu = d.cu_module_load_data()?;
50 let mut module: CUmodule = core::ptr::null_mut();
51 check(unsafe { cu(&mut module, image.as_ptr() as *const c_void) })?;
53 Ok(Self {
54 inner: Arc::new(ModuleInner {
55 handle: module,
56 context: context.clone(),
57 }),
58 })
59 }
60
61 pub fn load_ptx(context: &Context, ptx_source: &str) -> Result<Self> {
63 let c_src = CString::new(ptx_source).map_err(|_| {
65 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
66 library: "cuda-driver",
67 symbol: "cuModuleLoadData(PTX input contained a NUL byte)",
68 })
69 })?;
70 Self::load_raw(context, c_src.as_bytes_with_nul())
71 }
72
73 pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
76 let d = driver()?;
77 let cu = d.cu_module_get_global()?;
78 let c_name = CString::new(name).map_err(|_| {
79 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
80 library: "cuda-driver",
81 symbol: "cuModuleGetGlobal(name contained a NUL byte)",
82 })
83 })?;
84 let mut dptr = CUdeviceptr(0);
85 let mut bytes: usize = 0;
86 check(unsafe {
87 cu(
88 &mut dptr,
89 &mut bytes,
90 self.inner.handle,
91 c_name.as_ptr() as *const c_char,
92 )
93 })?;
94 Ok((dptr, bytes))
95 }
96
97 pub fn get_function(&self, name: &str) -> Result<Function> {
99 let d = driver()?;
100 let cu = d.cu_module_get_function()?;
101 let c_name = CString::new(name).map_err(|_| {
102 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
103 library: "cuda-driver",
104 symbol: "cuModuleGetFunction(kernel name contained a NUL byte)",
105 })
106 })?;
107 let mut func: CUfunction = core::ptr::null_mut();
108 check(unsafe {
111 cu(
112 &mut func,
113 self.inner.handle,
114 c_name.as_ptr() as *const c_char,
115 )
116 })?;
117 Ok(Function {
118 handle: func,
119 _owner: FunctionOwner::Module(self.clone()),
120 })
121 }
122
123 #[inline]
125 pub fn as_raw(&self) -> CUmodule {
126 self.inner.handle
127 }
128
129 pub fn loading_mode() -> Result<i32> {
133 let d = driver()?;
134 let cu = d.cu_module_get_loading_mode()?;
135 let mut mode: core::ffi::c_int = 0;
136 check(unsafe { cu(&mut mode) })?;
137 Ok(mode)
138 }
139
140 pub unsafe fn load_data_ex(
154 context: &Context,
155 image: &[u8],
156 options: &mut [i32],
157 option_values: &mut [*mut core::ffi::c_void],
158 ) -> Result<Self> { unsafe {
159 assert_eq!(
160 options.len(),
161 option_values.len(),
162 "load_data_ex: options and option_values must have the same length"
163 );
164 context.set_current()?;
165 let d = driver()?;
166 let cu = d.cu_module_load_data_ex()?;
167 let mut module: CUmodule = core::ptr::null_mut();
168 check(cu(
169 &mut module,
170 image.as_ptr() as *const c_void,
171 options.len() as core::ffi::c_uint,
172 options.as_mut_ptr(),
173 option_values.as_mut_ptr(),
174 ))?;
175 Ok(Self {
176 inner: Arc::new(ModuleInner {
177 handle: module,
178 context: context.clone(),
179 }),
180 })
181 }}
182
183 #[inline]
185 pub fn context(&self) -> &Context {
186 &self.inner.context
187 }
188}
189
190impl Drop for ModuleInner {
191 fn drop(&mut self) {
192 if let Ok(d) = driver() {
193 if let Ok(cu) = d.cu_module_unload() {
194 let _ = unsafe { cu(self.handle) };
195 }
196 }
197 }
198}
199
200#[derive(Clone, Debug)]
205pub struct Function {
206 handle: CUfunction,
207 _owner: FunctionOwner,
208}
209
210#[derive(Clone, Debug)]
211#[allow(dead_code)]
212enum FunctionOwner {
213 Module(Module),
215 Library(crate::library::Library),
217}
218
219impl Function {
220 pub(crate) fn from_raw_with_library(
224 handle: CUfunction,
225 library: crate::library::Library,
226 ) -> Self {
227 Self {
228 handle,
229 _owner: FunctionOwner::Library(library),
230 }
231 }
232}
233
234unsafe impl Send for Function {}
235unsafe impl Sync for Function {}
236
237impl Function {
238 #[inline]
240 pub fn as_raw(&self) -> CUfunction {
241 self.handle
242 }
243
244 #[inline]
248 pub fn module(&self) -> Option<&Module> {
249 match &self._owner {
250 FunctionOwner::Module(m) => Some(m),
251 FunctionOwner::Library(_) => None,
252 }
253 }
254
255 pub fn get_attribute(&self, attribute: i32) -> Result<i32> {
258 let d = driver()?;
259 let cu = d.cu_func_get_attribute()?;
260 let mut v: core::ffi::c_int = 0;
261 check(unsafe { cu(&mut v, attribute, self.handle) })?;
262 Ok(v)
263 }
264
265 pub fn name(&self) -> Result<String> {
267 let d = driver()?;
268 let cu = d.cu_func_get_name()?;
269 let mut p: *const core::ffi::c_char = core::ptr::null();
270 check(unsafe { cu(&mut p, self.handle) })?;
271 if p.is_null() {
272 return Ok(String::new());
273 }
274 let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
275 Ok(cstr.to_string_lossy().into_owned())
276 }
277
278 pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
281 let d = driver()?;
282 let cu = d.cu_func_get_param_info()?;
283 let mut off: usize = 0;
284 let mut sz: usize = 0;
285 check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
286 Ok((off, sz))
287 }
288
289 pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
291 let d = driver()?;
292 let cu = d.cu_func_get_module()?;
293 let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
294 check(unsafe { cu(&mut m, self.handle) })?;
295 Ok(m)
296 }
297
298 pub fn set_attribute(&self, attribute: i32, value: i32) -> Result<()> {
302 let d = driver()?;
303 let cu = d.cu_func_set_attribute()?;
304 check(unsafe { cu(self.handle, attribute, value) })
305 }
306
307 pub fn max_threads_per_block(&self) -> Result<i32> {
311 use baracuda_cuda_sys::types::CUfunction_attribute as A;
312 self.get_attribute(A::MAX_THREADS_PER_BLOCK)
313 }
314
315 pub fn shared_size_bytes(&self) -> Result<i32> {
317 use baracuda_cuda_sys::types::CUfunction_attribute as A;
318 self.get_attribute(A::SHARED_SIZE_BYTES)
319 }
320
321 pub fn num_regs(&self) -> Result<i32> {
323 use baracuda_cuda_sys::types::CUfunction_attribute as A;
324 self.get_attribute(A::NUM_REGS)
325 }
326
327 pub fn local_size_bytes(&self) -> Result<i32> {
329 use baracuda_cuda_sys::types::CUfunction_attribute as A;
330 self.get_attribute(A::LOCAL_SIZE_BYTES)
331 }
332
333 pub fn ptx_version(&self) -> Result<i32> {
335 use baracuda_cuda_sys::types::CUfunction_attribute as A;
336 self.get_attribute(A::PTX_VERSION)
337 }
338
339 pub fn binary_version(&self) -> Result<i32> {
341 use baracuda_cuda_sys::types::CUfunction_attribute as A;
342 self.get_attribute(A::BINARY_VERSION)
343 }
344}