use crate::ad_ndarray::scalar::*;
use crate::autodiff::AutoDiff;
use crate::autodiffable::*;
use crate::diffable::Diffable;
use crate::autotuple::*;
use crate::compose::*;
use crate::traits::InstOne;
use ndarray::prelude::*;
use num::complex::Complex;
use crate::funcs::Identity;
use crate as autodiff;
use autodiff_derive::*;
/// Test function mapping a 1-D array to the sum of its elements.
#[derive(Clone, Copy, SimpleForwardDiffable, FuncCompose)]
struct Sum1 {}

impl Diffable<()> for Sum1 {
    type Input = Array1<f64>;
    type Output = Scalar<f64>;
}

impl AutoDiffable<()> for Sum1 {
    /// Sums every element of `x` and wraps the total in a `Scalar`.
    fn eval(&self, x: &Array1<f64>, _: &()) -> Scalar<f64> {
        let total = x.sum();
        Scalar::new(total)
    }

    /// Returns the value from [`Self::eval`] paired with `x.one()` as the gradient.
    ///
    /// NOTE(review): `one()` comes from the imported `InstOne` trait; presumably it
    /// yields an all-ones array shaped like `x` (the derivative of a sum) — confirm.
    fn eval_grad(
        &self,
        x: &Array1<f64>,
        _: &(),
    ) -> (Scalar<f64>, Array1<f64>) {
        let value = self.eval(x, &());
        let gradient = x.one();
        (value, gradient)
    }
}
/// Test function mapping a 2-D array to the sum of its elements.
#[derive(Clone, Copy, SimpleForwardDiffable, FuncCompose)]
struct Sum2 {}

impl Diffable<()> for Sum2 {
    type Input = Array2<f64>;
    type Output = Scalar<f64>;
}

impl AutoDiffable<()> for Sum2 {
    /// Sums every element of `x` and wraps the total in a `Scalar`.
    fn eval(&self, x: &Array2<f64>, _: &()) -> Scalar<f64> {
        let total = x.sum();
        Scalar::new(total)
    }

    /// Returns the value from [`Self::eval`] paired with `x.one()` as the gradient.
    ///
    /// NOTE(review): `one()` comes from the imported `InstOne` trait; presumably it
    /// yields an all-ones array shaped like `x` — confirm.
    fn eval_grad(
        &self,
        x: &Array2<f64>,
        _: &(),
    ) -> (Scalar<f64>, Array2<f64>) {
        let value = self.eval(x, &());
        let gradient = x.one();
        (value, gradient)
    }
}
/// Test function mapping a 2-D array to the product of its elements.
#[derive(Clone, Copy, SimpleForwardDiffable, FuncCompose)]
struct Prod2 {}

impl Diffable<()> for Prod2 {
    type Input = Array2<f64>;
    type Output = Scalar<f64>;
}

impl AutoDiffable<()> for Prod2 {
    /// Multiplies every element of `x` together and wraps the result in a `Scalar`.
    fn eval(&self, x: &Array2<f64>, _: &()) -> Scalar<f64> {
        let product: f64 = x.iter().product();
        Scalar::new(product)
    }

    /// Returns the product of `x` and, per entry, the product of all *other* entries.
    ///
    /// Each gradient entry is obtained by temporarily replacing the corresponding
    /// input entry with `1.0` and re-evaluating the product. A single scratch copy
    /// of `x` is reused for every entry (the value is restored after each
    /// evaluation) instead of cloning the whole array once per element.
    fn eval_grad(
        &self,
        x: &Array2<f64>,
        _: &(),
    ) -> (Scalar<f64>, Array2<f64>) {
        let (rows, cols) = x.dim();
        let mut scratch = x.clone();
        let mut grad = Array2::<f64>::ones(x.raw_dim());
        for i in 0..rows {
            for j in 0..cols {
                // Neutralise entry (i, j), evaluate, then put the original value back.
                let saved = std::mem::replace(&mut scratch[[i, j]], 1.0);
                grad[[i, j]] = self.eval(&scratch, &()).value();
                scratch[[i, j]] = saved;
            }
        }
        (self.eval(x, &()), grad)
    }
}
/// Test function that stacks a 1-D array `n` times as the rows of a 2-D array.
#[derive(Clone, Copy, SimpleForwardDiffable, FuncCompose)]
struct UpcastN {
    /// Number of rows in the output.
    n: usize,
}

