1#![allow(clippy::too_many_arguments)]
2#![allow(clippy::useless_transmute)]
3#![allow(improper_ctypes)]
4#![allow(non_camel_case_types)]
5#![allow(non_snake_case)]
6#![allow(non_upper_case_globals)]
7#![allow(unused_variables)]
8
9pub mod build_script;
10
11#[cfg(target_os = "linux")]
12mod bindings;
13#[cfg(target_os = "linux")]
14#[allow(unused)]
15pub use bindings::*;
16
#[cfg(target_os = "linux")]
#[cfg(test)]
mod tests {
    use super::bindings::*;
    use std::{
        ffi::{CStr, CString},
        mem,
        os::raw::{c_char, c_int, c_void},
        ptr,
        time::Instant,
    };

    /// Number of `f32` elements processed by the test kernel.
    const ELEMENT_COUNT: usize = 1024;
    /// Number of threads per block used when launching the test kernel.
    const THREADS_PER_BLOCK: usize = 64;

    /// Compiles `source` with hiprtc and returns the generated code.
    ///
    /// On a compilation failure the compiler log is printed before the
    /// assertion on the compile status fires, so the diagnostics end up in the
    /// test output.
    ///
    /// # Panics
    ///
    /// Panics if any hiprtc call does not return `hiprtcResult_HIPRTC_SUCCESS`
    /// or if the generated code is empty.
    fn compile_kernel(source: &CStr) -> Vec<u8> {
        let mut program: hiprtcProgram = ptr::null_mut();
        // SAFETY: `program` is a valid out-pointer and `source` is a
        // NUL-terminated string that outlives the call. No headers are passed,
        // so the header arrays may be null.
        unsafe {
            let status = hiprtcCreateProgram(
                &mut program,
                source.as_ptr(),
                ptr::null(),
                0,
                ptr::null_mut(),
                ptr::null_mut(),
            );
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should create the program"
            );
        }

