1use core::ffi::{c_char, c_void};
13use std::ffi::CString;
14use std::sync::Arc;
15
16use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUkernel, CUlibrary};
17
18use crate::error::{check, Result};
19use crate::module::Function;
20
21#[derive(Clone)]
23pub struct Library {
24 inner: Arc<LibraryInner>,
25}
26
27struct LibraryInner {
28 handle: CUlibrary,
29}
30
31unsafe impl Send for LibraryInner {}
32unsafe impl Sync for LibraryInner {}
33
34impl core::fmt::Debug for LibraryInner {
35 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36 f.debug_struct("Library")
37 .field("handle", &self.handle)
38 .finish_non_exhaustive()
39 }
40}
41
42impl core::fmt::Debug for Library {
43 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44 self.inner.fmt(f)
45 }
46}
47
48impl Library {
49 pub fn load_raw(image: &[u8]) -> Result<Self> {
52 let d = driver()?;
53 let cu = d.cu_library_load_data()?;
54 let mut lib: CUlibrary = core::ptr::null_mut();
55 check(unsafe {
56 cu(
57 &mut lib,
58 image.as_ptr() as *const c_void,
59 core::ptr::null_mut(), core::ptr::null_mut(), 0, core::ptr::null_mut(), core::ptr::null_mut(), 0, )
66 })?;
67 Ok(Self {
68 inner: Arc::new(LibraryInner { handle: lib }),
69 })
70 }
71
72 pub fn load_ptx(ptx: &str) -> Result<Self> {
74 let c_src = CString::new(ptx).map_err(|_| {
75 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
76 library: "cuda-driver",
77 symbol: "cuLibraryLoadData(PTX contained a NUL byte)",
78 })
79 })?;
80 Self::load_raw(c_src.as_bytes_with_nul())
81 }
82
83 pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
85 let d = driver()?;
86 let cu = d.cu_library_get_kernel()?;
87 let c_name = CString::new(name).map_err(|_| {
88 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
89 library: "cuda-driver",
90 symbol: "cuLibraryGetKernel(name contained a NUL byte)",
91 })
92 })?;
93 let mut kernel: CUkernel = core::ptr::null_mut();
94 check(unsafe {
95 cu(
96 &mut kernel,
97 self.inner.handle,
98 c_name.as_ptr() as *const c_char,
99 )
100 })?;
101 Ok(Kernel {
102 handle: kernel,
103 _library: self.clone(),
104 })
105 }
106
107 pub fn kernel_count(&self) -> Result<u32> {
109 let d = driver()?;
110 let cu = d.cu_library_get_kernel_count()?;
111 let mut n: core::ffi::c_uint = 0;
112 check(unsafe { cu(&mut n, self.inner.handle) })?;
113 Ok(n)
114 }
115
116 pub fn enumerate_kernels(&self) -> Result<Vec<Kernel>> {
119 let d = driver()?;
120 let n = self.kernel_count()?;
121 let cu = d.cu_library_enumerate_kernels()?;
122 let mut raw: Vec<baracuda_cuda_sys::CUkernel> = vec![core::ptr::null_mut(); n as usize];
123 if n > 0 {
124 check(unsafe { cu(raw.as_mut_ptr(), n, self.inner.handle) })?;
125 }
126 Ok(raw
127 .into_iter()
128 .map(|h| Kernel {
129 handle: h,
130 _library: self.clone(),
131 })
132 .collect())
133 }
134
135 pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
139 let d = driver()?;
140 let cu = d.cu_library_get_module()?;
141 let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
142 check(unsafe { cu(&mut m, self.inner.handle) })?;
143 Ok(m)
144 }
145
146 pub fn get_managed(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
148 let d = driver()?;
149 let cu = d.cu_library_get_managed()?;
150 let c_name = CString::new(name).map_err(|_| {
151 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
152 library: "cuda-driver",
153 symbol: "cuLibraryGetManaged(name contained a NUL byte)",
154 })
155 })?;
156 let mut dptr = CUdeviceptr(0);
157 let mut bytes: usize = 0;
158 check(unsafe {
159 cu(
160 &mut dptr,
161 &mut bytes,
162 self.inner.handle,
163 c_name.as_ptr() as *const c_char,
164 )
165 })?;
166 Ok((dptr, bytes))
167 }
168
169 pub fn get_unified_function(&self, name: &str) -> Result<*mut core::ffi::c_void> {
173 let d = driver()?;
174 let cu = d.cu_library_get_unified_function()?;
175 let c_name = CString::new(name).map_err(|_| {
176 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
177 library: "cuda-driver",
178 symbol: "cuLibraryGetUnifiedFunction(name contained a NUL byte)",
179 })
180 })?;
181 let mut fptr: *mut core::ffi::c_void = core::ptr::null_mut();
182 check(unsafe {
183 cu(
184 &mut fptr,
185 self.inner.handle,
186 c_name.as_ptr() as *const c_char,
187 )
188 })?;
189 Ok(fptr)
190 }
191
192 pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
196 let d = driver()?;
197 let cu = d.cu_library_get_global()?;
198 let c_name = CString::new(name).map_err(|_| {
199 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
200 library: "cuda-driver",
201 symbol: "cuLibraryGetGlobal(name contained a NUL byte)",
202 })
203 })?;
204 let mut dptr = CUdeviceptr(0);
205 let mut bytes: usize = 0;
206 check(unsafe {
207 cu(
208 &mut dptr,
209 &mut bytes,
210 self.inner.handle,
211 c_name.as_ptr() as *const c_char,
212 )
213 })?;
214 Ok((dptr, bytes))
215 }
216
217 #[inline]
219 pub fn as_raw(&self) -> CUlibrary {
220 self.inner.handle
221 }
222}
223
224impl Drop for LibraryInner {
225 fn drop(&mut self) {
226 if let Ok(d) = driver() {
227 if let Ok(cu) = d.cu_library_unload() {
228 let _ = unsafe { cu(self.handle) };
229 }
230 }
231 }
232}
233
234#[derive(Clone, Debug)]
238pub struct Kernel {
239 handle: CUkernel,
240 _library: Library,
241}
242
243unsafe impl Send for Kernel {}
244unsafe impl Sync for Kernel {}
245
246impl Kernel {
247 #[inline]
249 pub fn as_raw(&self) -> CUkernel {
250 self.handle
251 }
252
253 pub fn function_for_current_context(&self) -> Result<Function> {
257 let d = driver()?;
258 let cu = d.cu_kernel_get_function()?;
259 let mut f: CUfunction = core::ptr::null_mut();
260 check(unsafe { cu(&mut f, self.handle) })?;
261 Ok(Function::from_raw_with_library(f, self._library.clone()))
263 }
264
265 pub fn attribute(&self, attr: i32, device: &crate::Device) -> Result<i32> {
267 let d = driver()?;
268 let cu = d.cu_kernel_get_attribute()?;
269 let mut v: core::ffi::c_int = 0;
270 check(unsafe { cu(&mut v, attr, self.handle, device.as_raw()) })?;
271 Ok(v)
272 }
273
274 pub fn set_attribute(&self, attr: i32, value: i32, device: &crate::Device) -> Result<()> {
278 let d = driver()?;
279 let cu = d.cu_kernel_set_attribute()?;
280 check(unsafe { cu(attr, value, self.handle, device.as_raw()) })
281 }
282
283 pub fn name(&self) -> Result<String> {
285 let d = driver()?;
286 let cu = d.cu_kernel_get_name()?;
287 let mut p: *const core::ffi::c_char = core::ptr::null();
288 check(unsafe { cu(&mut p, self.handle) })?;
289 if p.is_null() {
290 return Ok(String::new());
291 }
292 let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
294 Ok(cstr.to_string_lossy().into_owned())
295 }
296
297 pub fn set_cache_config(&self, config: u32, device: &crate::Device) -> Result<()> {
302 let d = driver()?;
303 let cu = d.cu_kernel_set_cache_config()?;
304 check(unsafe { cu(self.handle, config as core::ffi::c_int, device.as_raw()) })
305 }
306
307 pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
311 let d = driver()?;
312 let cu = d.cu_kernel_get_param_info()?;
313 let mut off: usize = 0;
314 let mut sz: usize = 0;
315 check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
316 Ok((off, sz))
317 }
318
319 pub fn library(&self) -> Library {
321 self._library.clone()
322 }
323}