1use std::{
2 borrow::Cow,
3 ffi::{c_void, CString},
4 ptr::null_mut,
5 rc::Rc,
6};
7
8use crate::*;
9
10pub struct Linker<'a> {
13 inner: *mut sys::CUlinkState_st,
14 info_buf: Vec<u8>, errors_buf: Vec<u8>,
16 handle: Rc<Handle<'a>>,
17}
18
19#[derive(Clone, Copy, Debug, PartialEq)]
21pub enum LinkerInputType {
22 Cubin,
23 Ptx,
24 Fatbin,
25}
26
27#[derive(Clone, Copy, Debug)]
29pub struct LinkerOptions {
30 pub debug_info: bool,
32 pub log_info: bool,
34 pub log_errors: bool,
36 pub verbose_logs: bool,
38}
39
40impl Default for LinkerOptions {
41 fn default() -> Self {
42 LinkerOptions {
43 debug_info: false,
44 log_info: true,
45 log_errors: true,
46 verbose_logs: false,
47 }
48 }
49}
50
51impl<'a> Linker<'a> {
52 pub fn new(
54 handle: &Rc<Handle<'a>>,
55 compute_capability: CudaVersion,
56 options: LinkerOptions,
57 ) -> CudaResult<Self> {
58 let mut linker = Linker {
59 inner: null_mut(),
60 info_buf: if options.log_info {
61 let mut buf = Vec::with_capacity(16 * 1024 * 1024);
62 buf.push(0);
63 unsafe { buf.set_len(buf.capacity()) };
64 buf
65 } else {
66 vec![]
67 },
68 errors_buf: if options.log_errors {
69 let mut buf = Vec::with_capacity(16 * 1024 * 1024);
70 buf.push(0);
71 unsafe { buf.set_len(buf.capacity()) };
72 buf
73 } else {
74 vec![]
75 },
76 handle: handle.clone(),
77 };
78 let log_verbose = if options.verbose_logs { 1u32 } else { 0u32 };
79 let debug_info = if options.debug_info { 1u32 } else { 0u32 };
80
81 let mut options = [
82 sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER,
83 sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
84 sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER,
85 sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
86 sys::CUjit_option_enum_CU_JIT_TARGET,
87 sys::CUjit_option_enum_CU_JIT_LOG_VERBOSE,
88 sys::CUjit_option_enum_CU_JIT_GENERATE_DEBUG_INFO,
89 ];
90 let target = match (compute_capability.major, compute_capability.minor) {
91 (2, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_20,
92 (2, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_21,
93 (3, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_30,
94 (3, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_32,
95 (3, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_35,
96 (3, 7) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_37,
97 (5, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_50,
98 (5, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_52,
99 (5, 3) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_53,
100 (6, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_60,
101 (6, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_61,
102 (6, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_62,
103 (7, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_70,
104 (7, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_72,
105 (7, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_75,
106 (8, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_80,
107 (8, 6) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_86,
108 (_, _) => return Err(ErrorCode::UnsupportedPtxVersion),
109 };
110
111 let mut values = [
112 linker.info_buf.as_mut_ptr() as *mut c_void,
113 linker.info_buf.len() as u32 as u64 as *mut c_void,
114 linker.errors_buf.as_mut_ptr() as *mut c_void,
115 linker.errors_buf.len() as u32 as u64 as *mut c_void,
116 target as u64 as *mut c_void,
117 log_verbose as u64 as *mut c_void,
118 debug_info as u64 as *mut c_void,
119 ];
120 cuda_error(unsafe {
121 sys::cuLinkCreate_v2(
122 options.len() as u32,
123 options.as_mut_ptr(),
124 values.as_mut_ptr(),
125 &mut linker.inner as *mut _,
126 )
127 })?;
128 Ok(linker)
129 }
130
131 fn emit_logs(&self) {
132 let info_string = self.info_buf.iter().position(|x| *x == 0);
133 if let Some(info_string) = info_string {
134 let info_string = String::from_utf8_lossy(&self.info_buf[..info_string]);
135 if !info_string.is_empty() {
136 info_string.split('\n').for_each(|line| {
137 println!("[CUDA INFO] {}", line);
138 });
139 }
140 }
141 let error_string = self.errors_buf.iter().position(|x| *x == 0);
142 if let Some(error_string) = error_string {
143 let error_string = String::from_utf8_lossy(&self.errors_buf[..error_string]);
144 if !error_string.is_empty() {
145 error_string.split('\n').for_each(|line| {
146 println!("[CUDA ERROR] {}", line);
147 });
148 }
149 }
150 }
151
152 pub fn add(self, name: &str, format: LinkerInputType, in_data: &[u8]) -> CudaResult<Self> {
154 let mut data = Cow::Borrowed(in_data);
155 if format == LinkerInputType::Ptx {
156 let mut new_data = Vec::with_capacity(in_data.len() + 1);
157 new_data.extend_from_slice(in_data);
158 new_data.push(0);
159 data = Cow::Owned(new_data)
160 }
161
162 let format = match format {
163 LinkerInputType::Cubin => sys::CUjitInputType_enum_CU_JIT_INPUT_CUBIN,
164 LinkerInputType::Ptx => sys::CUjitInputType_enum_CU_JIT_INPUT_PTX,
165 LinkerInputType::Fatbin => sys::CUjitInputType_enum_CU_JIT_INPUT_FATBINARY,
166 };
167 let name = CString::new(name).unwrap();
168
169 let out = cuda_error(unsafe {
170 sys::cuLinkAddData_v2(
171 self.inner,
172 format,
173 data.as_ptr() as *mut u8 as *mut c_void,
174 data.len() as sys::size_t,
175 name.as_ptr(),
176 0,
177 null_mut(),
178 null_mut(),
179 )
180 });
181
182 if let Err(e) = out {
183 self.emit_logs();
184 return Err(e);
185 }
186 Ok(self)
187 }
188
189 pub fn build(&self) -> CudaResult<&[u8]> {
191 let mut cubin_out: *mut c_void = null_mut();
192 let mut size_out: sys::size_t = 0;
193 let out = cuda_error(unsafe {
194 sys::cuLinkComplete(
195 self.inner,
196 &mut cubin_out as *mut *mut c_void,
197 &mut size_out as *mut sys::size_t,
198 )
199 });
200 self.emit_logs();
201 if let Err(e) = out {
202 return Err(e);
203 }
204 Ok(unsafe { std::slice::from_raw_parts(cubin_out as *const u8, size_out as usize) })
205 }
206
207 pub fn build_module(&self) -> CudaResult<Module<'a>> {
209 let built = self.build()?;
210 Module::load(&self.handle, built)
211 }
212}
213
214impl<'a> Drop for Linker<'a> {
215 fn drop(&mut self) {
216 if let Err(e) = cuda_error(unsafe { sys::cuLinkDestroy(self.inner) }) {
217 eprintln!("CUDA: failed to destroy cuda linker state: {:?}", e);
218 }
219 }
220}
221
222pub struct Module<'a> {
224 handle: Rc<Handle<'a>>,
225 inner: *mut sys::CUmod_st,
226}
227
228impl<'a> Module<'a> {
229 pub fn load(handle: &Rc<Handle<'a>>, module: &[u8]) -> CudaResult<Self> {
232 let mut inner = null_mut();
233 cuda_error(unsafe {
234 sys::cuModuleLoadData(&mut inner as *mut _, module.as_ptr() as *const _)
235 })?;
236 Ok(Module {
237 inner,
238 handle: handle.clone(),
239 })
240 }
241
242 pub fn load_fatcubin(handle: &Rc<Handle<'a>>, module: &[u8]) -> CudaResult<Self> {
244 let mut inner = null_mut();
245 cuda_error(unsafe {
246 sys::cuModuleLoadFatBinary(&mut inner as *mut _, module.as_ptr() as *const _)
247 })?;
248 Ok(Module {
249 inner,
250 handle: handle.clone(),
251 })
252 }
253
254 pub fn get_function<'b>(&'b self, name: &str) -> CudaResult<Function<'a, 'b>> {
256 let mut inner = null_mut();
257 let name = CString::new(name).unwrap();
258 cuda_error(unsafe {
259 sys::cuModuleGetFunction(&mut inner as *mut _, self.inner, name.as_ptr())
260 })?;
261 Ok(Function {
262 module: self,
263 inner,
264 })
265 }
266
267 pub fn get_global<'b: 'a>(&'b self, name: &str) -> CudaResult<DevicePtr<'b>> {
269 let mut out = DevicePtr {
270 handle: self.handle.clone(),
271 inner: 0,
272 len: 0,
273 };
274 let name = CString::new(name).unwrap();
275 cuda_error(unsafe {
276 sys::cuModuleGetGlobal_v2(
277 &mut out.inner,
278 &mut out.len as *mut u64 as *mut _,
279 self.inner,
280 name.as_ptr(),
281 )
282 })?;
283 Ok(out)
284 }
285
286 }
294
295impl<'a> Drop for Module<'a> {
296 fn drop(&mut self) {
297 if let Err(e) = cuda_error(unsafe { sys::cuModuleUnload(self.inner) }) {
298 eprintln!("CUDA: failed to destroy cuda module: {:?}", e);
299 }
300 }
301}