// kn_cuda_sys/wrapper/rtc/core.rs

1use std::collections::HashMap;
2use std::ffi::{c_void, CStr, CString};
3use std::fmt::Write;
4use std::ptr::{null, null_mut};
5use std::sync::Arc;
6
7use itertools::Itertools;
8
9use crate::bindings::{
10    CU_LAUNCH_PARAM_BUFFER_POINTER, CU_LAUNCH_PARAM_BUFFER_SIZE, CU_LAUNCH_PARAM_END, cuLaunchKernel, cuModuleGetFunction,
11    cuModuleLoadDataEx, cuModuleUnload, CUresult, nvrtcAddNameExpression, nvrtcCompileProgram, nvrtcCreateProgram,
12    nvrtcDestroyProgram, nvrtcGetLoweredName, nvrtcGetProgramLog, nvrtcGetProgramLogSize, nvrtcGetPTX,
13    nvrtcGetPTXSize, nvrtcResult,
14};
15use crate::wrapper::handle::{CudaDevice, CudaStream};
16use crate::wrapper::status::Status;
17
18#[derive(Debug)]
19pub struct CuModule {
20    inner: Arc<CuModuleInner>,
21}
22
23#[derive(Debug)]
24struct CuModuleInner {
25    device: CudaDevice,
26    inner: crate::bindings::CUmodule,
27}
28
29#[derive(Debug, Clone)]
30pub struct CuFunction {
31    // field is never used, but is present to keep module from being dropped
32    //   this is necessary because CUfunction points to something inside of the CUmodule structure
33    module: Arc<CuModuleInner>,
34    function: crate::bindings::CUfunction,
35}
36
37unsafe impl Send for CuModuleInner {}
38
39unsafe impl Send for CuFunction {}
40
41#[must_use]
42#[derive(Debug)]
43pub struct CompileResult {
44    pub source: String,
45    pub log: String,
46    pub module: Result<CuModule, nvrtcResult>,
47    pub lowered_names: HashMap<String, String>,
48}
49
/// Size of a CUDA launch grid or thread block along the x, y and z axes.
///
/// Build one with [`Dim3::new`] or [`Dim3::single`], or convert from a bare `u32`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Dim3 {
    /// Extent along x.
    pub x: u32,
    /// Extent along y.
    pub y: u32,
    /// Extent along z.
    pub z: u32,
}
56
57impl Drop for CuModuleInner {
58    fn drop(&mut self) {
59        unsafe {
60            cuModuleUnload(self.inner).unwrap_in_drop();
61        }
62    }
63}
64
65impl CuModule {
66    pub unsafe fn from_ptx(device: CudaDevice, ptx: &[u8]) -> CuModule {
67        let mut inner = null_mut();
68        cuModuleLoadDataEx(
69            &mut inner as *mut _,
70            ptx.as_ptr() as *const _,
71            0,
72            null_mut(),
73            null_mut(),
74        )
75        .unwrap();
76        CuModule {
77            inner: Arc::new(CuModuleInner { device, inner }),
78        }
79    }
80
81    pub fn from_source(
82        device: CudaDevice,
83        src: &str,
84        name: Option<&str>,
85        expected_names: &[&str],
86        headers: &HashMap<&str, &str>,
87    ) -> CompileResult {
88        device.switch_to();
89
90        unsafe {
91            let mut program = null_mut();
92
93            let src_c = CString::new(src.as_bytes()).unwrap();
94            let name_c = name.map(|name| CString::new(name.as_bytes()).unwrap());
95
96            let header_names_c = headers
97                .keys()
98                .map(|s| CString::new(s.as_bytes()).unwrap())
99                .collect_vec();
100            let header_sources_c = headers
101                .values()
102                .map(|s| CString::new(s.as_bytes()).unwrap())
103                .collect_vec();
104
105            let header_names_ptr = header_names_c.iter().map(|s| s.as_ptr()).collect_vec();
106            let header_sources_ptr = header_sources_c.iter().map(|s| s.as_ptr()).collect_vec();
107
108            nvrtcCreateProgram(
109                &mut program as *mut _,
110                src_c.as_ptr() as *const i8,
111                name_c.map_or(null(), |name_c| name_c.as_ptr() as *const i8),
112                headers.len() as i32,
113                header_sources_ptr.as_ptr(),
114                header_names_ptr.as_ptr(),
115            )
116            .unwrap();
117
118            // add requested names
119            for &expected_name in expected_names {
120                let expected_name_c = CString::new(expected_name.as_bytes()).unwrap();
121                nvrtcAddNameExpression(program, expected_name_c.as_ptr()).unwrap();
122            }
123
124            // figure out the arguments
125            let cap = device.compute_capability();
126            let args = vec![
127                format!("--gpu-architecture=compute_{}{}", cap.major, cap.minor),
128                "-std=c++11".to_string(),
129                "-lineinfo".to_string(),
130                // "-G".to_string(),
131                // "--generate-line-info".to_string(),
132                // "--define-macro=NVRTC".to_string(),
133            ];
134
135            let args = args
136                .into_iter()
137                .map(CString::new)
138                .collect::<Result<Vec<CString>, _>>()
139                .unwrap();
140            let args = args.iter().map(|s| s.as_ptr() as *const i8).collect_vec();
141
142            // actually compile the program
143            let result = nvrtcCompileProgram(program, args.len() as i32, args.as_ptr());
144
145            let mut log_size: usize = 0;
146            nvrtcGetProgramLogSize(program, &mut log_size as *mut _).unwrap();
147
148            let mut log_bytes = vec![0u8; log_size];
149            nvrtcGetProgramLog(program, log_bytes.as_mut_ptr() as *mut _).unwrap();
150
151            let log_c = CString::from_vec_with_nul(log_bytes).unwrap();
152            let log = log_c.to_str().unwrap().trim().to_owned();
153
154            if result != nvrtcResult::NVRTC_SUCCESS {
155                return CompileResult {
156                    source: src.to_owned(),
157                    log,
158                    module: Err(result),
159                    lowered_names: Default::default(),
160                };
161            }
162
163            // extract the lowered names
164            let lowered_names: HashMap<String, String> = expected_names
165                .iter()
166                .map(|&expected_name| {
167                    let expected_name_c = CString::new(expected_name.as_bytes()).unwrap();
168
169                    let mut lowered_name = null();
170                    nvrtcGetLoweredName(program, expected_name_c.as_ptr(), &mut lowered_name as *mut _).unwrap();
171                    let lowered_name = CStr::from_ptr(lowered_name).to_str().unwrap().to_owned();
172
173                    (expected_name.to_owned(), lowered_name)
174                })
175                .collect();
176
177            // get the resulting assembly
178            let mut ptx_size = 0;
179            nvrtcGetPTXSize(program, &mut ptx_size as *mut _).unwrap();
180
181            let mut ptx = vec![0u8; ptx_size];
182            nvrtcGetPTX(program, ptx.as_mut_ptr() as *mut _);
183
184            nvrtcDestroyProgram(&mut program as *mut _).unwrap();
185
186            let module = CuModule::from_ptx(device, &ptx);
187
188            CompileResult {
189                source: src.to_owned(),
190                log,
191                module: Ok(module),
192                lowered_names,
193            }
194        }
195    }
196
197    /// It's probably easier to use [CompileResult::get_function_by_name] if possible.
198    pub fn get_function_by_lower_name(&self, name: &str) -> Option<CuFunction> {
199        unsafe {
200            let name_c = CString::new(name.as_bytes()).unwrap();
201            let mut function = null_mut();
202
203            let result = cuModuleGetFunction(&mut function as *mut _, self.inner.inner, name_c.as_ptr());
204
205            if result == CUresult::CUDA_ERROR_NOT_FOUND {
206                None
207            } else {
208                result.unwrap();
209                Some(CuFunction {
210                    module: Arc::clone(&self.inner),
211                    function,
212                })
213            }
214        }
215    }
216
217    pub fn device(&self) -> CudaDevice {
218        self.inner.device
219    }
220}
221
222impl CuFunction {
223    pub unsafe fn launch_kernel_pointers(
224        &self,
225        grid_dim: impl Into<Dim3>,
226        block_dim: impl Into<Dim3>,
227        shared_mem_bytes: u32,
228        stream: &CudaStream,
229        args: &[*mut c_void],
230    ) {
231        assert_eq!(self.device(), CudaDevice::current());
232        let grid_dim = grid_dim.into();
233        let block_dim = block_dim.into();
234
235        cuLaunchKernel(
236            self.function,
237            grid_dim.x,
238            grid_dim.y,
239            grid_dim.z,
240            block_dim.x,
241            block_dim.y,
242            block_dim.z,
243            shared_mem_bytes,
244            stream.inner(),
245            args.as_ptr() as *mut _,
246            null_mut(),
247        )
248        .unwrap()
249    }
250
251    pub unsafe fn launch_kernel(
252        &self,
253        grid_dim: impl Into<Dim3>,
254        block_dim: impl Into<Dim3>,
255        shared_mem_bytes: u32,
256        stream: &CudaStream,
257        args: &[u8],
258    ) -> CUresult {
259        assert_eq!(self.device(), CudaDevice::current());
260        let grid_dim = grid_dim.into();
261        let block_dim = block_dim.into();
262
263        let mut config = [
264            CU_LAUNCH_PARAM_BUFFER_POINTER,
265            args.as_ptr() as *mut c_void,
266            CU_LAUNCH_PARAM_BUFFER_SIZE,
267            &mut args.len() as *mut usize as *mut c_void,
268            CU_LAUNCH_PARAM_END,
269        ];
270
271        cuLaunchKernel(
272            self.function,
273            grid_dim.x,
274            grid_dim.y,
275            grid_dim.z,
276            block_dim.x,
277            block_dim.y,
278            block_dim.z,
279            shared_mem_bytes,
280            stream.inner(),
281            null_mut(),
282            config.as_mut_ptr(),
283        )
284    }
285
286    pub fn device(&self) -> CudaDevice {
287        self.module.device
288    }
289}
290
291impl Dim3 {
292    pub fn single(x: u32) -> Self {
293        Self { x, y: 1, z: 1 }
294    }
295
296    pub fn new(x: u32, y: u32, z: u32) -> Self {
297        Self { x, y, z }
298    }
299}
300
301impl From<u32> for Dim3 {
302    fn from(x: u32) -> Self {
303        Dim3::single(x)
304    }
305}
306
307impl CompileResult {
308    pub fn get_function_by_name(&self, name: &str) -> Result<Option<CuFunction>, nvrtcResult> {
309        let module = self.module.as_ref().map_err(|&e| e)?;
310        let function = self
311            .lowered_names
312            .get(name)
313            .and_then(|lower_name| module.get_function_by_lower_name(lower_name));
314        Ok(function)
315    }
316
317    pub fn source_with_line_numbers(&self) -> String {
318        prefix_line_numbers(&self.source)
319    }
320}
321
/// Returns `s` with every line prefixed by its 1-based line number followed by `"| "`.
///
/// Numbers are right-aligned to the width of the largest line number, so the `|`
/// separators line up. Lines are split with [`str::lines`], so `"\r\n"` endings are
/// normalised. Every output line, including the last, ends with `'\n'`; an empty input
/// yields an empty string.
pub fn prefix_line_numbers(s: &str) -> String {
    let line_count = s.lines().count();
    // The widest label is the last line number, `line_count` itself. Using `line_count + 1`
    // would add a spurious padding column when the count is 9, 99, 999, ...
    let width = line_count.to_string().len();

    // Roughly: original text plus "<number>| " per line.
    let mut result = String::with_capacity(s.len() + line_count * (width + 3));
    for (index, line) in s.lines().enumerate() {
        // Writing into a `String` cannot fail.
        writeln!(result, "{:>width$}| {}", index + 1, line, width = width).unwrap();
    }
    result
}