use oxicuda_blas::GpuFloat;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
use crate::types::{ConvAlgorithm, ConvolutionDescriptor, TensorDesc, TensorDescMut, TensorLayout};
use super::algo_select;
/// Description of a single convolution problem: tensor extents, convolution
/// parameters, element types, and layout.
///
/// The per-axis vectors (`in_dims`, `filter_dims`, `padding`, `stride`,
/// `dilation`) are expected to hold one entry per spatial axis;
/// `ConvProblem::validate` rejects problems where their lengths disagree.
#[derive(Debug, Clone)]
pub struct ConvProblem {
/// Batch size; `from_descriptors` reads it from `dims[0]` of the input.
pub batch: u32,
/// Input channel count; `from_descriptors` reads it from `dims[1]` of the input.
pub in_channels: u32,
/// Input spatial extents; `from_descriptors` copies `dims[2..]` of the input.
pub in_dims: Vec<u32>,
/// Output channel count; `from_descriptors` reads it from `dims[0]` of the filter.
pub out_channels: u32,
/// Filter spatial extents; `from_descriptors` copies them from the filter
/// dims starting at index 2.
pub filter_dims: Vec<u32>,
/// Per-axis padding, cloned from the convolution descriptor.
pub padding: Vec<u32>,
/// Per-axis stride, cloned from the convolution descriptor; `validate`
/// rejects zero entries.
pub stride: Vec<u32>,
/// Per-axis dilation, cloned from the convolution descriptor; `validate`
/// rejects zero entries.
pub dilation: Vec<u32>,
/// Group count; `validate` requires it to be non-zero and to divide both
/// `in_channels` and `out_channels`.
pub groups: u32,
/// PTX element type of the input (`T::PTX_TYPE` when built by `from_descriptors`).
pub input_type: PtxType,
/// PTX element type of the output (`T::PTX_TYPE` when built by `from_descriptors`).
pub output_type: PtxType,
/// Tensor layout, taken from the input descriptor by `from_descriptors`.
pub layout: TensorLayout,
}
impl ConvProblem {
    /// Returns the output extent along the height axis.
    ///
    /// The axis is chosen by `h_index` and the size is computed by
    /// [`ConvolutionDescriptor::output_size`].
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when the height axis is missing
    /// from any of the spatial parameter vectors (instead of panicking), and
    /// forwards any error from `ConvolutionDescriptor::output_size`.
    pub fn output_h(&self) -> DnnResult<u32> {
        self.output_extent_at(self.h_index())
    }

    /// Returns the output extent along the width axis.
    ///
    /// The axis is chosen by `w_index` and the size is computed by
    /// [`ConvolutionDescriptor::output_size`].
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when the width axis is missing
    /// from any of the spatial parameter vectors (for example a problem with a
    /// single spatial dimension), and forwards any error from
    /// `ConvolutionDescriptor::output_size`.
    pub fn output_w(&self) -> DnnResult<u32> {
        self.output_extent_at(self.w_index())
    }

    /// Returns the output extent for every spatial axis, in `in_dims` order.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when `filter_dims`, `padding`,
    /// `stride` or `dilation` do not have the same length as `in_dims`; a
    /// mismatch would otherwise silently drop trailing axes. Errors from
    /// `ConvolutionDescriptor::output_size` are forwarded.
    pub fn output_dims(&self) -> DnnResult<Vec<u32>> {
        let n_spatial = self.in_dims.len();
        let lengths = [
            self.filter_dims.len(),
            self.padding.len(),
            self.stride.len(),
            self.dilation.len(),
        ];
        if lengths.iter().any(|&len| len != n_spatial) {
            return Err(DnnError::InvalidDimension(format!(
                "filter/padding/stride/dilation lengths {lengths:?} != input spatial dims ({n_spatial})"
            )));
        }
        (0..n_spatial)
            .map(|axis| self.output_extent_at(axis))
            .collect()
    }

    /// Returns `true` when every filter extent, stride and dilation equals 1.
    ///
    /// Padding is not inspected. Empty vectors satisfy the predicate.
    #[must_use]
    pub fn is_1x1(&self) -> bool {
        self.filter_dims.iter().all(|&d| d == 1)
            && self.stride.iter().all(|&s| s == 1)
            && self.dilation.iter().all(|&d| d == 1)
    }

    /// Returns `true` when `groups` equals both `in_channels` and `out_channels`.
    #[must_use]
    pub fn is_depthwise(&self) -> bool {
        self.groups == self.in_channels && self.groups == self.out_channels
    }

