use std::rc::Rc;
use crate::{
vector::DefaultDenseMatrix, Closure, ClosureNoJac, ClosureWithSens, ConstantClosure,
ConstantClosureWithSens, LinearClosure, LinearClosureWithSens, Matrix, OdeEquations,
OdeSolverProblem, Op, UnitCallable, Vector,
};
use anyhow::Result;
use super::equations::OdeSolverEquations;
/// Builder collecting solver settings (times, tolerances, parameters, flags)
/// before constructing an `OdeSolverProblem`.
///
/// All fields are plain data, so the builder is cheap to `Clone`. This is
/// useful because every `build_*` method consumes the builder.
#[derive(Debug, Clone)]
pub struct OdeBuilder {
    /// Initial time; also the time at which the initial condition is evaluated.
    t0: f64,
    /// Value forwarded as `h0` to `OdeSolverProblem::new` (presumably the
    /// initial step size).
    h0: f64,
    /// Relative tolerance.
    rtol: f64,
    /// Absolute tolerance: either a single value shared by all states, or one
    /// value per state.
    atol: Vec<f64>,
    /// Parameter vector handed to the equations.
    p: Vec<f64>,
    /// When true, the `build_*` methods call `calculate_sparsity` on the
    /// closures (or pass the flag on to the DiffSL equations).
    use_coloring: bool,
    /// Sensitivity request; only consulted by the DiffSL builder.
    sensitivities: bool,
    /// Forwarded to `OdeSolverProblem::new` by every `build_*` method.
    sensitivities_error_control: bool,
}
impl Default for OdeBuilder {
fn default() -> Self {
Self::new()
}
}
impl OdeBuilder {
    /// Create a builder with the default settings.
    ///
    /// The defaults are:
    /// - `t0 = 0.0`, `h0 = 1.0`, `rtol = 1e-6`, `atol = [1e-6]`;
    /// - no parameters;
    /// - coloring, sensitivities and sensitivity error control all disabled.
    pub fn new() -> Self {
        Self {
            t0: 0.0,
            h0: 1.0,
            rtol: 1e-6,
            atol: vec![1e-6],
            p: vec![],
            use_coloring: false,
            sensitivities: false,
            sensitivities_error_control: false,
        }
    }

    /// Set the initial time `t0` (default `0.0`).
    ///
    /// The closure-based `build_*` methods also evaluate the initial-condition
    /// closure at this time.
    pub fn t0(mut self, t0: f64) -> Self {
        self.t0 = t0;
        self
    }

    /// Set whether sensitivities are requested (default `false`).
    ///
    /// Only `build_diffsl` reads this flag. The closure-based builders ignore it:
    /// - `build_ode_with_sens` and `build_ode_with_mass_and_sens` always enable
    ///   sensitivities;
    /// - the other closure builders always disable them.
    pub fn sensitivities(mut self, sensitivities: bool) -> Self {
        self.sensitivities = sensitivities;
        self
    }

    /// Set whether sensitivities take part in error control (default `false`).
    ///
    /// The value is forwarded to `OdeSolverProblem::new` by every `build_*` method.
    pub fn sensitivities_error_control(mut self, sensitivities_error_control: bool) -> Self {
        self.sensitivities_error_control = sensitivities_error_control;
        self
    }

    /// Set `h0` (default `1.0`), forwarded to `OdeSolverProblem::new`.
    ///
    /// Presumably this is the initial step size.
    pub fn h0(mut self, h0: f64) -> Self {
        self.h0 = h0;
        self
    }

    /// Set the relative tolerance (default `1e-6`).
    pub fn rtol(mut self, rtol: f64) -> Self {
        self.rtol = rtol;
        self
    }

    /// Set the absolute tolerance (default a single `1e-6`).
    ///
    /// Accepts either one value, which is applied to every state, or exactly one
    /// value per state. Any other length makes the `build_*` methods return an
    /// error. Each element is converted with `f64::from`.
    pub fn atol<V, T>(mut self, atol: V) -> Self
    where
        V: IntoIterator<Item = T>,
        f64: From<T>,
    {
        self.atol = atol.into_iter().map(f64::from).collect();
        self
    }

    /// Set the parameter vector (default empty).
    ///
    /// Each element is converted with `f64::from`.
    pub fn p<V, T>(mut self, p: V) -> Self
    where
        V: IntoIterator<Item = T>,
        f64: From<T>,
    {
        self.p = p.into_iter().map(f64::from).collect();
        self
    }

