cubecl_hip_sys/
lib.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod hipconfig;
10pub use hipconfig::*;
11
12#[cfg(target_os = "linux")]
13mod bindings;
14#[cfg(target_os = "linux")]
15#[allow(unused)]
16pub use bindings::*;
17
#[cfg(target_os = "linux")]
#[cfg(test)]
mod tests {
    use super::bindings::*;
    use std::{
        ffi::{CStr, CString},
        os::raw::{c_char, c_void},
        ptr,
        time::Instant,
    };

    /// End-to-end smoke test of the raw bindings.
    ///
    /// Compiles a small kernel at runtime with hiprtc, loads the resulting code
    /// as a module, launches it on device 0 through an explicitly created
    /// stream, copies the result back and compares it element by element with
    /// the same computation done on the host.
    #[test]
    fn test_launch_kernel_end_to_end() {
        // Kernel that computes y values of a linear equation in slope-intercept form
        let source = CString::new(
            r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    out[tid] = x[tid] * a + b[tid];
  }
}
 "#,
        )
        .expect("Should construct kernel string");

        let func_name = CString::new("kernel").expect("Should construct kernel name");
        // reference: https://rocm.docs.amd.com/projects/HIP/en/docs-6.0.0/user_guide/hip_rtc.html

        // Step 0: Select the GPU device
        // SAFETY: plain FFI call taking the device index by value.
        unsafe {
            let status = hipSetDevice(0);
            assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");
        }

        // The runtime writes through both pointers, so the locals must be
        // mutable and handed over as `&mut`.
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both out-pointers refer to live, writable locals.
        unsafe {
            let status = hipMemGetInfo(&mut free, &mut total);
            assert_eq!(
                status, HIP_SUCCESS,
                "Should get the available memory of the device"
            );
        }
        println!("Free: {} | Total:{}", free, total);

        // Step 1: Create the program
        let mut program: hiprtcProgram = ptr::null_mut();
        // SAFETY: `program` is a writable local and `source` is a NUL-terminated
        // string that outlives the call; no headers are supplied.
        unsafe {
            let status = hiprtcCreateProgram(
                &mut program,    // Program
                source.as_ptr(), // kernel string
                ptr::null(),     // Name of the file (there is no file)
                0,               // Number of headers
                ptr::null_mut(), // Header sources
                ptr::null_mut(), // Name of header files
            );
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should create the program"
            );
        }

        // Step 2: Compile the program
        // SAFETY: `program` was created above; the log buffer is sized from the
        // reported log size and stays alive while it is read.
        unsafe {
            let compile_status = hiprtcCompileProgram(
                program,         // Program
                0,               // Number of options
                ptr::null_mut(), // Clang Options
            );
            if compile_status != hiprtcResult_HIPRTC_SUCCESS {
                // Print the compiler log before failing so the cause is visible.
                let mut log_size: usize = 0;
                let status = hiprtcGetProgramLogSize(program, &mut log_size);
                assert_eq!(
                    status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log size"
                );
                println!("Compilation log size: {log_size}");
                // One extra zeroed element keeps the buffer non-empty and
                // NUL-terminated whatever size was reported. `c_char` is used
                // because its signedness differs between targets.
                let mut log_buffer: Vec<c_char> = vec![0; log_size + 1];
                let status = hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
                assert_eq!(
                    status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log contents"
                );
                let log = CStr::from_ptr(log_buffer.as_ptr());
                println!("Compilation log: {}", log.to_string_lossy());
            }
            assert_eq!(
                compile_status, hiprtcResult_HIPRTC_SUCCESS,
                "Should compile the program"
            );
        }

        // Step 3: Load compiled code
        let mut code_size: usize = 0;
        // SAFETY: `code_size` is a writable local.
        unsafe {
            let status = hiprtcGetCodeSize(program, &mut code_size);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should get size of compiled code"
            );
        }
        let mut code: Vec<u8> = vec![0; code_size];
        // SAFETY: `code` holds exactly the number of bytes reported above.
        unsafe {
            let status = hiprtcGetCode(program, code.as_mut_ptr().cast());
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should load compiled code"
            );
        }

        // Step 4: Once the compiled code is loaded, the program can be destroyed
        // SAFETY: `program` is a valid handle and is not used after this call.
        unsafe {
            let status = hiprtcDestroyProgram(&mut program);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should destroy the program"
            );
        }
        assert!(!code.is_empty(), "Generated code should not be empty");

        // Step 5: Allocate Memory
        let n: usize = 1024;
        let a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];
        // Size in bytes of each of the three equally sized buffers.
        let bytes = n * std::mem::size_of::<f32>();
        // Allocate GPU memory for x, b, and out
        // There is no need to allocate memory for a and n as we can pass
        // host pointers directly to kernel launch function
        let mut device_x: *mut c_void = ptr::null_mut();
        let mut device_b: *mut c_void = ptr::null_mut();
        let mut device_out: *mut c_void = ptr::null_mut();
        // SAFETY: each out-pointer refers to a writable local.
        unsafe {
            let status_x = hipMalloc(&mut device_x, bytes);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, bytes);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, bytes);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // Step 6: Copy data to GPU memory
        // SAFETY: every host vector holds `n` f32 values (`bytes` bytes) and every
        // device buffer was allocated with `bytes` bytes.
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr().cast(),
                bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr().cast(),
                bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            // Initialize the output memory on device to 0.0
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr().cast(),
                bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        // Step 7: Create the module containing the kernel and get the function that points to it
        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        // SAFETY: `code` holds the image produced by hiprtc and `func_name` is a
        // NUL-terminated string; both outlive the calls.
        unsafe {
            let status_module = hipModuleLoadData(&mut module, code.as_ptr().cast());
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        // Step 8: Launch Kernel
        let start_time = Instant::now();
        // The kernel declares `int n`, so hand it a value of matching width
        // rather than a pointer to a wider `usize`.
        let n_arg = i32::try_from(n).expect("Element count should fit in a C int");
        // Create the array of arguments to pass to the kernel
        // They must be in the same order as the order of declaration of the kernel arguments
        // Each entry is the address of the host variable holding that argument.
        let mut args: [*mut c_void; 5] = [
            &a as *const f32 as *mut c_void,
            &device_x as *const *mut c_void as *mut c_void,
            &device_b as *const *mut c_void as *mut c_void,
            &device_out as *const *mut c_void as *mut c_void,
            &n_arg as *const i32 as *mut c_void,
        ];
        let block_dim_x: u32 = 64;
        let n_threads = u32::try_from(n).expect("Element count should fit in u32");
        assert_eq!(
            n_threads % block_dim_x,
            0,
            "Element count should be a multiple of the block size"
        );
        let grid_dim_x = n_threads / block_dim_x;
        // We could use the default stream by passing 0 to the launch kernel but for the sake of
        // coverage we create a stream explicitly
        let mut stream: hipStream_t = ptr::null_mut();
        // SAFETY: `stream` is a writable local.
        unsafe {
            let stream_status = hipStreamCreate(&mut stream);
            assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
        }
        // SAFETY: `function` and `stream` are valid handles; `args` has one entry
        // per kernel parameter, each pointing at a live local of the matching type.
        unsafe {
            let status_launch = hipModuleLaunchKernel(
                function, // Kernel function
                grid_dim_x,
                1,
                1, // Grid dimensions (group of blocks)
                block_dim_x,
                1,
                1,                 // Block dimensions (group of threads)
                0,                 // Shared memory size
                stream,            // Created stream
                args.as_mut_ptr(), // Kernel arguments
                ptr::null_mut(),   // Extra options
            );
            assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");
        }
        // not strictly necessary but for the sake of coverage we sync here
        // SAFETY: plain FFI call without arguments.
        unsafe {
            let status = hipDeviceSynchronize();
            assert_eq!(status, HIP_SUCCESS, "Should sync with the device");
        }
        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // Step 9: Copy the result back to host memory
        // SAFETY: `out` has room for `bytes` bytes and `device_out` was allocated
        // with the same size.
        unsafe {
            let status = hipMemcpy(
                out.as_mut_ptr().cast(),
                device_out,
                bytes,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            );
            // Without this check a failed copy would only surface as a
            // misleading output mismatch below.
            assert_eq!(status, HIP_SUCCESS, "Should copy the result back to host");
        }

        // Step 10: Verify the results
        for (i, ((&result, &xi), &bi)) in out.iter().zip(&x).zip(&b).enumerate() {
            let expected = a * xi + bi;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }

        // Step 11: Free up allocated memory on GPU device
        // NOTE(review): `stream` and `module` are never released in this test;
        // consider destroying/unloading them here as well.
        // SAFETY: each pointer came from `hipMalloc` above and is freed once.
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
        }
    }
}