    /// Returns `true` when more than one group is used and the problem is not
    /// depthwise (see [`Self::is_depthwise`]).
    #[must_use]
    pub fn is_grouped(&self) -> bool {
        self.groups > 1 && !self.is_depthwise()
    }

    /// Maps the convolution onto GEMM dimensions `(m, n, k)`.
    ///
    /// * `m` = `batch` x product of the output spatial extents
    /// * `n` = `out_channels`
    /// * `k` = (`in_channels / groups`) x product of the filter extents
    ///
    /// Every product saturates at `u32::MAX` rather than overflowing.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidArgument`] when `groups` is zero (the
    /// per-group channel count would divide by zero) and forwards any error
    /// from [`Self::output_dims`].
    pub fn conv_to_gemm_dims(&self) -> DnnResult<(u32, u32, u32)> {
        if self.groups == 0 {
            return Err(DnnError::InvalidArgument("groups must be >= 1".into()));
        }
        let out_dims = self.output_dims()?;

        // Saturating folds keep the whole computation consistent with the
        // saturating multiplies below; `Iterator::product` would panic in
        // debug builds and wrap in release builds on overflow.
        let saturating_volume =
            |dims: &[u32]| dims.iter().fold(1u32, |acc, &d| acc.saturating_mul(d));

        let gemm_m = self.batch.saturating_mul(saturating_volume(&out_dims));
        let gemm_n = self.out_channels;
        let channels_per_group = self.in_channels / self.groups;
        let gemm_k = channels_per_group.saturating_mul(saturating_volume(&self.filter_dims));
        Ok((gemm_m, gemm_n, gemm_k))
    }

    /// Picks a convolution algorithm for this problem on the given SM
    /// architecture by delegating to `algo_select::select_algorithm`.
    #[must_use]
    pub fn select_algorithm(&self, sm: SmVersion) -> ConvAlgorithm {
        algo_select::select_algorithm(self, sm)
    }

    /// Builds a problem description from tensor and convolution descriptors.
    ///
    /// The layout is taken from `input`. The `output` descriptor is only
    /// checked for rank; output extents are not compared against the computed
    /// ones here. Both element types are set to `T::PTX_TYPE`.
    ///
    /// This constructor does not call [`Self::validate`].
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when
    /// * `input`, `filter` or `output` does not have the rank the layout expects, or
    /// * `padding`, `stride` or `dilation` in `conv_desc` does not have one
    ///   entry per spatial dimension of the layout.
    pub fn from_descriptors<T: GpuFloat>(
        input: &TensorDesc<T>,
        filter: &TensorDesc<T>,
        output: &TensorDescMut<T>,
        conv_desc: &ConvolutionDescriptor,
    ) -> DnnResult<Self> {
        let layout = input.layout;
        let ndim = layout.expected_ndim();
        let spatial = layout.spatial_dims();

        let expect_rank = |what: &str, got: usize| -> DnnResult<()> {
            if got == ndim {
                Ok(())
            } else {
                Err(DnnError::InvalidDimension(format!(
                    "{what} has {got} dims, expected {ndim} for {layout:?} layout"
                )))
            }
        };
        expect_rank("input", input.dims.len())?;
        expect_rank("filter", filter.dims.len())?;
        expect_rank("output", output.dims.len())?;

        // All three per-axis parameters are checked, not just padding, so a
        // malformed descriptor is rejected before it can be indexed or zipped.
        let expect_spatial_len = |what: &str, got: usize| -> DnnResult<()> {
            if got == spatial {
                Ok(())
            } else {
                Err(DnnError::InvalidDimension(format!(
                    "conv_desc {what} length {got} != spatial dims {spatial}"
                )))
            }
        };
        expect_spatial_len("padding", conv_desc.padding.len())?;
        expect_spatial_len("stride", conv_desc.stride.len())?;
        expect_spatial_len("dilation", conv_desc.dilation.len())?;

        let (batch, in_channels, in_dims) = Self::extract_input_dims(input)?;
        let (out_channels, filter_dims) = Self::extract_filter_dims(filter, spatial)?;

        Ok(Self {
            batch,
            in_channels,
            in_dims,
            out_channels,
            filter_dims,
            padding: conv_desc.padding.clone(),
            stride: conv_desc.stride.clone(),
            dilation: conv_desc.dilation.clone(),
            groups: conv_desc.groups,
            input_type: T::PTX_TYPE,
            output_type: T::PTX_TYPE,
            layout,
        })
    }

