dorset 0.1.0

A rust implementation of Stan Math automatic differentiation library.
Documentation
//! Operation multiplication.
use super::*;
pub use std::ops::Mul;

fn chain_multiplication(vi: &Vari) {
    let adj = vi.adj();
    match (vi.a.clone(), vi.b.clone()) {
        (Operand::Vari(ptr), Operand::Data(bd)) => {
            let avi: &mut Vari = ptr.clone().into();
            let avi_adj = avi.adj();
            avi.set_adj(avi_adj + adj * bd);
        }
        (Operand::Data(ad), Operand::Vari(ptr)) => {
            let bvi: &mut Vari = ptr.clone().into();
            let bvi_adj = bvi.adj();
            bvi.set_adj(bvi_adj + adj * ad);
        }
        (Operand::Vari(ap), Operand::Vari(bp)) => {
            let avi: &mut Vari = ap.clone().into();
            let avi_val = avi.val();
            let avi_adj = avi.adj();
            let bvi: &mut Vari = bp.clone().into();
            let bvi_val = bvi.val();
            let bvi_adj = bvi.adj();
            avi.set_adj(avi_adj + adj * bvi_val);
            bvi.set_adj(bvi_adj + adj * avi_val);
        }
        _ => {}
    }
}

binop!(impl Mul, mul
       for Var, Real, |x, y| x * y,
       for Real, Var, |x, y| x * y,
       for Var, Var, |x, y| x * y,
       chain Fn = chain_multiplication);

#[cfg(test)]
mod test {
    use super::*;
    use core::memory::*;

    #[test]
    fn mul() {
        let mut x: Real = 3.6;
        let mut y: Real = 3.0;
        let stack = cstack!();
        let vx = var!(stack, x);
        let vy = var!(stack, y);
        let mut v = &vx * &vy;
        v.grad();
        assert_eq!(v.val(), x * y);
        assert_eq!(vx.adj(), y);
        assert_eq!(vy.adj(), x);

        v.set_zero_all_adjoints();
        y = 8.9;
        v = &vx * &y;
        v.grad();
        assert_eq!(v.val(), x * y);
        assert_eq!(vx.adj(), y);
        assert_eq!(vy.adj(), 0.0);

        v.set_zero_all_adjoints();
        x = 83.1;
        v = &x * &vy;
        v.grad();
        assert_eq!(v.val(), x * vy.val());
        assert_eq!(vx.adj(), 0.0);
        assert_eq!(vy.adj(), x);
    }
}