use nalgebra::SVector;
use roots::{find_root_brent, SimpleConvergency};
use super::callback::Callback;
use super::controller::{Controller, PIController, TryStep};
use super::integrator::Integrator;
use super::ode::ODE;
/// An ODE bundled with the integrator, step-size controller and callbacks
/// used to solve it.
#[derive(Clone)]
pub struct Problem<'a, const D: usize, S: Integrator<D>, P> {
    /// The equation being integrated; `solve` advances its `t` and `y` in place.
    ode: ODE<'a, D, P>,
    /// Stepping scheme used to propose each new state.
    integrator: S,
    /// Controller holding the current step-size guess.
    controller: PIController,
    /// Event/effect pairs examined after every step, in insertion order.
    callbacks: Vec<Callback<'a, D, P>>,
}
impl<'a, const D: usize, S, P> Problem<'a, D, S, P>
where
    S: Integrator<D> + Copy,
{
    /// Creates a problem from an ODE, an integrator and a step-size controller.
    ///
    /// The new problem has no callbacks; attach them with
    /// [`Problem::with_callback`].
    pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self {
        Problem {
            ode,
            integrator,
            controller,
            callbacks: Vec::new(),
        }
    }

    /// Integrates the ODE from its current time until `t >= t_end`.
    ///
    /// The stored ODE is advanced in place (`t` and `y` are overwritten after
    /// every step). The returned [`Solution`] holds the initial point followed
    /// by one time/state pair per step, plus one set of dense-output
    /// coefficients per step.
    ///
    /// For integrators with `S::ADAPTIVE` set, each step is retried with the
    /// controller's proposed step size until the controller accepts it. For
    /// other integrators the controller's current guess is used unchanged.
    ///
    /// After each step every callback is checked: if its `event` function
    /// changes sign across the step, the crossing is located with Brent's
    /// method, the step is shortened to end at the crossing, and the
    /// callback's `effect` is applied to the ODE.
    ///
    /// # Panics
    ///
    /// Panics if an adaptive integrator returns no error estimate, if a step
    /// returns no dense-output coefficients, if the accepted step guess
    /// cannot be reset, or if the root search for an event fails.
    pub fn solve(&mut self) -> Solution<S, D> {
        // Convergence settings for the Brent search that locates events.
        let mut convergency = SimpleConvergency {
            eps: 1e-12,
            max_iter: 1000,
        };
        let mut times: Vec<f64> = vec![self.ode.t];
        let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
        let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();

        while self.ode.t < self.ode.t_end {
            // Shorten the proposed step so it lands exactly on `t_end`
            // instead of overshooting it.
            if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end {
                self.controller.next_step_guess =
                    TryStep::NotYetAccepted(self.ode.t_end - self.ode.t);
            }

            let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE {
                let initial_guess = self.controller.next_step_guess.extract();
                let (mut trial_y, mut err_option, mut dense_option) =
                    self.integrator.step(&self.ode, initial_guess);
                let mut err =
                    err_option.expect("adaptive integrator returned no error estimate");
                let mut next_step_guess = <PIController as Controller<D>>::determine_step(
                    &mut self.controller,
                    initial_guess,
                    err,
                );
                while !next_step_guess.is_accepted() {
                    (trial_y, err_option, dense_option) =
                        self.integrator.step(&self.ode, next_step_guess.extract());
                    // Refresh the error BEFORE asking the controller, so the
                    // retry is judged on its own error estimate rather than
                    // on the one from the previously rejected step size.
                    err = err_option.expect("adaptive integrator returned no error estimate");
                    next_step_guess = <PIController as Controller<D>>::determine_step(
                        &mut self.controller,
                        next_step_guess.extract(),
                        err,
                    );
                }
                self.controller.next_step_guess = next_step_guess
                    .reset()
                    .expect("failed to reset the accepted step guess");
                (trial_y, next_step_guess.extract(), dense_option)
            } else {
                let fixed_step = self.controller.next_step_guess.extract();
                let (trial_y, _, dense_option) = self.integrator.step(&self.ode, fixed_step);
                (trial_y, fixed_step, dense_option)
            };

            for callback in &self.callbacks {
                let event_before = (callback.event)(self.ode.t, self.ode.y, &self.ode.params);
                let event_after =
                    (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params);
                // A strictly negative product means the event function
                // changed sign somewhere inside this step.
                if event_before * event_after < 0.0 {
                    // Event value after a trial step of length `test_t`
                    // taken from the current state.
                    let f = |test_t| {
                        let test_y = self.integrator.step(&self.ode, test_t).0;
                        (callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
                    };
                    let root = find_root_brent(0.0, curr_step, &f, &mut convergency)
                        .expect("root search for a callback event failed");
                    // Redo the step so it ends at the event, then let the
                    // callback modify the ODE.
                    curr_step = root;
                    (new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step);
                    (callback.effect)(&mut self.ode);
                }
            }

            self.ode.y = new_y;
            self.ode.t += curr_step;
            times.push(self.ode.t);
            states.push(self.ode.y);
            dense_coefficients
                .push(dense_option.expect("integrator returned no dense-output coefficients"));
        }

        Solution {
            integrator: self.integrator,
            times,
            states,
            dense: dense_coefficients,
        }
    }

    /// Appends `callback` to the problem and returns the updated problem.
    ///
    /// Callbacks are checked in the order in which they were added.
    pub fn with_callback(mut self, callback: Callback<'a, D, P>) -> Self {
        self.callbacks.push(callback);
        self
    }
}
/// The result of solving a problem: the visited time points, the state at
/// each of them, and the data needed to interpolate between them.
pub struct Solution<S: Integrator<D>, const D: usize> {
    /// Integrator whose `interpolate` method evaluates the dense output.
    pub integrator: S,
    /// Time points, starting with the initial time.
    pub times: Vec<f64>,
    /// State at each entry of `times` (same length, same order).
    pub states: Vec<SVector<f64, D>>,
    /// Dense-output coefficients; entry `i` belongs to the step from
    /// `times[i]` to `times[i + 1]`.
    pub dense: Vec<Vec<SVector<f64, D>>>,
}
impl<S, const D: usize> Solution<S, D>
where
    S: Integrator<D>,
{
    /// Evaluates the solution at time `t`.
    ///
    /// If `t` coincides with a stored time point, the stored state is
    /// returned. Otherwise the integrator's `interpolate` method is called
    /// with the bounds and dense-output coefficients of the step that
    /// contains `t`.
    ///
    /// # Panics
    ///
    /// Panics if the solution holds no time points, or if `t` is NaN or lies
    /// outside `[times.first(), times.last()]`. Because of that range check, a
    /// solution whose first time is greater than its last time cannot be
    /// interpolated at any `t`.
    pub fn interpolate(&self, t: f64) -> SVector<f64, D> {
        let first = *self
            .times
            .first()
            .expect("solution contains no time points");
        let last = *self
            .times
            .last()
            .expect("solution contains no time points");
        // Written as a positive range test so that a NaN `t` is rejected here
        // with a clear message instead of failing later on an index.
        assert!(
            first <= t && t <= last,
            "cannot interpolate at t = {t}: outside the solved interval [{first}, {last}]"
        );

        // Search the stored grid directly; no copy of `times` is needed.
        match self.times.binary_search_by(|x| x.total_cmp(&t)) {
            Ok(index) => self.states[index],
            // `total_cmp` orders -0.0 before +0.0 while the range check above
            // treats them as equal, so a query that equals an end point up to
            // the sign of zero can come back as "not found" at either end.
            Err(0) => self.states[0],
            Err(end_index) if end_index == self.times.len() => self.states[end_index - 1],
            Err(end_index) => {
                let t_start = self.times[end_index - 1];
                let t_end = self.times[end_index];
                self.integrator
                    .interpolate(t_start, t_end, &self.dense[end_index - 1], t)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::callback::stop;
    use crate::controller::PIController;
    use crate::integrator::dormand_prince::DormandPrince45;
    use approx::assert_relative_eq;
    use nalgebra::Vector3;

    /// The test equation takes no parameters.
    type Params = ();

    /// Right-hand side of `y' = y`, whose solution from `y(0) = 1` is `exp(t)`.
    fn growth(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
        y
    }

    /// Every stored state should follow `exp(t)` on `[0, 1]`.
    #[test]
    fn test_problem() {
        let y0 = Vector3::new(1.0, 1.0, 1.0);
        let ode = ODE::new(&growth, 0.0, 1.0, y0, ());
        let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
        let controller = PIController::default();
        let mut problem = Problem::new(ode, dp45, controller);
        let solution = problem.solve();
        for (time, state) in solution.times.iter().zip(solution.states.iter()) {
            assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2);
        }
    }

    /// With the `stop` effect attached to the event `10 - y[0]`, the last
    /// stored state should sit at `y[0] = 10`.
    #[test]
    fn test_with_callback() {
        let y0 = Vector3::new(1.0, 1.0, 1.0);
        let ode = ODE::new(&growth, 0.0, 10.0, y0, ());
        let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
        let controller = PIController::default();
        let value_too_high = Callback {
            event: &|_: f64, y: SVector<f64, 3>, _: &Params| 10.0 - y[0],
            effect: &stop,
        };
        let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high);
        let solution = problem.solve();
        let final_state = solution.states.last().unwrap();
        assert_relative_eq!(final_state[0], 10.0, max_relative = 1e-11);
    }

    /// Interpolating at `t = 8.8` should match `exp(8.8)`.
    #[test]
    fn test_with_interpolation() {
        let y0 = Vector3::new(1.0, 1.0, 1.0);
        let ode = ODE::new(&growth, 0.0, 10.0, y0, ());
        let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
        let controller = PIController::default();
        let mut problem = Problem::new(ode, dp45, controller);
        let solution = problem.solve();
        let interpolated = solution.interpolate(8.8);
        assert_relative_eq!(interpolated[0], 8.8_f64.exp(), max_relative = 1e-6);
    }
}