impl Diffable<()> for UpcastN {
    type Input = Array1<f64>;
    type Output = Array2<f64>;
}

impl AutoDiffable<()> for UpcastN {
    /// Builds an `(n, x.len())` array whose entry `(row, col)` is `x[col]`.
    fn eval(&self, x: &Array1<f64>, _: &()) -> Array2<f64> {
        let shape = (self.n, x.len());
        Array2::from_shape_fn(shape, |(_, col)| x[col])
    }

    /// Returns the output of [`Self::eval`] and a gradient of shape
    /// `(x.len(), n, x.len())` that is `1.0` where the first and last indices
    /// coincide and `0.0` everywhere else.
    fn eval_grad(
        &self,
        x: &Array1<f64>,
        _: &(),
    ) -> (Array2<f64>, Array3<f64>) {
        let len = x.len();
        let grad = Array3::<f64>::from_shape_fn((len, self.n, len), |(input, _, col)| {
            if input == col {
                1.0
            } else {
                0.0
            }
        });
        (self.eval(x, &()), grad)
    }
}
/// Test function that repeats a 1-D array `n` times as the columns of a 2-D array.
#[derive(Clone, Copy, SimpleForwardDiffable, FuncCompose)]
struct VertCastN {
    /// Number of columns in the output.
    n: usize,
}

impl Diffable<()> for VertCastN {
    type Input = Array1<f64>;
    type Output = Array2<f64>;
}

impl AutoDiffable<()> for VertCastN {
    /// Builds an `(x.len(), n)` array whose entry `(row, col)` is `x[row]`.
    fn eval(&self, x: &Array1<f64>, _: &()) -> Array2<f64> {
        let shape = (x.len(), self.n);
        Array2::from_shape_fn(shape, |(row, _)| x[row])
    }

    /// Returns the output of [`Self::eval`] and a gradient of shape
    /// `(x.len(), x.len(), n)` that is `1.0` where the first two indices
    /// coincide and `0.0` everywhere else.
    fn eval_grad(
        &self,
        x: &Array1<f64>,
        _: &(),
    ) -> (Array2<f64>, Array3<f64>) {
        let len = x.len();
        let grad = Array3::<f64>::from_shape_fn((len, len, self.n), |(input, row, _)| {
            if input == row {
                1.0
            } else {
                0.0
            }
        });
        (self.eval(x, &()), grad)
    }
}
/// Test function that sums every element of a `(2-D array, 1-D array)` tuple
/// into a single scalar.
#[derive(Clone, Copy)]
struct SumAutoTuples {}

/// Input of [`SumAutoTuples`]: a matrix and a vector.
type SInput = AutoTuple<(Array2<f64>, Array1<f64>)>;
/// Output of [`SumAutoTuples`]: a one-element tuple holding the total.
type SOutput = AutoTuple<(Scalar<f64>,)>;

impl Diffable<()> for SumAutoTuples {
    type Input = SInput;
    type Output = SOutput;
}

impl AutoDiffable<()> for SumAutoTuples {
    /// Adds the sum of the matrix component to the sum of the vector component.
    fn eval(&self, x: &SInput, _: &()) -> SOutput {
        let (matrix, vector) = &**x;
        let total = matrix.sum() + vector.sum();
        AutoTuple::new((Scalar::new(total),))
    }

