1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
//! Compiled module loading (PTX, CUBIN, fatbin) and kernel entry-point lookup.
use core::ffi::{c_char, c_void};
use std::ffi::CString;
use std::sync::Arc;
use baracuda_cuda_sys::{driver, CUdeviceptr, CUfunction, CUmodule};
use crate::context::Context;
use crate::error::{check, Result};
/// A loaded CUDA module (e.g. compiled PTX).
#[derive(Clone)]
pub struct Module {
inner: Arc<ModuleInner>,
}
struct ModuleInner {
handle: CUmodule,
context: Context,
}
unsafe impl Send for ModuleInner {}
unsafe impl Sync for ModuleInner {}
impl core::fmt::Debug for ModuleInner {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Module")
.field("handle", &self.handle)
.finish_non_exhaustive()
}
}
impl core::fmt::Debug for Module {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.inner.fmt(f)
}
}
impl Module {
/// Load a module from a raw binary image (CUBIN, fatbin, or PTX text with a trailing NUL).
///
/// For PTX, the bytes must be a null-terminated UTF-8 string matching the
/// `ptx` file on disk. [`Module::load_ptx`] is a convenience wrapper that
/// adds the NUL for you.
pub fn load_raw(context: &Context, image: &[u8]) -> Result<Self> {
context.set_current()?;
let d = driver()?;
let cu = d.cu_module_load_data()?;
let mut module: CUmodule = core::ptr::null_mut();
// SAFETY: `image.as_ptr()` is valid for reads of the image bytes.
check(unsafe { cu(&mut module, image.as_ptr() as *const c_void) })?;
Ok(Self {
inner: Arc::new(ModuleInner {
handle: module,
context: context.clone(),
}),
})
}
/// Load a module from a PTX source string.
pub fn load_ptx(context: &Context, ptx_source: &str) -> Result<Self> {
// cuModuleLoadData expects a null-terminated buffer for PTX.
let c_src = CString::new(ptx_source).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleLoadData(PTX input contained a NUL byte)",
})
})?;
Self::load_raw(context, c_src.as_bytes_with_nul())
}
/// Look up a `__device__` global variable by name. Returns
/// `(device_ptr, size_in_bytes)`.
pub fn get_global(&self, name: &str) -> Result<(CUdeviceptr, usize)> {
let d = driver()?;
let cu = d.cu_module_get_global()?;
let c_name = CString::new(name).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleGetGlobal(name contained a NUL byte)",
})
})?;
let mut dptr = CUdeviceptr(0);
let mut bytes: usize = 0;
check(unsafe {
cu(
&mut dptr,
&mut bytes,
self.inner.handle,
c_name.as_ptr() as *const c_char,
)
})?;
Ok((dptr, bytes))
}
/// Look up a kernel entry point by name.
pub fn get_function(&self, name: &str) -> Result<Function> {
let d = driver()?;
let cu = d.cu_module_get_function()?;
let c_name = CString::new(name).map_err(|_| {
crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-driver",
symbol: "cuModuleGetFunction(kernel name contained a NUL byte)",
})
})?;
let mut func: CUfunction = core::ptr::null_mut();
// SAFETY: `func` writable; `self.inner.handle` owned by this Arc;
// `c_name.as_ptr()` is null-terminated.
check(unsafe {
cu(
&mut func,
self.inner.handle,
c_name.as_ptr() as *const c_char,
)
})?;
Ok(Function {
handle: func,
_owner: FunctionOwner::Module(self.clone()),
})
}
/// Raw `CUmodule` handle. Use with care.
#[inline]
pub fn as_raw(&self) -> CUmodule {
self.inner.handle
}
/// Return the current process-wide module loading mode (eager vs. lazy).
/// Compare against
/// [`baracuda_cuda_sys::types::CUmoduleLoadingMode`] constants.
pub fn loading_mode() -> Result<i32> {
let d = driver()?;
let cu = d.cu_module_get_loading_mode()?;
let mut mode: core::ffi::c_int = 0;
check(unsafe { cu(&mut mode) })?;
Ok(mode)
}
/// Load a module from a raw image with extra JIT compiler options —
/// the typical use is capturing the JIT log when a PTX module
/// fails to compile. `options` and `option_values` are parallel
/// arrays whose entries follow the `CUjit_option` ABI (see the
/// CUDA driver reference). For PTX, the bytes must be a
/// null-terminated UTF-8 string.
///
/// # Safety
///
/// Each `option_value` must point at a value of the type the
/// matching `CUjit_option` expects (some are pointers, some are
/// integers cast to `*mut c_void`). The arrays must have the same
/// length.
pub unsafe fn load_data_ex(
context: &Context,
image: &[u8],
options: &mut [i32],
option_values: &mut [*mut core::ffi::c_void],
) -> Result<Self> {
assert_eq!(
options.len(),
option_values.len(),
"load_data_ex: options and option_values must have the same length"
);
context.set_current()?;
let d = driver()?;
let cu = d.cu_module_load_data_ex()?;
let mut module: CUmodule = core::ptr::null_mut();
check(cu(
&mut module,
image.as_ptr() as *const c_void,
options.len() as core::ffi::c_uint,
options.as_mut_ptr(),
option_values.as_mut_ptr(),
))?;
Ok(Self {
inner: Arc::new(ModuleInner {
handle: module,
context: context.clone(),
}),
})
}
/// The [`Context`] this module was loaded into.
#[inline]
pub fn context(&self) -> &Context {
&self.inner.context
}
}
impl Drop for ModuleInner {
fn drop(&mut self) {
if let Ok(d) = driver() {
if let Ok(cu) = d.cu_module_unload() {
let _ = unsafe { cu(self.handle) };
}
}
}
}
/// A kernel entry point — either inside a [`Module`] (classic
/// Driver API) or materialized from a [`crate::library::Kernel`] (CUDA
/// 12.0+ library API). Either way it keeps the parent alive via an Arc
/// so the kernel stays valid for as long as any [`Function`] handle exists.
#[derive(Clone, Debug)]
pub struct Function {
handle: CUfunction,
_owner: FunctionOwner,
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
enum FunctionOwner {
/// Owned by a `Module` (classic Driver API flow).
Module(Module),
/// Owned by a `Library` (CUDA 12.0+ cuLibrary flow).
Library(crate::library::Library),
}
impl Function {
/// Construct from an already-resolved `CUfunction` plus the parent
/// library that owns it. Intended for `library::Kernel`'s
/// `function_for_current_context`.
pub(crate) fn from_raw_with_library(
handle: CUfunction,
library: crate::library::Library,
) -> Self {
Self {
handle,
_owner: FunctionOwner::Library(library),
}
}
}
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
impl Function {
/// Raw `CUfunction`. Use with care.
#[inline]
pub fn as_raw(&self) -> CUfunction {
self.handle
}
/// The [`Module`] this kernel lives in, if it was obtained through
/// `Module::get_function`. Returns `None` for kernels materialized
/// from a `library::Kernel`.
#[inline]
pub fn module(&self) -> Option<&Module> {
match &self._owner {
FunctionOwner::Module(m) => Some(m),
FunctionOwner::Library(_) => None,
}
}
/// Query a kernel attribute (see
/// [`baracuda_cuda_sys::types::CUfunction_attribute`]).
pub fn get_attribute(&self, attribute: i32) -> Result<i32> {
let d = driver()?;
let cu = d.cu_func_get_attribute()?;
let mut v: core::ffi::c_int = 0;
check(unsafe { cu(&mut v, attribute, self.handle) })?;
Ok(v)
}
/// Return the demangled kernel name reported by the driver.
pub fn name(&self) -> Result<String> {
let d = driver()?;
let cu = d.cu_func_get_name()?;
let mut p: *const core::ffi::c_char = core::ptr::null();
check(unsafe { cu(&mut p, self.handle) })?;
if p.is_null() {
return Ok(String::new());
}
let cstr = unsafe { core::ffi::CStr::from_ptr(p) };
Ok(cstr.to_string_lossy().into_owned())
}
/// Return `(offset_in_bytes, size_in_bytes)` for the `index`-th
/// parameter in this function's ABI signature.
pub fn param_info(&self, index: usize) -> Result<(usize, usize)> {
let d = driver()?;
let cu = d.cu_func_get_param_info()?;
let mut off: usize = 0;
let mut sz: usize = 0;
check(unsafe { cu(self.handle, index, &mut off, &mut sz) })?;
Ok((off, sz))
}
/// Return the raw `CUmodule` this function was loaded from, if any.
pub fn module_raw(&self) -> Result<baracuda_cuda_sys::CUmodule> {
let d = driver()?;
let cu = d.cu_func_get_module()?;
let mut m: baracuda_cuda_sys::CUmodule = core::ptr::null_mut();
check(unsafe { cu(&mut m, self.handle) })?;
Ok(m)
}
/// Set a kernel attribute. Only a subset is writable (notably
/// `MAX_DYNAMIC_SHARED_SIZE_BYTES` and
/// `PREFERRED_SHARED_MEMORY_CARVEOUT`).
pub fn set_attribute(&self, attribute: i32, value: i32) -> Result<()> {
let d = driver()?;
let cu = d.cu_func_set_attribute()?;
check(unsafe { cu(self.handle, attribute, value) })
}
// Convenience named accessors for the most-read attributes.
/// Maximum threads per block this kernel supports on the current device.
pub fn max_threads_per_block(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::MAX_THREADS_PER_BLOCK)
}
/// Size of per-block statically-allocated shared memory (bytes).
pub fn shared_size_bytes(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::SHARED_SIZE_BYTES)
}
/// Number of registers used per thread.
pub fn num_regs(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::NUM_REGS)
}
/// Per-thread local-memory footprint (bytes).
pub fn local_size_bytes(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::LOCAL_SIZE_BYTES)
}
/// PTX version this kernel was compiled from, as `major*10 + minor`.
pub fn ptx_version(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::PTX_VERSION)
}
/// SM-architecture this kernel was compiled for, as `major*10 + minor`.
pub fn binary_version(&self) -> Result<i32> {
use baracuda_cuda_sys::types::CUfunction_attribute as A;
self.get_attribute(A::BINARY_VERSION)
}
}