        // SAFETY: `program` was successfully created above. No options are
        // passed, so the options array may be null.
        unsafe {
            let status = hiprtcCompileProgram(program, 0, ptr::null_mut());
            if status != hiprtcResult_HIPRTC_SUCCESS {
                let mut log_size: usize = 0;
                let log_status = hiprtcGetProgramLogSize(program, &mut log_size);
                assert_eq!(
                    log_status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log size"
                );
                println!("Compilation log size: {log_size}");
                // One extra zeroed element guarantees NUL termination even
                // when the reported size is zero or the log is not terminated.
                let mut log_buffer: Vec<c_char> = vec![0; log_size + 1];
                let log_status = hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
                assert_eq!(
                    log_status, hiprtcResult_HIPRTC_SUCCESS,
                    "Should retrieve the compilation log contents"
                );
                let log = CStr::from_ptr(log_buffer.as_ptr());
                println!("Compilation log: {}", log.to_string_lossy());
            }
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should compile the program"
            );
        }

        let mut code_size: usize = 0;
        // SAFETY: `program` is a compiled program and `code_size` is a valid
        // out-pointer.
        unsafe {
            let status = hiprtcGetCodeSize(program, &mut code_size);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should get size of compiled code"
            );
        }

        let mut code: Vec<u8> = vec![0; code_size];
        // SAFETY: `code` holds exactly `code_size` bytes, the size reported by
        // hiprtc for this program.
        unsafe {
            let status = hiprtcGetCode(program, code.as_mut_ptr().cast());
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should load compiled code"
            );
        }

        // SAFETY: `program` is a live program handle; it is not used again.
        unsafe {
            let status = hiprtcDestroyProgram(&mut program);
            assert_eq!(
                status, hiprtcResult_HIPRTC_SUCCESS,
                "Should destroy the program"
            );
        }
        assert!(!code.is_empty(), "Generated code should not be empty");
        code
    }

    /// Compiles a small `out = a * x + b` kernel at runtime, launches it on
    /// device 0 and checks every output element against the host computation.
    #[test]
    fn test_launch_kernel_end_to_end() {
        let source = CString::new(
            r#"
extern "C" __global__ void kernel(float a, float *x, float *b, float *out, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        out[tid] = x[tid] * a + b[tid];
    }
}
"#,
        )
        .expect("Should construct kernel string");
        let func_name = CString::new("kernel").expect("Should construct kernel name");

        // SAFETY: plain FFI call taking a device ordinal.
        unsafe {
            let status = hipSetDevice(0);
            assert_eq!(status, HIP_SUCCESS, "Should set the GPU device");
        }

        // The driver writes through both pointers, so they must come from
        // mutable bindings.
        let mut free: usize = 0;
        let mut total: usize = 0;
        // SAFETY: both arguments are valid, writable `usize` locations.
        unsafe {
            let status = hipMemGetInfo(&mut free, &mut total);
            assert_eq!(
                status, HIP_SUCCESS,
                "Should get the available memory of the device"
            );
        }
        println!("Free: {} | Total:{}", free, total);

        let code = compile_kernel(&source);

        // Host-side inputs and output buffer.
        let n = ELEMENT_COUNT;
        let byte_len = n * mem::size_of::<f32>();
        let mut a = 2.0f32;
        let x: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut out: Vec<f32> = vec![0.0; n];

        let mut device_x: *mut c_void = ptr::null_mut();
        let mut device_b: *mut c_void = ptr::null_mut();
        let mut device_out: *mut c_void = ptr::null_mut();
        // SAFETY: each first argument is a valid out-pointer for the
        // allocation address.
        unsafe {
            let status_x = hipMalloc(&mut device_x, byte_len);
            assert_eq!(status_x, HIP_SUCCESS, "Should allocate memory for device_x");
            let status_b = hipMalloc(&mut device_b, byte_len);
            assert_eq!(status_b, HIP_SUCCESS, "Should allocate memory for device_b");
            let status_out = hipMalloc(&mut device_out, byte_len);
            assert_eq!(
                status_out, HIP_SUCCESS,
                "Should allocate memory for device_out"
            );
        }

        // SAFETY: every host slice holds `n` f32 values (`byte_len` bytes) and
        // every device buffer was allocated with `byte_len` bytes above.
        unsafe {
            let status_device_x = hipMemcpy(
                device_x,
                x.as_ptr() as *const c_void,
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_x, HIP_SUCCESS,
                "Should copy device_x successfully"
            );
            let status_device_b = hipMemcpy(
                device_b,
                b.as_ptr() as *const c_void,
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_b, HIP_SUCCESS,
                "Should copy device_b successfully"
            );
            let status_device_out = hipMemcpy(
                device_out,
                out.as_ptr() as *const c_void,
                byte_len,
                hipMemcpyKind_hipMemcpyHostToDevice,
            );
            assert_eq!(
                status_device_out, HIP_SUCCESS,
                "Should copy device_out successfully"
            );
        }

        let mut module: hipModule_t = ptr::null_mut();
        let mut function: hipFunction_t = ptr::null_mut();
        // SAFETY: `code` is the image returned by hiprtc and stays alive for
        // the call; `func_name` is NUL-terminated.
        unsafe {
            let status_module = hipModuleLoadData(&mut module, code.as_ptr() as *const c_void);
            assert_eq!(
                status_module, HIP_SUCCESS,
                "Should load compiled code into module"
            );
            let status_function = hipModuleGetFunction(&mut function, module, func_name.as_ptr());
            assert_eq!(
                status_function, HIP_SUCCESS,
                "Should return module function"
            );
        }

        // The kernel declares `int n`, so the argument must be a C int rather
        // than a pointer-sized `usize`.
        let mut n_arg = c_int::try_from(n).expect("Element count should fit in a C int");
        // One pointer per kernel parameter, in declaration order.
        let mut args: [*mut c_void; 5] = [
            &mut a as *mut f32 as *mut c_void,
            &mut device_x as *mut *mut c_void as *mut c_void,
            &mut device_b as *mut *mut c_void as *mut c_void,
            &mut device_out as *mut *mut c_void as *mut c_void,
            &mut n_arg as *mut c_int as *mut c_void,
        ];

        // Round up so every element is covered; the kernel guards `tid < n`.
        let block_dim_x =
            u32::try_from(THREADS_PER_BLOCK).expect("Block dimension should fit in u32");
        let grid_dim_x = u32::try_from(n.div_ceil(THREADS_PER_BLOCK))
            .expect("Grid dimension should fit in u32");

        let mut stream: hipStream_t = ptr::null_mut();
        // SAFETY: `stream` is a valid out-pointer.
        unsafe {
            let stream_status = hipStreamCreate(&mut stream);
            assert_eq!(stream_status, HIP_SUCCESS, "Should create a stream");
        }

        // Start timing right before the launch so setup is not measured.
        let start_time = Instant::now();
        // SAFETY: `function` and `stream` are live handles, and `args` has one
        // entry per kernel parameter, each pointing at a live local of the
        // matching type.
        unsafe {
            let status_launch = hipModuleLaunchKernel(
                function,
                // Grid dimensions (blocks in x, y, z).
                grid_dim_x,
                1,
                1,
                // Block dimensions (threads in x, y, z).
                block_dim_x,
                1,
                1,
                // Dynamic shared memory in bytes.
                0,
                stream,
                args.as_mut_ptr(),
                // No extra launch options.
                ptr::null_mut(),
            );
            assert_eq!(status_launch, HIP_SUCCESS, "Should launch the kernel");
        }
        // SAFETY: plain FFI call without arguments.
        unsafe {
            let status = hipDeviceSynchronize();
            assert_eq!(status, HIP_SUCCESS, "Should sync with the device");
        }
        let duration = start_time.elapsed();
        println!("Execution time: {}µs", duration.as_micros());

        // SAFETY: `out` has room for `byte_len` bytes and `device_out` was
        // allocated with `byte_len` bytes.
        unsafe {
            let status = hipMemcpy(
                out.as_mut_ptr() as *mut c_void,
                device_out,
                byte_len,
                hipMemcpyKind_hipMemcpyDeviceToHost,
            );
            assert_eq!(status, HIP_SUCCESS, "Should copy the output back to the host");
        }

        // Release device resources before verifying so a mismatch does not
        // leave them allocated.
        // SAFETY: every handle was created above and is not used afterwards.
        unsafe {
            let status = hipFree(device_x);
            assert_eq!(status, HIP_SUCCESS, "Should free device_x successfully");
            let status = hipFree(device_b);
            assert_eq!(status, HIP_SUCCESS, "Should free device_b successfully");
            let status = hipFree(device_out);
            assert_eq!(status, HIP_SUCCESS, "Should free device_out successfully");
            let status = hipStreamDestroy(stream);
            assert_eq!(status, HIP_SUCCESS, "Should destroy the stream");
            let status = hipModuleUnload(module);
            assert_eq!(status, HIP_SUCCESS, "Should unload the module");
        }

        for (i, ((&result, &xi), &bi)) in out.iter().zip(&x).zip(&b).enumerate() {
            let expected = a * xi + bi;
            assert_eq!(result, expected, "Output mismatch at index {}", i);
        }
    }
}