    /// Enable or disable sparsity detection (default `false`).
    ///
    /// When enabled, the closure builders call `calculate_sparsity` on the rhs
    /// closure, and on the mass closure when there is one. The DiffSL builder
    /// passes the flag to `DiffSl::new` instead.
    pub fn use_coloring(mut self, use_coloring: bool) -> Self {
        self.use_coloring = use_coloring;
        self
    }

    /// Expand the user-supplied `atol` into a vector of length `nstates`.
    ///
    /// A single value is broadcast to every state. Otherwise the length must
    /// equal `nstates`, or an error is returned.
    fn build_atol<V: Vector>(atol: Vec<f64>, nstates: usize) -> Result<V> {
        if atol.len() == 1 {
            Ok(V::from_element(nstates, V::T::from(atol[0])))
        } else if atol.len() != nstates {
            Err(anyhow::anyhow!(
                "atol must have length 1 or equal to the number of states"
            ))
        } else {
            let mut v = V::zeros(nstates);
            for (i, &a) in atol.iter().enumerate() {
                v[i] = V::T::from(a);
            }
            Ok(v)
        }
    }

    /// Convert the stored `f64` parameters into the solver's vector type.
    fn build_p<V: Vector>(p: Vec<f64>) -> V {
        let mut v = V::zeros(p.len());
        for (i, &p) in p.iter().enumerate() {
            v[i] = V::T::from(p);
        }
        v
    }

