use super::variable::Variable;
/// Reads partial derivatives out of a gradient container.
///
/// `Wrt` names what to differentiate with respect to (one variable, or a
/// collection of them), and `Out` is the shape of the answer for that request
/// (e.g. a single `f64`, or a `Vec<f64>` with one entry per variable).
pub trait Gradient<Wrt, Out> {
    /// Returns the partial derivative(s) with respect to `v`.
    fn wrt(&self, v: Wrt) -> Out;
}
/// Looks up the partial derivative for a single variable.
///
/// The gradient vector is addressed by the variable's `index` field.
///
/// # Panics
///
/// Panics if `v.index` is not a valid position in this vector (presumably this
/// happens when `v` belongs to a different tape than the one that produced
/// the gradient — TODO confirm).
impl<'v> Gradient<&Variable<'v>, f64> for Vec<f64> {
    #[inline]
    fn wrt(&self, v: &Variable<'v>) -> f64 {
        let slot = v.index;
        self[slot]
    }
}
/// Collects the partial derivatives for every variable in a `Vec`, in order.
///
/// This impl has to exist separately because trait dispatch does not apply
/// deref coercion, so `grad.wrt(&vec)` would not find the slice impl on its
/// own. It delegates to that impl so all sequence-like inputs share one code
/// path.
///
/// # Panics
///
/// Panics if any variable's index is out of range for this gradient vector.
impl<'v> Gradient<&Vec<Variable<'v>>, Vec<f64>> for Vec<f64> {
    #[inline]
    fn wrt(&self, v: &Vec<Variable<'v>>) -> Vec<f64> {
        self.wrt(v.as_slice())
    }
}
/// Collects the partial derivatives for every variable in a slice, in order.
///
/// The result has the same length as `v`. Entry `k` is the derivative with
/// respect to `v[k]`.
///
/// # Panics
///
/// Panics if any variable's index is out of range for this gradient vector.
impl<'v> Gradient<&[Variable<'v>], Vec<f64>> for Vec<f64> {
    #[inline]
    fn wrt(&self, v: &[Variable<'v>]) -> Vec<f64> {
        // The slice iterator reports an exact size_hint, so `collect` allocates once.
        v.iter().map(|var| self.wrt(var)).collect()
    }
}
/// Collects the partial derivatives for an owned array of variables, in order.
///
/// Delegates to the slice implementation. The result has exactly `N` entries.
///
/// # Panics
///
/// Panics if any variable's index is out of range for this gradient vector.
impl<'v, const N: usize> Gradient<[Variable<'v>; N], Vec<f64>> for Vec<f64> {
    #[inline]
    fn wrt(&self, v: [Variable<'v>; N]) -> Vec<f64> {
        self.wrt(&v[..])
    }
}
/// Collects the partial derivatives for a borrowed array of variables, in order.
///
/// Delegates to the slice implementation. The result has exactly `N` entries.
///
/// # Panics
///
/// Panics if any variable's index is out of range for this gradient vector.
impl<'v, const N: usize> Gradient<&[Variable<'v>; N], Vec<f64>> for Vec<f64> {
    #[inline]
    fn wrt(&self, v: &[Variable<'v>; N]) -> Vec<f64> {
        self.wrt(&v[..])
    }
}
// Unit tests checking reverse-mode gradients against closed-form derivatives.
#[cfg(test)]
mod tests {
    use crate::{assert_approx_equal, autodiff::*};

    #[test]
    fn x_times_y_plus_sin_x() {
        let t = Tape::new();
        let x = t.var(0.5);
        let y = t.var(4.2);
        let z = x * y + x.sin();
        let grad = z.accumulate();
        assert_approx_equal!(z.value, 2.579425538604203, 1e-14);
        assert_approx_equal!(grad.wrt(&x), y.value + x.value.cos(), 1e-15);
        assert_approx_equal!(grad.wrt(&y), x.value, 1e-15);
    }

