use crate::{
ode_solver::problem::OdeSolverSolution, MatrixHost, OdeBuilder, OdeEquationsImplicit,
OdeSolverProblem, Vector,
};
use num_traits::{FromPrimitive, One, Zero};
/// Builds the three-state Robertson ODE problem from a DiffSL source string and
/// pairs it with a table of solution points.
///
/// The DiffSL model declares three inputs (`k1`, `k2`, `k3`), three states
/// (`x`, `y`, `z`) starting at `(1, 0, 0)`, and the right-hand side
///
/// ```text
/// dx/dt = -k1*x + k2*y*z
/// dy/dt =  k1*x - k2*y*z - k3*y*y
/// dz/dt =  k3*y*y
/// ```
///
/// The parameter values handed to the builder (`0.04`, `1e4`, `3e7`) are the
/// same numbers written in the DiffSL `in_i` block.
///
/// # Returns
///
/// A tuple of the built problem and an [`OdeSolverSolution`] whose `rtol` and
/// `atol` fields are copied from the problem.
///
/// # Panics
///
/// Panics if building the problem from the DiffSL source fails.
#[cfg(feature = "diffsl")]
// The tuple return type carrying an `impl Trait` bound is long enough to trip this lint.
#[allow(clippy::type_complexity)]
pub fn robertson_ode_diffsl_problem<
    M: MatrixHost<T = f64>,
    CG: crate::CodegenModuleJit + crate::CodegenModuleCompile,
