rocm_rs/hip/
kernel.rs

// src/hip/kernel.rs
//
// Kernel-launching functions for HIP
4
5use crate::hip::Stream;
6use crate::hip::error::{Error, Result};
7use crate::hip::ffi;
8use crate::hip::utils::Dim3;
9use std::ffi::{CString, c_void};
10use std::ptr;
11
/// A wrapper around a HIP function (kernel) handle.
///
/// The only data held is the raw `hipFunction_t`; `Function::new` fills it in
/// via `hipModuleGetFunction`. Nothing in this type ties it to the lifetime of
/// the module it was looked up in, and no `Drop` impl is defined in this file,
/// so dropping a `Function` releases nothing.
pub struct Function {
    /// Raw kernel handle as returned by `hipModuleGetFunction`.
    function: ffi::hipFunction_t,
}
16
17impl Function {
18    /// Create a new function from a module and function name
19    pub unsafe fn new(module: ffi::hipModule_t, name: &str) -> Result<Self> {
20        let func_name = CString::new(name).unwrap();
21        let mut function = ptr::null_mut();
22
23        let error = unsafe { ffi::hipModuleGetFunction(&mut function, module, func_name.as_ptr()) };
24
25        if error != ffi::hipError_t_hipSuccess {
26            return Err(Error::new(error));
27        }
28
29        Ok(Self { function })
30    }
31
32    /// Launch the kernel with the given parameters
33    pub fn launch(
34        &self,
35        grid_dim: Dim3,
36        block_dim: Dim3,
37        shared_mem_bytes: u32,
38        stream: Option<&Stream>,
39        kernel_params: &mut [*mut c_void],
40    ) -> Result<()> {
41        let stream_ptr = match stream {
42            Some(s) => s.as_raw(),
43            None => ptr::null_mut(),
44        };
45
46        let error = unsafe {
47            ffi::hipModuleLaunchKernel(
48                self.function,
49                grid_dim.x,
50                grid_dim.y,
51                grid_dim.z,
52                block_dim.x,
53                block_dim.y,
54                block_dim.z,
55                shared_mem_bytes,
56                stream_ptr,
57                kernel_params.as_mut_ptr(),
58                ptr::null_mut(), // extra
59            )
60        };
61
62        if error != ffi::hipError_t_hipSuccess {
63            return Err(Error::new(error));
64        }
65
66        Ok(())
67    }
68
69    /// Get the raw function handle
70    pub fn as_raw(&self) -> ffi::hipFunction_t {
71        self.function
72    }
73}
74
/// Types whose values can be handed to a kernel launch as a parameter.
///
/// HIP launch APIs take an array with one pointer per kernel parameter, each
/// pointing at that parameter's value; implementors supply such a pointer.
pub trait KernelArg {
    /// Return a type-erased pointer to this value.
    ///
    /// The pointer borrows `self`: it is only meaningful while `self` is alive
    /// and has not been moved.
    fn as_ptr(&self) -> *const c_void;
}

// Implementations for the primitive numeric types: each one erases the
// address of the value itself.
macro_rules! impl_kernel_arg {
    ($($scalar:ty),*) => {
        $(
            impl KernelArg for $scalar {
                fn as_ptr(&self) -> *const c_void {
                    let typed: *const Self = self;
                    typed.cast::<c_void>()
                }
            }
        )*
    };
}

