// cubecl_hip_sys/lib.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod hipconfig;
10pub use hipconfig::*;
11
12mod bindings;
13#[allow(unused)]
14pub use bindings::*;
15
#[cfg(target_os = "linux")]
#[cfg(test)]
mod tests {
    use super::bindings::*;
    use std::{ffi::CString, ptr, time::Instant};

    /// Prints the HIPRTC compilation log of `program` to stdout.
    ///
    /// Called only when compilation has failed, so the log is the useful diagnostic.
    fn print_compilation_log(program: hiprtcProgram) {
        let mut log_size: usize = 0;
        // SAFETY: `program` is a live HIPRTC program and `log_size` is a valid out-pointer.
        let status = unsafe { hiprtcGetProgramLogSize(program, &mut log_size) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should retrieve the compilation log size"
        );
        println!("Compilation log size: {log_size}");
        if log_size == 0 {
            return;
        }
        // `c_char` is `i8` on x86_64 but `u8` on e.g. aarch64 Linux, so use the alias.
        let mut log_buffer: Vec<std::os::raw::c_char> = vec![0; log_size];
        // SAFETY: the buffer holds exactly `log_size` elements, the size HIPRTC reported.
        let status = unsafe { hiprtcGetProgramLog(program, log_buffer.as_mut_ptr()) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should retrieve the compilation log contents"
        );
        // Decode up to the first NUL (if any) without assuming the buffer is NUL-terminated.
        let bytes: Vec<u8> = log_buffer.iter().map(|&c| c as u8).collect();
        let end = bytes.iter().position(|&c| c == 0).unwrap_or(bytes.len());
        println!("Compilation log: {}", String::from_utf8_lossy(&bytes[..end]));
    }

