tract-core 0.2.0

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use crate::ops::prelude::*;

/// Padding policy applied along the spatial axes of a windowed operation.
///
/// `PaddingSpec::compute` turns a policy into concrete per-axis padding
/// amounts and output dimensions.
#[derive(Debug, Clone)]
pub enum PaddingSpec {
    /// Caller-provided padding: the first vector is the amount added before
    /// each spatial axis, the second the amount added after it.
    Explicit(TVec<usize>, TVec<usize>),
    /// No padding: resolved as explicit zero padding on every axis.
    Valid,
    /// Padding is derived so the output dimension is `ceil(input / stride)`;
    /// the total is split in two and the remainder half goes after the data.
    SameUpper,
    /// Same total padding as `SameUpper`, but the remainder half goes before
    /// the data.
    SameLower,
}

/// Without an explicit choice, no padding is applied: the default policy is
/// `PaddingSpec::Valid`.
impl Default for PaddingSpec {
    fn default() -> Self {
        PaddingSpec::Valid
    }
}

/// Outcome of resolving a `PaddingSpec` against concrete shapes.
///
/// Each vector holds one entry per spatial axis, in axis order.
#[derive(Debug, Clone)]
pub struct ComputedPaddedDim<D: DimLike> {
    /// Padding inserted before the data on each spatial axis.
    pub pad_before: TVec<D>,
    /// Padding inserted after the data on each spatial axis.
    pub pad_after: TVec<D>,
    /// Resulting output dimension on each spatial axis.
    pub output: TVec<D>,
}

impl PaddingSpec {
    /// Tells whether axis `d` carries no padding at all.
    ///
    /// `Valid` always qualifies; `Explicit` qualifies when both its before and
    /// after amounts at index `d` are zero; the two `Same*` variants never do.
    /// `d` must be a valid index into both explicit padding vectors.
    pub fn valid_dim(&self, d: usize) -> bool {
        match *self {
            PaddingSpec::Explicit(ref before, ref after) => before[d] == 0 && after[d] == 0,
            PaddingSpec::Valid => true,
            PaddingSpec::SameUpper | PaddingSpec::SameLower => false,
        }
    }

    /// Returns a copy of this policy with spatial axis `d` dropped.
    ///
    /// Only `Explicit` stores per-axis data, so only it loses an entry (from
    /// both vectors); every other variant is returned as a plain clone.
    pub fn rm_axis(&self, d: usize) -> PaddingSpec {
        if let PaddingSpec::Explicit(before, after) = self {
            let (mut before, mut after) = (before.clone(), after.clone());
            before.remove(d);
            after.remove(d);
            PaddingSpec::Explicit(before, after)
        } else {
            self.clone()
        }
    }

    /// Resolves this policy into per-axis padding and output dimensions.
    ///
    /// All four slices describe the same spatial axes, in the same order.
    ///
    /// # Panics
    ///
    /// Panics when `dilations`, `strides`, `input_spatial_shape` and
    /// `kernel_spatial_shape` do not all have the same length.
    pub fn compute<D: DimLike, KD: Into<D> + Copy>(
        &self,
        input_spatial_shape: &[D],
        kernel_spatial_shape: &[KD],
        dilations: &[usize],
        strides: &[usize],
    ) -> ComputedPaddedDim<D> {
        let rank = dilations.len();
        assert_eq!(rank, strides.len());
        assert_eq!(rank, input_spatial_shape.len());
        assert_eq!(rank, kernel_spatial_shape.len());
        match *self {
            PaddingSpec::SameUpper => self.same(
                input_spatial_shape,
                kernel_spatial_shape,
                dilations,
                strides,
                true,
            ),
            PaddingSpec::SameLower => self.same(
                input_spatial_shape,
                kernel_spatial_shape,
                dilations,
                strides,
                false,
            ),
            PaddingSpec::Explicit(ref bef, ref aft) => Self::explicit(
                input_spatial_shape,
                kernel_spatial_shape,
                dilations,
                strides,
                bef,
                aft,
            ),
            PaddingSpec::Valid => {
                // "Valid" is explicit padding with zero on both sides of every axis.
                let no_padding = vec![0usize; rank];
                Self::explicit(
                    input_spatial_shape,
                    kernel_spatial_shape,
                    dilations,
                    strides,
                    &no_padding,
                    &no_padding,
                )
            }
        }
    }

    /// Computes output dimensions for caller-provided padding amounts.
    ///
    /// Per axis, the kernel field is `(kernel - 1) * dilation + 1` and the
    /// output is `ceil((input + bef + aft - kernel_field + 1) / stride)`.
    /// The given padding amounts are returned unchanged, converted to `D`.
    ///
    /// # Panics
    ///
    /// Panics when any slice length differs from `data_spatial_shape.len()`.
    fn explicit<D: DimLike, KD: Into<D> + Copy>(
        data_spatial_shape: &[D],
        kernel_spatial_shape: &[KD],
        dilations: &[usize],
        strides: &[usize],
        bef: &[usize],
        aft: &[usize],
    ) -> ComputedPaddedDim<D> {
        let rank = data_spatial_shape.len();
        assert_eq!(rank, kernel_spatial_shape.len());
        assert_eq!(rank, dilations.len());
        assert_eq!(rank, strides.len());
        assert_eq!(rank, aft.len());
        assert_eq!(rank, bef.len());

        let mut output = tvec![];
        for ax in 0..rank {
            // Extent actually covered by the (possibly dilated) kernel.
            let field = (kernel_spatial_shape[ax].into() - 1) * dilations[ax] + 1;
            let padded = data_spatial_shape[ax] + bef[ax] + aft[ax];
            output.push((padded - field + 1).div_ceil(strides[ax]));
        }

        ComputedPaddedDim {
            output,
            pad_before: bef.iter().map(|&p| D::from(p)).collect(),
            pad_after: aft.iter().map(|&p| D::from(p)).collect(),
        }
    }

    /// Derives padding so that each output dimension is `ceil(input / stride)`.
    ///
    /// The total padding on an axis is
    /// `(output - 1) * stride + kernel_field - input`. It is split into
    /// `total / 2` and the remainder; with `upper` set the remainder goes
    /// after the data, otherwise before it.
    ///
    /// NOTE(review): the total is not clamped at zero. When the stride exceeds
    /// the kernel field the expression can be negative (e.g. input 4,
    /// stride 2, kernel 1 gives -1); confirm how `D` behaves in that case.
    fn same<D: DimLike, KD: Into<D> + Copy>(
        &self,
        data_spatial_shape: &[D],
        kernel_spatial_shape: &[KD],
        dilations: &[usize],
        strides: &[usize],
        upper: bool,
    ) -> ComputedPaddedDim<D> {
        let mut output = tvec![];
        let mut pad_before = tvec![];
        let mut pad_after = tvec![];
        for ax in 0..data_spatial_shape.len() {
            let out_dim = data_spatial_shape[ax].div_ceil(strides[ax]);
            let field = (kernel_spatial_shape[ax].into() - 1) * dilations[ax] + 1;
            output.push(out_dim);

            let total = (out_dim - 1) * strides[ax] + field - data_spatial_shape[ax];
            let small_half = total / 2;
            let large_half = total - small_half;
            let (before, after) = if upper {
                (small_half, large_half)
            } else {
                (large_half, small_half)
            };
            pad_before.push(before);
            pad_after.push(after);
        }
        ComputedPaddedDim {
            pad_before,
            pad_after,
            output,
        }
    }
}