cuda_oxide/
module.rs

1use std::{
2    borrow::Cow,
3    ffi::{c_void, CString},
4    ptr::null_mut,
5    rc::Rc,
6};
7
8use crate::*;
9
10// Debug must not be derived, see comment on info_buf
11/// A CUDA JIT linker context, used to compile device-specific kernels from PTX assembly or link together several precompiled binaries
12pub struct Linker<'a> {
13    inner: *mut sys::CUlinkState_st,
14    info_buf: Vec<u8>, // both info_buf and errors_buf contain uninitialized memory! they should always be NUL terminated strings
15    errors_buf: Vec<u8>,
16    handle: Rc<Handle<'a>>,
17}
18
19/// The type of input to the linker
20#[derive(Clone, Copy, Debug, PartialEq)]
21pub enum LinkerInputType {
22    Cubin,
23    Ptx,
24    Fatbin,
25}
26
27/// Linker options for CUDA, can generally just be defaulted.
28#[derive(Clone, Copy, Debug)]
29pub struct LinkerOptions {
30    /// Add debug symbols to emitted binary
31    pub debug_info: bool,
32    /// Collect INFO logs from CUDA build/link, up to 16 MB, then emit to STDOUT
33    pub log_info: bool,
34    /// Collect ERROR logs from CUDA build/link, up to 16 MB, then emit to STDOUT
35    pub log_errors: bool,
36    /// Increase log verbosity
37    pub verbose_logs: bool,
38}
39
40impl Default for LinkerOptions {
41    fn default() -> Self {
42        LinkerOptions {
43            debug_info: false,
44            log_info: true,
45            log_errors: true,
46            verbose_logs: false,
47        }
48    }
49}
50
51impl<'a> Linker<'a> {
52    /// Creates a new [`Linker`] for the given context handle, compute capability, and linker options.
53    pub fn new(
54        handle: &Rc<Handle<'a>>,
55        compute_capability: CudaVersion,
56        options: LinkerOptions,
57    ) -> CudaResult<Self> {
58        let mut linker = Linker {
59            inner: null_mut(),
60            info_buf: if options.log_info {
61                let mut buf = Vec::with_capacity(16 * 1024 * 1024);
62                buf.push(0);
63                unsafe { buf.set_len(buf.capacity()) };
64                buf
65            } else {
66                vec![]
67            },
68            errors_buf: if options.log_errors {
69                let mut buf = Vec::with_capacity(16 * 1024 * 1024);
70                buf.push(0);
71                unsafe { buf.set_len(buf.capacity()) };
72                buf
73            } else {
74                vec![]
75            },
76            handle: handle.clone(),
77        };
78        let log_verbose = if options.verbose_logs { 1u32 } else { 0u32 };
79        let debug_info = if options.debug_info { 1u32 } else { 0u32 };
80
81        let mut options = [
82            sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER,
83            sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
84            sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER,
85            sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
86            sys::CUjit_option_enum_CU_JIT_TARGET,
87            sys::CUjit_option_enum_CU_JIT_LOG_VERBOSE,
88            sys::CUjit_option_enum_CU_JIT_GENERATE_DEBUG_INFO,
89        ];
90        let target = match (compute_capability.major, compute_capability.minor) {
91            (2, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_20,
92            (2, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_21,
93            (3, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_30,
94            (3, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_32,
95            (3, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_35,
96            (3, 7) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_37,
97            (5, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_50,
98            (5, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_52,
99            (5, 3) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_53,
100            (6, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_60,
101            (6, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_61,
102            (6, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_62,
103            (7, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_70,
104            (7, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_72,
105            (7, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_75,
106            (8, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_80,
107            (8, 6) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_86,
108            (_, _) => return Err(ErrorCode::UnsupportedPtxVersion),
109        };
110
111        let mut values = [
112            linker.info_buf.as_mut_ptr() as *mut c_void,
113            linker.info_buf.len() as u32 as u64 as *mut c_void,
114            linker.errors_buf.as_mut_ptr() as *mut c_void,
115            linker.errors_buf.len() as u32 as u64 as *mut c_void,
116            target as u64 as *mut c_void,
117            log_verbose as u64 as *mut c_void,
118            debug_info as u64 as *mut c_void,
119        ];
120        cuda_error(unsafe {
121            sys::cuLinkCreate_v2(
122                options.len() as u32,
123                options.as_mut_ptr(),
124                values.as_mut_ptr(),
125                &mut linker.inner as *mut _,
126            )
127        })?;
128        Ok(linker)
129    }
130
131    fn emit_logs(&self) {
132        let info_string = self.info_buf.iter().position(|x| *x == 0);
133        if let Some(info_string) = info_string {
134            let info_string = String::from_utf8_lossy(&self.info_buf[..info_string]);
135            if !info_string.is_empty() {
136                info_string.split('\n').for_each(|line| {
137                    println!("[CUDA INFO] {}", line);
138                });
139            }
140        }
141        let error_string = self.errors_buf.iter().position(|x| *x == 0);
142        if let Some(error_string) = error_string {
143            let error_string = String::from_utf8_lossy(&self.errors_buf[..error_string]);
144            if !error_string.is_empty() {
145                error_string.split('\n').for_each(|line| {
146                    println!("[CUDA ERROR] {}", line);
147                });
148            }
149        }
150    }
151
152    /// Add an input file to the linker context. `name` is only used for logs
153    pub fn add(self, name: &str, format: LinkerInputType, in_data: &[u8]) -> CudaResult<Self> {
154        let mut data = Cow::Borrowed(in_data);
155        if format == LinkerInputType::Ptx {
156            let mut new_data = Vec::with_capacity(in_data.len() + 1);
157            new_data.extend_from_slice(in_data);
158            new_data.push(0);
159            data = Cow::Owned(new_data)
160        }
161
162        let format = match format {
163            LinkerInputType::Cubin => sys::CUjitInputType_enum_CU_JIT_INPUT_CUBIN,
164            LinkerInputType::Ptx => sys::CUjitInputType_enum_CU_JIT_INPUT_PTX,
165            LinkerInputType::Fatbin => sys::CUjitInputType_enum_CU_JIT_INPUT_FATBINARY,
166        };
167        let name = CString::new(name).unwrap();
168
169        let out = cuda_error(unsafe {
170            sys::cuLinkAddData_v2(
171                self.inner,
172                format,
173                data.as_ptr() as *mut u8 as *mut c_void,
174                data.len() as sys::size_t,
175                name.as_ptr(),
176                0,
177                null_mut(),
178                null_mut(),
179            )
180        });
181
182        if let Err(e) = out {
183            self.emit_logs();
184            return Err(e);
185        }
186        Ok(self)
187    }
188
189    /// Emit the cubin assembly binary. You probably want [`Linker::build_module`]
190    pub fn build(&self) -> CudaResult<&[u8]> {
191        let mut cubin_out: *mut c_void = null_mut();
192        let mut size_out: sys::size_t = 0;
193        let out = cuda_error(unsafe {
194            sys::cuLinkComplete(
195                self.inner,
196                &mut cubin_out as *mut *mut c_void,
197                &mut size_out as *mut sys::size_t,
198            )
199        });
200        self.emit_logs();
201        if let Err(e) = out {
202            return Err(e);
203        }
204        Ok(unsafe { std::slice::from_raw_parts(cubin_out as *const u8, size_out as usize) })
205    }
206
207    /// Build a CUDA module from this [`Linker`].
208    pub fn build_module(&self) -> CudaResult<Module<'a>> {
209        let built = self.build()?;
210        Module::load(&self.handle, built)
211    }
212}
213
214impl<'a> Drop for Linker<'a> {
215    fn drop(&mut self) {
216        if let Err(e) = cuda_error(unsafe { sys::cuLinkDestroy(self.inner) }) {
217            eprintln!("CUDA: failed to destroy cuda linker state: {:?}", e);
218        }
219    }
220}
221
222/// A loaded CUDA module
223pub struct Module<'a> {
224    handle: Rc<Handle<'a>>,
225    inner: *mut sys::CUmod_st,
226}
227
228impl<'a> Module<'a> {
229    /// Takes a raw CUDA kernel image and loads the corresponding module module into the current context.
230    /// The pointer can be a cubin or PTX or fatbin file as a NULL-terminated text string
231    pub fn load(handle: &Rc<Handle<'a>>, module: &[u8]) -> CudaResult<Self> {
232        let mut inner = null_mut();
233        cuda_error(unsafe {
234            sys::cuModuleLoadData(&mut inner as *mut _, module.as_ptr() as *const _)
235        })?;
236        Ok(Module {
237            inner,
238            handle: handle.clone(),
239        })
240    }
241
242    /// Same as [`Module::load`] but uses `fatCubin` format.
243    pub fn load_fatcubin(handle: &Rc<Handle<'a>>, module: &[u8]) -> CudaResult<Self> {
244        let mut inner = null_mut();
245        cuda_error(unsafe {
246            sys::cuModuleLoadFatBinary(&mut inner as *mut _, module.as_ptr() as *const _)
247        })?;
248        Ok(Module {
249            inner,
250            handle: handle.clone(),
251        })
252    }
253
254    /// Retrieve a reference to a define CUDA kernel within the module.
255    pub fn get_function<'b>(&'b self, name: &str) -> CudaResult<Function<'a, 'b>> {
256        let mut inner = null_mut();
257        let name = CString::new(name).unwrap();
258        cuda_error(unsafe {
259            sys::cuModuleGetFunction(&mut inner as *mut _, self.inner, name.as_ptr())
260        })?;
261        Ok(Function {
262            module: self,
263            inner,
264        })
265    }
266
267    /// Get a pointer to a global variable defined by a CUDA module.
268    pub fn get_global<'b: 'a>(&'b self, name: &str) -> CudaResult<DevicePtr<'b>> {
269        let mut out = DevicePtr {
270            handle: self.handle.clone(),
271            inner: 0,
272            len: 0,
273        };
274        let name = CString::new(name).unwrap();
275        cuda_error(unsafe {
276            sys::cuModuleGetGlobal_v2(
277                &mut out.inner,
278                &mut out.len as *mut u64 as *mut _,
279                self.inner,
280                name.as_ptr(),
281            )
282        })?;
283        Ok(out)
284    }
285
286    // pub fn get_surface(&self, name: &str) {
287
288    // }
289
290    // pub fn get_texture(&self, name: &str) {
291
292    // }
293}
294
295impl<'a> Drop for Module<'a> {
296    fn drop(&mut self) {
297        if let Err(e) = cuda_error(unsafe { sys::cuModuleUnload(self.inner) }) {
298            eprintln!("CUDA: failed to destroy cuda module: {:?}", e);
299        }
300    }
301}