fil_rustacuda/
module.rs

1//! Functions and types for working with CUDA modules.
2
3use crate::error::{CudaResult, DropResult, ToResult};
4use crate::function::Function;
5use crate::memory::{CopyDestination, DeviceCopy, DevicePointer};
6use std::ffi::{c_void, CStr};
7use std::fmt;
8use std::marker::PhantomData;
9use std::mem;
10use std::ptr;
11
12/// A compiled CUDA module, loaded into a context.
13#[derive(Debug)]
14pub struct Module {
15    inner: cuda_driver_sys::CUmodule,
16}
17impl Module {
18    /// Load a module from the given file name into the current context.
19    ///
20    /// The given file should be either a cubin file, a ptx file, or a fatbin file such as
21    /// those produced by `nvcc`.
22    ///
23    /// # Example
24    ///
25    /// ```
26    /// # use rustacuda::*;
27    /// # use std::error::Error;
28    /// # fn main() -> Result<(), Box<dyn Error>> {
29    /// # let _ctx = quick_init()?;
30    /// use rustacuda::module::Module;
31    /// use std::ffi::CString;
32    ///
33    /// let filename = CString::new("./resources/add.ptx")?;
34    /// let module = Module::load_from_file(&filename)?;
35    /// # Ok(())
36    /// # }
37    /// ```
38    pub fn load_from_file(filename: &CStr) -> CudaResult<Module> {
39        unsafe {
40            let mut module = Module {
41                inner: ptr::null_mut(),
42            };
43            cuda_driver_sys::cuModuleLoad(
44                &mut module.inner as *mut cuda_driver_sys::CUmodule,
45                filename.as_ptr(),
46            )
47            .to_result()?;
48            Ok(module)
49        }
50    }
51
52    /// Load a module from a CStr.
53    ///
54    /// This is useful in combination with `include_str!`, to include the device code into the
55    /// compiled executable.
56    ///
57    /// The given CStr must contain the bytes of a cubin file, a ptx file or a fatbin file such as
58    /// those produced by `nvcc`.
59    ///
60    /// # Example
61    ///
62    /// ```
63    /// # use rustacuda::*;
64    /// # use std::error::Error;
65    /// # fn main() -> Result<(), Box<dyn Error>> {
66    /// # let _ctx = quick_init()?;
67    /// use rustacuda::module::Module;
68    /// use std::ffi::CString;
69    ///
70    /// let image = CString::new(include_str!("../resources/add.ptx"))?;
71    /// let module = Module::load_from_string(&image)?;
72    /// # Ok(())
73    /// # }
74    /// ```
75    pub fn load_from_string(image: &CStr) -> CudaResult<Module> {
76        unsafe {
77            let mut module = Module {
78                inner: ptr::null_mut(),
79            };
80            cuda_driver_sys::cuModuleLoadData(
81                &mut module.inner as *mut cuda_driver_sys::CUmodule,
82                image.as_ptr() as *const c_void,
83            )
84            .to_result()?;
85            Ok(module)
86        }
87    }
88
89    /// Load a module from a byte slice.
90    ///
91    /// This is useful in combination with [`include_bytes!`](std::include_bytes), to include the
92    /// device code into the compiled executable.
93    ///
94    /// The given slice must contain the bytes of a cubin file, a ptx file or a fatbin file such as
95    /// those produced by `nvcc`.
96    ///
97    /// # Example
98    ///
99    /// ```
100    /// # use rustacuda::*;
101    /// # use std::error::Error;
102    /// # fn main() -> Result<(), Box<dyn Error>> {
103    /// # let _ctx = quick_init()?;
104    /// use rustacuda::module::Module;
105    /// use std::ffi::CString;
106    ///
107    /// let image = include_bytes!("../resources/add.ptx");
108    /// let module = Module::load_from_bytes(image)?;
109    /// # Ok(())
110    /// # }
111    /// ```
112    pub fn load_from_bytes(image: &[u8]) -> CudaResult<Module> {
113        unsafe {
114            let mut module = Module {
115                inner: ptr::null_mut(),
116            };
117            cuda_driver_sys::cuModuleLoadData(
118                &mut module.inner as *mut cuda_driver_sys::CUmodule,
119                image.as_ptr() as *const c_void,
120            )
121            .to_result()?;
122            Ok(module)
123        }
124    }
125
126    /// Get a reference to a global symbol, which can then be copied to/from.
127    ///
128    /// # Panics:
129    ///
130    /// This function panics if the size of the symbol is not the same as the `mem::sizeof<T>()`.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// # use rustacuda::*;
136    /// # use rustacuda::memory::CopyDestination;
137    /// # use std::error::Error;
138    /// # fn main() -> Result<(), Box<dyn Error>> {
139    /// # let _ctx = quick_init()?;
140    /// use rustacuda::module::Module;
141    /// use std::ffi::CString;
142    ///
143    /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
144    /// let module = Module::load_from_string(&ptx)?;
145    /// let name = CString::new("my_constant")?;
146    /// let symbol = module.get_global::<u32>(&name)?;
147    /// let mut host_const = 0;
148    /// symbol.copy_to(&mut host_const)?;
149    /// assert_eq!(314, host_const);
150    /// # Ok(())
151    /// # }
152    /// ```
153    pub fn get_global<'a, T: DeviceCopy>(&'a self, name: &CStr) -> CudaResult<Symbol<'a, T>> {
154        unsafe {
155            let mut ptr: DevicePointer<T> = DevicePointer::null();
156            let mut size: usize = 0;
157
158            cuda_driver_sys::cuModuleGetGlobal_v2(
159                &mut ptr as *mut DevicePointer<T> as *mut cuda_driver_sys::CUdeviceptr,
160                &mut size as *mut usize,
161                self.inner,
162                name.as_ptr(),
163            )
164            .to_result()?;
165            assert_eq!(size, mem::size_of::<T>());
166            Ok(Symbol {
167                ptr,
168                module: PhantomData,
169            })
170        }
171    }
172
173    /// Get a reference to a kernel function which can then be launched.
174    ///
175    /// # Examples
176    ///
177    /// ```
178    /// # use rustacuda::*;
179    /// # use std::error::Error;
180    /// # fn main() -> Result<(), Box<dyn Error>> {
181    /// # let _ctx = quick_init()?;
182    /// use rustacuda::module::Module;
183    /// use std::ffi::CString;
184    ///
185    /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
186    /// let module = Module::load_from_string(&ptx)?;
187    /// let name = CString::new("sum")?;
188    /// let function = module.get_function(&name)?;
189    /// # Ok(())
190    /// # }
191    /// ```
192    pub fn get_function<'a>(&'a self, name: &CStr) -> CudaResult<Function<'a>> {
193        unsafe {
194            let mut func: cuda_driver_sys::CUfunction = ptr::null_mut();
195
196            cuda_driver_sys::cuModuleGetFunction(
197                &mut func as *mut cuda_driver_sys::CUfunction,
198                self.inner,
199                name.as_ptr(),
200            )
201            .to_result()?;
202            Ok(Function::new(func, self))
203        }
204    }
205
206    /// Destroy a `Module`, returning an error.
207    ///
208    /// Destroying a module can return errors from previous asynchronous work. This function
209    /// destroys the given module and returns the error and the un-destroyed module on failure.
210    ///
211    /// # Example
212    ///
213    /// ```
214    /// # use rustacuda::*;
215    /// # use std::error::Error;
216    /// # fn main() -> Result<(), Box<dyn Error>> {
217    /// # let _ctx = quick_init()?;
218    /// use rustacuda::module::Module;
219    /// use std::ffi::CString;
220    ///
221    /// let ptx = CString::new(include_str!("../resources/add.ptx"))?;
222    /// let module = Module::load_from_string(&ptx)?;
223    /// match Module::drop(module) {
224    ///     Ok(()) => println!("Successfully destroyed"),
225    ///     Err((e, module)) => {
226    ///         println!("Failed to destroy module: {:?}", e);
227    ///         // Do something with module
228    ///     },
229    /// }
230    /// # Ok(())
231    /// # }
232    /// ```
233    pub fn drop(mut module: Module) -> DropResult<Module> {
234        if module.inner.is_null() {
235            return Ok(());
236        }
237
238        unsafe {
239            let inner = mem::replace(&mut module.inner, ptr::null_mut());
240            match cuda_driver_sys::cuModuleUnload(inner).to_result() {
241                Ok(()) => {
242                    mem::forget(module);
243                    Ok(())
244                }
245                Err(e) => Err((e, Module { inner })),
246            }
247        }
248    }
249}
250impl Drop for Module {
251    fn drop(&mut self) {
252        if self.inner.is_null() {
253            return;
254        }
255        unsafe {
256            // No choice but to panic if this fails...
257            let module = mem::replace(&mut self.inner, ptr::null_mut());
258            cuda_driver_sys::cuModuleUnload(module)
259                .to_result()
260                .expect("Failed to unload CUDA module");
261        }
262    }
263}
264
265/// Handle to a symbol defined within a CUDA module.
266#[derive(Debug)]
267pub struct Symbol<'a, T: DeviceCopy> {
268    ptr: DevicePointer<T>,
269    module: PhantomData<&'a Module>,
270}
271impl<'a, T: DeviceCopy> crate::private::Sealed for Symbol<'a, T> {}
272impl<'a, T: DeviceCopy> fmt::Pointer for Symbol<'a, T> {
273    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274        fmt::Pointer::fmt(&self.ptr, f)
275    }
276}
277impl<'a, T: DeviceCopy> CopyDestination<T> for Symbol<'a, T> {
278    fn copy_from(&mut self, val: &T) -> CudaResult<()> {
279        let size = mem::size_of::<T>();
280        if size != 0 {
281            unsafe {
282                cuda_driver_sys::cuMemcpyHtoD_v2(
283                    self.ptr.as_raw_mut() as u64,
284                    val as *const T as *const c_void,
285                    size,
286                )
287                .to_result()?
288            }
289        }
290        Ok(())
291    }
292
293    fn copy_to(&self, val: &mut T) -> CudaResult<()> {
294        let size = mem::size_of::<T>();
295        if size != 0 {
296            unsafe {
297                cuda_driver_sys::cuMemcpyDtoH_v2(
298                    val as *const T as *mut c_void,
299                    self.ptr.as_raw() as u64,
300                    size,
301                )
302                .to_result()?
303            }
304        }
305        Ok(())
306    }
307}
308
#[cfg(test)]
mod test {
    use super::*;
    use crate::quick_init;
    use std::error::Error;
    use std::ffi::CString;

