Skip to main content

baracuda_driver/
module.rs

1//! Compiled module loading (PTX, CUBIN, fatbin) and kernel entry-point lookup.
2
3use core::ffi::{c_char, c_void};
4use std::ffi::CString;
5use std::sync::Arc;
6
7use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUmodule};
8
9use crate::context::Context;
10use crate::error::{check, Result};
11
/// A loaded CUDA module (e.g. compiled PTX).
///
/// Cloning is cheap: every clone shares the same reference-counted inner
/// state. When the last holder of that state is dropped, an unload of the
/// underlying handle is attempted.
#[derive(Clone)]
pub struct Module {
    // Shared ownership of the raw module handle and the context it was loaded into.
    inner: Arc<ModuleInner>,
}
17
/// Shared state behind every clone of a `Module`.
struct ModuleInner {
    // Raw driver handle produced by a successful module load.
    handle: CUmodule,
    // Clone of the context that was made current when the module was loaded.
    context: Context,
}
22
// `CUmodule` is a raw pointer, so `Send`/`Sync` are not derived automatically
// and have to be asserted by hand.
// SAFETY: NOTE(review): this relies on the driver allowing a module handle to
// be used and unloaded from any thread — confirm against the CUDA driver's
// threading guarantees.
unsafe impl Send for ModuleInner {}
unsafe impl Sync for ModuleInner {}
25
26impl core::fmt::Debug for ModuleInner {
27    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        f.debug_struct("Module")
29            .field("handle", &self.handle)
30            .finish_non_exhaustive()
31    }
32}
33
34impl core::fmt::Debug for Module {
35    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36        self.inner.fmt(f)
37    }
38}
39
40impl Module {
41    /// Load a module from a raw binary image (CUBIN, fatbin, or PTX text with a trailing NUL).
42    ///
43    /// For PTX, the bytes must be a null-terminated UTF-8 string matching the
44    /// `ptx` file on disk. [`Module::load_ptx`] is a convenience wrapper that
45    /// adds the NUL for you.
46    pub fn load_raw(context: &Context, image: &[u8]) -> Result<Self> {
47        context.set_current()?;
48        let d = driver()?;
49        let cu = d.cu_module_load_data()?;
50        let mut module: CUmodule = core::ptr::null_mut();
51        // SAFETY: `image.as_ptr()` is valid for reads of the image bytes.
52        check(unsafe { cu(&mut module, image.as_ptr() as *const c_void) })?;
53        Ok(Self {
54            inner: Arc::new(ModuleInner {
55                handle: module,
56                context: context.clone(),
57            }),
58        })
59    }
60
61    /// Load a module from a PTX source string.
62    pub fn load_ptx(context: &Context, ptx_source: &str) -> Result<Self> {
63        // cuModuleLoadData expects a null-terminated buffer for PTX.
64        let c_src = CString::new(ptx_source).map_err(|_| {
65            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
66                library: "cuda-driver",
67                symbol: "cuModuleLoadData(PTX input contained a NUL byte)",
68            })
69        })?;
70        Self::load_raw(context, c_src.as_bytes_with_nul())
71    }
72
73    /// Look up a `__device__` global variable by name. Returns
74    /// `(device_ptr, size_in_bytes)`.
75    pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
76        let d = driver()?;
77        let cu = d.cu_module_get_global()?;
78        let c_name = CString::new(name).map_err(|_| {
79            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
80                library: "cuda-driver",
81                symbol: "cuModuleGetGlobal(name contained a NUL byte)",
82            })
83        })?;
84        let mut dptr = CUdeviceptr(0);
85        let mut bytes: usize = 0;
86        check(unsafe {
87            cu(
88                &mut dptr,
89                &mut bytes,
90                self.inner.handle,
91                c_name.as_ptr() as *const c_char,
92            )
93        })?;
94        Ok((dptr, bytes))
95    }
96
97    /// Look up a kernel entry point by name.
98    pub fn get_function(&self, name: &str) -> Result<Function> {
99        let d = driver()?;
100        let cu = d.cu_module_get_function()?;
101        let c_name = CString::new(name).map_err(|_| {
102            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
103                library: "cuda-driver",
104                symbol: "cuModuleGetFunction(kernel name contained a NUL byte)",
105            })
106        })?;
107        let mut func: CUfunction = core::ptr::null_mut();
108        // SAFETY: `func` writable; `self.inner.handle` owned by this Arc;
109        // `c_name.as_ptr()` is null-terminated.
110        check(unsafe {
111            cu(
112                &mut func,
113                self.inner.handle,
114                c_name.as_ptr() as *const c_char,
115            )
116        })?;
117        Ok(Function {
118            handle: func,
119            _owner: FunctionOwner::Module(self.clone()),
120        })
121    }
122
123    /// Raw `CUmodule` handle. Use with care.
124    #[inline]
125    pub fn as_raw(&self) -> CUmodule {
126        self.inner.handle
127    }
128
129    /// Return the current process-wide module loading mode (eager vs. lazy).
130    /// Compare against
131    /// [`baracuda_cuda_sys::types::CUmoduleLoadingMode`] constants.
132    pub fn loading_mode() -> Result<i32> {
133        let d = driver()?;
134        let cu = d.cu_module_get_loading_mode()?;
135        let mut mode: core::ffi::c_int = 0;
136        check(unsafe { cu(&mut mode) })?;
137        Ok(mode)
138    }
139
140    /// Load a module from a raw image with extra JIT compiler options —
141    /// the typical use is capturing the JIT log when a PTX module
142    /// fails to compile. `options` and `option_values` are parallel
143    /// arrays whose entries follow the `CUjit_option` ABI (see the
144    /// CUDA driver reference). For PTX, the bytes must be a
145    /// null-terminated UTF-8 string.
146    ///
147    /// # Safety
148    ///
149    /// Each `option_value` must point at a value of the type the
150    /// matching `CUjit_option` expects (some are pointers, some are
151    /// integers cast to `*mut c_void`). The arrays must have the same
152    /// length.
153    pub unsafe fn load_data_ex(
154        context: &Context,
155        image: &[u8],
156        options: &mut [i32],
157        option_values: &mut [*mut core::ffi::c_void],
158    ) -> Result<Self> { unsafe {
159        assert_eq!(
160            options.len(),
161            option_values.len(),
162            "load_data_ex: options and option_values must have the same length"
163        );
164        context.set_current()?;
165        let d = driver()?;
166        let cu = d.cu_module_load_data_ex()?;
167        let mut module: CUmodule = core::ptr::null_mut();
168        check(cu(
169            &mut module,
170            image.as_ptr() as *const c_void,
171            options.len() as core::ffi::c_uint,
172            options.as_mut_ptr(),
173            option_values.as_mut_ptr(),
174        ))?;
175        Ok(Self {
176            inner: Arc::new(ModuleInner {
177                handle: module,
178                context: context.clone(),
179            }),
180        })
181    }}
182
183    /// The [`Context`] this module was loaded into.
184    #[inline]
185    pub fn context(&self) -> &Context {
186        &self.inner.context
187    }
188}
189
190impl Drop for ModuleInner {
191    fn drop(&mut self) {
192        if let Ok(d) = driver() {
193            if let Ok(cu) = d.cu_module_unload() {
194                let _ = unsafe { cu(self.handle) };
195            }
196        }
197    }
198}
199
/// A kernel entry point — either inside a [`Module`] (classic
/// Driver API) or materialized from a [`crate::library::Kernel`] (CUDA
/// 12.0+ library API). Either way it keeps the parent alive via an Arc
/// so the kernel stays valid for as long as any [`Function`] handle exists.
#[derive(Clone, Debug)]
pub struct Function {
    // Raw driver handle for the kernel; exposed through `Function::as_raw`.
    handle: CUfunction,
    // Parent object that owns the kernel. Held so the parent outlives this
    // handle; `Function::module` also reads it to hand back the owning module.
    _owner: FunctionOwner,
}
209
/// The parent object that owns the kernel behind a [`Function`].
#[derive(Clone, Debug)]
// NOTE(review): presumably needed because the `Library` payload is never read
// in this file — it is only held to keep the library alive. Confirm the lint
// would otherwise fire.
#[allow(dead_code)]
enum FunctionOwner {
    /// Owned by a `Module` (classic Driver API flow).
    Module(Module),
    /// Owned by a `Library` (CUDA 12.0+ cuLibrary flow).
    Library(crate::library::Library),
}
218
219impl Function {
220    /// Construct from an already-resolved `CUfunction` plus the parent
221    /// library that owns it. Intended for `library::Kernel`'s
222    /// `function_for_current_context`.
223    pub(crate) fn from_raw_with_library(
224        handle: CUfunction,
225        library: crate::library::Library,
226    ) -> Self {
227        Self {
228            handle,
229            _owner: FunctionOwner::Library(library),
230        }
231    }
232}
233
// `CUfunction` is a raw pointer, so `Send`/`Sync` are not derived automatically
// and have to be asserted by hand.
// SAFETY: NOTE(review): this relies on the driver allowing a function handle
// to be queried and launched from any thread — confirm against the CUDA
// driver's threading guarantees.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
236
237impl Function {
238    /// Raw `CUfunction`. Use with care.
239    #[inline]
240    pub fn as_raw(&self) -> CUfunction {
241        self.handle
242    }
243
244    /// The [`Module`] this kernel lives in, if it was obtained through
245    /// `Module::get_function`. Returns `None` for kernels materialized
246    /// from a `library::Kernel`.
247    #[inline]
248    pub fn module(&self) -> Option<&Module> {
249        match &self._owner {
250            FunctionOwner::Module(m) => Some(m),
251            FunctionOwner::Library(_) => None,
252        }
253    }
254
255    /// Query a kernel attribute (see
256    /// [`baracuda_cuda_sys::types::CUfunction_attribute`]).
257    pub fn get_attribute(&self, attribute: i32) -> Result<i32> {
258        let d = driver()?;
259        let cu = d.cu_func_get_attribute()?;
260        let mut v: core::ffi::c_int = 0;
261        check(unsafe { cu(&mut v, attribute, self.handle) })?;
262        Ok(v)
263    }
264
265    /// Return the demangled kernel name reported by the driver.
266    pub fn name(&self) -> Result<String> {
267        let d = driver()?;
268        let cu = d.cu_func_get_name()?;
269        let mut p: *const core::ffi::c_char = core::ptr::null();
270        check(unsafe { cu(&mut p, self.handle) })?;
271        if p.is_null() {
272            return Ok(String::new());
273        }
274        let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
275        Ok(cstr.to_string_lossy().into_owned())
276    }
277
278    /// Return `(offset_in_bytes, size_in_bytes)` for the `index`-th
279    /// parameter in this function's ABI signature.
280    pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
281        let d = driver()?;
282        let cu = d.cu_func_get_param_info()?;
283        let mut off: usize = 0;
284        let mut sz: usize = 0;
285        check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
286        Ok((off, sz))
287    }
288
289    /// Return the raw `CUmodule` this function was loaded from, if any.
290    pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
291        let d = driver()?;
292        let cu = d.cu_func_get_module()?;
293        let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
294        check(unsafe { cu(&mut m, self.handle) })?;
295        Ok(m)
296    }
297
298    /// Set a kernel attribute. Only a subset is writable (notably
299    /// `MAX_DYNAMIC_SHARED_SIZE_BYTES` and
300    /// `PREFERRED_SHARED_MEMORY_CARVEOUT`).
301    pub fn set_attribute(&self, attribute: i32, value: i32) -> Result<()> {
302        let d = driver()?;
303        let cu = d.cu_func_set_attribute()?;
304        check(unsafe { cu(self.handle, attribute, value) })
305    }
306
307    // Convenience named accessors for the most-read attributes.
308
309    /// Maximum threads per block this kernel supports on the current device.
310    pub fn max_threads_per_block(&self) -> Result<i32> {
311        use baracuda_cuda_sys::types::CUfunction_attribute as A;
312        self.get_attribute(A::MAX_THREADS_PER_BLOCK)
313    }
314
315    /// Size of per-block statically-allocated shared memory (bytes).
316    pub fn shared_size_bytes(&self) -> Result<i32> {
317        use baracuda_cuda_sys::types::CUfunction_attribute as A;
318        self.get_attribute(A::SHARED_SIZE_BYTES)
319    }
320
321    /// Number of registers used per thread.
322    pub fn num_regs(&self) -> Result<i32> {
323        use baracuda_cuda_sys::types::CUfunction_attribute as A;
324        self.get_attribute(A::NUM_REGS)
325    }
326
327    /// Per-thread local-memory footprint (bytes).
328    pub fn local_size_bytes(&self) -> Result<i32> {
329        use baracuda_cuda_sys::types::CUfunction_attribute as A;
330        self.get_attribute(A::LOCAL_SIZE_BYTES)
331    }
332
333    /// PTX version this kernel was compiled from, as `major*10 + minor`.
334    pub fn ptx_version(&self) -> Result<i32> {
335        use baracuda_cuda_sys::types::CUfunction_attribute as A;
336        self.get_attribute(A::PTX_VERSION)
337    }
338
339    /// SM-architecture this kernel was compiled for, as `major*10 + minor`.
340    pub fn binary_version(&self) -> Result<i32> {
341        use baracuda_cuda_sys::types::CUfunction_attribute as A;
342        self.get_attribute(A::BINARY_VERSION)
343    }
344}