use crate::compile::KernelCache;
use crate::error::KernelError;
use crate::kernel::{dtype_suffix, BinaryOp, UnaryOp};
use crate::utils::grid_block_config;
use rocm_rs::hip::{DeviceMemory, Stream};
use std::sync::Arc;
/// Launches the element-wise unary/binary HIP kernels, loading their modules
/// through a shared [`KernelCache`].
pub struct OpLauncher {
    cache: Arc<KernelCache>,
}

impl OpLauncher {
    /// Creates a launcher whose kernel cache is built for `device`.
    ///
    /// # Errors
    /// Propagates any [`KernelError`] returned by `KernelCache::new`.
    pub fn new(device: &rocm_rs::hip::Device) -> Result<Self, KernelError> {
        let cache = Arc::new(KernelCache::new(device)?);
        Ok(Self { cache })
    }

    /// Returns the `kernelParams` slot pointer for one kernel argument.
    ///
    /// HIP module launches read every argument *through* the pointer stored in
    /// the parameter array, so each entry must point at a host-side variable
    /// holding the argument value. This includes device pointers, which are
    /// themselves argument values. `value` must stay alive until the launch
    /// call has returned.
    fn arg_ptr<P>(value: &P) -> *mut std::ffi::c_void {
        value as *const P as *mut std::ffi::c_void
    }

    /// Device address of the optional dims/strides buffer, or a null device
    /// pointer when it is absent.
    fn dims_device_ptr(dims_and_strides: Option<&DeviceMemory<usize>>) -> *mut std::ffi::c_void {
        dims_and_strides.map_or(std::ptr::null_mut(), |info| {
            info.as_ptr() as *mut std::ffi::c_void
        })
    }

