// cubecl_hip_sys/lib.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod build_script;
10
11#[cfg(target_os = "linux")]
12mod bindings;
13#[cfg(target_os = "linux")]
14#[allow(unused)]
15pub use bindings::*;
16
#[cfg(target_os = "linux")]
#[cfg(test)]
mod tests {
    use super::bindings::*;
    use std::{ffi::CString, ptr, time::Instant};

    /// End-to-end smoke test of the HIP runtime and HIPRTC bindings.
    ///
    /// Compiles a small kernel at runtime with HIPRTC, loads it as a module,
    /// launches it on an explicitly created stream, copies the result back,
    /// checks it on the host, and finally releases every device resource it
    /// created (buffers, stream, module).
    ///
    /// Requires a HIP-capable GPU at device index 0.
    ///
    /// Reference: https://rocm.docs.amd.com/projects/HIP/en/docs-6.0.0/user_guide/hip_rtc.html
    #[test]
    fn test_launch_kernel_end_to_end() {
        // Kernel that computes y values of a linear equation in slope-intercept
        // form: out[i] = a * x[i] + b[i].
        let source = CString::new(
            r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    out[tid] = x[tid] * a + b[tid];
  }
}
 "#,
        )
        .expect("Should construct kernel string");

        let func_name = CString::new("kernel").expect("kernel name has no interior NUL");

        // Step 0: Select the GPU device
        unsafe {
            let status = hipSetDevice(0);
            assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");
        }

        // The outputs are written by the runtime through raw pointers, so they
        // must come from `&mut` borrows of mutable locals; writing through a
        // pointer derived from a shared reference is undefined behavior.
        let mut free: usize = 0;
        let mut total: usize = 0;
        unsafe {
            let status = hipMemGetInfo(&mut free, &mut total);
            assert_eq!(
                status, HIP_SUCCESS,
                "Should get the available memory of the device"
            );
        }
        println!("Free: {} | Total:{}", free, total);

        // Step 1: Create the program
        let mut program: hiprtcProgram = ptr::null_mut();
        unsafe {
            let status = hiprtcCreateProgram(
                &mut program,    // Program
                source.as_ptr(), // kernel string
                ptr::null(),     // Name of the file (there is no file)
                0,               // Number of headers
                ptr::null_mut(), // Header sources
                ptr::null_mut(), // Name of header files
            );
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should create the program"
            );
        }

        // Step 2: Compile the program
        unsafe {
            let status = hiprtcCompileProgram(
                program,         // Program
                0,               // Number of options
                ptr::null_mut(), // Clang Options
            );
            if status != hiprtcResult_HIPRTC_SUCCESS {
                let mut log_size: usize = 0;
                let log_status = hiprtcGetProgramLogSize(program, &mut log_size);
                assert_eq!(
                    log_status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log size"
                );
                println!("Compilation log size: {log_size}");
                if log_size > 0 {
                    // Read into a byte buffer so decoding neither depends on the
                    // platform's `c_char` signedness nor trusts the log to be
                    // NUL-terminated inside the buffer.
                    let mut log_buffer = vec![0u8; log_size];
                    let log_status = hiprtcGetProgramLog(program, log_buffer.as_mut_ptr().cast());
                    assert_eq!(
                        log_status, hiprtcResult_HIPRTC_SUCCESS,
                        "Should retrieve the compilation log contents"
                    );
                    let log = String::from_utf8_lossy(&log_buffer);
                    println!("Compilation log: {}", log.trim_end_matches('\0'));
                }
            }
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should compile the program"
            );
        }

        // Step 3: Load compiled code
        let mut code_size: usize = 0;
        unsafe {
            let status = hiprtcGetCodeSize(program, &mut code_size);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should get size of compiled code"
            );
        }
        let mut code: Vec<u8> = vec![0; code_size];
        unsafe {
            let status = hiprtcGetCode(program, code.as_mut_ptr().cast());
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should load compiled code"
            );
        }

        // Step 4: Once the compiled code is loaded, the program can be destroyed
        unsafe {
            let status = hiprtcDestroyProgram(&mut program);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should destroy the program"
            );
        }
        assert!(!code.is_empty(), "Generated code should not be empty");

        // Step 5: Allocate Memory
        let n: usize = 1024;
        // The kernel declares `n` as a C `int`, so it must receive a 32-bit
        // value; handing it a pointer to a `usize` only works by accident on
        // little-endian targets.
        let n_kernel = i32::try_from(n).expect("n must fit in the kernel's `int` parameter");
        let a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];
        let buffer_size = n * std::mem::size_of::<f32>();
        // Allocate GPU memory for x, b, and out.
        // There is no need to allocate device memory for the scalars a and n:
        // kernel arguments are passed as pointers to host values.
        let mut device_x: *mut ::std::os::raw::c_void = std::ptr::null_mut();
        let mut device_b: *mut ::std::os::raw::c_void = std::ptr::null_mut();
        let mut device_out: *mut ::std::os::raw::c_void = std::ptr::null_mut();
        unsafe {
            let status_x = hipMalloc(&mut device_x, buffer_size);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, buffer_size);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, buffer_size);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // Step 6: Copy data to GPU memory
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr() as *const libc::c_void,
                buffer_size,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr() as *const libc::c_void,
                buffer_size,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            // Initialize the output memory on device to 0.0
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr() as *const libc::c_void,
                buffer_size,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        // Step 7: Create the module containing the kernel and get the function that points to it
        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        unsafe {
            let status_module =
                hipModuleLoadData(&mut module, code.as_ptr() as *const libc::c_void);
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        // Step 8: Launch Kernel
        let start_time = Instant::now();
        // Array of pointers to the kernel arguments, in the same order as the
        // kernel's parameter list. Each pointee's type must match the C
        // parameter type (float, pointer, pointer, pointer, int).
        let mut args: [*mut libc::c_void; 5] = [
            &a as *const f32 as *mut libc::c_void,
            &device_x as *const _ as *mut libc::c_void,
            &device_b as *const _ as *mut libc::c_void,
            &device_out as *const _ as *mut libc::c_void,
            &n_kernel as *const i32 as *mut libc::c_void,
        ];
        let block_dim_x: usize = 64;
        // Round up so every element gets a thread even when n is not a multiple
        // of the block size; the kernel's `tid < n` guard skips the excess.
        let grid_dim_x: usize = n.div_ceil(block_dim_x);
        // We could use the default stream by passing 0 to the launch kernel but for the sake of
        // coverage we create a stream explicitly
        let mut stream: hipStream_t = std::ptr::null_mut();
        unsafe {
            let stream_status = hipStreamCreate(&mut stream);
            assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
        }
        // SAFETY: every pointer in `args` refers to a local that outlives the
        // launch call, and each pointee matches the kernel's parameter type.
        unsafe {
            let status_launch = hipModuleLaunchKernel(
                function,           // Kernel function
                grid_dim_x as u32,  // Grid dimensions (number of blocks)
                1,
                1,
                block_dim_x as u32, // Block dimensions (threads per block)
                1,
                1,
                0,                  // Shared memory size
                stream,             // Created stream
                args.as_mut_ptr(),  // Kernel arguments
                ptr::null_mut(),    // Extra options
            );
            assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");
        }
        // not strictly necessary but for the sake of coverage we sync here
        unsafe {
            let status = hipDeviceSynchronize();
            assert_eq!(status, HIP_SUCCESS, "Should sync with the device");
        }
        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // Step 9: Copy the result back to host memory
        unsafe {
            let status = hipMemcpy(
                out.as_mut_ptr() as *mut libc::c_void,
                device_out,
                buffer_size,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            );
            assert_eq!(status, HIP_SUCCESS, "Should copy device_out back to the host");
        }

        // Step 10: Verify the results. All values are small integers, so the
        // f32 arithmetic is exact and an equality check is safe.
        for (i, ((&result, &xi), &bi)) in out.iter().zip(&x).zip(&b).enumerate() {
            let expected = a * xi + bi;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }

        // Step 11: Release every device resource created above
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
            let status = hipStreamDestroy(stream);
            assert_eq!(status, HIP_SUCCESS, "Should destroy the stream");
            let status = hipModuleUnload(module);
            assert_eq!(status, HIP_SUCCESS, "Should unload the module");
        }
    }
}