baracuda_runtime/
module.rs1use core::ffi::{c_char, c_void};
8use std::ffi::CString;
9use std::sync::Arc;
10
11use baracuda_cuda_sys::runtime::{cudaKernel_t, cudaLibrary_t, runtime};
12use baracuda_types::{supports, CudaVersion, Feature};
13
14use crate::error::{check, Error, Result};
15
16#[derive(Clone)]
18pub struct Library {
19 inner: Arc<LibraryInner>,
20}
21
22struct LibraryInner {
23 handle: cudaLibrary_t,
24}
25
26unsafe impl Send for LibraryInner {}
27unsafe impl Sync for LibraryInner {}
28
29impl core::fmt::Debug for LibraryInner {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 f.debug_struct("Library")
32 .field("handle", &self.handle)
33 .finish_non_exhaustive()
34 }
35}
36
37impl core::fmt::Debug for Library {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 self.inner.fmt(f)
40 }
41}
42
43impl Library {
44 pub fn load_raw(image: &[u8]) -> Result<Self> {
46 let installed = crate::init::driver_version()?;
47 if !supports(installed, Feature::LibraryManagement) {
48 return Err(Error::FeatureNotSupported {
49 api: "cudaLibraryLoadData",
50 since: Feature::LibraryManagement.required_version(),
51 });
52 }
53
54 let r = runtime()?;
55 let cu = r.cuda_library_load_data()?;
56 let mut lib: cudaLibrary_t = core::ptr::null_mut();
57 check(unsafe {
58 cu(
59 &mut lib,
60 image.as_ptr() as *const c_void,
61 core::ptr::null_mut(), core::ptr::null_mut(), 0, core::ptr::null_mut(), core::ptr::null_mut(), 0, )
68 })?;
69 Ok(Self {
70 inner: Arc::new(LibraryInner { handle: lib }),
71 })
72 }
73
74 pub fn load_ptx(ptx_source: &str) -> Result<Self> {
76 let c_src = CString::new(ptx_source).map_err(|_| {
77 Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
78 library: "cuda-runtime",
79 symbol: "cudaLibraryLoadData(PTX input contained a NUL byte)",
80 })
81 })?;
82 Self::load_raw(c_src.as_bytes_with_nul())
83 }
84
85 pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
87 let r = runtime()?;
88 let cu = r.cuda_library_get_kernel()?;
89 let c_name = CString::new(name).map_err(|_| {
90 Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
91 library: "cuda-runtime",
92 symbol: "cudaLibraryGetKernel(kernel name contained a NUL byte)",
93 })
94 })?;
95 let mut kernel: cudaKernel_t = core::ptr::null_mut();
96 check(unsafe {
97 cu(
98 &mut kernel,
99 self.inner.handle,
100 c_name.as_ptr() as *const c_char,
101 )
102 })?;
103 Ok(Kernel {
104 handle: kernel,
105 _library: self.clone(),
106 })
107 }
108
109 #[inline]
111 pub fn as_raw(&self) -> cudaLibrary_t {
112 self.inner.handle
113 }
114}
115
116impl Drop for LibraryInner {
117 fn drop(&mut self) {
118 if let Ok(r) = runtime() {
119 if let Ok(cu) = r.cuda_library_unload() {
120 let _ = unsafe { cu(self.handle) };
121 }
122 }
123 }
124}
125
126#[derive(Clone, Debug)]
128pub struct Kernel {
129 handle: cudaKernel_t,
130 _library: Library,
132}
133
134unsafe impl Send for Kernel {}
135unsafe impl Sync for Kernel {}
136
137impl Kernel {
138 #[inline]
140 pub fn as_raw(&self) -> cudaKernel_t {
141 self.handle
142 }
143
144 #[inline]
149 pub fn as_launch_ptr(&self) -> *const c_void {
150 self.handle as *const c_void
151 }
152
153 pub fn max_active_blocks_per_multiprocessor(
157 &self,
158 block_size: i32,
159 dynamic_smem_bytes: usize,
160 ) -> Result<i32> {
161 let r = runtime()?;
162 let cu = r.cuda_occupancy_max_active_blocks_per_multiprocessor()?;
163 let mut n: core::ffi::c_int = 0;
164 check(unsafe { cu(&mut n, self.as_launch_ptr(), block_size, dynamic_smem_bytes) })?;
165 Ok(n)
166 }
167
168 pub fn max_active_blocks_per_multiprocessor_with_flags(
173 &self,
174 block_size: i32,
175 dynamic_smem_bytes: usize,
176 flags: u32,
177 ) -> Result<i32> {
178 let r = runtime()?;
179 let cu = r.cuda_occupancy_max_active_blocks_per_multiprocessor_with_flags()?;
180 let mut n: core::ffi::c_int = 0;
181 check(unsafe {
182 cu(
183 &mut n,
184 self.as_launch_ptr(),
185 block_size,
186 dynamic_smem_bytes,
187 flags,
188 )
189 })?;
190 Ok(n)
191 }
192
193 pub fn available_dynamic_smem_per_block(
197 &self,
198 num_blocks: i32,
199 block_size: i32,
200 ) -> Result<usize> {
201 let r = runtime()?;
202 let cu = r.cuda_occupancy_available_dynamic_smem_per_block()?;
203 let mut n: usize = 0;
204 check(unsafe { cu(&mut n, self.as_launch_ptr(), num_blocks, block_size) })?;
205 Ok(n)
206 }
207
208 pub fn set_attribute(&self, attr: i32, value: i32) -> Result<()> {
211 let r = runtime()?;
212 let cu = r.cuda_func_set_attribute()?;
213 check(unsafe { cu(self.as_launch_ptr(), attr, value) })
214 }
215}
216
217#[allow(dead_code)]
219fn require_library_management(installed: CudaVersion) -> Result<()> {
220 if supports(installed, Feature::LibraryManagement) {
221 Ok(())
222 } else {
223 Err(Error::FeatureNotSupported {
224 api: "cudaLibraryLoadData",
225 since: Feature::LibraryManagement.required_version(),
226 })
227 }
228}