use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
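/// Name of the compiled kernel module that contains the convolution kernels.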
pub const CONV_MODULE: &str = "conv";
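/// Launches the 1D convolution kernel for `dtype` on `stream`.
///
/// One thread is launched per output element (`batch * c_out * output_length`);
/// the call is a no-op when the output is empty. When `bias_ptr` is `None`, a
/// null pointer is passed to the kernel together with a `has_bias` flag of 0.
///
/// # Safety
///
/// `input_ptr`, `weight_ptr`, `output_ptr`, and `bias_ptr` (when `Some`) must be
/// valid device pointers on the device identified by `device_index`, sized
/// consistently with the shape arguments and `dtype`, and must remain valid
/// until the kernel has completed on `stream`.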
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_conv1d(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
weight_ptr: u64,
bias_ptr: Option<u64>,
output_ptr: u64,
batch: usize,
c_in: usize,
length: usize,
c_out: usize,
kernel_size: usize,
output_length: usize,
stride: usize,
padding: usize,
dilation: usize,
groups: usize,
) -> Result<()> {
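    // Total number of output elements; skip the launch entirely when the output is empty.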
let total = batch * c_out * output_length;
if total == 0 {
return Ok(());
}
unsafe {
let module = get_or_load_module(context, device_index, CONV_MODULE)?;
let func_name = kernel_name("conv1d", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(total);
let block = (BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let batch_u32 = batch as u32;
let c_in_u32 = c_in as u32;
let length_u32 = length as u32;
let c_out_u32 = c_out as u32;
let kernel_size_u32 = kernel_size as u32;
let output_length_u32 = output_length as u32;
let stride_u32 = stride as u32;
let padding_u32 = padding as u32;
let dilation_u32 = dilation as u32;
let groups_u32 = groups as u32;
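        // The kernel takes an explicit has_bias flag; a null pointer stands in when bias is absent.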
let has_bias_u32: u32 = if bias_ptr.is_some() { 1 } else { 0 };
let bias_ptr_val = bias_ptr.unwrap_or(0);
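        // Positional arguments below must match the kernel's parameter order exactly.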
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&weight_ptr);
builder.arg(&bias_ptr_val);
builder.arg(&output_ptr);
builder.arg(&batch_u32);
builder.arg(&c_in_u32);
builder.arg(&length_u32);
builder.arg(&c_out_u32);
builder.arg(&kernel_size_u32);
builder.arg(&output_length_u32);
builder.arg(&stride_u32);
builder.arg(&padding_u32);
builder.arg(&dilation_u32);
builder.arg(&groups_u32);
builder.arg(&has_bias_u32);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA conv1d kernel launch failed: {:?}", e)))?;
Ok(())
}
}
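/// Launches the 2D convolution kernel for `dtype` on `stream`.
///
/// One thread is launched per output element (`batch * c_out * output_h * output_w`);
/// the call is a no-op when the output is empty.
///
/// # Safety
///
/// Same contract as [`launch_conv1d`]: all device pointers must be valid for the
/// given shapes and `dtype` on device `device_index`, and must remain valid until
/// the kernel has completed on `stream`.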
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_conv2d(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
weight_ptr: u64,
bias_ptr: Option<u64>,
output_ptr: u64,
batch: usize,
c_in: usize,
height: usize,
width: usize,
c_out: usize,
kernel_h: usize,
kernel_w: usize,
output_h: usize,
output_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
groups: usize,
) -> Result<()> {
let total = batch * c_out * output_h * output_w;
if total == 0 {
return Ok(());
}
unsafe {
let module = get_or_load_module(context, device_index, CONV_MODULE)?;
let func_name = kernel_name("conv2d", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(total);
let block = (BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let batch_u32 = batch as u32;
let c_in_u32 = c_in as u32;
let height_u32 = height as u32;
let width_u32 = width as u32;
let c_out_u32 = c_out as u32;
let kernel_h_u32 = kernel_h as u32;
let kernel_w_u32 = kernel_w as u32;
let output_h_u32 = output_h as u32;
let output_w_u32 = output_w as u32;
let stride_h_u32 = stride_h as u32;
let stride_w_u32 = stride_w as u32;
let pad_h_u32 = pad_h as u32;
let pad_w_u32 = pad_w as u32;
let dilation_h_u32 = dilation_h as u32;
let dilation_w_u32 = dilation_w as u32;
let groups_u32 = groups as u32;
let has_bias_u32: u32 = if bias_ptr.is_some() { 1 } else { 0 };
let bias_ptr_val = bias_ptr.unwrap_or(0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&weight_ptr);
builder.arg(&bias_ptr_val);
builder.arg(&output_ptr);
builder.arg(&batch_u32);
builder.arg(&c_in_u32);
builder.arg(&height_u32);
builder.arg(&width_u32);
builder.arg(&c_out_u32);
builder.arg(&kernel_h_u32);
builder.arg(&kernel_w_u32);
builder.arg(&output_h_u32);
builder.arg(&output_w_u32);
builder.arg(&stride_h_u32);
builder.arg(&stride_w_u32);
builder.arg(&pad_h_u32);
builder.arg(&pad_w_u32);
builder.arg(&dilation_h_u32);
builder.arg(&dilation_w_u32);
builder.arg(&groups_u32);
builder.arg(&has_bias_u32);
builder
.launch(cfg)
.map_err(|e| Error::Internal(format!("CUDA conv2d kernel launch failed: {:?}", e)))?;
Ok(())
}
}
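/// Launches the depthwise 2D convolution kernel for `dtype` on `stream`, where
/// each of the `channels` input channels is convolved independently with its own filter.
///
/// One thread is launched per output element (`batch * channels * output_h * output_w`);
/// the call is a no-op when the output is empty.
///
/// # Safety
///
/// Same contract as [`launch_conv1d`].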
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_depthwise_conv2d(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
input_ptr: u64,
weight_ptr: u64,
bias_ptr: Option<u64>,
output_ptr: u64,
batch: usize,
channels: usize,
height: usize,
width: usize,
kernel_h: usize,
kernel_w: usize,
output_h: usize,
output_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
) -> Result<()> {
let total = batch * channels * output_h * output_w;
if total == 0 {
return Ok(());
}
unsafe {
let module = get_or_load_module(context, device_index, CONV_MODULE)?;
let func_name = kernel_name("depthwise_conv2d", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(total);
let block = (BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let batch_u32 = batch as u32;
let channels_u32 = channels as u32;
let height_u32 = height as u32;
let width_u32 = width as u32;
let kernel_h_u32 = kernel_h as u32;
let kernel_w_u32 = kernel_w as u32;
let output_h_u32 = output_h as u32;
let output_w_u32 = output_w as u32;
let stride_h_u32 = stride_h as u32;
let stride_w_u32 = stride_w as u32;
let pad_h_u32 = pad_h as u32;
let pad_w_u32 = pad_w as u32;
let dilation_h_u32 = dilation_h as u32;
let dilation_w_u32 = dilation_w as u32;
let has_bias_u32: u32 = if bias_ptr.is_some() { 1 } else { 0 };
let bias_ptr_val = bias_ptr.unwrap_or(0);
let mut builder = stream.launch_builder(&func);
builder.arg(&input_ptr);
builder.arg(&weight_ptr);
builder.arg(&bias_ptr_val);
builder.arg(&output_ptr);
builder.arg(&batch_u32);
builder.arg(&channels_u32);
builder.arg(&height_u32);
builder.arg(&width_u32);
builder.arg(&kernel_h_u32);
builder.arg(&kernel_w_u32);
builder.arg(&output_h_u32);
builder.arg(&output_w_u32);
builder.arg(&stride_h_u32);
builder.arg(&stride_w_u32);
builder.arg(&pad_h_u32);
builder.arg(&pad_w_u32);
builder.arg(&dilation_h_u32);
builder.arg(&dilation_w_u32);
builder.arg(&has_bias_u32);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA depthwise_conv2d kernel launch failed: {:?}",
e
))
})?;
Ok(())
}
}