use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
dtype_suffix, get_kernel_function, get_or_load_module, kernel_name, kernel_names,
launch_config, reduce_dim_launch_config, reduce_launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
pub(crate) use crate::ops::AccumulationPrecision;
/// Builds the name of a reduction kernel for `base_op`, `dtype` and the
/// requested accumulation precision.
///
/// The name always has the form `{base_op}_{dtype_suffix}`. An extra
/// accumulator tag is appended only for these combinations:
///
/// * `F16` / `BF16` input with `FP32` accumulation -> `_fp32acc`
/// * `FP8E4M3` / `FP8E5M2` input with `BF16` accumulation -> `_bf16acc`
/// * `F16`, `BF16`, `FP8E4M3`, `FP8E5M2` or `F32` input with `FP64`
///   accumulation -> `_fp64acc`
///
/// Every other combination yields the untagged name.
fn reduce_kernel_name(base_op: &str, dtype: DType, acc_precision: AccumulationPrecision) -> String {
    let type_tag = dtype_suffix(dtype);

    // An empty tag means "no dedicated accumulator variant": the plain kernel
    // name is used.
    let acc_tag = match (dtype, acc_precision) {
        (DType::F16 | DType::BF16, AccumulationPrecision::FP32) => "_fp32acc",
        (DType::FP8E4M3 | DType::FP8E5M2, AccumulationPrecision::BF16) => "_bf16acc",
        (
            DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 | DType::F32,
            AccumulationPrecision::FP64,
        ) => "_fp64acc",
        _ => "",
    };

    format!("{}_{}{}", base_op, type_tag, acc_tag)
}
#[allow(dead_code)] pub unsafe fn launch_reduce_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
op: &str,
dtype: DType,
input_ptr: u64,
output_ptr: u64,
numel: usize,
) -> Result<u32> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::REDUCE_MODULE)?;
let func_name = kernel_name(&kernel_names::reduce_kernel(op), dtype);
let func = get_kernel_function(&module, &func_name)?;
let (grid_size, block_size) = reduce_launch_config(numel);
let n = numel as u32;
let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA reduce kernel '{}' launch failed: {:?}",
op, e
))
})?;
Ok(grid_size)
}
}
pub unsafe fn launch_reduce_dim_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
op: &str,
dtype: DType,
input_ptr: u64,
output_ptr: u64,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
acc_precision: AccumulationPrecision,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::REDUCE_MODULE)?;
let base_op = kernel_names::reduce_dim_kernel(op);
let func_name = reduce_kernel_name(&base_op, dtype, acc_precision);
let func = get_kernel_function(&module, &func_name)?;
let (grid, block) = reduce_dim_launch_config(outer_size, inner_size);
let outer = outer_size as u32;
let reduce = reduce_size as u32;
let inner = inner_size as u32;
let cfg = launch_config(grid, (block, 1, 1), 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&outer);
builder.arg(&reduce);
builder.arg(&inner);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA reduce_dim kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
Ok(())
}
}
pub unsafe fn launch_argmax_dim(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
output_ptr: u64,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::REDUCE_MODULE)?;
let func_name = kernel_name("argmax_dim", dtype);
let func = get_kernel_function(&module, &func_name)?;
let (grid, block) = reduce_dim_launch_config(outer_size, inner_size);
let outer = outer_size as u32;
let reduce = reduce_size as u32;
let inner = inner_size as u32;
let cfg = launch_config(grid, (block, 1, 1), 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&outer);
builder.arg(&reduce);
builder.arg(&inner);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA argmax_dim kernel launch failed: {:?}", e))
})?;
Ok(())
}
}
pub unsafe fn launch_argmin_dim(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
output_ptr: u64,
outer_size: usize,
reduce_size: usize,
inner_size: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::REDUCE_MODULE)?;
let func_name = kernel_name("argmin_dim", dtype);
let func = get_kernel_function(&module, &func_name)?;
let (grid, block) = reduce_dim_launch_config(outer_size, inner_size);
let outer = outer_size as u32;
let reduce = reduce_size as u32;
let inner = inner_size as u32;
let cfg = launch_config(grid, (block, 1, 1), 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&outer);
builder.arg(&reduce);
builder.arg(&inner);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA argmin_dim kernel launch failed: {:?}", e))
})?;
Ok(())
}
}