Skip to main content

baracuda_driver/
library.rs

1//! Driver-API library + kernel management (CUDA 12.0+).
2//!
3//! Unlike [`crate::Module`] which loads into a specific [`crate::Context`],
4//! a [`Library`] is context-independent — one `cuLibraryLoadData` call
5//! produces a handle that can be queried for kernels from any context on
6//! any device. This is the modern preferred way to ship precompiled
7//! PTX / CUBIN / fatbin in a reusable form.
8//!
9//! Requires CUDA 12.0+ at runtime. On older drivers the loader reports
10//! `LoaderError::SymbolNotFound` on first use.
11
12use core::ffi::{c_char, c_void};
13use std::ffi::CString;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUkernel, CUlibrary};
17
18use crate::error::{check, Result};
19use crate::module::Function;
20
21/// A loaded CUDA library (CUDA 12.0+).
22#[derive(Clone)]
23pub struct Library {
24    inner: Arc<LibraryInner>,
25}
26
27struct LibraryInner {
28    handle: CUlibrary,
29}
30
31unsafe impl Send for LibraryInner {}
32unsafe impl Sync for LibraryInner {}
33
34impl core::fmt::Debug for LibraryInner {
35    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36        f.debug_struct("Library")
37            .field("handle", &self.handle)
38            .finish_non_exhaustive()
39    }
40}
41
42impl core::fmt::Debug for Library {
43    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44        self.inner.fmt(f)
45    }
46}
47
48impl Library {
49    /// Load a library from a raw image (CUBIN, fatbin, or null-terminated PTX)
50    /// with no JIT/library options.
51    pub fn load_raw(image: &[u8]) -> Result<Self> {
52        let d = driver()?;
53        let cu = d.cu_library_load_data()?;
54        let mut lib: CUlibrary = core::ptr::null_mut();
55        check(unsafe {
56            cu(
57                &mut lib,
58                image.as_ptr() as *const c_void,
59                core::ptr::null_mut(), // jit_options
60                core::ptr::null_mut(), // jit_option_values
61                0,                     // num_jit_options
62                core::ptr::null_mut(), // library_options
63                core::ptr::null_mut(), // library_option_values
64                0,                     // num_library_options
65            )
66        })?;
67        Ok(Self {
68            inner: Arc::new(LibraryInner { handle: lib }),
69        })
70    }
71
72    /// Load a library from a PTX source string (NUL-terminates internally).
73    pub fn load_ptx(ptx: &str) -> Result<Self> {
74        let c_src = CString::new(ptx).map_err(|_| {
75            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
76                library: "cuda-driver",
77                symbol: "cuLibraryLoadData(PTX contained a NUL byte)",
78            })
79        })?;
80        Self::load_raw(c_src.as_bytes_with_nul())
81    }
82
83    /// Look up a kernel entry point by name.
84    pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
85        let d = driver()?;
86        let cu = d.cu_library_get_kernel()?;
87        let c_name = CString::new(name).map_err(|_| {
88            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
89                library: "cuda-driver",
90                symbol: "cuLibraryGetKernel(name contained a NUL byte)",
91            })
92        })?;
93        let mut kernel: CUkernel = core::ptr::null_mut();
94        check(unsafe {
95            cu(
96                &mut kernel,
97                self.inner.handle,
98                c_name.as_ptr() as *const c_char,
99            )
100        })?;
101        Ok(Kernel {
102            handle: kernel,
103            _library: self.clone(),
104        })
105    }
106
107    /// Count of kernels this library exposes (CUDA 12.4+).
108    pub fn kernel_count(&self) -> Result<u32> {
109        let d = driver()?;
110        let cu = d.cu_library_get_kernel_count()?;
111        let mut n: core::ffi::c_uint = 0;
112        check(unsafe { cu(&mut n, self.inner.handle) })?;
113        Ok(n)
114    }
115
116    /// Enumerate every kernel in the library (CUDA 12.4+). Allocates the
117    /// result vector at the count reported by [`Self::kernel_count`].
118    pub fn enumerate_kernels(&self) -> Result<Vec<Kernel>> {
119        let d = driver()?;
120        let n = self.kernel_count()?;
121        let cu = d.cu_library_enumerate_kernels()?;
122        let mut raw: Vec<baracuda_cuda_sys::CUkernel> = vec![core::ptr::null_mut(); n as usize];
123        if n > 0 {
124            check(unsafe { cu(raw.as_mut_ptr(), n, self.inner.handle) })?;
125        }
126        Ok(raw
127            .into_iter()
128            .map(|h| Kernel {
129                handle: h,
130                _library: self.clone(),
131            })
132            .collect())
133    }
134
135    /// Retrieve the underlying `CUmodule` backing this library, if any
136    /// (CUDA 12.4+). Not all libraries have an addressable module — some
137    /// ship compiled-kernel images only.
138    pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
139        let d = driver()?;
140        let cu = d.cu_library_get_module()?;
141        let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
142        check(unsafe { cu(&mut m, self.inner.handle) })?;
143        Ok(m)
144    }
145
146    /// Look up a managed-memory global by name (CUDA 12.2+).
147    pub fn get_managed(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
148        let d = driver()?;
149        let cu = d.cu_library_get_managed()?;
150        let c_name = CString::new(name).map_err(|_| {
151            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
152                library: "cuda-driver",
153                symbol: "cuLibraryGetManaged(name contained a NUL byte)",
154            })
155        })?;
156        let mut dptr = CUdeviceptr(0);
157        let mut bytes: usize = 0;
158        check(unsafe {
159            cu(
160                &mut dptr,
161                &mut bytes,
162                self.inner.handle,
163                c_name.as_ptr() as *const c_char,
164            )
165        })?;
166        Ok((dptr, bytes))
167    }
168
169    /// Look up a unified-function pointer by name (CUDA 12.4+).
170    /// Returns the raw host-side function pointer; the caller is
171    /// responsible for casting it to the right signature before calling.
172    pub fn get_unified_function(&self, name: &str) -> Result<*mut core::ffi::c_void> {
173        let d = driver()?;
174        let cu = d.cu_library_get_unified_function()?;
175        let c_name = CString::new(name).map_err(|_| {
176            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
177                library: "cuda-driver",
178                symbol: "cuLibraryGetUnifiedFunction(name contained a NUL byte)",
179            })
180        })?;
181        let mut fptr: *mut core::ffi::c_void = core::ptr::null_mut();
182        check(unsafe {
183            cu(
184                &mut fptr,
185                self.inner.handle,
186                c_name.as_ptr() as *const c_char,
187            )
188        })?;
189        Ok(fptr)
190    }
191
192    /// Look up a `__device__` global variable by name across contexts.
193    /// Returns `(device_ptr, size_in_bytes)`. The returned pointer is valid
194    /// in whatever context is current when the caller dereferences it.
195    pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
196        let d = driver()?;
197        let cu = d.cu_library_get_global()?;
198        let c_name = CString::new(name).map_err(|_| {
199            crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
200                library: "cuda-driver",
201                symbol: "cuLibraryGetGlobal(name contained a NUL byte)",
202            })
203        })?;
204        let mut dptr = CUdeviceptr(0);
205        let mut bytes: usize = 0;
206        check(unsafe {
207            cu(
208                &mut dptr,
209                &mut bytes,
210                self.inner.handle,
211                c_name.as_ptr() as *const c_char,
212            )
213        })?;
214        Ok((dptr, bytes))
215    }
216
217    /// Raw `CUlibrary` handle. Use with care.
218    #[inline]
219    pub fn as_raw(&self) -> CUlibrary {
220        self.inner.handle
221    }
222}
223
224impl Drop for LibraryInner {
225    fn drop(&mut self) {
226        if let Ok(d) = driver() {
227            if let Ok(cu) = d.cu_library_unload() {
228                let _ = unsafe { cu(self.handle) };
229            }
230        }
231    }
232}
233
234/// A kernel from a [`Library`]. Library kernels are per-library but not
235/// per-context; use [`Kernel::function_for_current_context`] to materialize
236/// a [`Function`] for the active context before launching.
237#[derive(Clone, Debug)]
238pub struct Kernel {
239    handle: CUkernel,
240    _library: Library,
241}
242
243unsafe impl Send for Kernel {}
244unsafe impl Sync for Kernel {}
245
246impl Kernel {
247    /// Raw `CUkernel`. Use with care.
248    #[inline]
249    pub fn as_raw(&self) -> CUkernel {
250        self.handle
251    }
252
253    /// Materialize this library kernel into a [`Function`] suitable for
254    /// the caller's currently-current CUDA context. The returned
255    /// `Function` keeps a clone of the parent library alive.
256    pub fn function_for_current_context(&self) -> Result<Function> {
257        let d = driver()?;
258        let cu = d.cu_kernel_get_function()?;
259        let mut f: CUfunction = core::ptr::null_mut();
260        check(unsafe { cu(&mut f, self.handle) })?;
261        // Reuse the public module-less Function constructor via from_raw.
262        Ok(Function::from_raw_with_library(f, self._library.clone()))
263    }
264
265    /// Query a `CUfunction_attribute` on this kernel for a specific device.
266    pub fn attribute(&self, attr: i32, device: &crate::Device) -> Result<i32> {
267        let d = driver()?;
268        let cu = d.cu_kernel_get_attribute()?;
269        let mut v: core::ffi::c_int = 0;
270        check(unsafe { cu(&mut v, attr, self.handle, device.as_raw()) })?;
271        Ok(v)
272    }
273
274    /// Set a (writable) `CUfunction_attribute` on this kernel for a
275    /// specific device. Common writables: `MAX_DYNAMIC_SHARED_SIZE_BYTES`,
276    /// `PREFERRED_SHARED_MEMORY_CARVEOUT`.
277    pub fn set_attribute(&self, attr: i32, value: i32, device: &crate::Device) -> Result<()> {
278        let d = driver()?;
279        let cu = d.cu_kernel_set_attribute()?;
280        check(unsafe { cu(attr, value, self.handle, device.as_raw()) })
281    }
282
283    /// Return the kernel's demangled name as reported by the driver.
284    pub fn name(&self) -> Result<String> {
285        let d = driver()?;
286        let cu = d.cu_kernel_get_name()?;
287        let mut p: *const core::ffi::c_char = core::ptr::null();
288        check(unsafe { cu(&mut p, self.handle) })?;
289        if p.is_null() {
290            return Ok(String::new());
291        }
292        // SAFETY: driver returns a NUL-terminated static string; we copy to owned.
293        let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
294        Ok(cstr.to_string_lossy().into_owned())
295    }
296
297    /// Set the preferred L1 vs shared-memory cache config for this kernel
298    /// on `device`. Pass one of the
299    /// [`baracuda_cuda_sys::types::CUfunc_cache`] constants (PREFER_NONE,
300    /// PREFER_SHARED, PREFER_L1, PREFER_EQUAL).
301    pub fn set_cache_config(&self, config: u32, device: &crate::Device) -> Result<()> {
302        let d = driver()?;
303        let cu = d.cu_kernel_set_cache_config()?;
304        check(unsafe { cu(self.handle, config as core::ffi::c_int, device.as_raw()) })
305    }
306
307    /// Return `(offset_in_bytes, size_in_bytes)` for the `index`-th
308    /// parameter in the kernel's ABI signature. Useful for reflective
309    /// launches and matching Rust structs to kernel parameters.
310    pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
311        let d = driver()?;
312        let cu = d.cu_kernel_get_param_info()?;
313        let mut off: usize = 0;
314        let mut sz: usize = 0;
315        check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
316        Ok((off, sz))
317    }
318
319    /// Return the library that owns this kernel.
320    pub fn library(&self) -> Library {
321        self._library.clone()
322    }
323}