1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod hipconfig;
10pub use hipconfig::*;
11
12#[cfg(target_os = "linux")]
13mod bindings;
14#[cfg(target_os = "linux")]
15#[allow(unused)]
16pub use bindings::*;
17
#[cfg(all(target_os = "linux", test))]
mod tests {
    use super::bindings::*;
    use std::{
        ffi::{c_void, CStr, CString},
        mem::size_of,
        ptr,
        time::Instant,
    };

    /// HIP C++ source compiled at runtime by the test.
    ///
    /// Computes `out[i] = x[i] * a + b[i]` for every `i < n`; threads whose global index is
    /// `>= n` do nothing, so the launch may cover more threads than there are elements.
    const KERNEL_SOURCE: &str = r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        out[tid] = x[tid] * a + b[tid];
    }
}
"#;

    /// Prints the hiprtc compilation log of `program` to stdout.
    ///
    /// # Panics
    ///
    /// Panics if the log size or the log contents cannot be retrieved.
    fn print_compile_log(program: hiprtcProgram) {
        let mut log_size: usize = 0;
        // SAFETY: `log_size` is a live, writable local for the duration of the call.
        let status = unsafe { hiprtcGetProgramLogSize(program, &mut log_size) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should retrieve the compilation log size"
        );
        println!("Compilation log size: {log_size}");

        // Use a byte buffer (cast to the platform's `c_char` at the call site) and keep at
        // least one element so the pointer handed to hiprtc is never dangling.
        let mut log = vec![0u8; log_size.max(1)];
        // SAFETY: `log` holds at least `log_size` writable bytes.
        let status = unsafe { hiprtcGetProgramLog(program, log.as_mut_ptr().cast()) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should retrieve the compilation log contents"
        );
        // Stop at the first NUL if there is one; never read past the buffer.
        let end = log.iter().position(|&byte| byte == 0).unwrap_or(log.len());
        println!("Compilation log: {}", String::from_utf8_lossy(&log[..end]));
    }

    /// Compiles `source` with hiprtc and returns the generated code as bytes.
    ///
    /// The hiprtc program is destroyed before returning. On a compilation failure the
    /// compiler log is printed first.
    ///
    /// # Panics
    ///
    /// Panics if any hiprtc call reports a status other than `hiprtcResult_HIPRTC_SUCCESS`.
    fn compile_kernel(source: &CStr) -> Vec<u8> {
        let mut program: hiprtcProgram = ptr::null_mut();
        // SAFETY: `program` is a live out-parameter and `source` is a NUL-terminated string
        // that outlives the call; no name, headers or include names are supplied.
        let status = unsafe {
            hiprtcCreateProgram(
                &mut program,
                source.as_ptr(),
                ptr::null(),
                0,
                ptr::null_mut(),
                ptr::null_mut(),
            )
        };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should create the program"
        );

        // SAFETY: `program` was created above; zero options are passed.
        let status = unsafe { hiprtcCompileProgram(program, 0, ptr::null_mut()) };
        if status != hiprtcResult_HIPRTC_SUCCESS {
            print_compile_log(program);
        }
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should compile the program"
        );

        let mut code_size: usize = 0;
        // SAFETY: `code_size` is a live, writable local for the duration of the call.
        let status = unsafe { hiprtcGetCodeSize(program, &mut code_size) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should get size of compiled code"
        );

        let mut code = vec![0u8; code_size];
        // SAFETY: `code` holds exactly the `code_size` bytes reported by hiprtc above.
        let status = unsafe { hiprtcGetCode(program, code.as_mut_ptr().cast()) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should load compiled code"
        );

        // SAFETY: `program` is a valid handle and is not used after this call.
        let status = unsafe { hiprtcDestroyProgram(&mut program) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should destroy the program"
        );

        code
    }

    /// End-to-end smoke test: compiles a kernel at runtime, uploads the inputs, launches the
    /// kernel on a stream, downloads the output and checks every element.
    #[test]
    fn test_launch_kernel_end_to_end() {
        let source = CString::new(KERNEL_SOURCE).expect("Should construct kernel string");
        let func_name = CString::new("kernel").unwrap();

        // SAFETY: plain FFI call taking a device index by value.
        let status = unsafe { hipSetDevice(0) };
        assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");

        // Both values are written by HIP, so they must be mutable and passed as `&mut`.
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: `free` and `total` are live, writable locals for the duration of the call.
        let status = unsafe { hipMemGetInfo(&mut free, &mut total) };
        assert_eq!(
            status, HIP_SUCCESS,
            "Should get the available memory of the device"
        );
        println!("Free: {} | Total:{}", free, total);

        let code = compile_kernel(&source);
        assert!(!code.is_empty(), "Generated code should not be empty");

        // Host-side data.
        let n: usize = 1024;
        let byte_len = n * size_of::<f32>();
        let a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];

        // Device buffers.
        let mut device_x: *mut c_void = ptr::null_mut();
        let mut device_b: *mut c_void = ptr::null_mut();
        let mut device_out: *mut c_void = ptr::null_mut();
        // SAFETY: each out-parameter is a live, writable local pointer.
        unsafe {
            let status_x = hipMalloc(&mut device_x, byte_len);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, byte_len);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, byte_len);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // SAFETY: every host vector holds `n` f32 values (`byte_len` bytes) and every device
        // buffer was allocated with `byte_len` bytes above.
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            // Uploads the zero-initialised host output so the device buffer starts zeroed.
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        // SAFETY: `code` and `func_name` outlive the calls; `module` and `function` are live
        // out-parameters.
        unsafe {
            let status_module = hipModuleLoadData(&mut module, code.as_ptr().cast());
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        let start_time = Instant::now();

        // Kernel parameters are passed as an array of pointers to the argument values. Each
        // value lives in its own mutable local with exactly the type the kernel declares; in
        // particular `n` is marshalled as a 4-byte `i32` to match `int n`, not as a `usize`.
        let mut arg_a: f32 = a;
        let mut arg_x: *mut c_void = device_x;
        let mut arg_b: *mut c_void = device_b;
        let mut arg_out: *mut c_void = device_out;
        let mut arg_n: i32 = i32::try_from(n).expect("element count should fit in a C int");
        let mut args: [*mut c_void; 5] = [
            &mut arg_a as *mut f32 as *mut c_void,
            &mut arg_x as *mut *mut c_void as *mut c_void,
            &mut arg_b as *mut *mut c_void as *mut c_void,
            &mut arg_out as *mut *mut c_void as *mut c_void,
            &mut arg_n as *mut i32 as *mut c_void,
        ];

        // One thread per element; round the block count up so a size that is not a multiple
        // of the block size is still fully covered (the kernel guards `tid < n`).
        let block_dim_x: usize = 64;
        let grid_dim_x: usize = (n + block_dim_x - 1) / block_dim_x;
        let grid_x = u32::try_from(grid_dim_x).expect("grid dimension should fit in u32");
        let block_x = u32::try_from(block_dim_x).expect("block dimension should fit in u32");

        // NOTE(review): `stream` and `module` are never released in this test; consider
        // destroying/unloading them once the bindings for that are confirmed.
        let mut stream: hipStream_t = ptr::null_mut();
        // SAFETY: `stream` is a live out-parameter.
        let stream_status = unsafe { hipStreamCreate(&mut stream) };
        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");

        // SAFETY: `function` and `stream` were obtained above, and `args` (plus the locals it
        // points to) stay alive until the device synchronisation below.
        let status_launch = unsafe {
            hipModuleLaunchKernel(
                function,
                // Grid dimensions (blocks in x, y, z) come first...
                grid_x,
                1,
                1,
                // ...followed by block dimensions (threads in x, y, z).
                block_x,
                1,
                1,
                // Dynamic shared memory in bytes.
                0,
                stream,
                args.as_mut_ptr(),
                // No extra launch options.
                ptr::null_mut(),
            )
        };
        assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");

        // SAFETY: plain FFI call without arguments.
        let status = unsafe { hipDeviceSynchronize() };
        assert_eq!(status, HIP_SUCCESS, "Should sync with the device");

        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // SAFETY: `out` has room for `n` f32 values and `device_out` holds `byte_len` bytes.
        let status = unsafe {
            hipMemcpy(
                out.as_mut_ptr().cast(),
                device_out,
                byte_len,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            )
        };
        // Check the copy itself so a failed download is not reported as a wrong result.
        assert_eq!(status, HIP_SUCCESS, "Should copy the output back to the host");

        for (i, ((&result, &x_i), &b_i)) in out.iter().zip(&x).zip(&b).enumerate() {
            let expected = a * x_i + b_i;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }

        // SAFETY: each pointer came from `hipMalloc` above and is freed exactly once.
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
        }
    }
}