>() -> (
    OdeSolverProblem<impl crate::OdeEquationsImplicitAdjoint<M = M, V = M::V, T = M::T, C = M::C>>,
    OdeSolverSolution<M::V>,
) {
    // The literal is deliberately left unindented so its bytes stay exactly as they are.
    let diffsl_source = "
in_i { k1 = 0.04, k2 = 10000, k3 = 30000000 }
u_i {
x = 1,
y = 0,
z = 0,
}
F_i {
-k1*x + k2*y*z,
k1*x - k2*y*z - k3*y*y,
k3*y*y,
}";
    let builder = OdeBuilder::<M>::new()
        .p([0.04, 1.0e4, 3.0e7])
        .rtol(1e-4)
        .atol([1.0e-8, 1.0e-6, 1.0e-6]);
    let problem = builder.build_from_diffsl::<CG>(diffsl_source).unwrap();

    // Carry the problem's tolerances over to the solution table.
    let mut expected = soln::<M::V>(problem.context().clone());
    expected.atol = problem.atol.clone();
    expected.rtol = problem.rtol;
    (problem, expected)
}
/// Builds the Robertson ODE problem from Rust closures, replicated `ngroups`
/// times, and pairs it with a table of solution points.
///
/// Every group owns three consecutive states `(x, y, z)` that evolve as
///
/// ```text
/// dx/dt = -p0*x + p1*y*z
/// dy/dt =  p0*x - p1*y*z - p2*y*y
/// dz/dt =  p2*y*y
/// ```
///
/// with parameters `p = [0.04, 1e4, 3e7]` and initial state `(1, 0, 0)`.
/// The groups do not interact; the full state vector has `3 * ngroups` entries.
///
/// # Arguments
///
/// * `use_coloring` - forwarded unchanged to `OdeBuilder::use_coloring`.
/// * `ngroups` - number of independent copies of the three-state system.
///
/// # Returns
///
/// A tuple of the built problem and an [`OdeSolverSolution`] holding the
/// tabulated three-state points, each tiled `ngroups` times.
///
/// # Panics
///
/// Panics if the builder fails, or if a tabulated `f64` value cannot be
/// converted into `M::T`.
// The tuple return type carrying an `impl Trait` bound is long enough to trip this lint.
#[allow(clippy::type_complexity)]
pub fn robertson_ode<M: MatrixHost + 'static>(
    use_coloring: bool,
    ngroups: usize,
) -> (
    OdeSolverProblem<impl OdeEquationsImplicit<M = M, V = M::V, T = M::T, C = M::C>>,
    OdeSolverSolution<M::V>,
) {
    const N: usize = 3;
    let nstates = N * ngroups;

    // One absolute tolerance per state: the per-group triple tiled `ngroups` times.
    let atol: Vec<f64> = [1.0e-8, 1.0e-14, 1.0e-6].repeat(ngroups);

    let problem = OdeBuilder::<M>::new()
        .p([0.04, 1.0e4, 3.0e7])
        .rtol(1e-4)
        .atol(atol)
        .use_coloring(use_coloring)
        .rhs_implicit(
            // Right-hand side, evaluated group by group.
            move |x: &M::V, p: &M::V, _t: M::T, y: &mut M::V| {
                for base in (0..nstates).step_by(N) {
                    let (k1, k2, k3) = (p[0], p[1], p[2]);
                    let (xa, xb, xc) = (x[base], x[base + 1], x[base + 2]);
                    y[base] = -k1 * xa + k2 * xb * xc;
                    y[base + 1] = k1 * xa - k2 * xb * xc - k3 * xb * xb;
                    y[base + 2] = k3 * xb * xb;
                }
            },
            // Jacobian-vector product: derivative of the right-hand side at `x` applied to `v`.
            move |x: &M::V, p: &M::V, _t: M::T, v: &M::V, y: &mut M::V| {
                let two = M::T::from_f64(2.0).unwrap();
                for base in (0..nstates).step_by(N) {
                    let (k1, k2, k3) = (p[0], p[1], p[2]);
                    let (xb, xc) = (x[base + 1], x[base + 2]);
                    let (va, vb, vc) = (v[base], v[base + 1], v[base + 2]);
                    y[base] = -k1 * va + k2 * vb * xc + k2 * xb * vc;
                    y[base + 1] = k1 * va - k2 * vb * xc - k2 * xb * vc - two * k3 * xb * vb;
                    y[base + 2] = two * k3 * xb * vb;
                }
            },
        )
        .init(
            // Every group starts at (1, 0, 0).
            move |_p: &M::V, _t: M::T, y: &mut M::V| {
                for base in (0..nstates).step_by(N) {
                    y[base] = M::T::one();
                    y[base + 1] = M::T::zero();
                    y[base + 2] = M::T::zero();
                }
            },
            nstates,
        )
        .build()
        .unwrap();

    // Tabulated (state, time) points for a single three-state group.
    let table: Vec<([f64; 3], f64)> = vec![
        ([1.0, 0.0, 0.0], 0.0),
        ([9.851641e-01, 3.386242e-05, 1.480205e-02], 0.4),
        ([9.055097e-01, 2.240338e-05, 9.446793e-02], 4.0),
        ([7.158017e-01, 9.185037e-06, 2.841892e-01], 40.0),
        ([4.505360e-01, 3.223271e-06, 5.494608e-01], 400.0),
        ([1.832299e-01, 8.944378e-07, 8.167692e-01], 4000.0),
        ([3.898902e-02, 1.622006e-07, 9.610108e-01], 40000.0),
        ([4.936383e-03, 1.984224e-08, 9.950636e-01], 400000.0),
        ([5.168093e-04, 2.068293e-09, 9.994832e-01], 4000000.0),
        ([5.202440e-05, 2.081083e-10, 9.999480e-01], 4.0000e+07),
        ([5.201061e-06, 2.080435e-11, 9.999948e-01], 4.0000e+08),
        ([5.258603e-07, 2.103442e-12, 9.999995e-01], 4.0000e+09),
        ([6.934511e-08, 2.773804e-13, 9.999999e-01], 4.0000e+10),
    ];

    let mut expected = OdeSolverSolution::default();
    for (state, time) in table {
        // Tile the single-group state so it matches the `nstates`-long problem vector.
        let tiled = state.repeat(ngroups);
        expected.push(
            M::V::from_vec(
                tiled
                    .into_iter()
                    .map(|value| M::T::from_f64(value).unwrap())
                    .collect(),
                problem.context().clone(),
            ),
            M::T::from_f64(time).unwrap(),
        );
    }
    (problem, expected)
}
/// Assembles the table of three-state solution points into an
/// [`OdeSolverSolution`] whose vectors are created in the context `ctx`.
///
/// Each row is a `(state, time)` pair; the values are converted from `f64` to
/// `V::T` and pushed in the order listed (increasing time).
///
/// # Panics
///
/// Panics if a tabulated `f64` value cannot be converted into `V::T`.
#[cfg(feature = "diffsl")]
fn soln<V: Vector>(ctx: V::C) -> OdeSolverSolution<V> {
    let table: Vec<([f64; 3], f64)> = vec![
        ([1.0, 0.0, 0.0], 0.0),
        ([9.851641e-01, 3.386242e-05, 1.480205e-02], 0.4),
        ([9.055097e-01, 2.240338e-05, 9.446793e-02], 4.0),
        ([7.158017e-01, 9.185037e-06, 2.841892e-01], 40.0),
        ([4.505360e-01, 3.223271e-06, 5.494608e-01], 400.0),
        ([1.832299e-01, 8.944378e-07, 8.167692e-01], 4000.0),
        ([3.898902e-02, 1.622006e-07, 9.610108e-01], 40000.0),
        ([4.936383e-03, 1.984224e-08, 9.950636e-01], 400000.0),
        ([5.168093e-04, 2.068293e-09, 9.994832e-01], 4000000.0),
        ([5.202440e-05, 2.081083e-10, 9.999480e-01], 4.0000e+07),
        ([5.201061e-06, 2.080435e-11, 9.999948e-01], 4.0000e+08),
        ([5.258603e-07, 2.103442e-12, 9.999995e-01], 4.0000e+09),
        ([6.934511e-08, 2.773804e-13, 9.999999e-01], 4.0000e+10),
    ];

    let mut expected = OdeSolverSolution::default();
    for (state, time) in table {
        expected.push(
            V::from_vec(
                state
                    .iter()
                    .map(|&value| V::T::from_f64(value).unwrap())
                    .collect(),
                ctx.clone(),
            ),
            V::T::from_f64(time).unwrap(),
        );
    }
    expected
}