1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod hipconfig;
10pub use hipconfig::*;
11
12mod bindings;
13#[allow(unused)]
14pub use bindings::*;
15
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::bindings::*;
    use std::{
        ffi::CString,
        mem::size_of,
        os::raw::{c_char, c_int, c_void},
        ptr,
        time::Instant,
    };

    /// Number of `f32` elements processed by the test kernel.
    const ELEMENT_COUNT: usize = 1024;
    /// Number of threads launched per block.
    const THREADS_PER_BLOCK: usize = 64;

    /// End-to-end smoke test of the raw bindings.
    ///
    /// Compiles a small `out = a * x + b` kernel with hipRTC, uploads the
    /// inputs, launches the kernel on a freshly created stream, downloads the
    /// result and compares it element-wise against the same computation done
    /// on the host. Every HIP / hipRTC status code is asserted.
    ///
    /// Requires a usable HIP device with index 0.
    #[test]
    fn test_launch_kernel_end_to_end() {
        // Step 0: kernel source and entry-point name.
        let source = CString::new(
            r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        out[tid] = x[tid] * a + b[tid];
    }
}
"#,
        )
        .expect("Should construct kernel string");
        let func_name = CString::new("kernel").expect("Should construct kernel name");

        // Step 1: select the device and query its memory.
        // SAFETY: plain FFI call taking a device index by value.
        let status = unsafe { hipSetDevice(0) };
        assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");

        // These are written by the driver, so they must be mutable bindings
        // handed over as `*mut`; writing through a pointer derived from a
        // shared reference would be undefined behaviour.
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both pointers refer to live, writable locals.
        let status = unsafe { hipMemGetInfo(&mut free, &mut total) };
        assert_eq!(
            status, HIP_SUCCESS,
            "Should get the available memory of the device"
        );
        println!("Free: {} | Total:{}", free, total);

        // Step 2: create and compile the hipRTC program.
        let mut program: hiprtcProgram = ptr::null_mut();
        // SAFETY: `source` is a valid NUL-terminated string that outlives the
        // call; no name, headers or include names are supplied.
        let status = unsafe {
            hiprtcCreateProgram(
                &mut program,
                source.as_ptr(),
                ptr::null(),
                0,
                ptr::null_mut(),
                ptr::null_mut(),
            )
        };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should create the program"
        );

        // SAFETY: `program` was successfully created above; no options given.
        let compile_status = unsafe { hiprtcCompileProgram(program, 0, ptr::null_mut()) };
        if compile_status != hiprtcResult_HIPRTC_SUCCESS {
            // Print the compiler log before failing so the cause is visible.
            let mut log_size: usize = 0;
            // SAFETY: `log_size` is a live, writable local.
            let status = unsafe { hiprtcGetProgramLogSize(program, &mut log_size) };
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should retrieve the compilation log size"
            );
            println!("Compilation log size: {log_size}");
            // An empty Vec only has a dangling pointer, so skip the fetch.
            if log_size > 0 {
                let mut log_buffer: Vec<c_char> = vec![0; log_size];
                // SAFETY: the buffer holds exactly `log_size` writable bytes,
                // the size reported by hipRTC for this program's log.
                let status = unsafe { hiprtcGetProgramLog(program, log_buffer.as_mut_ptr()) };
                assert_eq!(
                    status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log contents"
                );
                // Stop at the first NUL but never read past the buffer, even
                // if the log turns out not to be NUL-terminated.
                let log_bytes: Vec<u8> = log_buffer
                    .iter()
                    .take_while(|&&c| c != 0)
                    .map(|&c| c as u8)
                    .collect();
                println!("Compilation log: {}", String::from_utf8_lossy(&log_bytes));
            }
        }
        assert_eq!(
            compile_status, hiprtcResult_HIPRTC_SUCCESS,
            "Should compile the program"
        );

        // Step 3: extract the compiled code object.
        let mut code_size: usize = 0;
        // SAFETY: `program` is compiled and `code_size` is a writable local.
        let status = unsafe { hiprtcGetCodeSize(program, &mut code_size) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should get size of compiled code"
        );
        let mut code: Vec<u8> = vec![0; code_size];
        // SAFETY: `code` holds exactly `code_size` writable bytes.
        let status = unsafe { hiprtcGetCode(program, code.as_mut_ptr().cast()) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should load compiled code"
        );

        // SAFETY: `program` is a valid handle that is not used afterwards.
        let status = unsafe { hiprtcDestroyProgram(&mut program) };
        assert_eq!(
            status, hiprtcResult_HIPRTC_SUCCESS,
            "Should destroy the program"
        );
        assert!(!code.is_empty(), "Generated code should not be empty");

        // Step 4: prepare host data and device buffers.
        let n = ELEMENT_COUNT;
        let byte_len = n * size_of::<f32>();
        let a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];
        // The kernel declares `int n`; pass a value of exactly that width
        // instead of pointing the driver at an 8-byte `usize`.
        let n_arg = c_int::try_from(n).expect("Element count should fit in a C int");

        let mut device_x: *mut c_void = ptr::null_mut();
        let mut device_b: *mut c_void = ptr::null_mut();
        let mut device_out: *mut c_void = ptr::null_mut();
        // SAFETY: each out-pointer refers to a live, writable local.
        unsafe {
            let status_x = hipMalloc(&mut device_x, byte_len);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, byte_len);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, byte_len);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // SAFETY: every host slice and device allocation is `byte_len` bytes.
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            // Zero the output buffer on the device.
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr().cast(),
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        // Step 5: load the module and look up the kernel.
        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        // SAFETY: `code` holds the code object produced by hipRTC above and
        // `func_name` is a valid NUL-terminated string.
        unsafe {
            let status_module = hipModuleLoadData(&mut module, code.as_ptr().cast());
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        // Step 6: launch the kernel.
        // Each entry points at the storage of one kernel parameter. The
        // driver only reads through these pointers, so deriving them from
        // shared references is fine.
        let mut args: [*mut c_void; 5] = [
            &a as *const f32 as *mut c_void,
            &device_x as *const *mut c_void as *mut c_void,
            &device_b as *const *mut c_void as *mut c_void,
            &device_out as *const *mut c_void as *mut c_void,
            &n_arg as *const c_int as *mut c_void,
        ];
        // Round up so that every element is covered even when `n` is not a
        // multiple of the block size; the kernel guards with `tid < n`.
        let block_dim_x =
            u32::try_from(THREADS_PER_BLOCK).expect("Block size should fit in a u32");
        let grid_dim_x = u32::try_from((n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK)
            .expect("Grid size should fit in a u32");

        let mut stream: hipStream_t = ptr::null_mut();
        // SAFETY: `stream` is a live, writable local.
        let stream_status = unsafe { hipStreamCreate(&mut stream) };
        assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");

        let start_time = Instant::now();
        // SAFETY: `function` and `stream` are valid handles, and `args` holds
        // one pointer per kernel parameter, each to storage of the type the
        // kernel declares, all alive until the synchronize below returns.
        let status_launch = unsafe {
            hipModuleLaunchKernel(
                function,
                // Grid dimensions (number of blocks) come first ...
                grid_dim_x,
                1,
                1,
                // ... followed by block dimensions (threads per block).
                block_dim_x,
                1,
                1,
                0, // dynamic shared memory in bytes
                stream,
                args.as_mut_ptr(),
                ptr::null_mut(), // no extra launch options
            )
        };
        assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");
        // SAFETY: plain FFI call without arguments.
        let status = unsafe { hipDeviceSynchronize() };
        assert_eq!(status, HIP_SUCCESS, "Should sync with the device");
        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // Step 7: download and verify the result.
        // SAFETY: `out` and `device_out` are both `byte_len` bytes long.
        let status = unsafe {
            hipMemcpy(
                out.as_mut_ptr().cast(),
                device_out,
                byte_len,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            )
        };
        assert_eq!(status, HIP_SUCCESS, "Should copy the output back to the host");

        for (i, ((&result, &xi), &bi)) in out.iter().zip(&x).zip(&b).enumerate() {
            let expected = a * xi + bi;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }

        // Step 8: release every device resource created by the test.
        // SAFETY: each handle was created above and is not used afterwards.
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
            let status = hipModuleUnload(module);
            assert_eq!(status, HIP_SUCCESS, "Should unload the module");
            let status = hipStreamDestroy(stream);
            assert_eq!(status, HIP_SUCCESS, "Should destroy the stream");
        }
    }
}