use crate::constants::N_EQUAL_STEPS;
use crate::{EulerBackward, EulerForward, ExplicitRungeKutta, Radau5};
use crate::{Method, OdeSolverTrait, Params, Stats, System, Workspace};
use crate::{Output, StrError};
use russell_lab::{vec_all_finite, Vector};
/// Facade that owns a concrete ODE integrator and drives the stepping loop
///
/// The generic parameter `A` is the type of the user arguments that `solve`
/// forwards to the integrator and to the output handler.
pub struct OdeSolver<'a, A> {
/// Parameters checked by `Params::validate` in `new`; replaced as a whole by `update_params`
params: Params,
/// Copy of `system.ndim`; `solve` requires `y0.dim()` to be equal to it
ndim: usize,
/// Concrete integrator selected in `new` according to `params.method`
actual: Box<dyn OdeSolverTrait<A> + 'a>,
/// Workspace (stepping state and statistics) handed to the integrator at every step
work: Workspace,
/// Output handler; consulted by `solve` only when `output_enabled` is true
output: Output<'a, A>,
/// Set to true by `enable_output`
output_enabled: bool,
}
impl<'a, A> OdeSolver<'a, A> {
/// Allocates a new solver for the given system
///
/// The concrete integrator is chosen from `params.method`: `Radau5`, `BwEuler`, and
/// `FwEuler` have dedicated implementations; every other method is handled by the
/// explicit Runge-Kutta implementation.
///
/// # Input
///
/// * `params` -- configuration parameters (validated here)
/// * `system` -- the ODE system; it is moved into the concrete integrator
///
/// # Errors
///
/// * `params` does not pass `Params::validate`
/// * the system defines a mass matrix and the method is not `Radau5`
/// * the explicit Runge-Kutta integrator cannot be allocated
pub fn new(params: Params, system: System<'a, A>) -> Result<Self, StrError>
where
    A: 'a,
{
    params.validate()?;
    if system.calc_mass.is_some() && params.method != Method::Radau5 {
        return Err("the method must be Radau5 for systems with a mass matrix");
    }
    // must be read before `system` is moved into the integrator
    let ndim = system.ndim;
    let actual: Box<dyn OdeSolverTrait<A>> = if params.method == Method::Radau5 {
        Box::new(Radau5::new(params, system))
    } else if params.method == Method::BwEuler {
        Box::new(EulerBackward::new(params, system))
    } else if params.method == Method::FwEuler {
        Box::new(EulerForward::new(system))
    } else {
        // propagate a failure to the caller instead of panicking in library code
        Box::new(ExplicitRungeKutta::new(params, system)?)
    };
    Ok(OdeSolver {
        params,
        ndim,
        actual,
        work: Workspace::new(params.method),
        output: Output::new(),
        output_enabled: false,
    })
}
/// Gives read access to the statistics stored in the workspace
pub fn stats(&self) -> &Stats {
    let workspace = &self.work;
    &workspace.stats
}
/// Integrates the system from `x0` to `x1`, updating `y0` in place
///
/// # Input
///
/// * `y0` -- initial values on entry; handed mutably to the integrator whenever a step is accepted
/// * `x0` -- initial station; must be finite
/// * `x1` -- final station; must be finite and greater than `x0`
/// * `h_equal` -- `Some(h)` requests equal steps: the number of steps is
///   `ceil((x1 - x0) / h)` (at least one) and the stepsize is adjusted so the steps fit
///   the interval exactly. `None` selects variable steps for embedded methods and
///   `N_EQUAL_STEPS` equal steps for the other methods.
/// * `args` -- user data forwarded to the integrator and to the output handler
///
/// # Errors
///
/// * `y0.dim()` differs from the system dimension
/// * `x0` or `x1` is NaN or infinite, or `x1 <= x0`
/// * `h_equal` is NaN or smaller than `10.0 * f64::EPSILON`
/// * the variable stepsize becomes too small, or the maximum number of steps is
///   reached before arriving at `x1`
/// * any error returned by the integrator, the output handler, or the finiteness
///   check of the solution vector
///
/// # Panics
///
/// Panics if the computed initial stepsize is not positive (internal invariant).
pub fn solve(
    &mut self,
    y0: &mut Vector,
    x0: f64,
    x1: f64,
    h_equal: Option<f64>,
    args: &mut A,
) -> Result<(), StrError> {
    // check data
    if y0.dim() != self.ndim {
        return Err("y0.dim() must be equal to ndim");
    }
    // NaN compares false with everything and would slip through `x1 <= x0`;
    // infinite stations cannot be reached with finite steps
    if !x0.is_finite() || !x1.is_finite() {
        return Err("x0 and x1 must be finite");
    }
    if x1 <= x0 {
        return Err("x1 must be greater than x0");
    }

    // information about the selected method
    let info = self.params.method.information();

    // stepping mode and initial stepsize
    let (equal_stepping, mut h) = match h_equal {
        Some(h_eq) => {
            // the explicit NaN test is required because `NaN < c` is always false
            if h_eq.is_nan() || h_eq < 10.0 * f64::EPSILON {
                return Err("h_equal must be ≥ 10.0 * f64::EPSILON");
            }
            // at least one step: an infinite h_eq would otherwise yield n = 0,
            // h = ∞, and a "successful" run that performs no step at all
            let n = (f64::ceil((x1 - x0) / h_eq) as usize).max(1);
            (true, (x1 - x0) / (n as f64))
        }
        None if info.embedded => (false, f64::min(self.params.step.h_ini, x1 - x0)),
        None => (true, (x1 - x0) / (N_EQUAL_STEPS as f64)),
    };
    assert!(h > 0.0);

    // reset the workspace for a new run
    self.work.reset(h, self.params.step.rel_error_prev_min);

    // current station and values
    let mut x = x0;
    let y = y0;

    // first output (at x0)
    if self.output_enabled {
        self.output.initialize(x0, x1, self.params.stiffness.save_results)?;
        if self.output.with_dense_output() {
            self.actual.enable_dense_output()?;
        }
        let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?;
        if stop {
            // consistent with the other early-stop branches below
            self.work.stats.stop_sw_total();
            return Ok(());
        }
    }

    // equal-stepping loop: every step is accepted
    if equal_stepping {
        let nstep = f64::ceil((x1 - x) / h) as usize;
        for _ in 0..nstep {
            self.work.stats.sw_step.reset();

            // step
            self.work.stats.n_steps += 1;
            self.actual.step(&mut self.work, x, &y, h, args)?;

            // update x and y
            self.work.stats.n_accepted += 1;
            self.actual.accept(&mut self.work, &mut x, y, h, args)?;

            // check for NaN or infinity
            vec_all_finite(&y, self.params.debug)?;

            // output
            if self.output_enabled {
                let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?;
                if stop {
                    self.work.stats.stop_sw_step();
                    self.work.stats.stop_sw_total();
                    return Ok(());
                }
            }
            self.work.stats.stop_sw_step();
        }
        if self.output_enabled {
            self.output.last(&self.work, h, x, y, args)?;
        }
        self.work.stats.stop_sw_total();
        return Ok(());
    }

    // variable stepping: control variables
    let mut success = false;
    let mut last_step = false;

    // variable-stepping loop
    for _ in 0..self.params.step.n_step_max {
        self.work.stats.sw_step.reset();

        // arrived at x1?
        let dx = x1 - x;
        if dx <= 10.0 * f64::EPSILON {
            success = true;
            self.work.stats.stop_sw_step();
            break;
        }

        // never step beyond x1; give up if the stepsize collapses
        h = f64::min(self.work.h_new, dx);
        if h <= 10.0 * f64::EPSILON {
            return Err("the stepsize becomes too small");
        }

        // step
        self.work.stats.n_steps += 1;
        self.actual.step(&mut self.work, x, &y, h, args)?;

        // diverging iterations: retry with a scaled stepsize
        if self.work.iterations_diverging {
            self.work.iterations_diverging = false;
            self.work.follows_reject_step = true;
            last_step = false;
            self.work.h_new = h * self.work.h_multiplier_diverging;
            continue;
        }

        if self.work.rel_error < 1.0 {
            // accept: update x and y
            self.work.stats.n_accepted += 1;
            self.actual.accept(&mut self.work, &mut x, y, h, args)?;

            // check for NaN or infinity
            vec_all_finite(&y, self.params.debug)?;

            // do not let the stepsize grow right after a rejected step
            if self.work.follows_reject_step {
                self.work.h_new = f64::min(self.work.h_new, h);
            }
            self.work.follows_reject_step = false;

            // save the stepsize and relative error for the next step
            self.work.h_prev = h;
            self.work.rel_error_prev = f64::max(self.params.step.rel_error_prev_min, self.work.rel_error);
            self.work.stats.h_accepted = self.work.h_new;

            // output
            if self.output_enabled {
                let stop = self.output.execute(&self.work, h, x, y, &self.actual, args)?;
                if stop {
                    self.work.stats.stop_sw_step();
                    self.work.stats.stop_sw_total();
                    return Ok(());
                }
            }

            // the step just accepted was flagged as the last one
            if last_step {
                success = true;
                self.work.stats.stop_sw_step();
                break;
            }

            // flag the next step as the last one if it reaches x1
            if x + self.work.h_new >= x1 {
                last_step = true;
            }
        } else {
            // reject; rejections before the first accepted step are not counted
            if self.work.stats.n_accepted > 0 {
                self.work.stats.n_rejected += 1;
            }
            self.work.follows_reject_step = true;
            last_step = false;

            // new stepsize
            if self.work.stats.n_accepted == 0 && self.params.step.m_first_reject > 0.0 {
                self.work.h_new = h * self.params.step.m_first_reject;
            } else {
                self.actual.reject(&mut self.work, h);
            }
        }
    }

    // last output (also issued when the loop ran out of steps)
    if self.output_enabled {
        self.output.last(&self.work, h, x, y, args)?;
    }

    self.work.stats.stop_sw_total();
    if success {
        Ok(())
    } else {
        Err("variable stepping did not converge")
    }
}
/// Replaces the parameters of the solver and of the concrete integrator
///
/// # Errors
///
/// * `params.method` differs from the method the solver was allocated with
/// * `params` does not pass `Params::validate`
pub fn update_params(&mut self, params: Params) -> Result<(), StrError> {
    let same_method = params.method == self.params.method;
    if !same_method {
        return Err("update_params must not change the method");
    }
    params.validate()?;

    // hand the new parameters to the integrator, then keep a copy here
    self.actual.update_params(params);
    self.params = params;
    Ok(())
}
/// Switches the output on and returns the output handler for configuration
pub fn enable_output(&mut self) -> &mut Output<'a, A> {
    self.output_enabled = true;
    let handler = &mut self.output;
    handler
}
/// Returns the `step_h` series stored in the output handler
pub fn out_step_h(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.step_h
}
/// Returns the `step_x` series stored in the output handler
pub fn out_step_x(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.step_x
}
/// Returns the `step_y` series stored under the key `m`
///
/// # Panics
///
/// Panics if the output handler holds no `step_y` entry for `m`.
pub fn out_step_y(&self, m: usize) -> &Vec<f64> {
    let series = self.output.step_y.get(&m);
    series.unwrap()
}
/// Returns the `step_global_error` series stored in the output handler
pub fn out_step_global_error(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.step_global_error
}
/// Returns the `dense_x` series stored in the output handler
pub fn out_dense_x(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.dense_x
}
/// Returns the `dense_y` series stored under the key `m`
///
/// # Panics
///
/// Panics if the output handler holds no `dense_y` entry for `m`.
pub fn out_dense_y(&self, m: usize) -> &Vec<f64> {
    let series = self.output.dense_y.get(&m);
    series.unwrap()
}
/// Returns the `stiff_step_index` series stored in the output handler
pub fn out_stiff_step_index(&self) -> &Vec<usize> {
    let handler = &self.output;
    &handler.stiff_step_index
}
/// Returns the `stiff_x` series stored in the output handler
pub fn out_stiff_x(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.stiff_x
}
/// Returns the `stiff_h_times_rho` series stored in the output handler
pub fn out_stiff_h_times_rho(&self) -> &Vec<f64> {
    let handler = &self.output;
    &handler.stiff_h_times_rho
}
}
#[cfg(test)]
mod tests {
use super::OdeSolver;
use crate::{Method, Params, Samples, System};
use crate::{NoArgs, OutCount, OutData, Stats, StrError};
use russell_lab::{approx_eq, array_approx_eq, vec_approx_eq, Vector};
use russell_sparse::Genie;
#[test]
fn new_captures_errors() {
    let mut params = Params::new(Method::MdEuler);

    // a system with a mass matrix requires Radau5
    let (with_mass, _, _, _, _) = Samples::simple_system_with_mass_matrix(false, Genie::Umfpack);
    let outcome = OdeSolver::new(params, with_mass).err();
    assert_eq!(
        outcome,
        Some("the method must be Radau5 for systems with a mass matrix")
    );

    // invalid parameters are rejected
    let (plain, _, _, _, _) = Samples::simple_equation_constant();
    params.step.m_max = 0.0;
    let outcome = OdeSolver::new(params, plain).err();
    assert_eq!(
        outcome,
        Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
    );
}
#[test]
fn solve_captures_errors() {
    let (system, _, _, mut args, _) = Samples::simple_equation_constant();
    let ndim = system.ndim;
    let mut solver = OdeSolver::new(Params::new(Method::FwEuler), system).unwrap();

    // y0 with the wrong dimension
    let mut oversized = Vector::new(ndim + 1);
    let outcome = solver.solve(&mut oversized, 0.0, 1.0, None, &mut args).err();
    assert_eq!(outcome, Some("y0.dim() must be equal to ndim"));

    // empty interval
    let mut y0 = Vector::new(ndim);
    let outcome = solver.solve(&mut y0, 0.0, 0.0, None, &mut args).err();
    assert_eq!(outcome, Some("x1 must be greater than x0"));

    // stepsize below the allowed minimum
    let too_small = Some(f64::EPSILON);
    let outcome = solver.solve(&mut y0, 0.0, 1.0, too_small, &mut args).err();
    assert_eq!(outcome, Some("h_equal must be ≥ 10.0 * f64::EPSILON"));
}
#[test]
fn nan_and_infinity_are_captured() {
    let (system, _, mut y0, mut args, _) = Samples::brusselator_ode();
    let expected = Some("an element of the vector is either infinite or NaN");

    // FwEuler with equal steps of size 1.0
    let mut euler = OdeSolver::new(Params::new(Method::FwEuler), system.clone()).unwrap();
    let outcome = euler.solve(&mut y0, 0.0, 9.0, Some(1.0), &mut args).err();
    assert_eq!(outcome, expected);

    // MdEuler with h_equal = None, continuing from the y0 left by the first run
    let mut md_euler = OdeSolver::new(Params::new(Method::MdEuler), system).unwrap();
    let outcome = md_euler.solve(&mut y0, 0.0, 1.0, None, &mut args).err();
    assert_eq!(outcome, expected);
}
#[test]
fn lack_of_convergence_is_captured() {
    let (system, _, mut y0, mut args, _) = Samples::simple_equation_constant();

    // a single step is allowed, which is not enough to reach x1
    let mut params = Params::new(Method::MdEuler);
    params.step.n_step_max = 1;
    let mut solver = OdeSolver::new(params, system).unwrap();

    let outcome = solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err();
    assert_eq!(outcome, Some("variable stepping did not converge"));
}
#[test]
fn out_initialize_errors_are_captured() {
    let (system, _, mut y0, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::DoPri5), system).unwrap();

