dfdx 0.13.0

Ergonomic auto differentiation in Rust, with pytorch like apis.
Documentation
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use crate::{shapes::*, tensor::*};

use super::ReshapeTo;

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub enum Pool2DKind {
    Avg,
    Min,
    Max,
}

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct Pool2DOp {
    pub kind: Pool2DKind,
    pub kernel: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub batch: usize,
    pub chan: usize,
    pub h_in: usize,
    pub h_out: usize,
    pub w_in: usize,
    pub w_out: usize,
}

pub(super) trait Pool2DKernel<E: Dtype>: Storage<E> {
    fn alloc<S: Shape>(&self, s: S) -> Result<Tensor<S, E, Self>, Self::Err>;

    fn forward<I: Shape, O: Shape>(
        &self,
        op: Pool2DOp,
        inp: &Tensor<I, E, Self>,
        out: &mut Tensor<O, E, Self>,
    ) -> Result<(), Self::Err>;

    #[allow(clippy::too_many_arguments)]
    fn backward<I: Shape, O: Shape>(
        &self,
        op: Pool2DOp,
        inp: &Tensor<I, E, Self>,
        grad_inp: &mut Self::Vec,
        out: &Tensor<O, E, Self>,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err>;
}

pub trait TryPool2D<Kernel, Stride, Padding, Dilation>: Sized {
    type Pooled;
    type Error: std::fmt::Debug;

    fn pool2d(
        self,
        kind: Pool2DKind,
        kernel: Kernel,
        stride: Stride,
        padding: Padding,
        dilation: Dilation,
    ) -> Self::Pooled {
        self.try_pool2d(kind, kernel, stride, padding, dilation)
            .unwrap()
    }

