Skip to main content

baracuda_runtime/
module.rs

1//! Runtime-API library + kernel loading.
2//!
//! Uses `cudaLibraryLoadData` (CUDA 12.0+) to load PTX, CUBIN, or fatbin
//! images at runtime, exposing kernels as [`Kernel`] handles. When the
//! installed driver predates CUDA 12.0, loading returns
//! [`crate::Error::FeatureNotSupported`].
6
7use core::ffi::{c_char, c_void};
8use std::ffi::CString;
9use std::sync::Arc;
10
11use baracuda_cuda_sys::runtime::{cudaKernel_t, cudaLibrary_t, runtime};
12use baracuda_types::{supports, CudaVersion, Feature};
13
14use crate::error::{check, Error, Result};
15
16/// A loaded CUDA library (CUDA 12.0+).
17#[derive(Clone)]
18pub struct Library {
19    inner: Arc<LibraryInner>,
20}
21
22struct LibraryInner {
23    handle: cudaLibrary_t,
24}
25
26unsafe impl Send for LibraryInner {}
27unsafe impl Sync for LibraryInner {}
28
29impl core::fmt::Debug for LibraryInner {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        f.debug_struct("Library")
32            .field("handle", &self.handle)
33            .finish_non_exhaustive()
34    }
35}
36
37impl core::fmt::Debug for Library {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        self.inner.fmt(f)
40    }
41}
42
43impl Library {
44    /// Load a library from a raw binary image (CUBIN, fatbin, or null-terminated PTX).
45    pub fn load_raw(image: &[u8]) -> Result<Self> {
46        let installed = crate::init::driver_version()?;
47        if !supports(installed, Feature::LibraryManagement) {
48            return Err(Error::FeatureNotSupported {
49                api: "cudaLibraryLoadData",
50                since: Feature::LibraryManagement.required_version(),
51            });
52        }
53
54        let r = runtime()?;
55        let cu = r.cuda_library_load_data()?;
56        let mut lib: cudaLibrary_t = core::ptr::null_mut();
57        check(unsafe {
58            cu(
59                &mut lib,
60                image.as_ptr() as *const c_void,
61                core::ptr::null_mut(), // jit_options
62                core::ptr::null_mut(), // jit_option_values
63                0,                     // num_jit_options
64                core::ptr::null_mut(), // library_options
65                core::ptr::null_mut(), // library_option_values
66                0,                     // num_library_options
67            )
68        })?;
69        Ok(Self {
70            inner: Arc::new(LibraryInner { handle: lib }),
71        })
72    }
73
74    /// Load a library from a PTX source string.
75    pub fn load_ptx(ptx_source: &str) -> Result<Self> {
76        let c_src = CString::new(ptx_source).map_err(|_| {
77            Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
78                library: "cuda-runtime",
79                symbol: "cudaLibraryLoadData(PTX input contained a NUL byte)",
80            })
81        })?;
82        Self::load_raw(c_src.as_bytes_with_nul())
83    }
84
85    /// Look up a kernel entry point by name.
86    pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
87        let r = runtime()?;
88        let cu = r.cuda_library_get_kernel()?;
89        let c_name = CString::new(name).map_err(|_| {
90            Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
91                library: "cuda-runtime",
92                symbol: "cudaLibraryGetKernel(kernel name contained a NUL byte)",
93            })
94        })?;
95        let mut kernel: cudaKernel_t = core::ptr::null_mut();
96        check(unsafe {
97            cu(
98                &mut kernel,
99                self.inner.handle,
100                c_name.as_ptr() as *const c_char,
101            )
102        })?;
103        Ok(Kernel {
104            handle: kernel,
105            _library: self.clone(),
106        })
107    }
108
109    /// Raw `cudaLibrary_t`. Use with care.
110    #[inline]
111    pub fn as_raw(&self) -> cudaLibrary_t {
112        self.inner.handle
113    }
114}
115
116impl Drop for LibraryInner {
117    fn drop(&mut self) {
118        if let Ok(r) = runtime() {
119            if let Ok(cu) = r.cuda_library_unload() {
120                let _ = unsafe { cu(self.handle) };
121            }
122        }
123    }
124}
125
126/// A kernel entry point inside a [`Library`].
127#[derive(Clone, Debug)]
128pub struct Kernel {
129    handle: cudaKernel_t,
130    // Keeps the library alive for the lifetime of the kernel.
131    _library: Library,
132}
133
134unsafe impl Send for Kernel {}
135unsafe impl Sync for Kernel {}
136
137impl Kernel {
138    /// Raw `cudaKernel_t`. Use with care.
139    #[inline]
140    pub fn as_raw(&self) -> cudaKernel_t {
141        self.handle
142    }
143
144    /// Returns the raw kernel handle cast to a `const void*` — the form
145    /// expected by `cudaLaunchKernel`. Library-loaded kernels can be
146    /// launched through the standard runtime launch function by
147    /// passing this pointer.
148    #[inline]
149    pub fn as_launch_ptr(&self) -> *const c_void {
150        self.handle as *const c_void
151    }
152
153    /// `cudaOccupancyMaxActiveBlocksPerMultiprocessor` — how many blocks
154    /// of size `block_size` can run concurrently per SM given
155    /// `dynamic_smem_bytes` of dynamic shared memory.
156    pub fn max_active_blocks_per_multiprocessor(
157        &self,
158        block_size: i32,
159        dynamic_smem_bytes: usize,
160    ) -> Result<i32> {
161        let r = runtime()?;
162        let cu = r.cuda_occupancy_max_active_blocks_per_multiprocessor()?;
163        let mut n: core::ffi::c_int = 0;
164        check(unsafe { cu(&mut n, self.as_launch_ptr(), block_size, dynamic_smem_bytes) })?;
165        Ok(n)
166    }
167
168    /// `cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags` — same
169    /// as [`Self::max_active_blocks_per_multiprocessor`] but accepting
170    /// occupancy-flag bits (0 = default, 1 = disable shared-memory
171    /// carveout adjustment).
172    pub fn max_active_blocks_per_multiprocessor_with_flags(
173        &self,
174        block_size: i32,
175        dynamic_smem_bytes: usize,
176        flags: u32,
177    ) -> Result<i32> {
178        let r = runtime()?;
179        let cu = r.cuda_occupancy_max_active_blocks_per_multiprocessor_with_flags()?;
180        let mut n: core::ffi::c_int = 0;
181        check(unsafe {
182            cu(
183                &mut n,
184                self.as_launch_ptr(),
185                block_size,
186                dynamic_smem_bytes,
187                flags,
188            )
189        })?;
190        Ok(n)
191    }
192
193    /// `cudaOccupancyAvailableDynamicSMemPerBlock` — how much dynamic
194    /// shared memory can each block use if `num_blocks` of `block_size`
195    /// threads run concurrently on each SM.
196    pub fn available_dynamic_smem_per_block(
197        &self,
198        num_blocks: i32,
199        block_size: i32,
200    ) -> Result<usize> {
201        let r = runtime()?;
202        let cu = r.cuda_occupancy_available_dynamic_smem_per_block()?;
203        let mut n: usize = 0;
204        check(unsafe { cu(&mut n, self.as_launch_ptr(), num_blocks, block_size) })?;
205        Ok(n)
206    }
207
208    /// Set a writable `cudaFuncAttribute`. The common one is
209    /// `cudaFuncAttributeMaxDynamicSharedMemorySize = 8`.
210    pub fn set_attribute(&self, attr: i32, value: i32) -> Result<()> {
211        let r = runtime()?;
212        let cu = r.cuda_func_set_attribute()?;
213        check(unsafe { cu(self.as_launch_ptr(), attr, value) })
214    }
215}
216
217/// Unused helper to check `CudaVersion` availability from outside this module.
218#[allow(dead_code)]
219fn require_library_management(installed: CudaVersion) -> Result<()> {
220    if supports(installed, Feature::LibraryManagement) {
221        Ok(())
222    } else {
223        Err(Error::FeatureNotSupported {
224            api: "cudaLibraryLoadData",
225            since: Feature::LibraryManagement.required_version(),
226        })
227    }
228}