use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_name, kernel_names, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Compute the launch configuration shared by the row-wise norm kernels:
/// one CUDA block per batch row, with threads cooperating across the hidden
/// dimension.
///
/// Returns `(grid_size, block_size, shared_mem_bytes)`:
/// - `grid_size`: one block per row (`batch_size`),
/// - `block_size`: threads per block, capped at `BLOCK_SIZE`,
/// - `shared_mem_bytes`: one `f32` reduction slot per thread.
///
/// NOTE(review): `batch_size`/`hidden_size` are narrowed with `as u32` and
/// would silently truncate above `u32::MAX`; realistic sizes are far below.
#[inline]
fn norm_launch_config(batch_size: usize, hidden_size: usize) -> (u32, u32, u32) {
    let block_size = BLOCK_SIZE.min(hidden_size as u32);
    let grid_size = batch_size as u32;
    // One f32 per thread for the block reduction (was a bare magic `4`).
    let shared_mem = block_size * std::mem::size_of::<f32>() as u32;
    (grid_size, block_size, shared_mem)
}
pub unsafe fn launch_rms_norm(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
weight_ptr: u64,
output_ptr: u64,
batch_size: usize,
hidden_size: usize,
eps: f32,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::NORM_MODULE)?;
let func_name = kernel_name("rms_norm", dtype);
let func = get_kernel_function(&module, &func_name)?;
let (grid_size, block_size, shared_mem) = norm_launch_config(batch_size, hidden_size);
let batch = batch_size as u32;
let hidden = hidden_size as u32;
let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&weight_ptr);
builder.arg(&output_ptr);
builder.arg(&batch);
builder.arg(&hidden);
builder.arg(&eps);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA rms_norm kernel launch failed: {:?}", e)))?;
Ok(())
}
}
/// Launch the dtype-specialized `layer_norm` CUDA kernel on `stream`, one
/// block per row of the `(batch_size, hidden_size)` input.
///
/// # Safety
/// `input_ptr`, `weight_ptr`, `bias_ptr`, and `output_ptr` must be valid
/// device pointers on `device_index`, laid out as the kernel expects for
/// `dtype`, and must remain live until the launch completes on `stream`.
pub unsafe fn launch_layer_norm(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    input_ptr: u64,
    weight_ptr: u64,
    bias_ptr: u64,
    output_ptr: u64,
    batch_size: usize,
    hidden_size: usize,
    eps: f32,
) -> Result<()> {
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::NORM_MODULE)?;
        let func = get_kernel_function(&module, &kernel_name("layer_norm", dtype))?;

        let (grid_size, block_size, base_smem) = norm_launch_config(batch_size, hidden_size);
        // layer_norm reserves twice the reduction scratch of rms_norm —
        // presumably one bank for the mean and one for the variance; confirm
        // against the device-side kernel.
        let shared_mem = base_smem * 2;
        let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem);

        // Scalar arguments; order must match the device-side signature.
        let batch = batch_size as u32;
        let hidden = hidden_size as u32;

        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&input_ptr)
            .arg(&weight_ptr)
            .arg(&bias_ptr)
            .arg(&output_ptr)
            .arg(&batch)
            .arg(&hidden)
            .arg(&eps);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA layer_norm kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
/// Launch the dtype-specialized `group_norm` CUDA kernel on `stream`, one
/// block per `(batch, group)` pair.
///
/// # Safety
/// All `*_ptr` arguments must be valid device pointers on `device_index`,
/// laid out as the kernel expects for `dtype`, and must remain live until
/// the launched kernel has completed on `stream`.
pub unsafe fn launch_group_norm(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    input_ptr: u64,
    weight_ptr: u64,
    bias_ptr: u64,
    output_ptr: u64,
    batch: usize,
    channels: usize,
    spatial: usize,
    num_groups: usize,
    channels_per_group: usize,
    eps: f32,
) -> Result<()> {
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::NORM_MODULE)?;
        let func = get_kernel_function(&module, &kernel_name("group_norm", dtype))?;

        // One block per (batch, group) pair; threads cover the group's
        // channels_per_group * spatial elements, capped at BLOCK_SIZE.
        let grid = (batch * num_groups) as u32;
        let threads = BLOCK_SIZE.min((channels_per_group * spatial) as u32);
        // Two 4-byte (f32) reduction slots per thread — presumably mean and
        // variance scratch, as in layer_norm; confirm against the kernel.
        let smem = threads * 2 * 4;
        let cfg = launch_config((grid, 1, 1), (threads, 1, 1), smem);

        // Scalar arguments; order must match the device-side signature.
        let n = batch as u32;
        let c = channels as u32;
        let hw = spatial as u32;
        let groups = num_groups as u32;
        let cpg = channels_per_group as u32;

        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&input_ptr)
            .arg(&weight_ptr)
            .arg(&bias_ptr)
            .arg(&output_ptr)
            .arg(&n)
            .arg(&c)
            .arg(&hw)
            .arg(&groups)
            .arg(&cpg)
            .arg(&eps);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA group_norm kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}