1use crate::hip::Stream;
6use crate::hip::error::{Error, Result};
7use crate::hip::ffi;
8use crate::hip::utils::Dim3;
9use std::ffi::{CString, c_void};
10use std::ptr;
11
12pub struct Function {
14 function: ffi::hipFunction_t,
15}
16
17impl Function {
18 pub unsafe fn new(module: ffi::hipModule_t, name: &str) -> Result<Self> {
20 let func_name = CString::new(name).unwrap();
21 let mut function = ptr::null_mut();
22
23 let error = unsafe { ffi::hipModuleGetFunction(&mut function, module, func_name.as_ptr()) };
24
25 if error != ffi::hipError_t_hipSuccess {
26 return Err(Error::new(error));
27 }
28
29 Ok(Self { function })
30 }
31
32 pub fn launch(
34 &self,
35 grid_dim: Dim3,
36 block_dim: Dim3,
37 shared_mem_bytes: u32,
38 stream: Option<&Stream>,
39 kernel_params: &mut [*mut c_void],
40 ) -> Result<()> {
41 let stream_ptr = match stream {
42 Some(s) => s.as_raw(),
43 None => ptr::null_mut(),
44 };
45
46 let error = unsafe {
47 ffi::hipModuleLaunchKernel(
48 self.function,
49 grid_dim.x,
50 grid_dim.y,
51 grid_dim.z,
52 block_dim.x,
53 block_dim.y,
54 block_dim.z,
55 shared_mem_bytes,
56 stream_ptr,
57 kernel_params.as_mut_ptr(),
58 ptr::null_mut(), )
60 };
61
62 if error != ffi::hipError_t_hipSuccess {
63 return Err(Error::new(error));
64 }
65
66 Ok(())
67 }
68
69 pub fn as_raw(&self) -> ffi::hipFunction_t {
71 self.function
72 }
73}
74
/// A host value that can be passed to a kernel launch as one parameter.
///
/// The launch helpers in this module take a list of untyped pointers, one per
/// kernel parameter, each pointing at that parameter's value. This trait
/// produces that pointer for `self`.
pub trait KernelArg {
    /// Returns the address of `self`, erased to an untyped pointer.
    ///
    /// The pointer is only meaningful while `self` is alive and not moved.
    fn as_ptr(&self) -> *const c_void;
}

/// Implements [`KernelArg`] for each listed type by returning `self`'s address.
macro_rules! impl_kernel_arg {
    ($($scalar:ty),*) => {
        $(
            impl KernelArg for $scalar {
                fn as_ptr(&self) -> *const c_void {
                    (self as *const Self).cast::<c_void>()
                }
            }
        )*
    };
}

