use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::types::{Activation, ConvAlgorithm, ConvolutionDescriptor, TensorDesc, TensorDescMut};
use super::descriptor::ConvProblem;
use super::dgrad::implicit_gemm::DgradImplicitGemm;
use super::fft_conv::FftConv2dPlan;
use super::fprop::direct::{Conv1x1, DepthwiseConv};
use super::fprop::im2col_gemm::Im2colGemmConv;
use super::fprop::implicit_gemm::ImplicitGemmConv;
use super::fprop::winograd::WinogradConv;
use super::fused::{FusedBnParams, FusedConvBnAct};
use super::wgrad::implicit_gemm::WgradImplicitGemm;
/// Runs a forward convolution, dispatching to the engine selected by
/// `ConvProblem::select_algorithm` for the handle's SM version.
///
/// `workspace` is only consulted by the algorithms that need scratch memory
/// (`Im2colGemm`, `Winograd`, and the FFT path); the others ignore it.
///
/// # Errors
/// - Any error from `ConvProblem::from_descriptors` or `ConvProblem::validate`.
/// - `DnnError::WorkspaceRequired(bytes)` when the selected algorithm needs a
///   workspace and `workspace` is `None`. On the FFT path this error is also
///   returned when the supplied buffer is smaller than `bytes`.
/// - Any error returned by engine construction or execution.
pub fn conv_forward<T: GpuFloat>(
    handle: &DnnHandle,
    input: &TensorDesc<T>,
    filter: &TensorDesc<T>,
    output: &mut TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
    workspace: Option<&mut DeviceBuffer<u8>>,
) -> DnnResult<()> {
    let problem = ConvProblem::from_descriptors(input, filter, output, conv_desc)?;
    problem.validate()?;
    let algo = problem.select_algorithm(handle.sm_version());
    match algo {
        ConvAlgorithm::Direct if problem.is_1x1() => {
            let engine = Conv1x1::new(problem, handle.sm_version())?;
            engine.execute(handle, input, filter, output)
        }
        ConvAlgorithm::Direct => {
            // NOTE(review): every non-1x1 `Direct` selection is routed to the
            // depthwise kernel; this relies on `select_algorithm` only picking
            // `Direct` for 1x1 or depthwise problems — confirm.
            let engine = DepthwiseConv::new(problem, handle.sm_version())?;
            engine.execute(handle, input, filter, output)
        }
        ConvAlgorithm::ImplicitGemm => {
            let engine = ImplicitGemmConv::new(problem, handle.sm_version());
            engine.execute(handle, input, filter, None, output)
        }
        ConvAlgorithm::Im2colGemm => {
            // The required size is only needed to build the error, so compute
            // it lazily (as the Winograd arm does) instead of cloning the
            // problem and constructing a throwaway engine on every call.
            let ws = workspace.ok_or_else(|| {
                DnnError::WorkspaceRequired(
                    Im2colGemmConv::new(problem.clone(), handle.sm_version())
                        .workspace_bytes()
                        .unwrap_or(0),
                )
            })?;
            let engine = Im2colGemmConv::new(problem, handle.sm_version());
            engine.execute(handle, input, filter, output, ws)
        }
        ConvAlgorithm::Winograd => {
            let ws = workspace.ok_or_else(|| {
                DnnError::WorkspaceRequired(
                    WinogradConv::new(problem.clone(), handle.sm_version())
                        .and_then(|w| w.workspace_bytes())
                        .unwrap_or(0),
                )
            })?;
            let engine = WinogradConv::new(problem, handle.sm_version())?;
            engine.execute(handle, input, filter, output, ws)
        }
        ConvAlgorithm::FftConv => {
            // The FFT plan is strictly 2-D. Any other rank falls back to
            // implicit GEMM, which is invoked without a workspace.
            if problem.in_dims.len() != 2
                || problem.filter_dims.len() != 2
                || problem.padding.len() != 2
                || problem.stride.len() != 2
            {
                let engine = ImplicitGemmConv::new(problem, handle.sm_version());
                return engine.execute(handle, input, filter, None, output);
            }
            let sm_num = sm_version_numeric(handle.sm_version());
            let fft_plan = FftConv2dPlan::new(
                problem.in_channels,
                problem.out_channels,
                problem.filter_dims[0],
                problem.filter_dims[1],
                problem.stride[0],
                problem.stride[1],
                problem.padding[0],
                problem.padding[1],
                sm_num,
                problem.input_type,
            );
            // A plan that cannot be built for this shape/arch is not fatal:
            // fall back to implicit GEMM.
            let Ok(plan) = fft_plan else {
                let engine = ImplicitGemmConv::new(problem, handle.sm_version());
                return engine.execute(handle, input, filter, None, output);
            };
            let required_ws =
                plan.workspace_bytes(problem.in_dims[0], problem.in_dims[1], problem.batch)?;
            let ws = workspace.ok_or(DnnError::WorkspaceRequired(required_ws))?;
            if ws.len() < required_ws {
                return Err(DnnError::WorkspaceRequired(required_ws));
            }
            // NOTE(review): the FFT plan is only used to size/validate the
            // workspace; execution goes through `Im2colGemmConv`, whose own
            // workspace requirement is not checked here — confirm intended.
            let engine = Im2colGemmConv::new(problem, handle.sm_version());
            engine.execute(handle, input, filter, output, ws)
        }
    }
}
/// Converts an `SmVersion` into its numeric compute-capability code
/// (for example `Sm86` becomes `86`). This is the integer form handed to
/// `FftConv2dPlan::new` on the FFT path.
///
/// `Sm90a` maps to the same code as `Sm90`.
#[inline]
fn sm_version_numeric(sm: oxicuda_ptx::arch::SmVersion) -> u32 {
    use oxicuda_ptx::arch::SmVersion as Sm;

    match sm {
        Sm::Sm75 => 75,
        Sm::Sm80 => 80,
        Sm::Sm86 => 86,
        Sm::Sm89 => 89,
        Sm::Sm90 | Sm::Sm90a => 90,
        Sm::Sm100 => 100,
        Sm::Sm120 => 120,
    }
}
/// Computes the data gradient of a convolution: writes `grad_input` from
/// `grad_output` and `filter` using the implicit-GEMM dgrad engine.
///
/// `_workspace` is accepted but currently ignored.
///
/// # Errors
/// Propagates failures from building or validating the dgrad problem and
/// from engine execution.
pub fn conv_backward_data<T: GpuFloat>(
    handle: &DnnHandle,
    filter: &TensorDesc<T>,
    grad_output: &TensorDesc<T>,
    grad_input: &mut TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
    _workspace: Option<&mut DeviceBuffer<u8>>,
) -> DnnResult<()> {
    let dgrad = build_dgrad_problem::<T>(filter, grad_output, grad_input, conv_desc)?;
    dgrad.validate()?;

    let kernel = DgradImplicitGemm::new(dgrad, handle.sm_version());
    kernel.execute(handle, grad_output, filter, grad_input)
}
/// Computes the filter gradient of a convolution: writes `grad_filter` from
/// the forward `input` and `grad_output` using the implicit-GEMM wgrad engine.
///
/// `_workspace` is accepted but currently ignored.
///
/// # Errors
/// Propagates failures from building or validating the wgrad problem and
/// from engine execution.
pub fn conv_backward_filter<T: GpuFloat>(
    handle: &DnnHandle,
    input: &TensorDesc<T>,
    grad_output: &TensorDesc<T>,
    grad_filter: &mut TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
    _workspace: Option<&mut DeviceBuffer<u8>>,
) -> DnnResult<()> {
    let wgrad = build_wgrad_problem::<T>(input, grad_output, grad_filter, conv_desc)?;
    wgrad.validate()?;

    let kernel = WgradImplicitGemm::new(wgrad, handle.sm_version());
    kernel.execute(handle, input, grad_output, grad_filter)
}
/// Runs a fused convolution, batch-norm and activation in a single engine.
///
/// Despite the name, the activation is whatever `activation` specifies.
///
/// # Errors
/// - Any error from building or validating the convolution problem.
/// - `DnnError::InvalidArgument` when `bn_params.channels` differs from the
///   convolution's output channel count.
/// - Any error returned by the fused engine.
pub fn conv_bn_relu<T: GpuFloat>(
    handle: &DnnHandle,
    input: &TensorDesc<T>,
    filter: &TensorDesc<T>,
    output: &mut TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
    bn_params: &FusedBnParams,
    activation: Activation,
) -> DnnResult<()> {
    let problem = ConvProblem::from_descriptors(input, filter, output, conv_desc)?;
    problem.validate()?;

    // The BN parameters are applied per output channel, so their count must
    // line up with the convolution's output channels.
    let channels_match = bn_params.channels == problem.out_channels;
    if !channels_match {
        return Err(DnnError::InvalidArgument(format!(
            "BN channels ({}) != out_channels ({})",
            bn_params.channels, problem.out_channels
        )));
    }

    let fused = FusedConvBnAct::new(problem, activation, handle.sm_version());
    fused.execute(handle, input, filter, output, bn_params)
}
/// Builds the `ConvProblem` that describes a data-gradient (dgrad) pass.
///
/// The problem geometry is taken from two descriptors:
/// - from `grad_input`: batch, input channels and spatial extent;
/// - from `filter`: output channels and kernel spatial extent.
///
/// `_grad_output` is currently not consulted. Both input and output element
/// types are set to `T::PTX_TYPE`.
///
/// # Errors
/// Returns `DnnError::InvalidDimension` in two cases:
/// - `grad_input` does not have the rank expected by its layout;
/// - `filter` has fewer than `2 + spatial` dimensions.
fn build_dgrad_problem<T: GpuFloat>(
    filter: &TensorDesc<T>,
    _grad_output: &TensorDesc<T>,
    grad_input: &TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
) -> DnnResult<ConvProblem> {
    let layout = grad_input.layout;
    let ndim = layout.expected_ndim();
    let spatial = layout.spatial_dims();
    if grad_input.dims.len() != ndim {
        return Err(DnnError::InvalidDimension(format!(
            "grad_input has {} dims, expected {ndim}",
            grad_input.dims.len()
        )));
    }
    // Check the filter rank *before* indexing into it. Reading `filter.dims[0]`
    // first would panic on an empty descriptor instead of returning an error.
    if filter.dims.len() < 2 + spatial {
        return Err(DnnError::InvalidDimension("filter dims too short".into()));
    }
    let batch = grad_input.dims[0];
    let in_channels = grad_input.dims[1];
    let in_dims = grad_input.dims[2..].to_vec();
    let out_channels = filter.dims[0];
    let filter_dims = filter.dims[2..2 + spatial].to_vec();
    Ok(ConvProblem {
        batch,
        in_channels,
        in_dims,
        out_channels,
        filter_dims,
        padding: conv_desc.padding.clone(),
        stride: conv_desc.stride.clone(),
        dilation: conv_desc.dilation.clone(),
        groups: conv_desc.groups,
        input_type: T::PTX_TYPE,
        output_type: T::PTX_TYPE,
        layout,
    })
}
/// Builds the `ConvProblem` that describes a weight-gradient (wgrad) pass.
///
/// The problem geometry is taken from two descriptors:
/// - from the forward `input`: batch, input channels and spatial extent;
/// - from `grad_filter`: output channels and kernel spatial extent.
///
/// `_grad_output` is currently not consulted. Both input and output element
/// types are set to `T::PTX_TYPE`.
///
/// # Errors
/// Returns `DnnError::InvalidDimension` in two cases:
/// - `input` does not have the rank expected by its layout;
/// - `grad_filter` has fewer than `2 + spatial` dimensions.
fn build_wgrad_problem<T: GpuFloat>(
    input: &TensorDesc<T>,
    _grad_output: &TensorDesc<T>,
    grad_filter: &TensorDescMut<T>,
    conv_desc: &ConvolutionDescriptor,
) -> DnnResult<ConvProblem> {
    let layout = input.layout;
    let ndim = layout.expected_ndim();
    let spatial = layout.spatial_dims();
    // Mirror `build_dgrad_problem`: reject a rank mismatch up front. The
    // fixed indexing below would otherwise panic or yield a problem whose
    // spatial rank disagrees with the layout.
    if input.dims.len() != ndim {
        return Err(DnnError::InvalidDimension(format!(
            "input has {} dims, expected {ndim}",
            input.dims.len()
        )));
    }
    // Check the filter rank before reading `grad_filter.dims[0]`.
    if grad_filter.dims.len() < 2 + spatial {
        return Err(DnnError::InvalidDimension(
            "grad_filter dims too short".into(),
        ));
    }
    let batch = input.dims[0];
    let in_channels = input.dims[1];
    let in_dims = input.dims[2..].to_vec();
    let out_channels = grad_filter.dims[0];
    let filter_dims = grad_filter.dims[2..2 + spatial].to_vec();
    Ok(ConvProblem {
        batch,
        in_channels,
        in_dims,
        out_channels,
        filter_dims,
        padding: conv_desc.padding.clone(),
        stride: conv_desc.stride.clone(),
        dilation: conv_desc.dilation.clone(),
        groups: conv_desc.groups,
        input_type: T::PTX_TYPE,
        output_type: T::PTX_TYPE,
        layout,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TensorLayout;

    /// A 1x1, stride-1, ungrouped problem is expected to select the direct kernel.
    #[test]
    fn select_algorithm_through_problem() {
        let problem = ConvProblem {
            batch: 1,
            in_channels: 64,
            in_dims: vec![32, 32],
            out_channels: 128,
            filter_dims: vec![1, 1],
            padding: vec![0, 0],
            stride: vec![1, 1],
            dilation: vec![1, 1],
            groups: 1,
            input_type: oxicuda_ptx::ir::PtxType::F32,
            output_type: oxicuda_ptx::ir::PtxType::F32,
            layout: TensorLayout::Nchw,
        };
        let algo = problem.select_algorithm(oxicuda_ptx::arch::SmVersion::Sm80);
        assert_eq!(algo, ConvAlgorithm::Direct);
    }

    /// dgrad geometry comes from `grad_input` (N, C, spatial) and the filter (K, spatial).
    #[test]
    fn build_dgrad_problem_validates_dims() {
        let filter = TensorDesc::<f32>::from_raw(
            0,
            vec![128, 64, 3, 3],
            vec![576, 9, 3, 1],
            TensorLayout::Nchw,
        );
        let grad_out = TensorDesc::<f32>::from_raw(
            0,
            vec![1, 128, 32, 32],
            vec![131072, 1024, 32, 1],
            TensorLayout::Nchw,
        );
        let grad_in = TensorDescMut::<f32>::from_raw(
            0,
            vec![1, 64, 32, 32],
            vec![65536, 1024, 32, 1],
            TensorLayout::Nchw,
        );
        let conv_desc = ConvolutionDescriptor::conv2d(1, 1, 1, 1, 1, 1, 1);
        if let (Ok(f), Ok(go), Ok(gi), Ok(cd)) = (filter, grad_out, grad_in, conv_desc) {
            let problem = build_dgrad_problem::<f32>(&f, &go, &gi, &cd);
            assert!(problem.is_ok());
            if let Ok(p) = problem {
                assert_eq!(p.batch, 1);
                assert_eq!(p.in_channels, 64);
                assert_eq!(p.out_channels, 128);
                assert_eq!(p.in_dims, vec![32, 32]);
                assert_eq!(p.filter_dims, vec![3, 3]);
            }
        }
    }

    /// wgrad geometry comes from the forward `input` (N, C, spatial) and `grad_filter` (K, spatial).
    #[test]
    fn build_wgrad_problem_extracts_dims() {
        let input = TensorDesc::<f32>::from_raw(
            0,
            vec![1, 64, 32, 32],
            vec![65536, 1024, 32, 1],
            TensorLayout::Nchw,
        );
        let grad_out = TensorDesc::<f32>::from_raw(
            0,
            vec![1, 128, 32, 32],
            vec![131072, 1024, 32, 1],
            TensorLayout::Nchw,
        );
        let grad_filter = TensorDescMut::<f32>::from_raw(
            0,
            vec![128, 64, 3, 3],
            vec![576, 9, 3, 1],
            TensorLayout::Nchw,
        );
        let conv_desc = ConvolutionDescriptor::conv2d(1, 1, 1, 1, 1, 1, 1);
        if let (Ok(i), Ok(go), Ok(gf), Ok(cd)) = (input, grad_out, grad_filter, conv_desc) {
            let problem = build_wgrad_problem::<f32>(&i, &go, &gf, &cd);
            assert!(problem.is_ok());
            if let Ok(p) = problem {
                assert_eq!(p.batch, 1);
                assert_eq!(p.in_channels, 64);
                assert_eq!(p.out_channels, 128);
                assert_eq!(p.in_dims, vec![32, 32]);
                assert_eq!(p.filter_dims, vec![3, 3]);
            }
        }
    }
}