dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
use crate::{shapes::*, tensor::*};

use std::sync::Arc;

use num_traits::{Float, FromPrimitive};

/// Widens the strides of a 3d or 4d shape into a fixed 4-element array.
///
/// The strides are right-aligned in the result: a 4d shape is copied through
/// unchanged, while a 3d shape gets a `0` in the first slot followed by its
/// three strides.
///
/// # Panics
///
/// Panics if `S::NUM_DIMS` is neither 3 nor 4.
fn make_4d<S: Shape>(strides: S::Concrete) -> [usize; 4] {
    let rank = S::NUM_DIMS;
    if rank != 3 && rank != 4 {
        panic!("Only implemented for 3d & 4d arrays");
    }

    // Number of leading slots that stay at zero (1 for 3d, 0 for 4d).
    let lead = 4 - rank;
    let mut padded = [0usize; 4];
    for (i, slot) in padded.iter_mut().skip(lead).enumerate() {
        *slot = strides[i];
    }
    padded
}

/// Per-kind scalar helpers used by the CPU pooling loops in this file.
impl super::Pool2DKind {
    /// Returns the value an accumulator starts from before any window element
    /// is folded in: zero for `Avg`, `+inf` for `Min`, `-inf` for `Max`.
    fn init<E: Float>(&self) -> E {
        match self {
            Self::Avg => E::zero(),
            Self::Min => E::infinity(),
            Self::Max => E::neg_infinity(),
        }
    }

    /// Folds `item` into `accum` and returns the new accumulator:
    /// the sum for `Avg`, the smaller value for `Min`, the larger for `Max`.
    fn accum<E: Float>(&self, accum: &E, item: &E) -> E {
        let (acc, value) = (*accum, *item);
        match self {
            Self::Avg => acc + value,
            Self::Min => acc.min(value),
            Self::Max => acc.max(value),
        }
    }

    /// Applies the final scaling for this pooling kind.
    ///
    /// `Avg` multiplies `item` by `1 / num_elements` (the reciprocal is
    /// computed in `f64` and then converted to `E`); `Min` and `Max` return
    /// `item` unchanged.
    ///
    /// # Panics
    ///
    /// Panics for `Avg` if the reciprocal cannot be converted to `E`.
    fn normalize<E: Float + FromPrimitive>(&self, item: E, num_elements: usize) -> E {
        match self {
            Self::Avg => {
                let scale = 1.0 / num_elements as f64;
                item * E::from_f64(scale).unwrap()
            }
            Self::Min | Self::Max => item,
        }
    }

    /// Selects how much of `item` is passed on for one window element.
    ///
    /// `Avg` always passes `item` through. `Min` and `Max` pass `item` only
    /// when `needle` equals `haystack`, and yield zero otherwise.
    fn filter<E: Float>(&self, item: E, needle: E, haystack: E) -> E {
        match self {
            Self::Avg => item,
            Self::Min | Self::Max => {
                if needle == haystack {
                    item
                } else {
                    E::zero()
                }
            }
        }
    }
}

/// CPU implementation of the 2d pooling kernel, written as plain nested loops
/// over batch, channel, output row and output column.
impl<E: Float + Dtype> super::Pool2DKernel<E> for Cpu {
    /// Allocates the tensor for shape `s` by delegating to `try_zeros_like`.
    fn alloc<S: Shape>(&self, s: S) -> Result<Tensor<S, E, Self>, Self::Err> {
        self.try_zeros_like(&s)
    }