    /// Enqueues the binary kernel `"{op}_{dtype}"` from the binary module on
    /// `stream`.
    ///
    /// Kernel arguments are passed in this order: `numel`, `num_dims`,
    /// dims/strides device pointer (null when `None`), `lhs`, `rhs`, `output`.
    /// No stream synchronization is performed here. When `numel == 0` nothing
    /// is launched.
    ///
    /// # Errors
    /// - Any [`KernelError`] from loading the module through the cache.
    /// - [`KernelError::Launch`] if the kernel for `op`/`T` is not found or the
    ///   launch itself fails.
    pub fn launch_binary<T: Copy + Send + Sync + 'static>(
        &self,
        stream: &Stream,
        op: BinaryOp,
        numel: usize,
        num_dims: usize,
        dims_and_strides: Option<&DeviceMemory<usize>>,
        lhs: &DeviceMemory<T>,
        rhs: &DeviceMemory<T>,
        output: &mut DeviceMemory<T>,
    ) -> Result<(), KernelError> {
        use crate::kernel::BinaryKernel;
        use crate::kernel::KernelSource;

        let module = self
            .cache
            .get_or_load(BinaryKernel::NAME, BinaryKernel::CODE)?;
        let kernel_name = format!("{}_{}", op.kernel_name(), dtype_suffix::<T>());
        let function = module
            .get_function(&kernel_name)
            .map_err(|e| KernelError::Launch(format!("Kernel {} not found: {}", kernel_name, e)))?;

        // Nothing to compute; grid_block_config(0) may also produce an empty
        // grid, which HIP rejects as an invalid launch configuration.
        if numel == 0 {
            return Ok(());
        }
        let (grid, block) = grid_block_config(numel);

        // Host-side copies of the device pointers: HIP dereferences each
        // `args` entry, so the entries must point at these locals.
        let dims_ptr = Self::dims_device_ptr(dims_and_strides);
        let lhs_ptr = lhs.as_ptr();
        let rhs_ptr = rhs.as_ptr();
        let out_ptr = output.as_ptr();
        let mut args: Vec<*mut std::ffi::c_void> = vec![
            Self::arg_ptr(&numel),
            Self::arg_ptr(&num_dims),
            Self::arg_ptr(&dims_ptr),
            Self::arg_ptr(&lhs_ptr),
            Self::arg_ptr(&rhs_ptr),
            Self::arg_ptr(&out_ptr),
        ];
        function
            .launch(grid, block, 0, Some(stream), &mut args)
            .map_err(|e| KernelError::Launch(e.to_string()))?;
        Ok(())
    }

    /// Enqueues the unary kernel `"{op}_{dtype}"` from the unary module on
    /// `stream`.
    ///
    /// Kernel arguments are passed in this order: `numel`, `num_dims`,
    /// dims/strides device pointer (null when `None`), `input`, `output`.
    /// No stream synchronization is performed here. When `numel == 0` nothing
    /// is launched.
    ///
    /// # Errors
    /// - Any [`KernelError`] from loading the module through the cache.
    /// - [`KernelError::Launch`] if the kernel for `op`/`T` is not found or the
    ///   launch itself fails.
    pub fn launch_unary<T: Copy + Send + Sync + 'static>(
        &self,
        stream: &Stream,
        op: UnaryOp,
        numel: usize,
        num_dims: usize,
        dims_and_strides: Option<&DeviceMemory<usize>>,
        input: &DeviceMemory<T>,
        output: &mut DeviceMemory<T>,
    ) -> Result<(), KernelError> {
        use crate::kernel::KernelSource;
        use crate::kernel::UnaryKernel;

        let module = self
            .cache
            .get_or_load(UnaryKernel::NAME, UnaryKernel::CODE)?;
        let kernel_name = format!("{}_{}", op.kernel_name(), dtype_suffix::<T>());
        let function = module
            .get_function(&kernel_name)
            .map_err(|e| KernelError::Launch(format!("Kernel {} not found: {}", kernel_name, e)))?;

        // Nothing to compute; avoid handing HIP a possibly empty grid.
        if numel == 0 {
            return Ok(());
        }
        let (grid, block) = grid_block_config(numel);

        // Host-side copies of the device pointers: HIP dereferences each
        // `args` entry, so the entries must point at these locals.
        let dims_ptr = Self::dims_device_ptr(dims_and_strides);
        let in_ptr = input.as_ptr();
        let out_ptr = output.as_ptr();
        let mut args: Vec<*mut std::ffi::c_void> = vec![
            Self::arg_ptr(&numel),
            Self::arg_ptr(&num_dims),
            Self::arg_ptr(&dims_ptr),
            Self::arg_ptr(&in_ptr),
            Self::arg_ptr(&out_ptr),
        ];
        function
            .launch(grid, block, 0, Some(stream), &mut args)
            .map_err(|e| KernelError::Launch(e.to_string()))?;
        Ok(())
    }

    /// Enqueues the `"upow_{dtype}"` kernel from the unary module on `stream`,
    /// passing `exp_val` as a scalar argument.
    ///
    /// Kernel arguments are passed in this order: `numel`, `num_dims`,
    /// dims/strides device pointer (null when `None`), `input`, `exp_val`,
    /// `output`. No stream synchronization is performed here. When
    /// `numel == 0` nothing is launched.
    ///
    /// # Errors
    /// - Any [`KernelError`] from loading the module through the cache.
    /// - [`KernelError::Launch`] if the kernel for `T` is not found or the
    ///   launch itself fails.
    pub fn launch_pow<T: Copy + Send + Sync + 'static>(
        &self,
        stream: &Stream,
        numel: usize,
        num_dims: usize,
        dims_and_strides: Option<&DeviceMemory<usize>>,
        input: &DeviceMemory<T>,
        exp_val: T,
        output: &mut DeviceMemory<T>,
    ) -> Result<(), KernelError> {
        use crate::kernel::KernelSource;
        use crate::kernel::UnaryKernel;

        let module = self
            .cache
            .get_or_load(UnaryKernel::NAME, UnaryKernel::CODE)?;
        let kernel_name = format!("upow_{}", dtype_suffix::<T>());
        let function = module
            .get_function(&kernel_name)
            .map_err(|e| KernelError::Launch(format!("Kernel {} not found: {}", kernel_name, e)))?;

        // Nothing to compute; avoid handing HIP a possibly empty grid.
        if numel == 0 {
            return Ok(());
        }
        let (grid, block) = grid_block_config(numel);

        // Host-side copies of the device pointers: HIP dereferences each
        // `args` entry, so the entries must point at these locals.
        let dims_ptr = Self::dims_device_ptr(dims_and_strides);
        let in_ptr = input.as_ptr();
        let out_ptr = output.as_ptr();
        let mut args: Vec<*mut std::ffi::c_void> = vec![
            Self::arg_ptr(&numel),
            Self::arg_ptr(&num_dims),
            Self::arg_ptr(&dims_ptr),
            Self::arg_ptr(&in_ptr),
            Self::arg_ptr(&exp_val),
            Self::arg_ptr(&out_ptr),
        ];
        function
            .launch(grid, block, 0, Some(stream), &mut args)
            .map_err(|e| KernelError::Launch(e.to_string()))?;
        Ok(())
    }

    /// Returns the shared kernel cache used by this launcher.
    pub fn cache(&self) -> &Arc<KernelCache> {
        &self.cache
    }
}