impl_kernel_arg!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64);
95
/// Launch a `Function` with a variadic list of kernel arguments.
///
/// Usage: `launch_kernel!(func, grid, block, shared_mem_bytes, stream, arg0, arg1, ...)`.
/// This expands to a call to `func.launch(...)` (which uses
/// `hipModuleLaunchKernel`) and evaluates to the `Result` that call returns.
///
/// Each argument expression is evaluated exactly once, in order. The pointer
/// returned by its `as_ptr()` method (e.g. `KernelArg::as_ptr`) is passed as
/// the corresponding kernel parameter. Argument expressions may be
/// temporaries such as `n as u32` or `a + b`: each one is bound to a local
/// before its address is taken, so it stays alive until `launch` returns.
///
/// A trailing comma is accepted, and the argument list may be empty.
///
/// NOTE(review): nothing checks that the arguments match the kernel's
/// parameter types; a mismatch is undefined behaviour on the device.
#[macro_export]
macro_rules! launch_kernel {
    ($func:expr, $grid:expr, $block:expr, $shared_mem:expr, $stream:expr $(, $arg:expr)* $(,)?) => {
        {
            #[allow(unused_mut)] // no pushes happen when the argument list is empty
            let mut args: ::std::vec::Vec<*mut ::std::ffi::c_void> = ::std::vec::Vec::new();
            $(
                // `let x = &expr;` extends the lifetime of a temporary `expr`
                // to the end of this block, i.e. past the launch below.
                // Shadowing `__kernel_arg` does not drop earlier bindings.
                let __kernel_arg = &$arg;
                args.push((*__kernel_arg).as_ptr() as *mut ::std::ffi::c_void);
            )*

            $func.launch($grid, $block, $shared_mem, $stream, &mut args)
        }
    };
}
111
112/// Launch a kernel via the runtime API
113///
114/// # Safety
115///
116/// This function is unsafe because it takes a raw function pointer and
117/// arguments that must match the function signature.
118pub unsafe fn launch_kernel(
119    kernel_func_ptr: *const c_void,
120    grid_dim: Dim3,
121    block_dim: Dim3,
122    shared_mem_bytes: u32,
123    stream: Option<&Stream>,
124    args: &[*mut c_void],
125) -> Result<()> {
126    let stream_ptr = match stream {
127        Some(s) => s.as_raw(),
128        None => ptr::null_mut(),
129    };
130
131    let native_grid_dim = grid_dim.to_native();
132    let native_block_dim = block_dim.to_native();
133
134    let error = unsafe {
135        ffi::hipLaunchKernel(
136            kernel_func_ptr,
137            native_grid_dim,
138            native_block_dim,
139            args.as_ptr() as *mut *mut c_void,
140            shared_mem_bytes.try_into().unwrap(),
141            stream_ptr,
142        )
143    };
144
145    if error != ffi::hipError_t_hipSuccess {
146        return Err(Error::new(error));
147    }
148
149    Ok(())
150}
151
/// Generate a typed launcher function for a kernel.
///
/// ```ignore
/// kernel_launcher!(launch_add, ADD_KERNEL, a: *mut f32, b: *mut f32, n: u32);
/// ```
///
/// defines
///
/// ```ignore
/// pub fn launch_add(
///     grid_dim: Dim3,
///     block_dim: Dim3,
///     shared_mem_bytes: u32,
///     stream: Option<&Stream>,
///     a: *mut f32,
///     b: *mut f32,
///     n: u32,
/// ) -> Result<()>
/// ```
///
/// The generated function forwards to `hip::kernel::launch_kernel`, passing
/// `$func` as the kernel pointer and the address of each argument as the
/// parameter array.
///
/// * `$func` must be a path to a value of type `*const c_void` (for example a
///   `const`) that identifies the kernel.
/// * Each kernel argument is written as `name: Type`; the names become the
///   parameter names of the generated function.
/// * A trailing comma is accepted, and the argument list may be empty.
///
/// NOTE(review): the generated function is safe to call, but nothing checks
/// that `$func` is a real kernel or that the argument types match its
/// signature; whoever invokes this macro takes on that obligation.
#[macro_export]
macro_rules! kernel_launcher {
    ($name:ident, $func:path $(, $arg:ident : $arg_ty:ty)* $(,)?) => {
        pub fn $name(
            grid_dim: $crate::hip::utils::Dim3,
            block_dim: $crate::hip::utils::Dim3,
            shared_mem_bytes: u32,
            stream: Option<&$crate::hip::Stream>
            $(, $arg: $arg_ty)*
        ) -> $crate::hip::error::Result<()> {
            // One pointer per kernel parameter. Each points at a parameter of
            // this function, so it stays valid for the whole launch call.
            let args: &[*mut ::std::ffi::c_void] = &[
                $(&$arg as *const $arg_ty as *mut ::std::ffi::c_void),*
            ];

            // SAFETY: every entry of `args` points at a live value of the
            // declared type. Matching those types to the kernel's signature,
            // and `$func` to a real kernel, is the macro invoker's contract.
            unsafe {
                $crate::hip::kernel::launch_kernel(
                    $func,
                    grid_dim,
                    block_dim,
                    shared_mem_bytes,
                    stream,
                    args,
                )
            }
        }
    };
}
183
184/// Helper function to convert a Stream reference to the rocrand stream type
185pub fn stream_to_rocrand(stream: &Stream) -> crate::rocrand::bindings::hipStream_t {
186    // Safe cast because both represent the same underlying HIP stream
187    stream.as_raw() as crate::rocrand::bindings::hipStream_t
188}