use crate::{
error::Error,
interpolate::{Interpolation, cubic_hermite_interpolate},
linalg::Matrix,
methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
ode::{ODE, OrdinaryNumericalMethod},
stats::Evals,
status::Status,
traits::{Real, State},
utils::validate_step_size_parameters,
};
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
OrdinaryNumericalMethod<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
{
fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
where
F: ODE<T, Y>,
{
let mut evals = Evals::new();
match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
Ok(h0) => self.h = h0,
Err(status) => return Err(status),
}
self.stiffness_counter = 0;
self.newton_iterations = 0;
self.jacobian_evaluations = 0;
self.lu_decompositions = 0;
self.t = t0;
self.y = *y0;
ode.diff(self.t, &self.y, &mut self.dydt);
evals.function += 1;
self.t_prev = self.t;
self.y_prev = self.y;
self.dydt_prev = self.dydt;
let dim = y0.len();
self.jacobian = Matrix::zeros(dim, dim);
self.z = *y0;
self.jacobian_age = 0;
self.status = Status::Initialized;
Ok(evals)
}
fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
where
F: ODE<T, Y>,
{
let mut evals = Evals::new();
if self.steps >= self.max_steps {
self.status = Status::Error(Error::MaxSteps {
t: self.t,
y: self.y,
});
return Err(Error::MaxSteps {
t: self.t,
y: self.y,
});
}
self.steps += 1;
let dim = self.y.len();
for stage in 0..self.stages {
let mut rhs = self.y;
for j in 0..stage {
rhs += self.k[j] * (self.a[stage][j] * self.h);
}
self.z = self.y;
let mut newton_converged = false;
let mut newton_iter = 0;
let mut increment_norm = T::infinity();
while !newton_converged && newton_iter < self.max_newton_iter {
newton_iter += 1;
self.newton_iterations += 1;
evals.newton += 1;
let t_stage = self.t + self.c[stage] * self.h;
let mut f_stage = Y::zeros();
ode.diff(t_stage, &self.z, &mut f_stage);
evals.function += 1;
let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
let mut residual_norm = T::zero();
self.rhs_newton = -residual;
for i in 0..dim {
residual_norm = residual_norm.max(residual.get(i).abs());
}
if residual_norm < self.newton_tol {
newton_converged = true;
break;
}
if newton_iter > 1 && increment_norm < self.newton_tol {
newton_converged = true;
break;
}
if newton_iter == 1 || self.jacobian_age > 3 {
ode.jacobian(t_stage, &self.z, &mut self.jacobian);
evals.jacobian += 1;
self.jacobian_age = 0;
self.jacobian
.component_mul_mut(-self.h * self.a[stage][stage]);
self.jacobian += Matrix::identity(dim);
}
self.jacobian_age += 1;
self.delta_z = self.jacobian.lin_solve(self.rhs_newton).unwrap();
self.lu_decompositions += 1;
increment_norm = T::zero();
self.z += self.delta_z;
for row_idx in 0..dim {
increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
}
}
if !newton_converged {
self.status = Status::Error(Error::Stiffness {
t: self.t,
y: self.y,
});
return Err(Error::Stiffness {
t: self.t,
y: self.y,
});
}
let t_stage = self.t + self.c[stage] * self.h;
ode.diff(t_stage, &self.z, &mut self.k[stage]);
evals.function += 1;
}
let mut y_new = self.y;
for i in 0..self.stages {
y_new += self.k[i] * (self.b[i] * self.h);
}
self.status = Status::Solving;
self.t_prev = self.t;
self.y_prev = self.y;
self.dydt_prev = self.dydt;
self.h_prev = self.h;
self.t += self.h;
self.y = y_new;
ode.diff(self.t, &self.y, &mut self.dydt);
evals.function += 1;
Ok(evals)
}
fn t(&self) -> T {
self.t
}
fn y(&self) -> &Y {
&self.y
}
fn t_prev(&self) -> T {
self.t_prev
}
fn y_prev(&self) -> &Y {
&self.y_prev
}
fn h(&self) -> T {
self.h
}
fn set_h(&mut self, h: T) {
self.h = h;
}
fn status(&self) -> &Status<T, Y> {
&self.status
}
fn set_status(&mut self, status: Status<T, Y>) {
self.status = status;
}
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
    for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
{
    /// Evaluates the dense output at `t_interp` within the last completed step.
    ///
    /// Uses cubic Hermite interpolation between `(t_prev, y_prev, dydt_prev)` and
    /// `(t, y, dydt)`. The bounds check accepts either ordering of `t_prev` and `t`,
    /// so a step taken backward in time is not rejected outright.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OutOfBounds`] if `t_interp` lies outside the closed interval
    /// spanned by `t_prev` and `t`.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Order the endpoints so the check does not assume time increases.
        // NOTE(review): assumes `cubic_hermite_interpolate` handles t_prev > t (a negative
        // interval length) — confirm against its implementation.
        let (t_lo, t_hi) = if self.t_prev <= self.t {
            (self.t_prev, self.t)
        } else {
            (self.t, self.t_prev)
        };
        if t_interp < t_lo || t_interp > t_hi {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }
        Ok(cubic_hermite_interpolate(
            self.t_prev,
            self.t,
            &self.y_prev,
            &self.y,
            &self.dydt_prev,
            &self.dydt,
            t_interp,
        ))
    }
}