use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_name, kernel_names, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Computes `(grid, block, shared_mem_bytes)` for the fused add+norm
/// kernels: one CUDA block per row, up to `BLOCK_SIZE` threads per block,
/// and `shared_arrays` scratch arrays of `block` elements in shared memory.
///
/// NOTE(review): assumes `batch_size > 0`, `hidden_size > 0`, and that both
/// fit in `u32` — a zero or truncated value would yield an invalid launch
/// configuration that only surfaces as a launch error. Confirm callers
/// guarantee this.
#[inline]
fn fused_norm_launch_config(
    batch_size: usize,
    hidden_size: usize,
    shared_arrays: usize,
    dtype: DType,
) -> (u32, u32, u32) {
    // Never ask for more threads than the row has elements.
    let threads = u32::min(BLOCK_SIZE, hidden_size as u32);
    // Bytes per shared-memory element: 8 for F64 kernels, 4 for every
    // other dtype (matching the kernels' accumulator width).
    let bytes_per_elem: u32 = if dtype == DType::F64 { 8 } else { 4 };
    let smem_bytes = (shared_arrays as u32) * threads * bytes_per_elem;
    (batch_size as u32, threads, smem_bytes)
}
/// Launches the fused residual-add + RMSNorm forward kernel over a
/// `batch_size x hidden_size` input, one CUDA block per row.
///
/// The kernel receives `input`, `residual`, `weight`, `output`, and
/// `pre_norm` device pointers in that order (presumably `pre_norm` holds
/// the pre-normalization add result — confirm against the kernel source).
///
/// # Errors
///
/// Returns `Error::Internal` if module/function lookup or the kernel
/// launch fails.
///
/// # Safety
///
/// Every `*_ptr` must be a valid device pointer on device `device_index`,
/// sized appropriately for `dtype`, `batch_size`, and `hidden_size`, and
/// must remain valid until the kernel completes on `stream`.
pub unsafe fn launch_fused_add_rms_norm(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    input_ptr: u64,
    residual_ptr: u64,
    weight_ptr: u64,
    output_ptr: u64,
    pre_norm_ptr: u64,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?;
    let func_name = kernel_name("fused_add_rms_norm", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // The forward RMSNorm kernel uses one shared-memory reduction array.
    let (grid_size, block_size, shared_mem) =
        fused_norm_launch_config(batch_size, hidden_size, 1, dtype);
    let batch = batch_size as u32;
    let hidden = hidden_size as u32;
    let eps_f64 = f64::from(eps);
    let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);
    let mut builder = stream.launch_builder(&func);
    // Argument order must match the kernel's parameter declaration order.
    builder.arg(&input_ptr);
    builder.arg(&residual_ptr);
    builder.arg(&weight_ptr);
    builder.arg(&output_ptr);
    builder.arg(&pre_norm_ptr);
    builder.arg(&batch);
    builder.arg(&hidden);
    // The F64 kernel variant takes a double-precision epsilon; all other
    // variants take the f32 as-is.
    if dtype == DType::F64 {
        builder.arg(&eps_f64);
    } else {
        builder.arg(&eps);
    }
    // SAFETY: the caller upholds the pointer-validity contract documented
    // in `# Safety`; the unsafe scope is limited to the launch itself.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA fused_add_rms_norm kernel launch failed: {:?}",
            e
        ))
    })?;
    Ok(())
}
/// Launches the fused add + RMSNorm backward kernel, producing gradients
/// w.r.t. the (input + residual) sum and the weight.
///
/// The kernel receives `grad`, `pre_norm`, `weight`, `d_input_residual`,
/// and `d_weight` device pointers in that order.
///
/// # Errors
///
/// Returns `Error::Internal` if module/function lookup or the kernel
/// launch fails.
///
/// # Safety
///
/// Every `*_ptr` must be a valid device pointer on device `device_index`,
/// sized appropriately for `dtype`, `batch_size`, and `hidden_size`, and
/// must remain valid until the kernel completes on `stream`.
pub unsafe fn launch_fused_add_rms_norm_bwd(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    grad_ptr: u64,
    pre_norm_ptr: u64,
    weight_ptr: u64,
    d_input_residual_ptr: u64,
    d_weight_ptr: u64,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?;
    let func_name = kernel_name("fused_add_rms_norm_bwd", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // The RMSNorm backward kernel uses two shared-memory scratch arrays.
    let (grid_size, block_size, shared_mem) =
        fused_norm_launch_config(batch_size, hidden_size, 2, dtype);
    let batch = batch_size as u32;
    let hidden = hidden_size as u32;
    let eps_f64 = f64::from(eps);
    let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);
    let mut builder = stream.launch_builder(&func);
    // Argument order must match the kernel's parameter declaration order.
    builder.arg(&grad_ptr);
    builder.arg(&pre_norm_ptr);
    builder.arg(&weight_ptr);
    builder.arg(&d_input_residual_ptr);
    builder.arg(&d_weight_ptr);
    builder.arg(&batch);
    builder.arg(&hidden);
    // The F64 kernel variant takes a double-precision epsilon; all other
    // variants take the f32 as-is.
    if dtype == DType::F64 {
        builder.arg(&eps_f64);
    } else {
        builder.arg(&eps);
    }
    // SAFETY: the caller upholds the pointer-validity contract documented
    // in `# Safety`; the unsafe scope is limited to the launch itself.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA fused_add_rms_norm_bwd kernel launch failed: {:?}",
            e
        ))
    })?;
    Ok(())
}
/// Launches the fused residual-add + LayerNorm forward kernel over a
/// `batch_size x hidden_size` input, one CUDA block per row.
///
/// The kernel receives `input`, `residual`, `weight`, `bias`, `output`,
/// and `pre_norm` device pointers in that order (presumably `pre_norm`
/// holds the pre-normalization add result — confirm against the kernel
/// source).
///
/// # Errors
///
/// Returns `Error::Internal` if module/function lookup or the kernel
/// launch fails.
///
/// # Safety
///
/// Every `*_ptr` must be a valid device pointer on device `device_index`,
/// sized appropriately for `dtype`, `batch_size`, and `hidden_size`, and
/// must remain valid until the kernel completes on `stream`.
pub unsafe fn launch_fused_add_layer_norm(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    input_ptr: u64,
    residual_ptr: u64,
    weight_ptr: u64,
    bias_ptr: u64,
    output_ptr: u64,
    pre_norm_ptr: u64,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?;
    let func_name = kernel_name("fused_add_layer_norm", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // The forward LayerNorm kernel uses two shared-memory scratch arrays.
    let (grid_size, block_size, shared_mem) =
        fused_norm_launch_config(batch_size, hidden_size, 2, dtype);
    let batch = batch_size as u32;
    let hidden = hidden_size as u32;
    let eps_f64 = f64::from(eps);
    let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);
    let mut builder = stream.launch_builder(&func);
    // Argument order must match the kernel's parameter declaration order.
    builder.arg(&input_ptr);
    builder.arg(&residual_ptr);
    builder.arg(&weight_ptr);
    builder.arg(&bias_ptr);
    builder.arg(&output_ptr);
    builder.arg(&pre_norm_ptr);
    builder.arg(&batch);
    builder.arg(&hidden);
    // The F64 kernel variant takes a double-precision epsilon; all other
    // variants take the f32 as-is.
    if dtype == DType::F64 {
        builder.arg(&eps_f64);
    } else {
        builder.arg(&eps);
    }
    // SAFETY: the caller upholds the pointer-validity contract documented
    // in `# Safety`; the unsafe scope is limited to the launch itself.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA fused_add_layer_norm kernel launch failed: {:?}",
            e
        ))
    })?;
    Ok(())
}
/// Launches the fused add + LayerNorm backward kernel, producing gradients
/// w.r.t. the (input + residual) sum, the weight, and the bias.
///
/// The kernel receives `grad`, `pre_norm`, `weight`, `d_input_residual`,
/// `d_weight`, and `d_bias` device pointers in that order.
///
/// # Errors
///
/// Returns `Error::Internal` if module/function lookup or the kernel
/// launch fails.
///
/// # Safety
///
/// Every `*_ptr` must be a valid device pointer on device `device_index`,
/// sized appropriately for `dtype`, `batch_size`, and `hidden_size`, and
/// must remain valid until the kernel completes on `stream`.
pub unsafe fn launch_fused_add_layer_norm_bwd(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    grad_ptr: u64,
    pre_norm_ptr: u64,
    weight_ptr: u64,
    d_input_residual_ptr: u64,
    d_weight_ptr: u64,
    d_bias_ptr: u64,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?;
    let func_name = kernel_name("fused_add_layer_norm_bwd", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // The LayerNorm backward kernel uses four shared-memory scratch arrays.
    let (grid_size, block_size, shared_mem) =
        fused_norm_launch_config(batch_size, hidden_size, 4, dtype);
    let batch = batch_size as u32;
    let hidden = hidden_size as u32;
    let eps_f64 = f64::from(eps);
    let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);
    let mut builder = stream.launch_builder(&func);
    // Argument order must match the kernel's parameter declaration order.
    builder.arg(&grad_ptr);
    builder.arg(&pre_norm_ptr);
    builder.arg(&weight_ptr);
    builder.arg(&d_input_residual_ptr);
    builder.arg(&d_weight_ptr);
    builder.arg(&d_bias_ptr);
    builder.arg(&batch);
    builder.arg(&hidden);
    // The F64 kernel variant takes a double-precision epsilon; all other
    // variants take the f32 as-is.
    if dtype == DType::F64 {
        builder.arg(&eps_f64);
    } else {
        builder.arg(&eps);
    }
    // SAFETY: the caller upholds the pointer-validity contract documented
    // in `# Safety`; the unsafe scope is limited to the launch itself.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA fused_add_layer_norm_bwd kernel launch failed: {:?}",
            e
        ))
    })?;
    Ok(())
}