use crate::{
ode_solver::problem::OdeSolverSolution, MatrixHost, OdeBuilder, OdeEquationsImplicitSens,
OdeSolverProblem, Vector,
};
use num_traits::{FromPrimitive, One, Zero};
/// Builds the Robertson chemical-kinetics test problem, written as a stiff ODE, together with
/// the closures needed to integrate forward sensitivities `dx/dp` alongside the state `x`.
///
/// The state `x = [x0, x1, x2]` evolves under three reactions whose rate constants are the
/// parameters `p = [0.04, 1.0e4, 3.0e7]`:
///
/// ```text
/// dx0/dt = -p0*x0 + p1*x1*x2
/// dx1/dt =  p0*x0 - p1*x1*x2 - p2*x1^2
/// dx2/dt =                     p2*x1^2
/// ```
///
/// The problem starts from `x(0) = [1, 0, 0]`, and the initial sensitivities are zero. It is
/// configured with `rtol = 1e-4` and `atol = [1e-8, 1e-6, 1e-6]`.
///
/// # Arguments
/// * `use_coloring` - forwarded unchanged to the builder's `use_coloring` setting.
///
/// # Returns
/// The built problem, paired with an [`OdeSolverSolution`]. The solution holds tabulated
/// reference states at `t = 0` and at `t = 0.4 * 10^k` for `k = 0..=11`.
///
/// # Panics
/// Panics in either of these cases:
/// * the builder rejects the problem;
/// * `M::T` cannot represent one of the `f64` constants used here.
// The opaque `impl` equations type inside the returned tuple trips clippy's type-complexity
// lint. It wraps the closures below, which cannot be named, so the lint is silenced on purpose.
#[allow(clippy::type_complexity)]
pub fn robertson_ode_with_sens<M: MatrixHost + 'static>(
    use_coloring: bool,
) -> (
    OdeSolverProblem<impl OdeEquationsImplicitSens<M = M, V = M::V, T = M::T, C = M::C>>,
    OdeSolverSolution<M::V>,
) {
    let problem = OdeBuilder::<M>::new()
        .p([0.04, 1.0e4, 3.0e7])
        .rtol(1e-4)
        .atol([1.0e-8, 1.0e-6, 1.0e-6])
        .use_coloring(use_coloring)
        .rhs_sens_implicit(
            // Right-hand side f(x, p): the three reaction rates, combined per species.
            |u: &M::V, k: &M::V, _t: M::T, out: &mut M::V| {
                let r_a = k[0] * u[0];
                let r_b = k[1] * u[1] * u[2];
                let r_c = k[2] * u[1] * u[1];
                out[0] = -r_a + r_b;
                out[1] = r_a - r_b - r_c;
                out[2] = r_c;
            },
            // Jacobian-vector product (df/dx) * w.
            // The terms are kept separate so the additions happen in the same order as the
            // expanded formula.
            |u: &M::V, k: &M::V, _t: M::T, w: &M::V, out: &mut M::V| {
                let two = M::T::from_f64(2.0).unwrap();
                let d_a = k[0] * w[0];
                let d_b1 = k[1] * w[1] * u[2];
                let d_b2 = k[1] * u[1] * w[2];
                let d_c = two * k[2] * u[1] * w[1];
                out[0] = -d_a + d_b1 + d_b2;
                out[1] = d_a - d_b1 - d_b2 - d_c;
                out[2] = d_c;
            },
            // Parameter-sensitivity product (df/dp) * w.
            // Each reaction rate is linear in its own parameter.
            |u: &M::V, _k: &M::V, _t: M::T, w: &M::V, out: &mut M::V| {
                let s_a = w[0] * u[0];
                let s_b = w[1] * u[1] * u[2];
                let s_c = w[2] * u[1] * u[1];
                out[0] = -s_a + s_b;
                out[1] = s_a - s_b - s_c;
                out[2] = s_c;
            },
        )
        .init_sens(
            // Initial state x(0) = [1, 0, 0].
            |_p: &M::V, _t: M::T, x0: &mut M::V| {
                x0.fill(M::T::zero());
                x0[0] = M::T::one();
            },
            // The initial state ignores p, so its sensitivity product is identically zero.
            |_p: &M::V, _t: M::T, _w: &M::V, x0: &mut M::V| x0.fill(M::T::zero()),
            3,
        )
        .build()
        .unwrap();

    // Reference states [x0, x1, x2], keyed by output time.
    let reference: &[(f64, [f64; 3])] = &[
        (0.0, [1.0, 0.0, 0.0]),
        (0.4, [9.851641e-01, 3.386242e-05, 1.480205e-02]),
        (4.0, [9.055097e-01, 2.240338e-05, 9.446793e-02]),
        (40.0, [7.158017e-01, 9.185037e-06, 2.841892e-01]),
        (400.0, [4.505360e-01, 3.223271e-06, 5.494608e-01]),
        (4000.0, [1.832299e-01, 8.944378e-07, 8.167692e-01]),
        (40000.0, [3.898902e-02, 1.622006e-07, 9.610108e-01]),
        (400000.0, [4.936383e-03, 1.984224e-08, 9.950636e-01]),
        (4000000.0, [5.168093e-04, 2.068293e-09, 9.994832e-01]),
        (4.0e7, [5.202440e-05, 2.081083e-10, 9.999480e-01]),
        (4.0e8, [5.201061e-06, 2.080435e-11, 9.999948e-01]),
        (4.0e9, [5.258603e-07, 2.103442e-12, 9.999995e-01]),
        (4.0e10, [6.934511e-08, 2.773804e-13, 9.999999e-01]),
    ];

    let mut soln = OdeSolverSolution::default();
    for &(t, state) in reference {
        let values: Vec<_> = state
            .iter()
            .map(|&v| M::T::from_f64(v).unwrap())
            .collect();
        soln.push(
            M::V::from_vec(values, problem.context().clone()),
            M::T::from_f64(t).unwrap(),
        );
    }
    (problem, soln)
}