use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
const MODULE: &str = "fused_elementwise";
pub unsafe fn launch_fused_mul_add(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
a_ptr: u64,
b_ptr: u64,
c_ptr: u64,
output_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
launch_ternary_kernel(
context,
stream,
device_index,
"fused_mul_add",
dtype,
a_ptr,
b_ptr,
c_ptr,
output_ptr,
numel,
)
}
}
pub unsafe fn launch_fused_add_mul(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
a_ptr: u64,
b_ptr: u64,
c_ptr: u64,
output_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
launch_ternary_kernel(
context,
stream,
device_index,
"fused_add_mul",
dtype,
a_ptr,
b_ptr,
c_ptr,
output_ptr,
numel,
)
}
}
pub unsafe fn launch_fused_mul_add_scalar(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
a_ptr: u64,
output_ptr: u64,
numel: usize,
scale: f64,
bias: f64,
) -> Result<()> {
let module = get_or_load_module(context, device_index, MODULE)?;
let func_name = kernel_name("fused_mul_add_scalar", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let scale_f32 = scale as f32;
let bias_f32 = bias as f32;
let mut builder = stream.launch_builder(&func);
unsafe {
builder.arg(&a_ptr);
builder.arg(&output_ptr);
builder.arg(&n);
match dtype {
DType::F64 => {
builder.arg(&scale);
builder.arg(&bias);
}
_ => {
builder.arg(&scale_f32);
builder.arg(&bias_f32);
}
}
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA fused_mul_add_scalar kernel launch failed: {:?}",
e
))
})?;
}
Ok(())
}
unsafe fn launch_ternary_kernel(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
op: &str,
dtype: DType,
a_ptr: u64,
b_ptr: u64,
c_ptr: u64,
output_ptr: u64,
numel: usize,
) -> Result<()> {
let module = get_or_load_module(context, device_index, MODULE)?;
let func_name = kernel_name(op, dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
unsafe {
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&c_ptr);
builder.arg(&output_ptr);
builder.arg(&n);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?;
}
Ok(())
}