    /// Returns the value from [`Self::eval`] and, as the gradient, the `one()` of
    /// each input component.
    ///
    /// NOTE(review): `one()` comes from the imported `InstOne` trait; presumably an
    /// all-ones array shaped like its receiver — confirm.
    fn eval_grad(&self, x: &SInput, _: &()) -> (SOutput, SInput) {
        let (matrix, vector) = &**x;
        let value = self.eval(x, &());
        let gradient = AutoTuple::new((matrix.one(), vector.one()));
        (value, gradient)
    }
}
impl ForwardDiffable<()> for SumAutoTuples {
    /// Returns the value at `x` and the forward-mode derivative along `dx`.
    ///
    /// The derivative is the sum of all entries of both components of `dx`.
    fn eval_forward_grad(
        &self,
        x: &SInput,
        dx: &SInput,
        _: &(),
    ) -> (SOutput, SOutput) {
        let (d_matrix, d_vector) = &**dx;
        // Accumulate from 0.0, matrix part first, then vector part.
        let tangent = [d_matrix.sum(), d_vector.sum()]
            .iter()
            .fold(0.0_f64, |acc, part| acc + part);
        let value = self.eval(x, &());
        (value, AutoTuple::new((Scalar::new(tangent),)))
    }
}
#[derive(Clone, Copy, FuncCompose, SimpleForwardDiffable)]
struct ComposeSumUpcastAutoTuple(SumAutoTuples, UpcastAutoTuple);
impl Diffable<()> for ComposeSumUpcastAutoTuple {
type Input = AutoTuple<(Array1<f64>,)>;
type Output = AutoTuple<(Scalar<f64>,)>;
}
impl AutoDiffable<()> for ComposeSumUpcastAutoTuple {
fn eval_grad(&self, x: &Self::Input, _: &()) -> (Self::Output, AutoTuple<(Array1<f64>,)>) {
let (y, _dy) = self.0.eval_grad(&self.1.eval(x, &()), &());
(y, (Array1::ones((**x).0.len()) * ((**x).0.len() as f64 + 1.0)).into())
}
}
impl FuncCompose<(), UpcastAutoTuple> for SumAutoTuples {
type Output = AutoDiff<(), ComposeSumUpcastAutoTuple>;
fn func_compose(self, rhs: UpcastAutoTuple) -> Self::Output {
AutoDiff::new(ComposeSumUpcastAutoTuple(self, rhs))
}
}
/// Test function that maps a vector `x` of length `n` to the pair
/// (`n × n` matrix with `x` in every row, `x` itself).
#[derive(Clone, Copy, FuncCompose)]
struct UpcastAutoTuple {}

/// Input of [`UpcastAutoTuple`]: a single vector.
type UInput = AutoTuple<(Array1<f64>,)>;
/// Output of [`UpcastAutoTuple`]: the tiled matrix and a copy of the vector.
type UOutput = AutoTuple<(Array2<f64>, Array1<f64>)>;
/// Gradient of [`UpcastAutoTuple`]: one Jacobian per output component.
type UGrad = AutoTuple<(Array3<f64>, Array2<f64>)>;

impl Diffable<()> for UpcastAutoTuple {
    type Input = UInput;
    type Output = UOutput;
}

impl AutoDiffable<()> for UpcastAutoTuple {
    /// Returns the `n × n` matrix whose entry `(row, col)` is `x[col]`, together
    /// with a copy of `x`.
    fn eval(&self, x: &UInput, _: &()) -> UOutput {
        let input = (**x).0.clone();
        let n = input.len();
        let tiled = Array2::from_shape_fn((n, n), |(_, col)| input[col]);
        AutoTuple::new((tiled, input))
    }

    /// Returns the output of [`Self::eval`] and the Jacobians of both components.
    ///
    /// The matrix Jacobian has shape `(n, n, n)` and is `1.0` where the first and
    /// last indices coincide; the vector Jacobian is the `n × n` identity. Only the
    /// input length is read here, so the input vector is not cloned.
    fn eval_grad(&self, x: &UInput, _: &()) -> (UOutput, UGrad) {
        let n = (**x).0.len();
        let d1_dx = Array3::from_shape_fn((n, n, n), |(input, _, col)| {
            if input == col {
                1.0
            } else {
                0.0
            }
        });
        let d2_dx = Array2::eye(n);
        (self.eval(x, &()), AutoTuple::new((d1_dx, d2_dx)))
    }
}
impl ForwardDiffable<()> for UpcastAutoTuple {
    /// Returns the value at `x` and the forward-mode derivative along `dx`.
    ///
    /// `eval` produces a matrix with `y[row, col] = x[col]`, so the matrix tangent
    /// is `dx[col]` at `(row, col)`. The previous code wrote `dx[row]`, which is the
    /// transpose and only agrees when all entries of `dx` are equal. The vector
    /// tangent is the first `n` entries of `dx`, where `n` is the length of `x`.
    fn eval_forward_grad(
        &self,
        x: &UInput,
        dx: &UInput,
        _: &(),
    ) -> (UOutput, UOutput) {
        let n = (**x).0.len();
        let tangent = &(**dx).0;
        // Every row of the matrix tangent is `dx`, mirroring how `eval` tiles `x`.
        let d1 = Array2::from_shape_fn((n, n), |(_, col)| tangent[col]);
        let d2 = Array1::from_shape_fn(n, |i| tangent[i]);
        (self.eval(x, &()), AutoTuple::new((d1, d2)))
    }
}
/// Test function on complex numbers that implements only `ForwardDiffable`.
#[derive(Clone, Copy)]
struct OnlyForward {}