    /// Checks the internal consistency of the problem.
    ///
    /// # Errors
    ///
    /// * [`DnnError::InvalidArgument`] when `groups` is zero, when
    ///   `in_channels` or `out_channels` is not divisible by `groups`, or when
    ///   any stride or dilation entry is zero.
    /// * [`DnnError::InvalidDimension`] when the spatial parameter vectors do
    ///   not all have the same length as `in_dims`.
    /// * Any error produced while computing the output extents.
    pub fn validate(&self) -> DnnResult<()> {
        if self.groups == 0 {
            return Err(DnnError::InvalidArgument("groups must be >= 1".into()));
        }
        if self.in_channels % self.groups != 0 {
            return Err(DnnError::InvalidArgument(format!(
                "in_channels ({}) not divisible by groups ({})",
                self.in_channels, self.groups
            )));
        }
        if self.out_channels % self.groups != 0 {
            return Err(DnnError::InvalidArgument(format!(
                "out_channels ({}) not divisible by groups ({})",
                self.out_channels, self.groups
            )));
        }

        let n_spatial = self.in_dims.len();
        if self.filter_dims.len() != n_spatial {
            return Err(DnnError::InvalidDimension(format!(
                "filter spatial dims ({}) != input spatial dims ({n_spatial})",
                self.filter_dims.len()
            )));
        }
        let per_axis_lengths = [self.padding.len(), self.stride.len(), self.dilation.len()];
        if per_axis_lengths.iter().any(|&len| len != n_spatial) {
            return Err(DnnError::InvalidDimension(
                "padding/stride/dilation length mismatch with spatial dims".into(),
            ));
        }

        if let Some(i) = self.stride.iter().position(|&s| s == 0) {
            return Err(DnnError::InvalidArgument(format!("stride[{i}] is zero")));
        }
        if let Some(i) = self.dilation.iter().position(|&d| d == 0) {
            return Err(DnnError::InvalidArgument(format!("dilation[{i}] is zero")));
        }

        // Computing the output extents surfaces any size error; the values
        // themselves are not needed here.
        self.output_dims().map(|_| ())
    }

    /// Computes the output extent along one spatial axis with checked access
    /// to every per-axis parameter vector.
    fn output_extent_at(&self, axis: usize) -> DnnResult<u32> {
        match (
            self.in_dims.get(axis),
            self.filter_dims.get(axis),
            self.padding.get(axis),
            self.stride.get(axis),
            self.dilation.get(axis),
        ) {
            (Some(&inp), Some(&flt), Some(&pad), Some(&str_), Some(&dil)) => {
                ConvolutionDescriptor::output_size(inp, flt, pad, str_, dil)
            }
            _ => Err(DnnError::InvalidDimension(format!(
                "spatial axis {axis} is out of range for a problem with {} spatial dims",
                self.in_dims.len()
            ))),
        }
    }

    /// Index of the height axis within the spatial vectors: 1 when there are
    /// exactly three spatial dims, otherwise 0.
    // NOTE(review): presumably 3-D problems are ordered (D, H, W) — confirm.
    fn h_index(&self) -> usize {
        if self.in_dims.len() == 3 { 1 } else { 0 }
    }

    /// Index of the width axis within the spatial vectors: 2 when there are
    /// exactly three spatial dims, otherwise 1.
    fn w_index(&self) -> usize {
        if self.in_dims.len() == 3 { 2 } else { 1 }
    }

    /// Splits the input dims into `(batch, channels, spatial extents)`.
    ///
    /// Reads `dims[0]` as batch and `dims[1]` as channels regardless of the
    /// descriptor's layout.
    // NOTE(review): assumes dims are always stored in N, C, spatial... order
    // even for channel-last layouts — confirm against `TensorDesc`.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when fewer than two dims are present.
    fn extract_input_dims<T: GpuFloat>(input: &TensorDesc<T>) -> DnnResult<(u32, u32, Vec<u32>)> {
        if input.dims.len() < 2 {
            return Err(DnnError::InvalidDimension(format!(
                "input has {} dims, expected at least 2",
                input.dims.len()
            )));
        }
        let batch = input.dims[0];
        let in_channels = input.dims[1];
        let spatial = input.dims[2..].to_vec();
        Ok((batch, in_channels, spatial))
    }

