#[cfg(feature = "cudnn")]
use cudarc::cudnn::sys::{cudnnConvolutionMode_t, cudnnTensorFormat_t};
#[cfg(feature = "cudnn")]
use cudarc::cudnn::{ConvBackwardData, ConvBackwardFilter, ConvForward, Cudnn};
#[cfg(feature = "cudnn")]
use cudarc::driver::{CudaSlice, CudaStream};
#[cfg(feature = "cudnn")]
use std::sync::Arc;
#[cfg(feature = "cudnn")]
use super::cuda::CudaBackend;
/// Runs a 2D convolution forward pass through cuDNN (f32, NCHW layout).
///
/// `input` is `[batch_size, in_channels, in_height, in_width]`, `weight` is
/// `[out_channels, in_channels / groups, kernel_h, kernel_w]`, and `bias`,
/// when present, holds one value per output channel and is added after the
/// convolution.
///
/// Returns the `[batch_size, out_channels, out_h, out_w]` output buffer, or
/// `None` when the geometry is invalid (`groups == 0`, a zero stride, or a
/// kernel larger than the padded input) or any cuDNN/driver call fails.
#[cfg(feature = "cudnn")]
#[allow(clippy::too_many_arguments)]
pub fn cudnn_conv2d_forward(
    cudnn: &Arc<Cudnn>,
    stream: &Arc<CudaStream>,
    cuda_backend: &CudaBackend,
    input: &CudaSlice<f32>,
    weight: &CudaSlice<f32>,
    bias: Option<&CudaSlice<f32>>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    groups: usize,
) -> Option<CudaSlice<f32>> {
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    // Guard the divisions below; a zero group count or stride would panic.
    if groups == 0 || stride_h == 0 || stride_w == 0 {
        return None;
    }
    // checked_sub: a kernel larger than the padded input would otherwise
    // panic on usize underflow — fail softly with `None` instead.
    let out_h = (in_height + 2 * pad_h).checked_sub(kernel_h)? / stride_h + 1;
    let out_w = (in_width + 2 * pad_w).checked_sub(kernel_w)? / stride_w + 1;
    let n = batch_size as i32;
    let c_in = in_channels as i32;
    let h = in_height as i32;
    let w = in_width as i32;
    let c_out = out_channels as i32;
    let kh = kernel_h as i32;
    let kw = kernel_w as i32;
    let oh = out_h as i32;
    let ow = out_w as i32;
    let x_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_in, h, w])
        .ok()?;
    // Grouped convolution: the filter's channel dim is per-group.
    let c_in_per_group = c_in / groups as i32;
    let filter_desc = cudnn
        .create_4d_filter::<f32>(
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            [c_out, c_in_per_group, kh, kw],
        )
        .ok()?;
    let y_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_out, oh, ow])
        .ok()?;
    let mut conv_desc = cudnn
        .create_conv2d::<f32>(
            [pad_h as i32, pad_w as i32],
            [stride_h as i32, stride_w as i32],
            [1, 1], // dilation fixed at 1x1
            cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        )
        .ok()?;
    if groups > 1 {
        conv_desc.set_group_count(groups as i32).ok()?;
    }
    // Pin the math type so results don't depend on the library default.
    conv_desc
        .set_math_type(cudarc::cudnn::sys::cudnnMathType_t::CUDNN_DEFAULT_MATH)
        .ok()?;
    let op = ConvForward {
        conv: &conv_desc,
        x: &x_desc,
        w: &filter_desc,
        y: &y_desc,
    };
    let algo = op.pick_algorithm().ok()?;
    let workspace_size = op.get_workspace_size(algo).ok()?;
    // cuDNN may report a zero-byte workspace; allocate at least one byte so
    // the buffer handle is always valid.
    let mut workspace = stream.alloc_zeros::<u8>(workspace_size.max(1)).ok()?;
    let total_out = batch_size * out_channels * out_h * out_w;
    let mut output = stream.alloc_zeros::<f32>(total_out).ok()?;
    unsafe {
        // SAFETY: the tensor/filter descriptors match the shapes of the
        // buffers passed in, and `workspace` satisfies the size cuDNN
        // reported for `algo`.
        op.launch(
            algo,
            Some(&mut workspace),
            (1.0f32, 0.0f32), // alpha/beta: y = 1*conv(x, w) + 0*y
            input,
            weight,
            &mut output,
        )
        .ok()?;
    }
    // Broadcast one bias value per output channel over the spatial dims.
    if let Some(bias_data) = bias {
        let spatial = out_h * out_w;
        cuda_backend
            .bias_add_channels_f32(&mut output, bias_data, spatial, total_out)
            .ok()?;
    }
    Some(output)
}
/// Computes the gradient of a 2D convolution with respect to its input
/// (cuDNN backward-data pass, f32, NCHW layout).
///
/// `grad_output` is `[batch_size, out_channels, out_h, out_w]` and `weight`
/// is `[out_channels, in_channels / groups, kernel_h, kernel_w]`.
///
/// Returns the `[batch_size, in_channels, in_height, in_width]` input
/// gradient, or `None` when `groups == 0` or any cuDNN/driver call fails.
#[cfg(feature = "cudnn")]
#[allow(clippy::too_many_arguments)]
pub fn cudnn_conv2d_backward_data(
    cudnn: &Arc<Cudnn>,
    stream: &Arc<CudaStream>,
    grad_output: &CudaSlice<f32>,
    weight: &CudaSlice<f32>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    out_h: usize,
    out_w: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    groups: usize,
) -> Option<CudaSlice<f32>> {
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    // Guard the per-group channel division below.
    if groups == 0 {
        return None;
    }
    let n = batch_size as i32;
    let c_in = in_channels as i32;
    let h = in_height as i32;
    let w = in_width as i32;
    let c_out = out_channels as i32;
    let kh = kernel_h as i32;
    let kw = kernel_w as i32;
    let oh = out_h as i32;
    let ow = out_w as i32;
    let c_in_per_group = c_in / groups as i32;
    let dx_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_in, h, w])
        .ok()?;
    let filter_desc = cudnn
        .create_4d_filter::<f32>(
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            [c_out, c_in_per_group, kh, kw],
        )
        .ok()?;
    let dy_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_out, oh, ow])
        .ok()?;
    let mut conv_desc = cudnn
        .create_conv2d::<f32>(
            [pad_h as i32, pad_w as i32],
            [stride_h as i32, stride_w as i32],
            [1, 1], // dilation fixed at 1x1
            cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        )
        .ok()?;
    if groups > 1 {
        conv_desc.set_group_count(groups as i32).ok()?;
    }
    // Pin the math type to match the forward pass, so gradients are computed
    // under the same math mode as the activations.
    conv_desc
        .set_math_type(cudarc::cudnn::sys::cudnnMathType_t::CUDNN_DEFAULT_MATH)
        .ok()?;
    let op = ConvBackwardData {
        conv: &conv_desc,
        dx: &dx_desc,
        w: &filter_desc,
        dy: &dy_desc,
    };
    let algo = op.pick_algorithm().ok()?;
    let workspace_size = op.get_workspace_size(algo).ok()?;
    // cuDNN may report a zero-byte workspace; keep the buffer non-empty.
    let mut workspace = stream.alloc_zeros::<u8>(workspace_size.max(1)).ok()?;
    let total_in = batch_size * in_channels * in_height * in_width;
    let mut grad_input = stream.alloc_zeros::<f32>(total_in).ok()?;
    unsafe {
        // SAFETY: descriptors match the buffer shapes and `workspace`
        // satisfies the size cuDNN reported for `algo`.
        op.launch(
            algo,
            Some(&mut workspace),
            (1.0f32, 0.0f32), // alpha/beta: dx = 1*dconv + 0*dx
            &mut grad_input,
            weight,
            grad_output,
        )
        .ok()?;
    }
    Some(grad_input)
}
/// Computes the gradient of a 2D convolution with respect to its weights
/// (cuDNN backward-filter pass, f32, NCHW layout).
///
/// `grad_output` is `[batch_size, out_channels, out_h, out_w]` and `input`
/// is `[batch_size, in_channels, in_height, in_width]`.
///
/// Returns the `[out_channels, in_channels / groups, kernel_h, kernel_w]`
/// weight gradient, or `None` when `groups == 0` or any cuDNN/driver call
/// fails.
#[cfg(feature = "cudnn")]
#[allow(clippy::too_many_arguments)]
pub fn cudnn_conv2d_backward_filter(
    cudnn: &Arc<Cudnn>,
    stream: &Arc<CudaStream>,
    grad_output: &CudaSlice<f32>,
    input: &CudaSlice<f32>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    out_h: usize,
    out_w: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    groups: usize,
) -> Option<CudaSlice<f32>> {
    let (stride_h, stride_w) = stride;
    let (pad_h, pad_w) = padding;
    // Guard the per-group channel divisions below.
    if groups == 0 {
        return None;
    }
    let n = batch_size as i32;
    let c_in = in_channels as i32;
    let h = in_height as i32;
    let w = in_width as i32;
    let c_out = out_channels as i32;
    let kh = kernel_h as i32;
    let kw = kernel_w as i32;
    let oh = out_h as i32;
    let ow = out_w as i32;
    let c_in_per_group = c_in / groups as i32;
    let x_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_in, h, w])
        .ok()?;
    let dw_desc = cudnn
        .create_4d_filter::<f32>(
            cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            [c_out, c_in_per_group, kh, kw],
        )
        .ok()?;
    let dy_desc = cudnn
        .create_4d_tensor::<f32>(cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [n, c_out, oh, ow])
        .ok()?;
    let mut conv_desc = cudnn
        .create_conv2d::<f32>(
            [pad_h as i32, pad_w as i32],
            [stride_h as i32, stride_w as i32],
            [1, 1], // dilation fixed at 1x1
            cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        )
        .ok()?;
    if groups > 1 {
        conv_desc.set_group_count(groups as i32).ok()?;
    }
    // Pin the math type to match the forward pass, so gradients are computed
    // under the same math mode as the activations.
    conv_desc
        .set_math_type(cudarc::cudnn::sys::cudnnMathType_t::CUDNN_DEFAULT_MATH)
        .ok()?;
    let op = ConvBackwardFilter {
        conv: &conv_desc,
        x: &x_desc,
        dw: &dw_desc,
        dy: &dy_desc,
    };
    let algo = op.pick_algorithm().ok()?;
    let workspace_size = op.get_workspace_size(algo).ok()?;
    // cuDNN may report a zero-byte workspace; keep the buffer non-empty.
    let mut workspace = stream.alloc_zeros::<u8>(workspace_size.max(1)).ok()?;
    let weight_size = out_channels * (in_channels / groups) * kernel_h * kernel_w;
    let mut grad_weight = stream.alloc_zeros::<f32>(weight_size).ok()?;
    unsafe {
        // SAFETY: descriptors match the buffer shapes and `workspace`
        // satisfies the size cuDNN reported for `algo`.
        op.launch(
            algo,
            Some(&mut workspace),
            (1.0f32, 0.0f32), // alpha/beta: dw = 1*dconv + 0*dw
            input,
            &mut grad_weight,
            grad_output,
        )
        .ok()?;
    }
    Some(grad_weight)
}
/// Stub used when the `cudnn` feature is disabled.
///
/// Always returns `None`, signalling to the caller that no cuDNN-accelerated
/// path is available so a fallback implementation must be used instead.
/// Note the signature intentionally differs from the cudnn-enabled version:
/// it takes host slices rather than device buffers, since no CUDA types
/// exist in this configuration.
#[cfg(not(feature = "cudnn"))]
#[allow(clippy::too_many_arguments)]
pub fn cudnn_conv2d_forward(
_input: &[f32],
_weight: &[f32],
_bias: Option<&[f32]>,
_batch_size: usize,
_in_channels: usize,
_in_height: usize,
_in_width: usize,
_out_channels: usize,
_kernel_h: usize,
_kernel_w: usize,
_stride: (usize, usize),
_padding: (usize, usize),
_groups: usize,
) -> Option<Vec<f32>> {
None
}