impl Diffable<()> for OnlyForward {
    type Input = Complex<f64>;
    type Output = Complex<f64>;
}

impl ForwardDiffable<()> for OnlyForward {
    /// Returns `x` multiplied by its complex conjugate.
    fn eval_forward(&self, x: &Complex<f64>, _: &()) -> Complex<f64> {
        let conjugate = x.conj();
        x * conjugate
    }

    /// Returns `x * conj(x)` together with `2 * conj(x) * dx`.
    fn eval_forward_grad(
        &self,
        x: &Complex<f64>,
        dx: &Complex<f64>,
        _: &(),
    ) -> (Complex<f64>, Complex<f64>) {
        let conjugate = x.conj();
        let value = x * conjugate;
        let tangent = 2.0 * conjugate * dx;
        (value, tangent)
    }
}
/// Exercises the ndarray-backed test functions defined in this file through
/// `AutoDiff`: plain evaluation, gradients, forward-mode derivatives, products
/// of functions, and composition (including the `AutoTuple`-based functions).
#[test]
fn test_ad_ndarray() {
// --- Sum1 and Identity on a 1-D input, with an all-ones direction ---
let f = AutoDiff::new(Sum1 {});
let i = AutoDiff::new(Identity::new());
let x = Array1::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
let dx = Array1::<f64>::from_vec(vec![1.0, 1.0, 1.0]);
let (f_x, df_dx) = f.eval_forward_grad(&x, &dx, &());
let (i_x, di_dx) = i.eval_forward_grad(&x, &dx, &());
assert_eq!(f_x.value(), 6.0);
assert_eq!(df_dx.value(), 3.0);
assert_eq!(i_x, x);
assert_eq!(di_dx, dx);
// --- Product of the two functions: g(x) = sum(x) * x ---
let g = f * i;
let (g_x, dg_dx) = g.eval_forward_grad(&x, &dx, &());
assert_eq!(g_x, Array1::from_vec(vec![6.0, 12.0, 18.0]));
assert_eq!(dg_dx, Array1::from_vec(vec![9.0, 12.0, 15.0]));
// --- UpcastN: value, full gradient, and forward derivative ---
let y = Array1::from_vec(vec![2.0, 3.0]);
let dy = Array1::from_vec(vec![1.0, 1.0]);
let u = AutoDiff::new(UpcastN { n: 3 });
let (u_y, du_dy) = u.eval_grad(&y, &());
let du = u.forward_grad(&y, &dy, &());
assert_eq!(
u_y,
Array2::from_shape_vec((3, 2), vec![2.0, 3.0, 2.0, 3.0, 2.0, 3.0]).unwrap()
);
assert_eq!(
du_dy,
Array3::from_shape_vec((2, 3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
.unwrap()
);
assert_eq!(
du,
Array2::from_shape_vec((3, 2), vec![dy[0], dy[1], dy[0], dy[1], dy[0], dy[1],]).unwrap()
);
// Cross-check: contracting the gradient's first (input) axis with dy must
// reproduce the forward derivative.
assert_eq!(du, (du_dy.reversed_axes()*&dy).sum_axis(Axis(2)).reversed_axes());
// --- VertCastN: value, full gradient, and forward derivative ---
let vu = AutoDiff::new(VertCastN { n: 3 });
let (vu_y, dvu_dy) = vu.eval_grad(&y, &());
let dvu = vu.forward_grad(&y, &dy, &());
assert_eq!(
vu_y,
Array2::from_shape_vec((2, 3), vec![2.0, 2.0, 2.0, 3.0, 3.0, 3.0]).unwrap()
);
assert_eq!(
dvu_dy,
Array3::from_shape_vec((2, 2, 3), vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
.unwrap()
);
assert_eq!(
dvu,
Array2::from_shape_vec((2, 3), vec![dy[0], dy[0], dy[0], dy[1], dy[1], dy[1],]).unwrap()
);
// --- Composition Sum2 ∘ UpcastN ---
let s2 = AutoDiff::new(Sum2 {});
let v = s2.compose(u);
// Called only to check that eval_grad runs on the composition; results unused.
let (_v_y, _dv_dy) = v.eval_grad(&y, &());
let (v_y, dv_dy) = v.eval_forward_grad(&y, &dy, &());
assert_eq!(v_y.value(), 15.0);
assert_eq!(dv_dy.value(), 6.0);
// --- Composition Prod2 ∘ UpcastN ---
let p2 = AutoDiff::new(Prod2 {});
let w = p2.compose(u);
let (w_y, dw_dy) = w.eval_forward_grad(&y, &dy, &());
assert_eq!(w_y.value(), 216.0);
assert_eq!(dw_dy.value(), 540.0);
// --- SumAutoTuples on a (matrix, vector) tuple ---
let sum_auto_tuples: AutoDiff<(), SumAutoTuples> = AutoDiff::new(SumAutoTuples {});
let a2 = Array2::<f64>::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let a1 = Array1::<f64>::from_vec(vec![5.0, 6.0]);
let x: SInput = AutoTuple::new((a2.clone(), a1.clone()));
let dx: SInput = AutoTuple::new((a2.clone().one(), a1.clone().one()));
let (y, dy_dx): (SOutput, SOutput) = sum_auto_tuples.eval_forward_grad(&x, &dx, &());
assert_eq!((*y).0.value(), 21.0);
assert_eq!((*dy_dx).0.value(), 6.0);
// Square of the sum: value 21^2, derivative 2 * 21 * 6.
let f2 = sum_auto_tuples * sum_auto_tuples;
let (f2_x, df2_dx): (SOutput, SOutput) = f2.eval_forward_grad(&x, &dx, &());
assert_eq!((*f2_x).0.value(), 441.0);
assert_eq!((*df2_dx).0.value(), 2.0 * 21.0 * 6.0);
// --- UpcastAutoTuple on a one-element tuple ---
let upcast_auto_tuple: AutoDiff<(), UpcastAutoTuple> = AutoDiff::new(UpcastAutoTuple {});
let a1 = Array1::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
let x = AutoTuple::new((a1.clone(),));
// NOTE(review): dx is built with `one()`, presumably all ones; with a uniform
// direction the tangent checks below cannot distinguish dx[row] from dx[col] —
// consider a non-uniform dx.
let dx = AutoTuple::new((a1.clone().one(),));
let (y, dy_dx): (SInput, UOutput) = upcast_auto_tuple.eval_forward_grad(&x, &dx, &());
assert_eq!(
(*y).0,
Array2::from_shape_vec((3, 3), vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0])
.unwrap()
.t()
);
assert_eq!((*y).1, a1);
assert_eq!(
(*dy_dx).0,
(dx.0 .0.clone()
* Array3::from_shape_vec(
(3, 3, 3),
vec![
1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1., 0.0, 0.0, 0., 1., 0., 0., 1., 0., 0., 1., 0.,
0., 0., 1., 0., 0., 1., 0., 0., 1.,
]
)
.unwrap())
.sum_axis(Axis(2))
);
assert_eq!((*dy_dx).1, (*dx).0);
// --- Manual chaining vs. the composed SumAutoTuples ∘ UpcastAutoTuple ---
let (s_of_u, _dsu) = sum_auto_tuples.eval_forward_grad(&y, &dy_dx, &());
assert_eq!((*s_of_u).0.value(), 24.0);
let su = sum_auto_tuples.compose(upcast_auto_tuple);
let (su_x, dsu_dx) = su.eval_forward_grad(&x, &dx, &());
assert_eq!(su_x, s_of_u);
assert_eq!((*dsu_dx).0.value(), 12.0);
}