fil_rustacuda/module.rs
1//! Functions and types for working with CUDA modules.
2
3use crate::error::{CudaResult, DropResult, ToResult};
4use crate::function::Function;
5use crate::memory::{CopyDestination, DeviceCopy, DevicePointer};
6use std::ffi::{c_void, CStr};
7use std::fmt;
8use std::marker::PhantomData;
9use std::mem;
10use std::ptr;
11
12/// A compiled CUDA module, loaded into a context.
13#[derive(Debug)]
14pub struct Module {
15 inner: cuda_driver_sys::CUmodule,
16}
17impl Module {
18 /// Load a module from the given file name into the current context.
19 ///
20 /// The given file should be either a cubin file, a ptx file, or a fatbin file such as
21 /// those produced by `nvcc`.
22 ///
23 /// # Example
24 ///
25 /// ```
26 /// # use rustacuda::*;
27 /// # use std::error::Error;
28 /// # fn main() -> Result<(), Box<dyn Error>> {
29 /// # let _ctx = quick_init()?;
30 /// use rustacuda::module::Module;
31 /// use std::ffi::CString;
32 ///
33 /// let filename = CString::new("./resources/add.ptx")?;
34 /// let module = Module::load_from_file(&filename)?;
35 /// # Ok(())
36 /// # }
37 /// ```
38 pub fn load_from_file(filename: &CStr) -> CudaResult<Module> {
39 unsafe {
40 let mut module = Module {
41 inner: ptr::null_mut(),
42 };
43 cuda_driver_sys::cuModuleLoad(
44 &mut module.inner as *mut cuda_driver_sys::CUmodule,
45 filename.as_ptr(),
46 )
47 .to_result()?;
48 Ok(module)
49 }
50 }
51
52 /// Load a module from a CStr.
53 ///
54 /// This is useful in combination with `include_str!`, to include the device code into the
55 /// compiled executable.
56 ///
57 /// The given CStr must contain the bytes of a cubin file, a ptx file or a fatbin file such as
58 /// those produced by `nvcc`.
59 ///
60 /// # Example
61 ///
62 /// ```
63 /// # use rustacuda::*;
64 /// # use std::error::Error;
65 /// # fn main() -> Result<(), Box<dyn Error>> {
66 /// # let _ctx = quick_init()?;
67 /// use rustacuda::module::Module;
68 /// use std::ffi::CString;
69 ///
70 /// let image = CString::new(include_str!("../resources/add.ptx"))?;
71 /// let module = Module::load_from_string(&image)?;
72 /// # Ok(())
73 /// # }
74 /// ```
75 pub fn load_from_string(image: &CStr) -> CudaResult<Module> {
76 unsafe {
77 let mut module = Module {
78 inner: ptr::null_mut(),
79 };
80 cuda_driver_sys::cuModuleLoadData(
81 &mut module.inner as *mut cuda_driver_sys::CUmodule,
82 image.as_ptr() as *const c_void,
83 )
84 .to_result()?;
85 Ok(module)
86 }
87 }
88
89 /// Load a module from a byte slice.
90 ///
91 /// This is useful in combination with [`include_bytes!`](std::include_bytes), to include the
92 /// device code into the compiled executable.
93 ///
94 /// The given slice must contain the bytes of a cubin file, a ptx file or a fatbin file such as
95 /// those produced by `nvcc`.
96 ///
97 /// # Example
98 ///
99 /// ```
100 /// # use rustacuda::*;
101 /// # use std::error::Error;
102 /// # fn main() -> Result<(), Box<dyn Error>> {
103 /// # let _ctx = quick_init()?;
104 /// use rustacuda::module::Module;
105 /// use std::ffi::CString;
106 ///
107 /// let image = include_bytes!("../resources/add.ptx");
108 /// let module = Module::load_from_bytes(image)?;
109 /// # Ok(())
110 /// # }
111 /// ```
112 pub fn load_from_bytes(image: &[u8]) -> CudaResult<Module> {
113 unsafe {
114 let mut module = Module {
115 inner: ptr::null_mut(),
116 };
117 cuda_driver_sys::cuModuleLoadData(
118 &mut module.inner as *mut cuda_driver_sys::CUmodule,
119 image.as_ptr() as *const c_void,
120 )
121 .to_result()?;
122 Ok(module)
123 }
124 }
125
126 /// Get a reference to a global symbol, which can then be copied to/from.
127 ///
128 /// # Panics:
129 ///
130 /// This function panics if the size of the symbol is not the same as the `mem::sizeof<T>()`.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// # use rustacuda::*;
136 /// # use rustacuda::memory::CopyDestination;
137 /// # use std::error::Error;
138 /// # fn main() -> Result<(), Box<dyn Error>> {
139 /// # let _ctx = quick_init()?;
140 /// use rustacuda::module::Module;
141 /// use std::ffi::CString;
142 ///
143 /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
144 /// let module = Module::load_from_string(&ptx)?;
145 /// let name = CString::new("my_constant")?;
146 /// let symbol = module.get_global::<u32>(&name)?;
147 /// let mut host_const = 0;
148 /// symbol.copy_to(&mut host_const)?;
149 /// assert_eq!(314, host_const);
150 /// # Ok(())
151 /// # }
152 /// ```
153 pub fn get_global<'a, T: DeviceCopy>(&'a self, name: &CStr) -> CudaResult<Symbol<'a, T>> {
154 unsafe {
155 let mut ptr: DevicePointer<T> = DevicePointer::null();
156 let mut size: usize = 0;
157
158 cuda_driver_sys::cuModuleGetGlobal_v2(
159 &mut ptr as *mut DevicePointer<T> as *mut cuda_driver_sys::CUdeviceptr,
160 &mut size as *mut usize,
161 self.inner,
162 name.as_ptr(),
163 )
164 .to_result()?;
165 assert_eq!(size, mem::size_of::<T>());
166 Ok(Symbol {
167 ptr,
168 module: PhantomData,
169 })
170 }
171 }
172
173 /// Get a reference to a kernel function which can then be launched.
174 ///
175 /// # Examples
176 ///
177 /// ```
178 /// # use rustacuda::*;
179 /// # use std::error::Error;
180 /// # fn main() -> Result<(), Box<dyn Error>> {
181 /// # let _ctx = quick_init()?;
182 /// use rustacuda::module::Module;
183 /// use std::ffi::CString;
184 ///
185 /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
186 /// let module = Module::load_from_string(&ptx)?;
187 /// let name = CString::new("sum")?;
188 /// let function = module.get_function(&name)?;
189 /// # Ok(())
190 /// # }
191 /// ```
192 pub fn get_function<'a>(&'a self, name: &CStr) -> CudaResult<Function<'a>> {
193 unsafe {
194 let mut func: cuda_driver_sys::CUfunction = ptr::null_mut();
195
196 cuda_driver_sys::cuModuleGetFunction(
197 &mut func as *mut cuda_driver_sys::CUfunction,
198 self.inner,
199 name.as_ptr(),
200 )
201 .to_result()?;
202 Ok(Function::new(func, self))
203 }
204 }
205
206 /// Destroy a `Module`, returning an error.
207 ///
208 /// Destroying a module can return errors from previous asynchronous work. This function
209 /// destroys the given module and returns the error and the un-destroyed module on failure.
210 ///
211 /// # Example
212 ///
213 /// ```
214 /// # use rustacuda::*;
215 /// # use std::error::Error;
216 /// # fn main() -> Result<(), Box<dyn Error>> {
217 /// # let _ctx = quick_init()?;
218 /// use rustacuda::module::Module;
219 /// use std::ffi::CString;
220 ///
221 /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
222 /// let module = Module::load_from_string(&ptx)?;
223 /// match Module::drop(module) {
224 /// Ok(()) => println!("Successfully destroyed"),
225 /// Err((e, module)) => {
226 /// println!("Failed to destroy module: {:?}", e);
227 /// // Do something with module
228 /// },
229 /// }
230 /// # Ok(())
231 /// # }
232 /// ```
233 pub fn drop(mut module: Module) -> DropResult<Module> {
234 if module.inner.is_null() {
235 return Ok(());
236 }
237
238 unsafe {
239 let inner = mem::replace(&mut module.inner, ptr::null_mut());
240 match cuda_driver_sys::cuModuleUnload(inner).to_result() {
241 Ok(()) => {
242 mem::forget(module);
243 Ok(())
244 }
245 Err(e) => Err((e, Module { inner })),
246 }
247 }
248 }
249}
250impl Drop for Module {
251 fn drop(&mut self) {
252 if self.inner.is_null() {
253 return;
254 }
255 unsafe {
256 // No choice but to panic if this fails...
257 let module = mem::replace(&mut self.inner, ptr::null_mut());
258 cuda_driver_sys::cuModuleUnload(module)
259 .to_result()
260 .expect("Failed to unload CUDA module");
261 }
262 }
263}
264
265/// Handle to a symbol defined within a CUDA module.
266#[derive(Debug)]
267pub struct Symbol<'a, T: DeviceCopy> {
268 ptr: DevicePointer<T>,
269 module: PhantomData<&'a Module>,
270}
271impl<'a, T: DeviceCopy> crate::private::Sealed for Symbol<'a, T> {}
272impl<'a, T: DeviceCopy> fmt::Pointer for Symbol<'a, T> {
273 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274 fmt::Pointer::fmt(&self.ptr, f)
275 }
276}
277impl<'a, T: DeviceCopy> CopyDestination<T> for Symbol<'a, T> {
278 fn copy_from(&mut self, val: &T) -> CudaResult<()> {
279 let size = mem::size_of::<T>();
280 if size != 0 {
281 unsafe {
282 cuda_driver_sys::cuMemcpyHtoD_v2(
283 self.ptr.as_raw_mut() as u64,
284 val as *const T as *const c_void,
285 size,
286 )
287 .to_result()?
288 }
289 }
290 Ok(())
291 }
292
293 fn copy_to(&self, val: &mut T) -> CudaResult<()> {
294 let size = mem::size_of::<T>();
295 if size != 0 {
296 unsafe {
297 cuda_driver_sys::cuMemcpyDtoH_v2(
298 val as *const T as *mut c_void,
299 self.ptr.as_raw() as u64,
300 size,
301 )
302 .to_result()?
303 }
304 }
305 Ok(())
306 }
307}
308
#[cfg(test)]
mod test {
    use super::*;
    use crate::quick_init;
    use std::error::Error;
    use std::ffi::CString;

