use anyhow::Result;
use nalgebra::ComplexField;
use num_traits::{One, Pow};
use std::rc::Rc;
use crate::{
matrix::default_solver::DefaultSolver, scalar::Scalar, scale, ConstantOp, InitOp,
NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations, OdeSolverProblem, Op,
SensEquations, SolverProblem, Vector,
};
/// Why a call to [`OdeSolverMethod::step`] returned.
///
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq` so callers can log, copy and compare stop
/// reasons (e.g. `assert_eq!(reason, OdeSolverStopReason::TstopReached)`); each derived impl
/// only requires the corresponding trait on `T`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OdeSolverStopReason<T: Scalar> {
    /// The step completed without hitting the stop time or a root.
    InternalTimestep,
    /// A root was found during the step. The payload is a scalar — presumably the time of
    /// the root; confirm against the solver implementations.
    RootFound(T),
    /// The stop time configured with [`OdeSolverMethod::set_stop_time`] was reached.
    TstopReached,
}
/// Common interface implemented by the ODE solver methods.
pub trait OdeSolverMethod<Eqn: OdeEquations> {
    /// The problem the solver is currently attached to, if any.
    fn problem(&self) -> Option<&OdeSolverProblem<Eqn>>;

    /// Attaches `problem` to the solver, starting from `state`.
    fn set_problem(&mut self, state: OdeSolverState<Eqn::V>, problem: &OdeSolverProblem<Eqn>);

    /// Takes one step and reports why it returned (see [`OdeSolverStopReason`]).
    fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>>;

    /// Sets the time at which [`step`](Self::step) is expected to report
    /// [`OdeSolverStopReason::TstopReached`].
    fn set_stop_time(&mut self, tstop: Eqn::T) -> Result<()>;

    /// Interpolates the solution at time `t` (the valid range of `t` is
    /// implementation-defined).
    fn interpolate(&self, t: Eqn::T) -> Result<Eqn::V>;

    /// Interpolates the sensitivity vectors at time `t` (presumably one per parameter,
    /// matching [`OdeSolverState::s`]).
    fn interpolate_sens(&self, t: Eqn::T) -> Result<Vec<Eqn::V>>;

    /// The current solver state, if any.
    fn state(&self) -> Option<&OdeSolverState<Eqn::V>>;

    /// Mutable access to the current solver state, if any.
    fn state_mut(&mut self) -> Option<&mut OdeSolverState<Eqn::V>>;

    /// Order of the method; [`OdeSolverState::new`] uses it to choose the initial step size.
    fn order(&self) -> usize;

    /// Takes the current state out of the solver, if there is one (presumably leaving the
    /// solver without a state).
    fn take_state(&mut self) -> Option<OdeSolverState<Eqn::V>>;

    /// Solves `problem` from its initial time up to `t` and returns the final solution.
    ///
    /// Builds a consistent initial state with [`OdeSolverState::new`], attaches it with
    /// [`set_problem`](Self::set_problem), sets `t` as the stop time and then calls
    /// [`step`](Self::step) until it reports [`OdeSolverStopReason::TstopReached`]. Any other
    /// stop reason (including a found root) is ignored and stepping continues. Returns a clone
    /// of the state's `y` at that point.
    ///
    /// # Errors
    /// Propagates errors from state initialisation, `set_stop_time` and `step`.
    ///
    /// # Panics
    /// If the implementation reports no state after `set_problem` — an implementation bug.
    fn solve(&mut self, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Result<Eqn::V>
    where
        Eqn::M: DefaultSolver,
        Self: Sized,
    {
        let state = OdeSolverState::new(problem, self)?;
        self.set_problem(state, problem);
        self.set_stop_time(t)?;
        // Only reaching the stop time ends the loop; roots and ordinary steps are stepped past.
        loop {
            if let OdeSolverStopReason::TstopReached = self.step()? {
                break;
            }
        }
        Ok(self
            .state()
            .expect("OdeSolverMethod has no state after set_problem; implementation bug")
            .y
            .clone())
    }
}
/// Snapshot of an ODE solver: current time and step size, the solution, and the
/// (optional) forward sensitivities.
#[derive(Clone)]
pub struct OdeSolverState<V: Vector> {
    /// Current time.
    pub t: V::T,
    /// Current step size.
    pub h: V::T,
    /// Solution vector at `t`.
    pub y: V,
    /// Time derivative of `y` at `t`.
    pub dy: V,
    /// Sensitivity vectors, one per parameter; empty when the problem has no sensitivity
    /// equations.
    pub s: Vec<V>,
    /// Time derivatives of the sensitivity vectors, index-aligned with `s`.
    pub ds: Vec<V>,
}
impl<V: Vector> OdeSolverState<V> {
    /// Creates a state at the problem's initial time, ready to be handed to `solver`.
    ///
    /// Builds the raw state with [`Self::new_without_initialise`], then:
    /// 1. makes `y`/`dy` consistent with [`Self::set_consistent`],
    /// 2. makes `s`/`ds` consistent with [`Self::set_consistent_sens`],
    /// 3. picks an initial step size with [`Self::set_step_size`] using `solver.order()`.
    ///
    /// Each consistency solve uses its own `NewtonNonlinearSolver` backed by the matrix
    /// type's default linear solver.
    ///
    /// # Errors
    /// Propagates any error from the two consistency solves.
    pub fn new<Eqn, S>(ode_problem: &OdeSolverProblem<Eqn>, solver: &S) -> Result<Self>
    where
        Eqn: OdeEquations<T = V::T, V = V>,
        Eqn::M: DefaultSolver,
        S: OdeSolverMethod<Eqn>,
    {
        let mut ret = Self::new_without_initialise(ode_problem);
        let mut root_solver =
            NewtonNonlinearSolver::new(<Eqn::M as DefaultSolver>::default_solver());
        ret.set_consistent(ode_problem, &mut root_solver)?;
        let mut root_solver_sens =
            NewtonNonlinearSolver::new(<Eqn::M as DefaultSolver>::default_solver());
        ret.set_consistent_sens(ode_problem, &mut root_solver_sens)?;
        ret.set_step_size(ode_problem, solver.order());
        Ok(ret)
    }

    /// Creates a state at `ode_problem.t0` without consistency solves or step-size selection.
    ///
    /// - `y` is `eqn.init()` evaluated at `t0`, and `dy` is all zeros.
    /// - `h` is copied from `ode_problem.h0`.
    /// - If the problem has sensitivity equations, `s[i]` is the sensitivity `init()` output
    ///   for parameter `i` (one entry per `rhs().nparams()`), and each `ds[i]` is zero.
    ///   Otherwise `s` and `ds` are empty.
    pub fn new_without_initialise<Eqn>(ode_problem: &OdeSolverProblem<Eqn>) -> Self
    where
        Eqn: OdeEquations<T = V::T, V = V>,
    {
        let t = ode_problem.t0;
        let h = ode_problem.h0;
        let y = ode_problem.eqn.init().call(t);
        let dy = V::zeros(y.len());
        let nparams = ode_problem.eqn.rhs().nparams();
        let (s, ds) = match ode_problem.eqn_sens.as_ref() {
            None => (vec![], vec![]),
            Some(eqn_sens) => {
                eqn_sens.init().update_state(t);
                let mut s = Vec::with_capacity(nparams);
                let mut ds = Vec::with_capacity(nparams);
                for i in 0..nparams {
                    // `call` takes no parameter index, so select parameter `i` first.
                    eqn_sens.init().set_param_index(i);
                    s.push(eqn_sens.init().call(t));
                    ds.push(V::zeros(y.len()));
                }
                (s, ds)
            }
        };
        Self { y, t, h, dy, s, ds }
    }

    /// Makes `dy` — and, when the problem has a mass matrix, `y` — consistent with the
    /// equations.
    ///
    /// `dy` is first overwritten with `rhs(y, t)`. Without a mass matrix that is all.
    /// Otherwise an [`InitOp`] is built from the current `y`/`dy` and solved with
    /// `root_solver` using the problem's `atol`/`rtol`. The solution is scattered back into
    /// `y`/`dy`.
    ///
    /// # Errors
    /// Propagates any error from `root_solver.solve_in_place`.
    pub fn set_consistent<Eqn, S>(
        &mut self,
        ode_problem: &OdeSolverProblem<Eqn>,
        root_solver: &mut S,
    ) -> Result<()>
    where
        Eqn: OdeEquations<T = V::T, V = V>,
        S: NonLinearSolver<InitOp<Eqn>> + ?Sized,
    {
        ode_problem
            .eqn
            .rhs()
            .call_inplace(&self.y, self.t, &mut self.dy);
        if ode_problem.eqn.mass().is_none() {
            return Ok(());
        }
        // NOTE(review): the InitOp is built at `ode_problem.t0`, but the solve runs at
        // `self.t`. These agree when called from `new`, but may differ if the state has been
        // advanced. Confirm which time is intended.
        let f = Rc::new(InitOp::new(
            &ode_problem.eqn,
            ode_problem.t0,
            &self.y,
            &self.dy,
        ));
        let rtol = ode_problem.rtol;
        let atol = ode_problem.atol.clone();
        let init_problem = SolverProblem::new(f.clone(), atol, rtol);
        root_solver.set_problem(&init_problem);
        let mut y = f.y0.borrow().clone();
        root_solver.solve_in_place(&mut y, self.t)?;
        f.scatter_soln(&y, &mut self.y, &mut self.dy);
        Ok(())
    }

    /// Makes the sensitivities `ds` — and, with a mass matrix, `s` — consistent.
    ///
    /// Does nothing when the problem has no sensitivity equations. Otherwise:
    /// 1. updates the sensitivity rhs with the current `y`, `dy` and `t`;
    /// 2. sets each `ds[i]` to the sensitivity rhs for parameter `i`, applied to `s[i]`;
    /// 3. with a mass matrix, solves one [`InitOp`] system per parameter with `root_solver`
    ///    and scatters the result back into `s[i]`/`ds[i]`.
    ///
    /// # Errors
    /// Propagates any error from `root_solver.solve_in_place`.
    pub fn set_consistent_sens<Eqn, S>(
        &mut self,
        ode_problem: &OdeSolverProblem<Eqn>,
        root_solver: &mut S,
    ) -> Result<()>
    where
        Eqn: OdeEquations<T = V::T, V = V>,
        S: NonLinearSolver<InitOp<SensEquations<Eqn>>> + ?Sized,
    {
        let eqn_sens = match ode_problem.eqn_sens.as_ref() {
            Some(eqn_sens) => eqn_sens,
            None => return Ok(()),
        };
        eqn_sens.rhs().update_state(&self.y, &self.dy, self.t);
        for i in 0..ode_problem.eqn.rhs().nparams() {
            eqn_sens.init().set_param_index(i);
            eqn_sens.rhs().set_param_index(i);
            eqn_sens
                .rhs()
                .call_inplace(&self.s[i], self.t, &mut self.ds[i]);
        }
        if ode_problem.eqn.mass().is_none() {
            return Ok(());
        }
        for i in 0..ode_problem.eqn.rhs().nparams() {
            eqn_sens.init().set_param_index(i);
            eqn_sens.rhs().set_param_index(i);
            let f = Rc::new(InitOp::new(
                eqn_sens,
                ode_problem.t0,
                &self.s[i],
                &self.ds[i],
            ));
            root_solver.set_problem(&SolverProblem::new(
                f.clone(),
                ode_problem.atol.clone(),
                ode_problem.rtol,
            ));
            let mut y = f.y0.borrow().clone();
            root_solver.solve_in_place(&mut y, self.t)?;
            f.scatter_soln(&y, &mut self.s[i], &mut self.ds[i]);
        }
        Ok(())
    }

    /// Chooses the initial step size `self.h` for a method of order `solver_order`.
    ///
    /// This is the standard starting-step heuristic (Hairer, Nørsett & Wanner, *Solving ODEs
    /// I*, §II.4). `self.dy` is used as the derivative `f0` at `(y, t)`. All norms are
    /// `squared_norm(y0, atol, rtol).sqrt()`, weighted by the problem's tolerances.
    ///
    /// 1. `d0 = ‖y0‖` and `d1 = ‖f0‖`. Take `h0 = 1e-6` if either is below `1e-5`, otherwise
    ///    `h0 = 0.01 · d0 / d1`.
    /// 2. Take an explicit Euler step, `y1 = y0 + h0 · f0`. Then `d2 = ‖rhs(y1, t + h0) - f0‖ / h0`.
    /// 3. If `max(d1, d2) < 1e-15`, set `h1 = max(1e-6, 1e-3 · h0)`. Otherwise set
    ///    `h1 = (0.01 / max(d1, d2))^(1 / (order + 1))`.
    /// 4. Set `h = min(100 · h0, h1)`.
    ///
    /// Overwrites `self.h`. `ode_problem.h0` is not consulted.
    pub fn set_step_size<Eqn>(&mut self, ode_problem: &OdeSolverProblem<Eqn>, solver_order: usize)
    where
        Eqn: OdeEquations<T = V::T, V = V>,
    {
        let y0 = &self.y;
        let t0 = self.t;
        let f0 = &self.dy;
        let rtol = ode_problem.rtol;
        let atol = ode_problem.atol.as_ref();

        // Step 1: first guess from the sizes of the solution and its derivative.
        let d0 = y0.squared_norm(y0, atol, rtol).sqrt();
        let d1 = f0.squared_norm(y0, atol, rtol).sqrt();
        let h0 = if d0 < Eqn::T::from(1e-5) || d1 < Eqn::T::from(1e-5) {
            Eqn::T::from(1e-6)
        } else {
            Eqn::T::from(0.01) * (d0 / d1)
        };

        // Step 2: estimate the second derivative with one explicit Euler step.
        let y1 = f0.clone() * scale(h0) + y0;
        let t1 = t0 + h0;
        let f1 = ode_problem.eqn.rhs().call(&y1, t1);
        let df = f1 - f0;
        let d2 = df.squared_norm(y0, atol, rtol).sqrt() / h0;

        // Step 3: step size implied by the local error model of a method of this order.
        let max_d = if d2 < d1 { d1 } else { d2 };
        let h1 = if max_d < Eqn::T::from(1e-15) {
            let h1 = h0 * Eqn::T::from(1e-3);
            if h1 < Eqn::T::from(1e-6) {
                Eqn::T::from(1e-6)
            } else {
                h1
            }
        } else {
            (Eqn::T::from(0.01) / max_d)
                .pow(Eqn::T::one() / Eqn::T::from(1.0 + solver_order as f64))
        };

        // Step 4: never grow more than 100x beyond the first guess.
        let h_cap = Eqn::T::from(100.0) * h0;
        self.h = if h_cap > h1 { h1 } else { h_cap };
    }
}