    /// End-to-end smoke test of the raw HIP / HIPRTC bindings.
    ///
    /// The test does the following:
    /// - compiles a kernel at runtime with HIPRTC;
    /// - loads the code object as a module;
    /// - launches the kernel on an explicitly created stream;
    /// - checks every output element on the host.
    ///
    /// Requires a HIP-capable GPU at device index 0.
    #[test]
    fn test_launch_kernel_end_to_end() {
        // Kernel that computes y values of a linear equation in slope-intercept form
        let source = CString::new(
            r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    out[tid] = x[tid] * a + b[tid];
  }
}
 "#,
        )
        .expect("Should construct kernel string");

        let func_name = CString::new("kernel").expect("Should construct kernel name");
        // reference: https://rocm.docs.amd.com/projects/HIP/en/docs-6.0.0/user_guide/hip_rtc.html

        // Step 0: Select the GPU device
        // SAFETY: plain FFI call without pointer arguments.
        let status = unsafe { hipSetDevice(0) };
        assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");

        // The out-pointers must come from `&mut` borrows of mutable locals. Writing through a
        // pointer cast from a shared `&usize` (as before) is undefined behavior.
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both pointers refer to live, writable `usize` locals.
        let status = unsafe { hipMemGetInfo(&mut free, &mut total) };
        assert_eq!(
            status, HIP_SUCCESS,
            "Should get the available memory of the device"
        );
        println!("Free: {} | Total:{}", free, total);

        // Step 1: Create the program
        let mut program: hiprtcProgram = ptr::null_mut();
        // SAFETY: `program` is a valid out-pointer and `source` is a NUL-terminated string
        // that outlives the call; no headers are passed.
        let status = unsafe {
            hiprtcCreateProgram(
                &mut program,    // Program
                source.as_ptr(), // kernel string
                ptr::null(),     // Name of the file (there is no file)
                0,               // Number of headers
                ptr::null_mut(), // Header sources
                ptr::null_mut(), // Name of header files
            )
        };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should create the program"
        );

        // Step 2: Compile the program
        // SAFETY: `program` was created above; zero options with a null options array.
        let status = unsafe {
            hiprtcCompileProgram(
                program,         // Program
                0,               // Number of options
                ptr::null_mut(), // Clang Options
            )
        };
        if status != hiprtcResult_HIPRTC_SUCCESS {
            print_compilation_log(program);
        }
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should compile the program"
        );

        // Step 3: Load compiled code
        let mut code_size: usize = 0;
        // SAFETY: `program` compiled successfully and `code_size` is a valid out-pointer.
        let status = unsafe { hiprtcGetCodeSize(program, &mut code_size) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should get size of compiled code"
        );
        let mut code: Vec<u8> = vec![0; code_size];
        // SAFETY: `code` holds exactly `code_size` bytes, the size HIPRTC reported.
        let status = unsafe { hiprtcGetCode(program, code.as_mut_ptr() as *mut _) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should load compiled code"
        );

        // Step 4: Once the compiled code is loaded, the program can be destroyed
        // SAFETY: `program` is live and is not used after this call.
        let status = unsafe { hiprtcDestroyProgram(&mut program) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should destroy the program"
        );
        assert!(!code.is_empty(), "Generated code should not be empty");

        // Step 5: Allocate Memory
        let n: usize = 1024;
        // The kernel declares `int n`, so it must receive a pointer to a 32-bit value.
        // Pointing it at the 64-bit `usize` only worked by accident on little-endian targets.
        let n_arg = i32::try_from(n).expect("Element count should fit in the kernel's `int n`");
        let a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];
        let size_in_bytes = n * std::mem::size_of::<f32>();
        // Allocate GPU memory for x, b, and out.
        // There is no need to allocate memory for a and n as we can pass
        // host pointers directly to kernel launch function
        let mut device_x: *mut ::std::os::raw::c_void = ptr::null_mut();
        let mut device_b: *mut ::std::os::raw::c_void = ptr::null_mut();
        let mut device_out: *mut ::std::os::raw::c_void = ptr::null_mut();
        // SAFETY: each out-pointer refers to a live local that receives the device address.
        unsafe {
            let status_x = hipMalloc(&mut device_x, size_in_bytes);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, size_in_bytes);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, size_in_bytes);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // Step 6: Copy data to GPU memory
        // SAFETY: every host buffer and every device allocation is `size_in_bytes` long.
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr() as *const libc::c_void,
                size_in_bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr() as *const libc::c_void,
                size_in_bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            // Initialize the output memory on device to 0.0
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr() as *const libc::c_void,
                size_in_bytes,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        // Step 7: Create the module containing the kernel and get the function that points to it
        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        // SAFETY: `code` holds the compiled code object and outlives the load call;
        // `func_name` is NUL-terminated.
        unsafe {
            let status_module =
                hipModuleLoadData(&mut module, code.as_ptr() as *const libc::c_void);
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        // Step 8: Launch Kernel
        // Create the array of arguments to pass to the kernel.
        // They must be in the same order as the kernel's parameter declarations, and each
        // pointee must have exactly the C type the kernel declares.
        let mut args: [*mut libc::c_void; 5] = [
            &a as *const _ as *mut libc::c_void,
            &device_x as *const _ as *mut libc::c_void,
            &device_b as *const _ as *mut libc::c_void,
            &device_out as *const _ as *mut libc::c_void,
            &n_arg as *const _ as *mut libc::c_void,
        ];
        let block_dim_x: u32 = 64;
        // Round up so every element is covered even when `n` is not a multiple of the block
        // size; the kernel's `tid < n` guard discards the surplus threads.
        let grid_dim_x =
            u32::try_from(n.div_ceil(block_dim_x as usize)).expect("Grid size should fit in u32");
        // We could use the default stream by passing 0 to the launch kernel but for the sake of
        // coverage we create a stream explicitly
        let mut stream: hipStream_t = ptr::null_mut();
        // SAFETY: `stream` is a valid out-pointer.
        let stream_status = unsafe { hipStreamCreate(&mut stream) };
        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");

        let start_time = Instant::now();
        // SAFETY: `function` comes from the loaded module; `args` points at live locals of the
        // types the kernel declares; the device buffers hold `n` floats each.
        let status_launch = unsafe {
            hipModuleLaunchKernel(
                function,          // Kernel function
                grid_dim_x,        // Grid dimensions (number of blocks)
                1,
                1,
                block_dim_x,       // Block dimensions (threads per block)
                1,
                1,
                0,                 // Shared memory size
                stream,            // Created stream
                args.as_mut_ptr(), // Kernel arguments
                ptr::null_mut(),   // Extra options
            )
        };
        assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");
        // not strictly necessary but for the sake of coverage we sync here
        // SAFETY: plain FFI call without pointer arguments.
        let status = unsafe { hipDeviceSynchronize() };
        assert_eq!(status, HIP_SUCCESS, "Should sync with the device");
        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // The device is synchronized, so no work is pending on the stream: release it.
        // SAFETY: `stream` was created above and is not used after this call.
        let status = unsafe { hipStreamDestroy(stream) };
        assert_eq!(status, HIP_SUCCESS, "Should destroy the stream");

        // Step 9: Copy the result back to host memory
        // SAFETY: `out` and `device_out` are both `size_in_bytes` long.
        let status = unsafe {
            hipMemcpy(
                out.as_mut_ptr() as *mut libc::c_void,
                device_out,
                size_in_bytes,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            )
        };
        assert_eq!(status, HIP_SUCCESS, "Should copy device_out back to host");

        // Step 10: Verify the results
        for (i, (&result, (&xi, &bi))) in out.iter().zip(x.iter().zip(&b)).enumerate() {
            let expected = a * xi + bi;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }

        // Step 11: Free up allocated memory on GPU device and unload the module
        // SAFETY: each pointer was returned by `hipMalloc`/`hipModuleLoadData` above and is
        // released exactly once; no kernel is still running.
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
            let status = hipModuleUnload(module);
            assert_eq!(status, HIP_SUCCESS, "Should unload the module successfully");
        }
    }
}