// Fixed-width unsigned integers, signed integers, and floats.
impl_kernel_arg!(u8, u16, u32, u64);
impl_kernel_arg!(i8, i16, i32, i64);
impl_kernel_arg!(f32, f64);
95
/// Launches a module `Function` with a variadic list of kernel arguments.
///
/// `launch_kernel!(func, grid, block, shared_mem, stream, a, b, ...)` calls
/// `func.launch(grid, block, shared_mem, stream, params)`, where `params`
/// holds `x.as_ptr()` for each trailing argument `x`, in order. Zero
/// arguments and a trailing comma are accepted.
///
/// Each argument may be any expression whose type provides `as_ptr`, normally
/// through the `KernelArg` trait. The trait is imported inside the expansion,
/// so callers do not need to bring it into scope. Temporaries created while
/// evaluating the arguments (for example `n + 1` or `len as u32`) are kept
/// alive until `launch` returns, so the pointers handed to it stay valid.
///
/// # Examples
///
/// ```ignore
/// launch_kernel!(func, grid, block, 0, Some(&stream), len as u32, scale)?;
/// ```
#[macro_export]
macro_rules! launch_kernel {
    ($func:expr, $grid:expr, $block:expr, $shared_mem:expr, $stream:expr $(,)?) => {
        // No arguments: the expected parameter type fixes the empty slice's type.
        $func.launch($grid, $block, $shared_mem, $stream, &mut [])
    };
    ($func:expr, $grid:expr, $block:expr, $shared_mem:expr, $stream:expr, $($arg:expr),+ $(,)?) => {{
        // Make `KernelArg::as_ptr` visible to method resolution at the call site.
        #[allow(unused_imports)]
        use $crate::hip::kernel::KernelArg as _;

        // Temporaries created in a `match` scrutinee are dropped only after the
        // whole `match` finishes. Every value `as_ptr` borrowed therefore
        // outlives the `launch` call that reads through these pointers.
        match [$($arg.as_ptr() as *mut ::std::ffi::c_void),+] {
            mut kernel_params => {
                $func.launch($grid, $block, $shared_mem, $stream, &mut kernel_params)
            }
        }
    }};
}
111
112pub unsafe fn launch_kernel(
119 kernel_func_ptr: *const c_void,
120 grid_dim: Dim3,
121 block_dim: Dim3,
122 shared_mem_bytes: u32,
123 stream: Option<&Stream>,
124 args: &[*mut c_void],
125) -> Result<()> {
126 let stream_ptr = match stream {
127 Some(s) => s.as_raw(),
128 None => ptr::null_mut(),
129 };
130
131 let native_grid_dim = grid_dim.to_native();
132 let native_block_dim = block_dim.to_native();
133
134 let error = unsafe {
135 ffi::hipLaunchKernel(
136 kernel_func_ptr,
137 native_grid_dim,
138 native_block_dim,
139 args.as_ptr() as *mut *mut c_void,
140 shared_mem_bytes.try_into().unwrap(),
141 stream_ptr,
142 )
143 };
144
145 if error != ffi::hipError_t_hipSuccess {
146 return Err(Error::new(error));
147 }
148
149 Ok(())
150}
151
/// Generates a typed launcher function for a kernel.
///
/// `kernel_launcher!(launch_add, add_kernel, a: *mut f32, n: u32)` defines
///
/// ```ignore
/// pub fn launch_add(
///     grid_dim: Dim3,
///     block_dim: Dim3,
///     shared_mem_bytes: u32,
///     stream: Option<&Stream>,
///     a: *mut f32,
///     n: u32,
/// ) -> Result<()>
/// ```
///
/// The generated function passes the address of each argument, in order, to
/// `crate::hip::kernel::launch_kernel`. `$func` is cast to `*const c_void`, so
/// it may name a kernel function item or any path castable to that pointer type.
///
/// The generated function is safe to call. Writing the macro invocation is
/// therefore the point where soundness is asserted: the listed argument types
/// must match the kernel's parameter list exactly, in order.
#[macro_export]
macro_rules! kernel_launcher {
    ($name:ident, $func:path $(, $arg:ident : $arg_ty:ty)* $(,)?) => {
        pub fn $name(
            grid_dim: $crate::hip::utils::Dim3,
            block_dim: $crate::hip::utils::Dim3,
            shared_mem_bytes: u32,
            stream: ::std::option::Option<&$crate::hip::Stream>,
            $($arg: $arg_ty),*
        ) -> $crate::hip::error::Result<()> {
            // One pointer per kernel parameter. Each points at this function's
            // by-value parameter, which stays alive until the launch returns.
            let args: &[*mut ::std::ffi::c_void] = &[
                $(&$arg as *const $arg_ty as *mut ::std::ffi::c_void),*
            ];

            // SAFETY: every entry of `args` addresses a live parameter of the
            // declared type. The macro invocation asserts that those types
            // match `$func`'s kernel signature.
            unsafe {
                $crate::hip::kernel::launch_kernel(
                    $func as *const ::std::ffi::c_void,
                    grid_dim,
                    block_dim,
                    shared_mem_bytes,
                    stream,
                    args,
                )
            }
        }
    };
}
183
184pub fn stream_to_rocrand(stream: &Stream) -> crate::rocrand::bindings::hipStream_t {
186 stream.as_raw() as crate::rocrand::bindings::hipStream_t
188}