use std::cell::RefCell;
use numra_autodiff::Dual;
use numra_core::Scalar;
use crate::problem::OdeSystem;
/// ODE system whose right-hand side is written over dual numbers, so that
/// its Jacobian can be obtained by automatic differentiation instead of
/// finite differences.
///
/// The closure `F` receives `(t, y, dy)` with `y` and `dy` expressed as
/// [`Dual`] values and must write the derivative of every state into `dy`.
///
/// The scratch buffers are kept in [`RefCell`]s. As a result, this type is not
/// `Sync`. Calling back into the same instance from inside the closure would
/// panic on the nested `borrow_mut`.
pub struct AutodiffJacobianSystem<S: Scalar, F>
where
    F: Fn(S, &[Dual<S>], &mut [Dual<S>]),
{
    /// User-supplied right-hand side evaluated on dual numbers.
    rhs_dual: F,
    /// Number of state variables; every scratch buffer has exactly this length.
    dim: usize,
    /// Reusable input buffer holding the dual-number image of `y`.
    dual_y: RefCell<Vec<Dual<S>>>,
    /// Reusable output buffer the closure writes its dual results into.
    dual_dy: RefCell<Vec<Dual<S>>>,
}
impl<S: Scalar, F> AutodiffJacobianSystem<S, F>
where
    F: Fn(S, &[Dual<S>], &mut [Dual<S>]),
{
    /// Builds a system of `dim` states around the dual-number RHS `rhs_dual`.
    ///
    /// Both scratch buffers are allocated here, filled with dual zeros, and
    /// reused by every later `rhs` / `jacobian` evaluation, so those calls do
    /// not allocate.
    pub fn new(rhs_dual: F, dim: usize) -> Self {
        let scratch = || RefCell::new(vec![Dual::constant(S::ZERO); dim]);
        Self {
            rhs_dual,
            dim,
            dual_y: scratch(),
            dual_dy: scratch(),
        }
    }
}
impl<S: Scalar, F> OdeSystem<S> for AutodiffJacobianSystem<S, F>
where
    F: Fn(S, &[Dual<S>], &mut [Dual<S>]),
{
    /// Number of state variables this system was constructed with.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Evaluates `dy = f(t, y)`.
    ///
    /// Every state is lifted to a constant dual number and the outputs are
    /// cleared. The closure is run once, and only the primal parts of its
    /// results are copied into `dy`.
    ///
    /// Panics if `y` or `dy` is shorter than `self.dim()`.
    fn rhs(&self, t: S, y: &[S], dy: &mut [S]) {
        let zero = Dual::constant(S::ZERO);
        let mut inputs = self.dual_y.borrow_mut();
        let mut outputs = self.dual_dy.borrow_mut();

        for (k, (input, output)) in inputs.iter_mut().zip(outputs.iter_mut()).enumerate() {
            *input = Dual::constant(y[k]);
            *output = zero;
        }

        (self.rhs_dual)(t, &inputs[..], &mut outputs[..]);

        for (k, output) in outputs.iter().enumerate() {
            dy[k] = output.value();
        }
    }

    /// Fills `jac` with the row-major Jacobian `jac[row * n + col] = ∂f_row/∂y_col`.
    ///
    /// The method does one closure evaluation per column. Only `y[col]` is
    /// seeded with `Dual::variable`, and the derivative parts of the outputs
    /// form that column. The seed is then reset to a constant before the next
    /// column.
    ///
    /// Panics if `y` is shorter than `n` or `jac` is shorter than `n * n`.
    fn jacobian(&self, t: S, y: &[S], jac: &mut [S]) {
        let n = self.dim;
        let zero = Dual::constant(S::ZERO);
        let mut seeds = self.dual_y.borrow_mut();
        let mut tangents = self.dual_dy.borrow_mut();

        // Start from an all-constant point; columns toggle one entry at a time.
        for (k, seed) in seeds.iter_mut().enumerate() {
            *seed = Dual::constant(y[k]);
        }

        for col in 0..n {
            let pivot = y[col];
            seeds[col] = Dual::variable(pivot);
            // The closure may read its outputs, so they must start at zero.
            tangents.fill(zero);

            (self.rhs_dual)(t, &seeds[..], &mut tangents[..]);

            for (row, out) in tangents.iter().enumerate() {
                jac[row * n + col] = out.deriv();
            }
            seeds[col] = Dual::constant(pivot);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major Jacobian of a small nonlinear system at a single point.
    #[test]
    fn ad_jacobian_2x2_nonlinear() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |_t, y: &[Dual<f64>], dy: &mut [Dual<f64>]| {
                dy[0] = y[0] * y[0] + y[1];
                dy[1] = y[0] * y[1];
            },
            2,
        );
        let y = [3.0_f64, 4.0];
        let mut jac = [0.0_f64; 4];
        sys.jacobian(0.0, &y, &mut jac);
        assert!((jac[0] - 6.0).abs() < 1e-14, "∂f0/∂y0 = 2*y0");
        assert!((jac[1] - 1.0).abs() < 1e-14, "∂f0/∂y1 = 1");
        assert!((jac[2] - 4.0).abs() < 1e-14, "∂f1/∂y0 = y1");
        assert!((jac[3] - 3.0).abs() < 1e-14, "∂f1/∂y1 = y0");
    }

    /// `rhs` must return the same values the closure would produce on plain f64.
    #[test]
    fn ad_rhs_matches_f64_primal() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |_t, y: &[Dual<f64>], dy: &mut [Dual<f64>]| {
                dy[0] = y[0] * y[0] + y[1];
                dy[1] = y[0] * y[1];
            },
            2,
        );
        let y = [3.0_f64, 4.0];
        let mut dy = [0.0_f64; 2];
        sys.rhs(0.0, &y, &mut dy);
        assert!((dy[0] - (9.0 + 4.0)).abs() < 1e-14);
        assert!((dy[1] - (3.0 * 4.0)).abs() < 1e-14);
    }

    /// Many sequential `jacobian` calls on one instance.
    ///
    /// This exercises reuse of the shared scratch buffers: every call must
    /// start from a clean state. It does NOT exercise true re-entrancy, because
    /// the `RefCell` borrows are never nested. The name is kept for CI filters.
    #[test]
    fn ad_jacobian_reentrant() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |_t, y: &[Dual<f64>], dy: &mut [Dual<f64>]| {
                dy[0] = y[0] * y[1];
                dy[1] = y[1] - y[0];
            },
            2,
        );
        let mut jac = [0.0_f64; 4];
        for k in 0..10_000 {
            let y = [k as f64 * 0.01, (k as f64) * 0.02 - 1.0];
            sys.jacobian(0.0, &y, &mut jac);
            assert!((jac[0] - y[1]).abs() < 1e-14);
            assert!((jac[1] - y[0]).abs() < 1e-14);
            assert!((jac[2] - (-1.0)).abs() < 1e-14);
            assert!((jac[3] - 1.0).abs() < 1e-14);
        }
    }

    /// A closure that reads its output slice must see zeros on entry.
    ///
    /// This must hold for every `jacobian` column and every `rhs` call,
    /// whatever evaluations ran earlier on the same instance.
    #[test]
    fn ad_outputs_are_zeroed_between_calls() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |_t, y: &[Dual<f64>], dy: &mut [Dual<f64>]| {
                dy[0] = dy[0] + y[0] * y[1];
                dy[1] = dy[1] + y[0];
            },
            2,
        );
        let y = [2.0_f64, 5.0];
        let mut jac = [0.0_f64; 4];
        let mut dy = [0.0_f64; 2];
        for _ in 0..3 {
            sys.jacobian(0.0, &y, &mut jac);
            assert!((jac[0] - 5.0).abs() < 1e-14, "∂f0/∂y0 = y1");
            assert!((jac[1] - 2.0).abs() < 1e-14, "∂f0/∂y1 = y0");
            assert!((jac[2] - 1.0).abs() < 1e-14, "∂f1/∂y0 = 1");
            assert!(jac[3].abs() < 1e-14, "∂f1/∂y1 = 0");

            sys.rhs(0.0, &y, &mut dy);
            assert!((dy[0] - 10.0).abs() < 1e-14);
            assert!((dy[1] - 2.0).abs() < 1e-14);
        }
    }

    /// The time argument must reach the closure unchanged in both entry points.
    #[test]
    fn ad_forwards_time_argument() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |t, y: &[Dual<f64>], dy: &mut [Dual<f64>]| {
                dy[0] = y[0] * Dual::constant(t);
            },
            1,
        );
        let y = [3.0_f64];
        let mut dy = [0.0_f64; 1];
        sys.rhs(2.0, &y, &mut dy);
        assert!((dy[0] - 6.0).abs() < 1e-14, "f0 = t*y0");

        let mut jac = [0.0_f64; 1];
        sys.jacobian(2.0, &y, &mut jac);
        assert!((jac[0] - 2.0).abs() < 1e-14, "∂f0/∂y0 = t");
    }

    /// A zero-dimensional system is a no-op and must not panic.
    #[test]
    fn ad_zero_dimensional_system() {
        let sys = AutodiffJacobianSystem::<f64, _>::new(
            |_t, _y: &[Dual<f64>], _dy: &mut [Dual<f64>]| {},
            0,
        );
        assert_eq!(sys.dim(), 0);
        let y: [f64; 0] = [];
        let mut dy: [f64; 0] = [];
        let mut jac: [f64; 0] = [];
        sys.rhs(0.0, &y, &mut dy);
        sys.jacobian(0.0, &y, &mut jac);
    }
}