    fn try_pool2d(
        self,
        kind: Pool2DKind,
        kernel: Kernel,
        stride: Stride,
        padding: Padding,
        dilation: Dilation,
    ) -> Result<Self::Pooled, Self::Error>;
}

impl<
        const KERNEL: usize,
        const STRIDE: usize,
        const PADDING: usize,
        const DILATION: usize,
        const DIM: usize,
    > TryPool2D<Const<KERNEL>, Const<STRIDE>, Const<PADDING>, Const<DILATION>> for Const<DIM>
where
    Const<{ (DIM + 2 * PADDING - DILATION * (KERNEL - 1) - 1) / STRIDE + 1 }>: Sized,
{
    type Pooled = Const<{ (DIM + 2 * PADDING - DILATION * (KERNEL - 1) - 1) / STRIDE + 1 }>;
    type Error = std::convert::Infallible;
    fn try_pool2d(
        self,
        _: Pool2DKind,
        _: Const<KERNEL>,
        _: Const<STRIDE>,
        _: Const<PADDING>,
        _: Const<DILATION>,
    ) -> Result<Self::Pooled, Self::Error> {
        Ok(Const)
    }
}

impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim>
    TryPool2D<Kernel, Stride, Padding, Dilation> for usize
{
    type Pooled = usize;
    type Error = std::convert::Infallible;
    fn try_pool2d(
        self,
        _: Pool2DKind,
        kernel: Kernel,
        stride: Stride,
        padding: Padding,
        dilation: Dilation,
    ) -> Result<Self::Pooled, Self::Error> {
        Ok((self + 2 * padding.size() - 1)
            .checked_sub(dilation.size() * (kernel.size() - 1))
            .unwrap()
            / stride.size()
            + 1)
    }
}

impl<Chan, Kernel, Stride, Padding, Dilation, H, W, E, D, T>
    TryPool2D<Kernel, Stride, Padding, Dilation> for Tensor<(Chan, H, W), E, D, T>
where
    Chan: Dim,
    Kernel: Dim,
    Stride: Dim,
    Padding: Dim,
    Dilation: Dim,
    H: Dim + TryPool2D<Kernel, Stride, Padding, Dilation>,
    H::Pooled: Dim,
    W: Dim + TryPool2D<Kernel, Stride, Padding, Dilation>,
    W::Pooled: Dim,
    E: Dtype,
    D: Pool2DKernel<E> + crate::tensor_ops::reshape_to::ReshapeKernel<E>,
    T: Tape<E, D>,
{
    type Pooled = Tensor<(Chan, H::Pooled, W::Pooled), E, D, T>;
    type Error = D::Err;

    fn try_pool2d(
        self,
        kind: Pool2DKind,
        kernel: Kernel,
        stride: Stride,
        padding: Padding,
        dilation: Dilation,
    ) -> Result<Self::Pooled, Self::Error> {
        let (chan, h, w) = self.shape;
        let img = self.try_reshape_like(&(Const::<1>, chan, h, w))?;
        let out = img.try_pool2d(kind, kernel, stride, padding, dilation)?;
        let (_, _, out_h, out_w) = out.shape;
        out.try_reshape_like(&(chan, out_h, out_w))
    }
}

impl<Chan, Kernel, Stride, Padding, Dilation, Batch, H, W, E, D, T>
    TryPool2D<Kernel, Stride, Padding, Dilation> for Tensor<(Batch, Chan, H, W), E, D, T>
where
    Chan: Dim,
    Kernel: Dim,
    Stride: Dim,
    Padding: Dim,
    Dilation: Dim,
    Batch: Dim,
    H: Dim + TryPool2D<Kernel, Stride, Padding, Dilation>,
    H::Pooled: Dim,
    W: Dim + TryPool2D<Kernel, Stride, Padding, Dilation>,
    W::Pooled: Dim,
    E: Dtype,
    D: Pool2DKernel<E>,
    T: Tape<E, D>,
{
    type Pooled = Tensor<(Batch, Chan, H::Pooled, W::Pooled), E, D, T>;
    type Error = D::Err;

    fn try_pool2d(
        self,
        kind: Pool2DKind,
        kernel: Kernel,
        stride: Stride,
        padding: Padding,
        dilation: Dilation,
    ) -> Result<Self::Pooled, Self::Error> {
        let (batch, chan, h, w) = self.shape;
        if self.strides != self.shape.strides() {
            panic!("Image input to pool2d must be contiguous");
        }
        let h_out = h.pool2d(kind, kernel, stride, padding, dilation);
        let w_out = w.pool2d(kind, kernel, stride, padding, dilation);
        let op = Pool2DOp {
            kind,
            stride: stride.size(),
            padding: padding.size(),
            kernel: kernel.size(),
            dilation: dilation.size(),
            batch: batch.size(),
            chan: chan.size(),
            h_in: h.size(),
            h_out: h_out.size(),
            w_in: w.size(),
            w_out: w_out.size(),
        };
        let (img, mut tape) = self.split_tape();
        let mut out = img.device.alloc((batch, chan, h_out, w_out))?;
        img.device.forward(op, &img, &mut out)?;
        let img_ghost = img.ghost();
        let out_ghost = out.ghost();
        let out_clone = out.clone();
        tape.add_backward_op(move |grads| {
            grads.try_alloc_for(&img_ghost)?;
            grads.try_alloc_for(&out_ghost)?;
            let (grad_img, grad_out) = grads.mut_and_ref(&img_ghost, &out_ghost);
            img.device
                .backward(op, &img, grad_img, &out_clone, grad_out)
        });
        Ok(out.put_tape(tape))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{tensor_ops::*, tests::*};

    #[test]
    fn test_pool2d_3d_max2d_eq_grads() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([[[1.0, 1., 0.5, 0.2], [0.2, 0.2, 0.5, 1.2]]])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().pool2d(
            Pool2DKind::Max,
            Const::<2>,
            Const::<1>,
            Const::<0>,
            Const::<1>,
        );
        assert_close_to_literal!(r, [[[1., 1., 1.2]]]);
        let g = r.sum().backward();
        assert_close_to_literal!(g.get(&x), [[[1., 2., 0., 0.], [0., 0., 0., 1.]]]);
    }

    #[test]
    fn test_pool2d_3d_min2d_eq_grads() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([[[1., 1., 0.5, 0.2], [0.2, 0.2, 0.5, 1.2]]])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().pool2d(
            Pool2DKind::Min,
            Const::<2>,
            Const::<1>,
            Const::<0>,
            Const::<1>,
        );
        assert_close_to_literal!(r, [[[0.2, 0.2, 0.2]]]);
        let g = r.sum().backward();
        assert_close_to_literal!(g.get(&x), [[[0., 0., 0., 1.], [1., 2., 0., 0.]]]);
    }

