dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
use crate::{
    shapes::{Shape, Unit},
    tensor::{
        cpu::{Cpu, LendingIterator},
        Tensor, ZerosTensor,
    },
};

use super::{
    CmpKernel, EqKernelOp, GeKernelOp, GtKernelOp, LeKernelOp, LtKernelOp, NeKernelOp,
    ScalarCmpKernel,
};

/// Per-element predicate that a comparison operator supplies to the CPU kernels.
///
/// `func` receives one element from each side of the comparison and reports
/// whether the comparison holds for that pair.
trait CmpOpCpuKernel<E>
where
    E: Unit,
{
    /// Returns the outcome of comparing `lhs` against `rhs`.
    fn func(lhs: E, rhs: E) -> bool;
}

impl<Op: CmpOpCpuKernel<E>, E: Unit> CmpKernel<Op, E> for Cpu {
    /// Element-wise comparison of two tensors sharing the shape type `S`.
    ///
    /// Entry `i` of the returned boolean tensor is `Op::func(lhs[i], rhs[i])`,
    /// with elements taken in the order yielded by the tensors' iterators.
    /// Filling stops as soon as any of the three iterators runs out.
    ///
    /// NOTE(review): no runtime check here that `lhs` and `rhs` have matching
    /// dimensions; presumably the caller validates this — confirm.
    fn forward<S: Shape, T>(
        &self,
        lhs: &Tensor<S, E, Self, T>,
        rhs: &Tensor<S, E, Self, T>,
    ) -> Result<Tensor<S, bool, Self>, Self::Err> {
        // Output takes `lhs`'s shape; an allocation error is propagated to the caller.
        let mut mask: Tensor<S, bool, Self> = self.try_zeros_like(&lhs.shape)?;
        let mut left = lhs.iter();
        let mut right = rhs.iter();
        let mut dst = mask.iter_mut();
        // All three iterators advance once per pass, in lockstep.
        while let (Some(slot), Some(a), Some(b)) = (dst.next(), left.next(), right.next()) {
            *slot = Op::func(*a, *b);
        }
        Ok(mask)
    }
}

impl<Op: CmpOpCpuKernel<E>, E: Unit> ScalarCmpKernel<Op, E> for Cpu {
    /// Compares every element of `lhs` against the single value `scalar`.
    ///
    /// Entry `i` of the returned boolean tensor is `Op::func(lhs[i], scalar)`,
    /// with elements taken in the order yielded by the tensors' iterators.
    fn forward<S: Shape, T>(
        &self,
        lhs: &Tensor<S, E, Self, T>,
        scalar: E,
    ) -> Result<Tensor<S, bool, Self>, Self::Err> {
        // Output takes `lhs`'s shape; an allocation error is propagated to the caller.
        let mut mask: Tensor<S, bool, Self> = self.try_zeros_like(&lhs.shape)?;
        let mut src = lhs.iter();
        let mut dst = mask.iter_mut();
        // Both iterators advance once per pass; the loop ends when either is exhausted.
        while let (Some(slot), Some(value)) = (dst.next(), src.next()) {
            *slot = Op::func(*value, scalar);
        }
        Ok(mask)
    }
}

/// Equality predicate: holds when the two elements compare equal via `PartialEq`.
impl<E: Unit> CmpOpCpuKernel<E> for EqKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs == rhs`.
        PartialEq::eq(&lhs, &rhs)
    }
}

/// Inequality predicate: holds when the two elements do not compare equal via `PartialEq`.
impl<E: Unit> CmpOpCpuKernel<E> for NeKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs != rhs`.
        PartialEq::ne(&lhs, &rhs)
    }
}

/// Strictly-greater predicate: holds when `lhs` orders after `rhs` via `PartialOrd`.
impl<E: Unit> CmpOpCpuKernel<E> for GtKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs > rhs`.
        PartialOrd::gt(&lhs, &rhs)
    }
}

/// Greater-or-equal predicate: holds when `lhs` orders after or equal to `rhs` via `PartialOrd`.
impl<E: Unit> CmpOpCpuKernel<E> for GeKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs >= rhs`.
        PartialOrd::ge(&lhs, &rhs)
    }
}

/// Strictly-less predicate: holds when `lhs` orders before `rhs` via `PartialOrd`.
impl<E: Unit> CmpOpCpuKernel<E> for LtKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs < rhs`.
        PartialOrd::lt(&lhs, &rhs)
    }
}

/// Less-or-equal predicate: holds when `lhs` orders before or equal to `rhs` via `PartialOrd`.
impl<E: Unit> CmpOpCpuKernel<E> for LeKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        // Explicit form of `lhs <= rhs`.
        PartialOrd::le(&lhs, &rhs)
    }
}