    /// Build a problem whose rhs has a mass term, from closures.
    ///
    /// Closure roles:
    /// - `rhs` and `rhs_jac` are wrapped in a `Closure`; `rhs_jac` is presumably
    ///   the Jacobian-vector product of `rhs` (TODO confirm).
    /// - `mass` is wrapped in a `LinearClosure`.
    /// - `init` is wrapped in a `ConstantClosure`.
    ///
    /// The number of states is the length of `init(p, t0)`. If `use_coloring`
    /// is set, `calculate_sparsity` is called on the rhs (at `y0`, `t0`) and on
    /// the mass (at `t0`). Sensitivities are disabled.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    // The return type is inherently nested; naming it via an alias would hide
    // which closure wrapper each generic parameter ends up in.
    #[allow(clippy::type_complexity)]
    pub fn build_ode_with_mass<M, F, G, H, I>(
        self,
        rhs: F,
        rhs_jac: G,
        mass: H,
        init: I,
    ) -> Result<
        OdeSolverProblem<
            OdeSolverEquations<M, Closure<M, F, G>, ConstantClosure<M, I>, LinearClosure<M, H>>,
        >,
    >
    where
        M: Matrix,
        F: Fn(&M::V, &M::V, M::T, &mut M::V),
        G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        H: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
        I: Fn(&M::V, M::T) -> M::V,
    {
        let p = Rc::new(Self::build_p(self.p));
        let t0 = M::T::from(self.t0);
        let y0 = init(&p, t0);
        let nstates = y0.len();
        let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone());
        let mut mass = LinearClosure::new(mass, nstates, nstates, p.clone());
        let init = ConstantClosure::new(init, p.clone());
        if self.use_coloring {
            rhs.calculate_sparsity(&y0, t0);
            mass.calculate_sparsity(t0);
        }
        let mass = Some(Rc::new(mass));
        let rhs = Rc::new(rhs);
        let init = Rc::new(init);
        let eqn = OdeSolverEquations::new(rhs, mass, None, init, p);
        let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            M::T::from(self.rtol),
            atol,
            t0,
            M::T::from(self.h0),
            false,
            self.sensitivities_error_control,
        )
    }

    /// Build a problem with a mass term and forward sensitivities, from closures.
    ///
    /// Same as `build_ode_with_mass`, except for the following:
    /// - each closure is paired with its sensitivity closure (`rhs_sens`,
    ///   `mass_sens`, `init_sens`) in the `*WithSens` wrappers;
    /// - sensitivities are always enabled, regardless of the builder's
    ///   `sensitivities` flag.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    // Nested return type and one closure per equation part; both are inherent
    // to this constructor.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    pub fn build_ode_with_mass_and_sens<M, F, G, H, I, J, K, L>(
        self,
        rhs: F,
        rhs_jac: G,
        rhs_sens: J,
        mass: H,
        mass_sens: L,
        init: I,
        init_sens: K,
    ) -> Result<
        OdeSolverProblem<
            OdeSolverEquations<
                M,
                ClosureWithSens<M, F, G, J>,
                ConstantClosureWithSens<M, I, K>,
                LinearClosureWithSens<M, H, L>,
            >,
        >,
    >
    where
        M: Matrix,
        F: Fn(&M::V, &M::V, M::T, &mut M::V),
        G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        H: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
        I: Fn(&M::V, M::T) -> M::V,
        J: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        K: Fn(&M::V, M::T, &M::V, &mut M::V),
        L: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
    {
        let p = Rc::new(Self::build_p(self.p));
        let t0 = M::T::from(self.t0);
        let y0 = init(&p, t0);
        let nstates = y0.len();
        let mut rhs = ClosureWithSens::new(rhs, rhs_jac, rhs_sens, nstates, nstates, p.clone());
        let mut mass = LinearClosureWithSens::new(mass, mass_sens, nstates, nstates, p.clone());
        let init = ConstantClosureWithSens::new(init, init_sens, nstates, nstates, p.clone());
        if self.use_coloring {
            rhs.calculate_sparsity(&y0, t0);
            mass.calculate_sparsity(t0);
        }
        let mass = Some(Rc::new(mass));
        let rhs = Rc::new(rhs);
        let init = Rc::new(init);
        let eqn = OdeSolverEquations::new(rhs, mass, None, init, p);
        let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            M::T::from(self.rtol),
            atol,
            t0,
            M::T::from(self.h0),
            true,
            self.sensitivities_error_control,
        )
    }

    /// Build a problem without a mass term, from closures.
    ///
    /// `rhs` and `rhs_jac` are wrapped in a `Closure`, and `init` in a
    /// `ConstantClosure`. The number of states is the length of `init(p, t0)`.
    /// If `use_coloring` is set, `calculate_sparsity` is called on the rhs at
    /// `(y0, t0)`. Sensitivities are disabled.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    // The return type is inherently nested.
    #[allow(clippy::type_complexity)]
    pub fn build_ode<M, F, G, I>(
        self,
        rhs: F,
        rhs_jac: G,
        init: I,
    ) -> Result<OdeSolverProblem<OdeSolverEquations<M, Closure<M, F, G>, ConstantClosure<M, I>>>>
    where
        M: Matrix,
        F: Fn(&M::V, &M::V, M::T, &mut M::V),
        G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        I: Fn(&M::V, M::T) -> M::V,
    {
        let p = Rc::new(Self::build_p(self.p));
        let t0 = M::T::from(self.t0);
        let y0 = init(&p, t0);
        let nstates = y0.len();
        let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone());
        let init = ConstantClosure::new(init, p.clone());
        if self.use_coloring {
            rhs.calculate_sparsity(&y0, t0);
        }
        let rhs = Rc::new(rhs);
        let init = Rc::new(init);
        let eqn = OdeSolverEquations::new(rhs, None, None, init, p);
        let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            M::T::from(self.rtol),
            atol,
            t0,
            M::T::from(self.h0),
            false,
            self.sensitivities_error_control,
        )
    }

    /// Build a problem without a mass term, with forward sensitivities, from closures.
    ///
    /// Same as `build_ode`, except for the following:
    /// - `rhs` is paired with `rhs_sens` and `init` with `init_sens` in the
    ///   `*WithSens` wrappers;
    /// - sensitivities are always enabled, regardless of the builder's
    ///   `sensitivities` flag.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    // The return type is inherently nested.
    #[allow(clippy::type_complexity)]
    pub fn build_ode_with_sens<M, F, G, I, J, K>(
        self,
        rhs: F,
        rhs_jac: G,
        rhs_sens: J,
        init: I,
        init_sens: K,
    ) -> Result<
        OdeSolverProblem<
            OdeSolverEquations<M, ClosureWithSens<M, F, G, J>, ConstantClosureWithSens<M, I, K>>,
        >,
    >
    where
        M: Matrix,
        F: Fn(&M::V, &M::V, M::T, &mut M::V),
        G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        I: Fn(&M::V, M::T) -> M::V,
        J: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        K: Fn(&M::V, M::T, &M::V, &mut M::V),
    {
        let p = Rc::new(Self::build_p(self.p));
        let t0 = M::T::from(self.t0);
        let y0 = init(&p, t0);
        let nstates = y0.len();
        let init = ConstantClosureWithSens::new(init, init_sens, nstates, nstates, p.clone());
        let mut rhs = ClosureWithSens::new(rhs, rhs_jac, rhs_sens, nstates, nstates, p.clone());
        if self.use_coloring {
            rhs.calculate_sparsity(&y0, t0);
        }
        let rhs = Rc::new(rhs);
        let init = Rc::new(init);
        let eqn = OdeSolverEquations::new(rhs, None, None, init, p);
        let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            M::T::from(self.rtol),
            atol,
            t0,
            M::T::from(self.h0),
            true,
            self.sensitivities_error_control,
        )
    }

    /// Build a problem without a mass term but with a root function, from closures.
    ///
    /// `root` is wrapped in a `ClosureNoJac` built with `nstates` and `nroots`
    /// (presumably `nstates` inputs and `nroots` outputs — TODO confirm). The
    /// wrapper is passed as the root slot of `OdeSolverEquations`. Sparsity
    /// detection, when enabled, is applied to the rhs only. Sensitivities are
    /// disabled.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    // The return type is inherently nested.
    #[allow(clippy::type_complexity)]
    pub fn build_ode_with_root<M, F, G, I, H>(
        self,
        rhs: F,
        rhs_jac: G,
        init: I,
        root: H,
        nroots: usize,
    ) -> Result<
        OdeSolverProblem<
            OdeSolverEquations<
                M,
                Closure<M, F, G>,
                ConstantClosure<M, I>,
                UnitCallable<M>,
                ClosureNoJac<M, H>,
            >,
        >,
    >
    where
        M: Matrix,
        F: Fn(&M::V, &M::V, M::T, &mut M::V),
        G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
        H: Fn(&M::V, &M::V, M::T, &mut M::V),
        I: Fn(&M::V, M::T) -> M::V,
    {
        let p = Rc::new(Self::build_p(self.p));
        let t0 = M::T::from(self.t0);
        let y0 = init(&p, t0);
        let nstates = y0.len();
        let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone());
        let root = Rc::new(ClosureNoJac::new(root, nstates, nroots, p.clone()));
        let init = ConstantClosure::new(init, p.clone());
        if self.use_coloring {
            rhs.calculate_sparsity(&y0, t0);
        }
        let rhs = Rc::new(rhs);
        let init = Rc::new(init);
        let eqn = OdeSolverEquations::new(rhs, None, Some(root), init, p);
        let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            M::T::from(self.rtol),
            atol,
            t0,
            M::T::from(self.h0),
            false,
            self.sensitivities_error_control,
        )
    }

    /// Same as `build_ode`, but only the vector type `V` needs naming.
    ///
    /// The matrix type is `V::M`, supplied by `DefaultDenseMatrix`.
    ///
    /// # Errors
    /// Same as `build_ode`.
    // The return type is inherently nested.
    #[allow(clippy::type_complexity)]
    pub fn build_ode_dense<V, F, G, I>(
        self,
        rhs: F,
        rhs_jac: G,
        init: I,
    ) -> Result<
        OdeSolverProblem<OdeSolverEquations<V::M, Closure<V::M, F, G>, ConstantClosure<V::M, I>>>,
    >
    where
        V: Vector + DefaultDenseMatrix,
        F: Fn(&V, &V, V::T, &mut V),
        G: Fn(&V, &V, V::T, &V, &mut V),
        I: Fn(&V, V::T) -> V,
    {
        self.build_ode(rhs, rhs_jac, init)
    }

    /// Build a problem from a DiffSL context.
    ///
    /// The returned problem borrows `context`. `use_coloring` is passed to
    /// `DiffSl::new`, the parameters are installed with `set_params`, and
    /// sensitivities follow the builder's `sensitivities` flag.
    ///
    /// # Errors
    /// Returns an error if `atol` has an invalid length, or if
    /// `OdeSolverProblem::new` fails.
    #[cfg(feature = "diffsl")]
    pub fn build_diffsl(
        self,
        context: &crate::ode_solver::diffsl::DiffSlContext,
    ) -> Result<OdeSolverProblem<crate::ode_solver::diffsl::DiffSl<'_>>> {
        use crate::ode_solver::diffsl;
        type V = diffsl::V;
        type T = diffsl::T;
        let p = Self::build_p::<V>(self.p);
        let mut eqn = diffsl::DiffSl::new(context, self.use_coloring);
        eqn.set_params(p);
        let atol = Self::build_atol::<V>(self.atol, eqn.rhs().nstates())?;
        OdeSolverProblem::new(
            eqn,
            T::from(self.rtol),
            atol,
            T::from(self.t0),
            T::from(self.h0),
            self.sensitivities,
            self.sensitivities_error_control,
        )
    }
}