1use std::collections::HashMap;
2use std::ffi::{c_void, CStr, CString};
3use std::fmt::Write;
4use std::ptr::{null, null_mut};
5use std::sync::Arc;
6
7use itertools::Itertools;
8
9use crate::bindings::{
10 CU_LAUNCH_PARAM_BUFFER_POINTER, CU_LAUNCH_PARAM_BUFFER_SIZE, CU_LAUNCH_PARAM_END, cuLaunchKernel, cuModuleGetFunction,
11 cuModuleLoadDataEx, cuModuleUnload, CUresult, nvrtcAddNameExpression, nvrtcCompileProgram, nvrtcCreateProgram,
12 nvrtcDestroyProgram, nvrtcGetLoweredName, nvrtcGetProgramLog, nvrtcGetProgramLogSize, nvrtcGetPTX,
13 nvrtcGetPTXSize, nvrtcResult,
14};
15use crate::wrapper::handle::{CudaDevice, CudaStream};
16use crate::wrapper::status::Status;
17
18#[derive(Debug)]
19pub struct CuModule {
20 inner: Arc<CuModuleInner>,
21}
22
23#[derive(Debug)]
24struct CuModuleInner {
25 device: CudaDevice,
26 inner: crate::bindings::CUmodule,
27}
28
29#[derive(Debug, Clone)]
30pub struct CuFunction {
31 module: Arc<CuModuleInner>,
34 function: crate::bindings::CUfunction,
35}
36
37unsafe impl Send for CuModuleInner {}
38
39unsafe impl Send for CuFunction {}
40
41#[must_use]
42#[derive(Debug)]
43pub struct CompileResult {
44 pub source: String,
45 pub log: String,
46 pub module: Result<CuModule, nvrtcResult>,
47 pub lowered_names: HashMap<String, String>,
48}
49
/// Three-axis launch extent, used for both the grid and block dimensions passed to
/// `cuLaunchKernel`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Dim3 {
    /// Extent along the first axis.
    pub x: u32,
    /// Extent along the second axis.
    pub y: u32,
    /// Extent along the third axis.
    pub z: u32,
}
56
57impl Drop for CuModuleInner {
58 fn drop(&mut self) {
59 unsafe {
60 cuModuleUnload(self.inner).unwrap_in_drop();
61 }
62 }
63}
64
65impl CuModule {
66 pub unsafe fn from_ptx(device: CudaDevice, ptx: &[u8]) -> CuModule {
67 let mut inner = null_mut();
68 cuModuleLoadDataEx(
69 &mut inner as *mut _,
70 ptx.as_ptr() as *const _,
71 0,
72 null_mut(),
73 null_mut(),
74 )
75 .unwrap();
76 CuModule {
77 inner: Arc::new(CuModuleInner { device, inner }),
78 }
79 }
80
81 pub fn from_source(
82 device: CudaDevice,
83 src: &str,
84 name: Option<&str>,
85 expected_names: &[&str],
86 headers: &HashMap<&str, &str>,
87 ) -> CompileResult {
88 device.switch_to();
89
90 unsafe {
91 let mut program = null_mut();
92
93 let src_c = CString::new(src.as_bytes()).unwrap();
94 let name_c = name.map(|name| CString::new(name.as_bytes()).unwrap());
95
96 let header_names_c = headers
97 .keys()
98 .map(|s| CString::new(s.as_bytes()).unwrap())
99 .collect_vec();
100 let header_sources_c = headers
101 .values()
102 .map(|s| CString::new(s.as_bytes()).unwrap())
103 .collect_vec();
104
105 let header_names_ptr = header_names_c.iter().map(|s| s.as_ptr()).collect_vec();
106 let header_sources_ptr = header_sources_c.iter().map(|s| s.as_ptr()).collect_vec();
107
108 nvrtcCreateProgram(
109 &mut program as *mut _,
110 src_c.as_ptr() as *const i8,
111 name_c.map_or(null(), |name_c| name_c.as_ptr() as *const i8),
112 headers.len() as i32,
113 header_sources_ptr.as_ptr(),
114 header_names_ptr.as_ptr(),
115 )
116 .unwrap();
117
118 for &expected_name in expected_names {
120 let expected_name_c = CString::new(expected_name.as_bytes()).unwrap();
121 nvrtcAddNameExpression(program, expected_name_c.as_ptr()).unwrap();
122 }
123
124 let cap = device.compute_capability();
126 let args = vec![
127 format!("--gpu-architecture=compute_{}{}", cap.major, cap.minor),
128 "-std=c++11".to_string(),
129 "-lineinfo".to_string(),
130 ];
134
135 let args = args
136 .into_iter()
137 .map(CString::new)
138 .collect::<Result<Vec<CString>, _>>()
139 .unwrap();
140 let args = args.iter().map(|s| s.as_ptr() as *const i8).collect_vec();
141
142 let result = nvrtcCompileProgram(program, args.len() as i32, args.as_ptr());
144
145 let mut log_size: usize = 0;
146 nvrtcGetProgramLogSize(program, &mut log_size as *mut _).unwrap();
147
148 let mut log_bytes = vec![0u8; log_size];
149 nvrtcGetProgramLog(program, log_bytes.as_mut_ptr() as *mut _).unwrap();
150
151 let log_c = CString::from_vec_with_nul(log_bytes).unwrap();
152 let log = log_c.to_str().unwrap().trim().to_owned();
153
154 if result != nvrtcResult::NVRTC_SUCCESS {
155 return CompileResult {
156 source: src.to_owned(),
157 log,
158 module: Err(result),
159 lowered_names: Default::default(),
160 };
161 }
162
163 let lowered_names: HashMap<String, String> = expected_names
165 .iter()
166 .map(|&expected_name| {
167 let expected_name_c = CString::new(expected_name.as_bytes()).unwrap();
168
169 let mut lowered_name = null();
170 nvrtcGetLoweredName(program, expected_name_c.as_ptr(), &mut lowered_name as *mut _).unwrap();
171 let lowered_name = CStr::from_ptr(lowered_name).to_str().unwrap().to_owned();
172
173 (expected_name.to_owned(), lowered_name)
174 })
175 .collect();
176
177 let mut ptx_size = 0;
179 nvrtcGetPTXSize(program, &mut ptx_size as *mut _).unwrap();
180
181 let mut ptx = vec![0u8; ptx_size];
182 nvrtcGetPTX(program, ptx.as_mut_ptr() as *mut _);
183
184 nvrtcDestroyProgram(&mut program as *mut _).unwrap();
185
186 let module = CuModule::from_ptx(device, &ptx);
187
188 CompileResult {
189 source: src.to_owned(),
190 log,
191 module: Ok(module),
192 lowered_names,
193 }
194 }
195 }
196
197 pub fn get_function_by_lower_name(&self, name: &str) -> Option<CuFunction> {
199 unsafe {
200 let name_c = CString::new(name.as_bytes()).unwrap();
201 let mut function = null_mut();
202
203 let result = cuModuleGetFunction(&mut function as *mut _, self.inner.inner, name_c.as_ptr());
204
205 if result == CUresult::CUDA_ERROR_NOT_FOUND {
206 None
207 } else {
208 result.unwrap();
209 Some(CuFunction {
210 module: Arc::clone(&self.inner),
211 function,
212 })
213 }
214 }
215 }
216
217 pub fn device(&self) -> CudaDevice {
218 self.inner.device
219 }
220}
221
222impl CuFunction {
223 pub unsafe fn launch_kernel_pointers(
224 &self,
225 grid_dim: impl Into<Dim3>,
226 block_dim: impl Into<Dim3>,
227 shared_mem_bytes: u32,
228 stream: &CudaStream,
229 args: &[*mut c_void],
230 ) {
231 assert_eq!(self.device(), CudaDevice::current());
232 let grid_dim = grid_dim.into();
233 let block_dim = block_dim.into();
234
235 cuLaunchKernel(
236 self.function,
237 grid_dim.x,
238 grid_dim.y,
239 grid_dim.z,
240 block_dim.x,
241 block_dim.y,
242 block_dim.z,
243 shared_mem_bytes,
244 stream.inner(),
245 args.as_ptr() as *mut _,
246 null_mut(),
247 )
248 .unwrap()
249 }
250
251 pub unsafe fn launch_kernel(
252 &self,
253 grid_dim: impl Into<Dim3>,
254 block_dim: impl Into<Dim3>,
255 shared_mem_bytes: u32,
256 stream: &CudaStream,
257 args: &[u8],
258 ) -> CUresult {
259 assert_eq!(self.device(), CudaDevice::current());
260 let grid_dim = grid_dim.into();
261 let block_dim = block_dim.into();
262
263 let mut config = [
264 CU_LAUNCH_PARAM_BUFFER_POINTER,
265 args.as_ptr() as *mut c_void,
266 CU_LAUNCH_PARAM_BUFFER_SIZE,
267 &mut args.len() as *mut usize as *mut c_void,
268 CU_LAUNCH_PARAM_END,
269 ];
270
271 cuLaunchKernel(
272 self.function,
273 grid_dim.x,
274 grid_dim.y,
275 grid_dim.z,
276 block_dim.x,
277 block_dim.y,
278 block_dim.z,
279 shared_mem_bytes,
280 stream.inner(),
281 null_mut(),
282 config.as_mut_ptr(),
283 )
284 }
285
286 pub fn device(&self) -> CudaDevice {
287 self.module.device
288 }
289}
290
291impl Dim3 {
292 pub fn single(x: u32) -> Self {
293 Self { x, y: 1, z: 1 }
294 }
295
296 pub fn new(x: u32, y: u32, z: u32) -> Self {
297 Self { x, y, z }
298 }
299}
300
301impl From<u32> for Dim3 {
302 fn from(x: u32) -> Self {
303 Dim3::single(x)
304 }
305}
306
307impl CompileResult {
308 pub fn get_function_by_name(&self, name: &str) -> Result<Option<CuFunction>, nvrtcResult> {
309 let module = self.module.as_ref().map_err(|&e| e)?;
310 let function = self
311 .lowered_names
312 .get(name)
313 .and_then(|lower_name| module.get_function_by_lower_name(lower_name));
314 Ok(function)
315 }
316
317 pub fn source_with_line_numbers(&self) -> String {
318 prefix_line_numbers(&self.source)
319 }
320}
321
/// Returns `s` with every line prefixed by its 1-based line number and `"| "`.
///
/// Numbers are right-aligned to the width of the largest line number. Every output line,
/// including the last, ends with `'\n'`, and an empty input yields an empty string.
///
/// ```text
/// 1| first
/// 2| second
/// ```
pub fn prefix_line_numbers(s: &str) -> String {
    let line_count = s.lines().count();
    // The widest number printed is `line_count` itself, because numbering starts at 1. Sizing
    // the column from `line_count + 1` added a spurious space for 9, 99, ... lines.
    let width = line_count.to_string().len();

    // Each line gains `width` digits, "| " and a trailing newline.
    let mut result = String::with_capacity(s.len() + line_count * (width + 3));
    for (index, line) in s.lines().enumerate() {
        writeln!(result, "{:>width$}| {}", index + 1, line, width = width)
            .expect("formatting into a String cannot fail");
    }
    result
}