    // first interior station coincides with x0
    solver
        .enable_output()
        .set_dense_x_out(&[0.0, 1.0])
        .unwrap()
        .set_dense_recording(&[0]);
    let outcome = solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err();
    assert_eq!(
        outcome,
        Some("the first interior x_out for dense output must be > x0")
    );

    // last interior station coincides with x1
    solver.enable_output().set_dense_x_out(&[0.1, 1.0]).unwrap();
    let outcome = solver.solve(&mut y0, 0.0, 1.0, None, &mut args).err();
    assert_eq!(
        outcome,
        Some("the last interior x_out for dense output must be < x1")
    );
}
#[test]
fn solve_with_n_equal_steps_works() {
    let (system, x0, y0, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::FwEuler), system).unwrap();

    // FwEuler with h_equal = None
    let x_final = 1.0;
    let mut y = y0.clone();
    solver.solve(&mut y, x0, x_final, None, &mut args).unwrap();
    vec_approx_eq(&y, &[1.0], 1e-15);
}
#[test]
fn solve_completes_after_a_single_step() {
    let (system, _, mut y0, mut args, _) = Samples::simple_equation_constant();

    // initial stepsize larger than the whole interval
    let mut params = Params::new(Method::DoPri5);
    params.step.h_ini = 20.0;
    let mut solver = OdeSolver::new(params, system).unwrap();

    solver.solve(&mut y0, 0.0, 1.0, None, &mut args).unwrap();
    assert_eq!(solver.stats().n_accepted, 1);
    vec_approx_eq(&y0, &[1.0], 1e-15);
}
#[test]
fn solve_with_variable_steps_works() {
    let (system, _, mut y0, mut args, _) = Samples::simple_equation_constant();

    // MdEuler with h_equal = None and an initial stepsize of 0.1
    let mut params = Params::new(Method::MdEuler);
    params.step.h_ini = 0.1;
    let mut solver = OdeSolver::new(params, system).unwrap();

    let x_final = 0.3;
    solver.solve(&mut y0, 0.0, x_final, None, &mut args).unwrap();
    vec_approx_eq(&y0, &[0.3], 1e-15);
}
#[test]
fn update_params_captures_errors() {
    let (system, _, _, _, _) = Samples::simple_equation_constant();

    // `new` rejects invalid parameters
    let mut params = Params::new(Method::MdEuler);
    params.step.n_step_max = 0;
    let outcome = OdeSolver::new(params, system.clone()).err();
    assert_eq!(outcome, Some("parameter must satisfy: n_step_max ≥ 1"));

    // the solver holds its own copy: changing the local value has no effect
    params.step.n_step_max = 1000;
    let mut solver = OdeSolver::new(params, system).unwrap();
    assert_eq!(solver.params.step.n_step_max, 1000);
    params.step.n_step_max = 2;
    assert_eq!(solver.params.step.n_step_max, 1000);

    // ... until update_params is called
    solver.update_params(params).unwrap();
    assert_eq!(solver.params.step.n_step_max, 2);

    // the method cannot be changed
    params.method = Method::FwEuler;
    let outcome = solver.update_params(params).err();
    assert_eq!(outcome, Some("update_params must not change the method"));

    // invalid parameters are rejected
    params.method = Method::MdEuler;
    params.step.m_max = 0.0;
    let outcome = solver.update_params(params).err();
    assert_eq!(
        outcome,
        Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
    );
}
#[test]
fn solve_with_out_dense_captures_errors() {
    let (system, x0, mut y0, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::FwEuler), system).unwrap();

    // request dense recording with FwEuler
    solver.enable_output().set_dense_recording(&[0]);

    let x_final = 1.0;
    let outcome = solver.solve(&mut y0, x0, x_final, None, &mut args).err();
    assert_eq!(
        outcome,
        Some("dense output is not available for the FwEuler method")
    );
}
#[test]
fn solve_with_step_output_works() {
    let (system, _, y0, mut args, y_fn_x) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::DoPri5), system).unwrap();

    // step output: reference solution, files, in-memory recording, and a checking callback
    let path_key = "/tmp/russell_ode/test_solve_step_output_works";
    solver
        .enable_output()
        .set_yx_correct(y_fn_x)
        .set_step_file_writing(path_key)
        .set_step_recording(&[0])
        .set_step_callback(|stats, h, x, y, _args| {
            assert_eq!(h, 0.2);
            let expected = (stats.n_accepted as f64) * h;
            approx_eq(x, expected, 1e-15);
            approx_eq(y[0], expected, 1e-15);
            Ok(false)
        });

    // equal steps of 0.2 from 0.0 to 0.4
    let mut y = y0.clone();
    solver.solve(&mut y, 0.0, 0.4, Some(0.2), &mut args).unwrap();
    vec_approx_eq(&y, &[0.4], 1e-15);

    // recorded series
    array_approx_eq(solver.out_step_h(), &[0.2, 0.2, 0.2], 1e-15);
    array_approx_eq(solver.out_step_x(), &[0.0, 0.2, 0.4], 1e-15);
    array_approx_eq(solver.out_step_y(0), &[0.0, 0.2, 0.4], 1e-15);
    array_approx_eq(solver.out_step_global_error(), &[0.0, 0.0, 0.0], 1e-15);

    // files written during the run
    let count = OutCount::read_json(&format!("{}_count.json", path_key)).unwrap();
    assert_eq!(count.n, 3);
    for i in 0..count.n {
        let res = OutData::read_json(&format!("{}_{}.json", path_key, i)).unwrap();
        let expected = (i as f64) * 0.2;
        assert_eq!(res.h, 0.2);
        approx_eq(res.x, expected, 1e-15);
        approx_eq(res.y[0], expected, 1e-15);
    }

    // stand-alone closure with the callback signature, called directly
    let cb = |_stats: &Stats, _h: f64, _x: f64, _y: &Vector, _args: &mut NoArgs| -> Result<bool, StrError> {
        Err("unreachable")
    };
    assert_eq!(cb(&solver.stats(), 0.0, 0.0, &y0, &mut args).err(), Some("unreachable"));

    // callback requesting a stop once at least one step has been accepted
    solver
        .enable_output()
        .set_step_callback(|stats, _h, _x, _y, _args| Ok(stats.n_accepted > 0));
    let mut y = y0.clone();
    solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap();
    assert!(y[0] > 0.0 && y[0] < 0.4);
}
#[test]
fn solve_with_step_captures_errors() {
    let (system, _, y0, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::FwEuler), system).unwrap();
    solver.enable_output().set_step_recording(&[0]);

    // callback failing on its first call
    solver
        .enable_output()
        .set_step_callback(|_stats, _h, _x, _y, _args| Err("stop with error (first accepted step)"));
    let mut y = y0.clone();
    let outcome = solver.solve(&mut y, 0.0, 0.4, None, &mut args).err();
    assert_eq!(outcome, Some("stop with error (first accepted step)"));

    // callback failing once at least one step has been accepted
    solver.enable_output().set_step_callback(|stats, _h, _x, _y, _args| {
        if stats.n_accepted == 0 {
            return Ok(false);
        }
        Err("stop with error (subsequent steps)")
    });
    let mut y = y0.clone();
    let outcome = solver.solve(&mut y, 0.0, 0.4, None, &mut args).err();
    assert_eq!(outcome, Some("stop with error (subsequent steps)"));
}
#[test]
fn solve_with_dense_output_h_out_works_1() {
    let (system, _, _, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::DoPri5), system).unwrap();

    // dense output every 0.2
    solver
        .enable_output()
        .set_dense_h_out(0.2)
        .unwrap()
        .set_dense_recording(&[0]);

    // requested stepsize slightly larger than the output spacing
    let (x0, x1) = (0.0, 1.0);
    let mut y = Vector::from(&[x0]);
    solver.solve(&mut y, x0, x1, Some(0.201), &mut args).unwrap();
    vec_approx_eq(&y, &[x1], 1e-15);

    let expected = &[0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
    array_approx_eq(solver.out_dense_x(), expected, 1e-15);
    array_approx_eq(solver.out_dense_y(0), expected, 1e-15);
}
// Exercises dense output with a fixed spacing (H_OUT) on DoPri5, then the ways
// in which a dense callback can stop or abort `solve`. The same solver is reused
// for all runs; only the dense callback is replaced between them.
#[test]
fn solve_with_dense_output_h_out_works_2() {
let (system, _, y0, mut args, y_fn_x) = Samples::simple_equation_constant();
let params = Params::new(Method::DoPri5);
let mut solver = OdeSolver::new(params, system).unwrap();
// dense output: reference solution, files, in-memory recording, and a checking callback
const H_OUT: f64 = 0.1;
let path_key = "/tmp/russell_ode/test_solve_dense_output_h_out_works";
solver
.enable_output()
.set_yx_correct(y_fn_x)
.set_dense_h_out(H_OUT)
.unwrap()
.set_dense_file_writing(path_key)
.unwrap()
.set_dense_recording(&[0])
.set_dense_callback(|_stats, h, x, y, _args| {
assert_eq!(h, 0.2);
approx_eq(y[0], x, 1e-15);
Ok(false)
});
// run 1: equal steps of 0.2 from 0.0 to 0.4
let h_equal = Some(0.2);
let mut y = y0.clone();
solver.solve(&mut y, 0.0, 0.4, h_equal, &mut args).unwrap();
vec_approx_eq(&y, &[0.4], 1e-15);
// the dense series is expected at multiples of H_OUT
array_approx_eq(&solver.out_dense_x(), &[0.0, 0.1, 0.2, 0.3, 0.4], 1e-15);
array_approx_eq(&solver.out_dense_y(0), &[0.0, 0.1, 0.2, 0.3, 0.4], 1e-15);
// files written during the run
let count = OutCount::read_json(&format!("{}_count.json", path_key)).unwrap();
assert_eq!(count.n, 5);
for i in 0..count.n {
let res = OutData::read_json(&format!("{}_{}.json", path_key, i)).unwrap();
assert_eq!(res.h, 0.2); approx_eq(res.x, (i as f64) * H_OUT, 1e-15);
approx_eq(res.y[0], (i as f64) * H_OUT, 1e-15);
}
// NOTE(review): this closure is never registered with the solver; it is only
// called directly on the next line — presumably kept for coverage, confirm
let cb = |_stats: &Stats, _h: f64, _x: f64, _y: &Vector, _args: &mut NoArgs| -> Result<bool, StrError> {
Err("unreachable")
};
assert_eq!(cb(&solver.stats(), 0.0, 0.0, &y0, &mut args).err(), Some("unreachable"));
// run 2: the callback always requests a stop; no step is accepted and y keeps its initial value
solver.enable_output().set_dense_callback(|_stats, _h, _x, _y, _args| {
Ok(true) });
let mut y = y0.clone();
solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap();
assert_eq!(solver.work.stats.n_accepted, 0);
assert_eq!(y[0], 0.0);
// runs 3 and 4: the callback requests a stop once at least one step has been accepted;
// the solution must have advanced but not reached x1 = 0.4
solver.enable_output().set_dense_callback(|stats, _h, _x, _y, _args| {
if stats.n_accepted > 0 {
Ok(true) } else {
Ok(false) }
});
// run 3: with h_equal = Some(0.2)
let mut y = y0.clone();
solver.solve(&mut y, 0.0, 0.4, Some(0.2), &mut args).unwrap();
assert!(y[0] > 0.0 && y[0] < 0.4);
// run 4: with h_equal = None
let mut y = y0.clone();
solver.solve(&mut y, 0.0, 0.4, None, &mut args).unwrap();
assert!(y[0] > 0.0 && y[0] < 0.4);
// run 5: the callback always fails; the error is returned by solve
solver
.enable_output()
.set_dense_callback(|_stats, _h, _x, _y, _args| Err("stop with error"));
let mut y = y0.clone();
assert_eq!(
solver.solve(&mut y, 0.0, 0.4, None, &mut args).err(),
Some("stop with error")
);
// run 6: the callback fails once at least one step has been accepted
solver.enable_output().set_dense_callback(|stats, _h, _x, _y, _args| {
if stats.n_accepted > 0 {
Err("stop with error")
} else {
Ok(false) }
});
let mut y = y0.clone();
assert_eq!(
solver.solve(&mut y, 0.0, 0.4, None, &mut args).err(),
Some("stop with error")
);
}
#[test]
fn solve_with_dense_output_x_out_works() {
    let (system, _, _, mut args, _) = Samples::simple_equation_constant();
    let mut solver = OdeSolver::new(Params::new(Method::DoPri5), system).unwrap();

    // dense output at prescribed interior stations
    let interior_stations = &[-0.5, 0.0, 0.5];
    let components = &[0];
    solver
        .enable_output()
        .set_dense_x_out(interior_stations)
        .unwrap()
        .set_dense_recording(components);

    // equal steps of 0.2 from -1.0 to 1.0
    let (x0, x1) = (-1.0, 1.0);
    let mut y = Vector::from(&[x0]);
    solver.solve(&mut y, x0, x1, Some(0.2), &mut args).unwrap();
    vec_approx_eq(&y, &[x1], 1e-15);

    // the recorded stations are x0, the interior stations, and x1
    let expected = &[-1.0, -0.5, 0.0, 0.5, 1.0];
    assert_eq!(solver.out_dense_x(), expected);
    array_approx_eq(solver.out_dense_y(0), expected, 1e-15);
}
// Checks that errors raised by the system function `f` and by the dense output
// callback are returned by `solve`. Failures are triggered artificially: each
// function counts its successful calls in `args` and fails when its counter is
// equal to the corresponding barrier. The counters persist across the calls to
// `solve` below, so the order of the statements matters.
#[test]
fn solve_captures_errors_from_f_and_out() {
struct Args {
// number of successful calls to f
f_count: usize,
// f fails when f_count == f_barrier
f_barrier: usize,
// number of successful calls to the dense callback
out_count: usize,
// the dense callback fails when out_count == out_barrier
out_barrier: usize,
}
let mut args = Args {
f_count: 0,
f_barrier: 0,
out_count: 0,
out_barrier: 2, };
let ndim = 1;
// f writes f[0] = 1.0, unless the barrier is hit
let system = System::new(ndim, |f: &mut Vector, _x: f64, _y: &Vector, args: &mut Args| {
if args.f_count == args.f_barrier {
return Err("f: artificial error");
}
f[0] = 1.0;
args.f_count += 1;
Ok(())
});
let x0 = 0.0;
let x1 = 0.2;
let mut y = Vector::from(&[0.0]);
let mut params = Params::new(Method::DoPri8);
params.step.h_ini = 0.2;
let mut solver = OdeSolver::new(params, system).unwrap();
// dense output every 0.1 with a callback that fails at the barrier
solver.enable_output().set_dense_h_out(0.1).unwrap().set_dense_callback(
|_stats, _h, _x, _y, args: &mut Args| {
if args.out_count == args.out_barrier {
return Err("out: artificial error");
}
args.out_count += 1;
Ok(false) },
);
// equal stepping (h_equal = Some(0.2)): with f_barrier = 0 the very first call to f fails
assert_eq!(
solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(),
Some("f: artificial error")
);
// NOTE(review): 12 presumably is the number of f evaluations in one DoPri8 step — confirm
args.f_barrier += 12; assert_eq!(
solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(),
Some("f: artificial error")
);
// with the f barrier moved further away, the callback barrier is expected to be hit first
args.f_barrier += 2 * 12; assert_eq!(
solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(),
Some("out: artificial error")
);
// both barriers moved; the callback is expected to fail again
args.f_barrier += 2 * 12; args.out_barrier += 2; assert_eq!(
solver.solve(&mut y, x0, x1, Some(0.2), &mut args).err(),
Some("out: artificial error")
);
// automatic stepping (h_equal = None): start again from fresh counters
args.f_count = 0;
args.f_barrier = 0;
args.out_count = 0;
args.out_barrier = 2; assert_eq!(
solver.solve(&mut y, x0, x1, None, &mut args).err(),
Some("f: artificial error")
);
args.f_barrier += 12; assert_eq!(
solver.solve(&mut y, x0, x1, None, &mut args).err(),
Some("f: artificial error")
);
// NOTE(review): 15 + 1 presumably lets one full step (including the extra evaluations
// needed by the dense output) succeed before f would fail — confirm
args.f_count = 0;
args.f_barrier = 15 + 1; args.out_count = 0;
args.out_barrier = 2; assert_eq!(
solver.solve(&mut y, x0, x1, None, &mut args).err(),
Some("out: artificial error")
);
}
}