    // Every test propagates a failed `quick_init()` with `?` so that a missing device or driver
    // is reported as the initialisation error itself rather than as a later module-load failure.

    /// A module can be loaded from a PTX file on disk and then unloaded.
    #[test]
    fn test_load_from_file() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let filename = CString::new("./resources/add.ptx")?;
        let module = Module::load_from_file(&filename)?;
        drop(module);
        Ok(())
    }

    /// A module can be loaded from PTX text embedded in the test binary and then unloaded.
    #[test]
    fn test_load_from_memory() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;
        let ptx_text = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx_text)?;
        drop(module);
        Ok(())
    }

    /// Reading the `my_constant` global yields the value the test expects (314).
    #[test]
    fn test_copy_from_module() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let ptx = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx)?;

        let constant_name = CString::new("my_constant")?;
        let symbol = module.get_global::<u32>(&constant_name)?;

        let mut constant_copy = 0u32;
        symbol.copy_to(&mut constant_copy)?;
        assert_eq!(314, constant_copy);
        Ok(())
    }

    /// A value written to the `my_constant` global can be read back unchanged.
    #[test]
    fn test_copy_to_module() -> Result<(), Box<dyn Error>> {
        let _context = quick_init()?;

        let ptx = CString::new(include_str!("../resources/add.ptx"))?;
        let module = Module::load_from_string(&ptx)?;

        let constant_name = CString::new("my_constant")?;
        let mut symbol = module.get_global::<u32>(&constant_name)?;

        symbol.copy_from(&100)?;

        let mut constant_copy = 0u32;
        symbol.copy_to(&mut constant_copy)?;
        assert_eq!(100, constant_copy);
        Ok(())
    }
}