    #[test]
    fn x_times_y_plus_tan_x() {
        let t = Tape::new();
        let x = t.var(1.0);
        let y = t.var(2.0);
        let z = x * y + x.tan();
        let grad = z.accumulate();
        assert_approx_equal!(z.value, 3.5574077246549, 1e-14);
        assert_approx_equal!(grad.wrt(&x), 5.4255188208147597, 1e-15);
        assert_approx_equal!(grad.wrt(&y), 1.0, 1e-15);
    }

    #[test]
    fn cosh_x_times_y() {
        let t = Tape::new();
        let x = t.var(1.0);
        let y = t.var(2.0);
        let z = (x * y).cosh();
        let grad = z.accumulate();
        assert_approx_equal!(z.value, 3.762195691083631459, 1e-10);
        assert_approx_equal!(grad.wrt(&x), 7.2537208156940375, 1e-10);
        assert_approx_equal!(grad.wrt(&y), 3.62686040784701, 1e-10);
    }

    #[test]
    fn cosh_xy_div_tanh_x_times_sinh_y() {
        let t = Tape::new();
        let x = t.var(1.0);
        let y = t.var(2.0);
        let z = (x * y).cosh() / (x.tanh() * y.sinh());
        let grad = z.accumulate();
        assert_approx_equal!(z.value, 1.3620308304831552, 1e-8);
        assert_approx_equal!(grad.wrt(&x), 1.87499075136386965, 1e-15);
        assert_approx_equal!(grad.wrt(&y), -0.099819345045613269, 1e-15);
    }

    #[test]
    fn test_block_assign() {
        let t = Tape::new();
        let x = t.var(1.0);
        let y = t.var(2.0);
        // The expression is built inside a block to check that intermediate
        // variables going out of scope do not break accumulation.
        let f = {
            let z = x.sin() + y.tan();
            z.exp()
        };
        let grad = f.accumulate();
        assert_approx_equal!(grad.wrt(&x), 0.1409718084254616945815, 1e-15);
        assert_approx_equal!(grad.wrt(&y), 1.5066148885971964908277, 1e-15);
    }

    #[test]
    fn test_closure() {
        let t = Tape::new();
        let x = t.var(1.0);
        let y = t.var(2.0);
        let z = || (x * y).cosh() / (x.tanh() * y.sinh());
        let grad = z().accumulate();
        assert_approx_equal!(z().value, 1.3620308304831552, 1e-8);
        assert_approx_equal!(grad.wrt(&x), 1.87499075136386965, 1e-15);
        assert_approx_equal!(grad.wrt(&y), -0.099819345045613269, 1e-15);
    }

    #[test]
    fn test_function_input() {
        // f(a, b, c; d0, d1) = a^b + sin(d0) - asinh(c) / d1
        fn diff_fn<'v>(params: &[Variable<'v>], data: &[f64]) -> Variable<'v> {
            params[0].powf(params[1]) + data[0].sin() - params[2].asinh() / data[1]
        }
        let tape = Tape::new();
        let params = tape.vars(&[3.0, 2.0, 1.0]);
        let data = [1., 2.];
        let result = diff_fn(&params, &data);
        let gradients = result.accumulate();
        let grads = gradients.wrt(&params);

        // Value at (a, b, c) = (3, 2, 1), (d0, d1) = (1, 2).
        assert_approx_equal!(
            result.value,
            9.0 + 1.0_f64.sin() - 1.0_f64.asinh() / 2.0,
            1e-12
        );
        assert_eq!(grads.len(), 3);
        // df/da = b * a^(b - 1) = 2 * 3 = 6
        assert_approx_equal!(grads[0], 6.0, 1e-10);
        // df/db = a^b * ln(a) = 9 * ln(3)
        assert_approx_equal!(grads[1], 9.0 * 3.0_f64.ln(), 1e-10);
        // df/dc = -1 / (d1 * sqrt(1 + c^2)) = -1 / (2 * sqrt(2))
        assert_approx_equal!(grads[2], -1.0 / (2.0 * 2.0_f64.sqrt()), 1e-10);
    }
}