    // Every test propagates `quick_init` failures with `?` so a missing device or driver is
    // reported directly, rather than surfacing later as a confusing module error.

    #[test]
    fn test_load_from_file() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let filename = CString::new("./resources/add.ptx")?;
        let module = Module::load_from_file(&filename)?;
        drop(module);
        Ok(())
    }

    #[test]
    fn test_load_from_memory() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;
        drop(module);
        Ok(())
    }

    #[test]
    fn test_load_from_bytes() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;
        // `include_bytes!` data is not NUL-terminated; loading PTX from it must still work.
        let module = Module::load_from_bytes(include_bytes!("../resources/add.ptx"))?;
        drop(module);
        Ok(())
    }

    #[test]
    fn test_explicit_drop() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;
        Module::drop(module).map_err(|(e, _module)| e)?;
        Ok(())
    }

    #[test]
    fn test_copy_from_module() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let ptx = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx)?;

        let constant_name = CString::new("my_constant")?;
        let symbol = module.get_global::<u32>(&constant_name)?;

        let mut constant_copy = 0u32;
        symbol.copy_to(&mut constant_copy)?;
        assert_eq!(314, constant_copy);
        Ok(())
    }

    #[test]
    fn test_copy_to_module() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let ptx = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx)?;

        let constant_name = CString::new("my_constant")?;
        let mut symbol = module.get_global::<u32>(&constant_name)?;

        symbol.copy_from(&100)?;

        let mut constant_copy = 0u32;
        symbol.copy_to(&mut constant_copy)?;
        assert_eq!(100, constant_copy);
        Ok(())
    }
}