dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
use super::{AdamConfig, AdamKernel, WeightDecay};
use crate::{
    dtypes::{Dtype, NotMixedPrecision},
    tensor::Cpu,
};

#[cfg(feature = "f16")]
impl AdamKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
    fn adam_kernel(
        &self,
        t: i32,
        cfg: &AdamConfig,
        param: &mut Self::Vec,
        moment1: &mut Self::Vec,
        moment2: &mut Self::Vec,
        grad: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let betas = cfg.betas.map(|x| x as f32);
        let eps = cfg.eps as f32;
        let lr = cfg.lr as f32;

        for ((p, g), (m, v)) in param
            .iter_mut()
            .zip(grad.iter().cloned())
            .zip(moment1.iter_mut().zip(moment2.iter_mut()))
        {
            let p_f32 = p.0.to_f32();
            let mut g_f32 = g.0.to_f32();
            let mut m_f32 = m.0.to_f32();
            let mut v_f32 = v.0.to_f32();

            if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
                g_f32 += (wd as f32) * p_f32;
            }

            m_f32 = m_f32 * betas[0] + g_f32 * (1.0 - betas[0]);
            v_f32 = v_f32 * betas[1] + g_f32.powi(2) * (1.0 - betas[1]);
            let m_hat = m_f32 * (1.0 - betas[0].powi(t)).recip();
            let v_hat = v_f32 * (1.0 - betas[1].powi(t)).recip();
            g_f32 = lr * m_hat / (v_hat.sqrt() + eps);

            if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
                g_f32 += (wd * cfg.lr) as f32 * p_f32;
            }

            p.0 = crate::dtypes::f16::from_f32(p_f32 - g_f32);
            m.0 = crate::dtypes::f16::from_f32(m_f32);
            v.0 = crate::dtypes::f16::from_f32(v_f32);
        }
        Ok(())
    }
}

impl<E: num_traits::Float + Dtype + NotMixedPrecision> AdamKernel<E> for Cpu {
    fn adam_kernel(
        &self,
        t: i32,
        cfg: &AdamConfig,
        param: &mut Self::Vec,
        moment1: &mut Self::Vec,
        moment2: &mut Self::Vec,
        grad: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let betas = cfg.betas.map(E::from_f64).map(Option::unwrap);
        let eps = E::from_f64(cfg.eps).unwrap();
        let lr = E::from_f64(cfg.lr).unwrap();

        for ((p, mut g), (m, v)) in param
            .iter_mut()
            .zip(grad.iter().cloned())
            .zip(moment1.iter_mut().zip(moment2.iter_mut()))
        {
            if let Some(WeightDecay::L2(wd)) = cfg.weight_decay {
                g += E::from_f64(wd).unwrap() * *p;
            }

            *m = *m * betas[0] + g * (E::one() - betas[0]);
            *v = *v * betas[1] + g.powi(2) * (E::one() - betas[1]);
            let m_hat = *m * (E::one() - betas[0].powi(t)).recip();
            let v_hat = *v * (E::one() - betas[1].powi(t)).recip();
            g = lr * m_hat / (v_hat.sqrt() + eps);

            if let Some(WeightDecay::Decoupled(wd)) = cfg.weight_decay {
                g += E::from_f64(wd * cfg.lr).unwrap() * *p;
            }

            *p -= g;
        }
        Ok(())
    }
}