use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Module name passed to `get_or_load_module` to fetch the forward
/// fused activation-multiply kernels.
const FUSED_ACTIVATION_MUL_MODULE: &str = "fused_activation_mul";
/// Module name passed to `get_or_load_module` to fetch the backward
/// fused activation-multiply kernels.
const FUSED_ACTIVATION_MUL_BWD_MODULE: &str = "fused_activation_mul_bwd";
unsafe fn launch_fused_activation_mul_fwd(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
op: &str,
dtype: DType,
a_ptr: u64,
b_ptr: u64,
output_ptr: u64,
numel: usize,
) -> Result<()> {
let module = get_or_load_module(context, device_index, FUSED_ACTIVATION_MUL_MODULE)?;
let func_name = kernel_name(op, dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
unsafe {
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&output_ptr);
builder.arg(&n);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?;
}
Ok(())
}
unsafe fn launch_fused_activation_mul_bwd(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
op: &str,
dtype: DType,
grad_ptr: u64,
a_ptr: u64,
b_ptr: u64,
d_a_ptr: u64,
d_b_ptr: u64,
numel: usize,
) -> Result<()> {
let module = get_or_load_module(context, device_index, FUSED_ACTIVATION_MUL_BWD_MODULE)?;
let func_name = kernel_name(op, dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
unsafe {
builder.arg(&grad_ptr);
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&d_a_ptr);
builder.arg(&d_b_ptr);
builder.arg(&n);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?;
}
Ok(())
}
/// Generates public forward launchers for fused activation-multiply kernels.
///
/// Each `name => "op"` entry expands to a `pub unsafe fn name(...)` that
/// forwards to `launch_fused_activation_mul_fwd`, using `"op"` as the kernel
/// op name. Any doc comments written before an entry are attached to the
/// generated function.
macro_rules! fused_activation_mul_fwd {
    ($($(#[doc = $doc:expr])* $name:ident => $op:expr),+ $(,)?) => {
        $(
            $(#[doc = $doc])*
            ///
            /// Forwards to the shared fused activation-multiply forward launcher.
            /// That launcher selects the kernel by this entry's op name and `dtype`,
            /// then passes `a_ptr`, `b_ptr`, `output_ptr` and `numel` to it in
            /// that order.
            ///
            /// # Errors
            ///
            /// Propagates any error from the shared launcher, for example a
            /// failure to load the kernel module/function or to launch the kernel.
            ///
            /// # Safety
            ///
            /// All pointers must be valid device memory with at least `numel` elements.
            /// Kernel launches are asynchronous, so that memory must also remain valid
            /// until the kernel has finished executing on `stream`.
            pub unsafe fn $name(
                context: &Arc<CudaContext>,
                stream: &CudaStream,
                device_index: usize,
                dtype: DType,
                a_ptr: u64,
                b_ptr: u64,
                output_ptr: u64,
                numel: usize,
            ) -> Result<()> {
                // SAFETY: pointer validity and lifetime are the caller's
                // obligation under this function's `# Safety` contract.
                unsafe {
                    launch_fused_activation_mul_fwd(
                        context, stream, device_index, $op, dtype, a_ptr, b_ptr, output_ptr, numel,
                    )
                }
            }
        )+
    };
}
fused_activation_mul_fwd! {
    /// Launches the fused `silu_mul` forward kernel.
    launch_silu_mul => "silu_mul",
    /// Launches the fused `gelu_mul` forward kernel.
    launch_gelu_mul => "gelu_mul",
    /// Launches the fused `relu_mul` forward kernel.
    launch_relu_mul => "relu_mul",
    /// Launches the fused `sigmoid_mul` forward kernel.
    launch_sigmoid_mul => "sigmoid_mul",
}
/// Generates public backward launchers for fused activation-multiply kernels.
///
/// Each `name => "op"` entry expands to a `pub unsafe fn name(...)` that
/// forwards to `launch_fused_activation_mul_bwd`, using `"op"` as the kernel
/// op name. Any doc comments written before an entry are attached to the
/// generated function.
macro_rules! fused_activation_mul_bwd {
    ($($(#[doc = $doc:expr])* $name:ident => $op:expr),+ $(,)?) => {
        $(
            $(#[doc = $doc])*
            ///
            /// Forwards to the shared fused activation-multiply backward launcher.
            /// That launcher selects the kernel by this entry's op name and `dtype`,
            /// then passes `grad_ptr`, `a_ptr`, `b_ptr`, `d_a_ptr`, `d_b_ptr` and
            /// `numel` to it in that order.
            ///
            /// # Errors
            ///
            /// Propagates any error from the shared launcher, for example a
            /// failure to load the kernel module/function or to launch the kernel.
            ///
            /// # Safety
            ///
            /// All pointers must be valid device memory with at least `numel` elements.
            /// Kernel launches are asynchronous, so that memory must also remain valid
            /// until the kernel has finished executing on `stream`.
            pub unsafe fn $name(
                context: &Arc<CudaContext>,
                stream: &CudaStream,
                device_index: usize,
                dtype: DType,
                grad_ptr: u64,
                a_ptr: u64,
                b_ptr: u64,
                d_a_ptr: u64,
                d_b_ptr: u64,
                numel: usize,
            ) -> Result<()> {
                // SAFETY: pointer validity and lifetime are the caller's
                // obligation under this function's `# Safety` contract.
                unsafe {
                    launch_fused_activation_mul_bwd(
                        context, stream, device_index, $op, dtype, grad_ptr, a_ptr, b_ptr,
                        d_a_ptr, d_b_ptr, numel,
                    )
                }
            }
        )+
    };
}
fused_activation_mul_bwd! {
    /// Launches the fused `silu_mul_bwd` backward kernel.
    launch_silu_mul_bwd => "silu_mul_bwd",
    /// Launches the fused `gelu_mul_bwd` backward kernel.
    launch_gelu_mul_bwd => "gelu_mul_bwd",
    /// Launches the fused `relu_mul_bwd` backward kernel.
    launch_relu_mul_bwd => "relu_mul_bwd",
    /// Launches the fused `sigmoid_mul_bwd` backward kernel.
    launch_sigmoid_mul_bwd => "sigmoid_mul_bwd",
}