    #[test]
    fn test_pool2d_3d_max2d() {
        let dev = TestDevice::seed_from_u64(234);
        let x: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
        let r = x.leaky_trace().pool2d(
            Pool2DKind::Max,
            Const::<2>,
            Const::<2>,
            Const::<0>,
            Const::<1>,
        );
        assert_close_to_literal!(r, [[[1.79155397, 1.10126066]], [[1.14464748, 2.26301837]]]);
        let g = r.exp().mean().backward();
        #[rustfmt::skip]
        assert_close_to_literal!(
            g.get(&x),
            [
                [[1.49969184, 0., 0., 0.75198889],[0., 0., 0., 0.],[0., 0., 0., 0.]],
                [[0., 0., 2.40301466, 0.],[0.78533345, 0., 0., 0.],[0., 0., 0., 0.]]
            ]
        );
    }

    #[test]
    fn test_pool2d_3d_min2d() {
        let dev = TestDevice::seed_from_u64(234);
        let x: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
        let r = x.leaky_trace().pool2d(
            Pool2DKind::Min,
            Const::<2>,
            Const::<2>,
            Const::<0>,
            Const::<1>,
        );
        assert_close_to_literal!(
            r,
            [[[-1.09635627, -1.07717276]], [[-0.01996479, -1.82562149]]]
        );
        let g = r.exp().mean().backward();
        #[rustfmt::skip]
        assert_close_to_literal!(
            g.get(&x),
            [
                [[0., 0., 0., 0.],[0.083521545, 0., 0., 0.08513925],[0., 0., 0., 0.]],
                [[0., 0.2450583, 0., 0.04027937],[0., 0., 0., 0.],[0., 0., 0., 0.]],
            ]
        );
    }

    #[test]
    fn test_pool2d_4d_avg2d() {
        let dev = TestDevice::seed_from_u64(234);
        let x: Tensor<Rank4<2, 4, 2, 2>, TestDtype, _> = dev.sample_normal();
        let r = x.leaky_trace().pool2d(
            Pool2DKind::Avg,
            Const::<1>,
            Const::<2>,
            Const::<0>,
            Const::<1>,
        );
        assert_close_to_literal!(
            r,
            [
                [[[1.791554]], [[-1.0963563]], [[0.86268073]], [[0.28538525]]],
                [[[1.1446475]], [[0.2833436]], [[-1.2026008]], [[0.21327473]]],
            ]
        );
        let g = r.exp().mean().backward();
        #[rustfmt::skip]
        assert_close_to_literal!(
            g.get(&x),
            [
                [[[0.7498459, 0.0], [0.0, 0.0]],[[0.041760772, 0.0], [0.0, 0.0]],[[0.29618803, 0.0], [0.0, 0.0]],[[0.16628431, 0.0], [0.0, 0.0]]],
                [[[0.39266673, 0.0], [0.0, 0.0]],[[0.16594516, 0.0], [0.0, 0.0]],[[0.037551485, 0.0], [0.0, 0.0]],[[0.15471558, 0.0], [0.0, 0.0]]]
            ]
        );
    }

    #[test]
    fn test_pool2d_dilated() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([[
                [0., 1., 2., 4., 5.],
                [6., 7., 8., 9., 10.],
                [11., 12., 13., 14., 15.],
                [16., 17., 18., 19., 20.],
            ]])
            .to_dtype::<TestDtype>();
        let y_max = x.leaky_trace().pool2d(
            Pool2DKind::Max,
            Const::<2>,
            Const::<1>,
            Const::<0>,
            Const::<2>,
        );
        assert_close_to_literal!(y_max, [[[13., 14., 15.], [18., 19., 20.]]]);
        let y_min = x.clone().pool2d(
            Pool2DKind::Min,
            Const::<2>,
            Const::<1>,
            Const::<0>,
            Const::<2>,
        );
        assert_close_to_literal!(y_min, [[[0., 1., 2.], [6., 7., 8.]]]);

        let grads = y_max.mean().backward();
        let v = 1.0 / 6.0;
        assert_close_to_literal!(
            grads.get(&x),
            [[
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., v, v, v],
                [0., 0., v, v, v]
            ]]
        );
    }
}