    /// Splits the filter dims into `(out_channels, spatial extents)`.
    ///
    /// Reads `dims[0]` as the output channel count and the `spatial_count`
    /// entries starting at index 2 as the filter extents; `dims[1]` is not used.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidDimension`] when the filter has fewer than
    /// `2 + spatial_count` dims.
    fn extract_filter_dims<T: GpuFloat>(
        filter: &TensorDesc<T>,
        spatial_count: usize,
    ) -> DnnResult<(u32, Vec<u32>)> {
        if filter.dims.len() < 2 + spatial_count {
            return Err(DnnError::InvalidDimension(format!(
                "filter has {} dims, expected at least {}",
                filter.dims.len(),
                2 + spatial_count
            )));
        }
        let out_channels = filter.dims[0];
        let filter_spatial = filter.dims[2..2 + spatial_count].to_vec();
        Ok((out_channels, filter_spatial))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2-D, f32, NCHW problem with square input/filter/padding and
    /// unit stride and dilation.
    fn square_problem(
        batch: u32,
        in_channels: u32,
        out_channels: u32,
        extent: u32,
        kernel: u32,
        pad: u32,
        groups: u32,
    ) -> ConvProblem {
        ConvProblem {
            batch,
            in_channels,
            in_dims: vec![extent; 2],
            out_channels,
            filter_dims: vec![kernel; 2],
            padding: vec![pad; 2],
            stride: vec![1; 2],
            dilation: vec![1; 2],
            groups,
            input_type: PtxType::F32,
            output_type: PtxType::F32,
            layout: TensorLayout::Nchw,
        }
    }

    /// 1x64x32x32 input, 128 output channels, 3x3 filter, padding 1.
    fn make_problem_3x3() -> ConvProblem {
        square_problem(1, 64, 128, 32, 3, 1, 1)
    }

    /// 2x256x16x16 input, 512 output channels, 1x1 filter, no padding.
    fn make_problem_1x1() -> ConvProblem {
        square_problem(2, 256, 512, 16, 1, 0, 1)
    }

    /// 1x64x32x32 input, 64 output channels, 3x3 filter, one group per channel.
    fn make_depthwise() -> ConvProblem {
        square_problem(1, 64, 64, 32, 3, 1, 64)
    }

    #[test]
    fn output_h_basic() {
        assert!(matches!(make_problem_3x3().output_h(), Ok(32)));
    }

    #[test]
    fn output_w_basic() {
        assert!(matches!(make_problem_3x3().output_w(), Ok(32)));
    }

    #[test]
    fn is_1x1_true() {
        let problem = make_problem_1x1();
        assert!(problem.is_1x1());
    }

    #[test]
    fn is_1x1_false() {
        let problem = make_problem_3x3();
        assert!(!problem.is_1x1());
    }

    #[test]
    fn is_depthwise_true() {
        let problem = make_depthwise();
        assert!(problem.is_depthwise());
    }

    #[test]
    fn is_depthwise_false() {
        let problem = make_problem_3x3();
        assert!(!problem.is_depthwise());
    }

    #[test]
    fn conv_to_gemm_dims_3x3() {
        // m = 1 * 32 * 32, n = 128, k = 64 * 3 * 3
        let dims = make_problem_3x3().conv_to_gemm_dims().ok();
        assert_eq!(dims, Some((1024, 128, 576)));
    }

    #[test]
    fn conv_to_gemm_dims_1x1() {
        // m = 2 * 16 * 16, n = 512, k = 256 * 1 * 1
        let dims = make_problem_1x1().conv_to_gemm_dims().ok();
        assert_eq!(dims, Some((512, 512, 256)));
    }

    #[test]
    fn validate_ok() {
        assert!(make_problem_3x3().validate().is_ok());
    }

    #[test]
    fn validate_zero_groups() {
        let problem = ConvProblem {
            groups: 0,
            ..make_problem_3x3()
        };
        assert!(problem.validate().is_err());
    }

    #[test]
    fn validate_channels_not_divisible() {
        // 64 input channels are not divisible by 3 groups.
        let problem = ConvProblem {
            groups: 3,
            ..make_problem_3x3()
        };
        assert!(problem.validate().is_err());
    }

    #[test]
    fn validate_zero_stride() {
        let problem = ConvProblem {
            stride: vec![0, 1],
            ..make_problem_3x3()
        };
        assert!(problem.validate().is_err());
    }

    #[test]
    fn output_dims_strided() {
        let problem = ConvProblem {
            stride: vec![2, 2],
            ..make_problem_3x3()
        };
        assert_eq!(problem.output_dims().ok(), Some(vec![16, 16]));
    }
}