    /// Computes the pooled output and writes it into `out`.
    ///
    /// For every output position the `op.kernel x op.kernel` window is
    /// visited; each window element maps to the input coordinate
    /// `o * stride + dilation * k - padding`. Coordinates for which that
    /// subtraction underflows, or which are not below `h_in` / `w_in`, are
    /// skipped. The accumulated value is then normalized with
    /// `op.kernel * op.kernel`, regardless of how many elements were skipped.
    fn forward<I: Shape, O: Shape>(
        &self,
        op: super::Pool2DOp,
        inp: &Tensor<I, E, Self>,
        out: &mut Tensor<O, E, Self>,
    ) -> Result<(), Self::Err> {
        // Flat offset of element (b, c, y, x) under the 4-element stride array `s`.
        let offset = |s: &[usize; 4], b: usize, c: usize, y: usize, x: usize| -> usize {
            b * s[0] + c * s[1] + y * s[2] + x * s[3]
        };

        let in_strides = make_4d::<I>(inp.strides);
        let out_strides = make_4d::<O>(out.strides);

        let src = inp.data.as_ref();
        let dst = Arc::make_mut(&mut out.data);

        for b in 0..op.batch {
            for c in 0..op.chan {
                for oh in 0..op.h_out {
                    for ow in 0..op.w_out {
                        let mut acc: E = op.kind.init();
                        for k1 in 0..op.kernel {
                            // `None` when the row lies in the top padding.
                            let row = (oh * op.stride + op.dilation * k1).checked_sub(op.padding);
                            for k2 in 0..op.kernel {
                                // `None` when the column lies in the left padding.
                                let col =
                                    (ow * op.stride + op.dilation * k2).checked_sub(op.padding);
                                match (row, col) {
                                    (Some(y), Some(x)) if y < op.h_in && x < op.w_in => {
                                        let src_idx = offset(&in_strides, b, c, y, x);
                                        acc = op.kind.accum(&acc, &src[src_idx]);
                                    }
                                    // Outside the input: contributes nothing.
                                    _ => {}
                                }
                            }
                        }
                        acc = op.kind.normalize(acc, op.kernel * op.kernel);
                        dst[offset(&out_strides, b, c, oh, ow)] = acc;
                    }
                }
            }
        }
        Ok(())
    }

    /// Accumulates input gradients into `grad_inp`.
    ///
    /// For every output position the output gradient is first normalized with
    /// `op.kernel * op.kernel`. The same window as in `forward` is then
    /// walked, and for each in-bounds input element
    /// `op.kind.filter(gradient, input value, output value)` is added to the
    /// matching slot of `grad_inp`.
    fn backward<I: Shape, O: Shape>(
        &self,
        op: super::Pool2DOp,
        inp: &Tensor<I, E, Self>,
        grad_inp: &mut Self::Vec,
        out: &Tensor<O, E, Self>,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        // Flat offset of element (b, c, y, x) under the 4-element stride array `s`.
        let offset = |s: &[usize; 4], b: usize, c: usize, y: usize, x: usize| -> usize {
            b * s[0] + c * s[1] + y * s[2] + x * s[3]
        };

        let in_strides = make_4d::<I>(inp.strides);
        let out_strides = make_4d::<O>(out.strides);

        let in_vals = inp.data.as_ref();
        let out_vals = out.data.as_ref();

        for b in 0..op.batch {
            for c in 0..op.chan {
                for oh in 0..op.h_out {
                    for ow in 0..op.w_out {
                        let dst_idx = offset(&out_strides, b, c, oh, ow);
                        let upstream =
                            op.kind.normalize(grad_out[dst_idx], op.kernel * op.kernel);
                        let pooled = out_vals[dst_idx];

                        for k1 in 0..op.kernel {
                            // `None` when the row lies in the top padding.
                            let row = (oh * op.stride + op.dilation * k1).checked_sub(op.padding);
                            for k2 in 0..op.kernel {
                                // `None` when the column lies in the left padding.
                                let col =
                                    (ow * op.stride + op.dilation * k2).checked_sub(op.padding);
                                match (row, col) {
                                    (Some(y), Some(x)) if x < op.w_in && y < op.h_in => {
                                        let src_idx = offset(&in_strides, b, c, y, x);
                                        grad_inp[src_idx] +=
                                            op.kind.filter(upstream, in_vals[src_idx], pooled);
                                    }
                                    // Outside the input: receives no gradient.
                                    _ => {}
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }
}