use faer::{ComplexField, Conjugate, SimpleEntity};
use numra_core::Scalar;
use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
use crate::error::SolverError;
use crate::problem::OdeSystem;
use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
use crate::t_eval::{validate_grid, TEvalEmitter};
/// Three-stage ESDIRK integrator (tableau in `esdirk32_tableau`).
///
/// The step-size controller treats this method as order 2. The type carries no
/// state; it only selects the method through its `Solver` implementation.
#[derive(Clone, Debug, Default)]
pub struct Esdirk32;

impl Esdirk32 {
    /// Creates the stateless solver marker (equivalent to `Esdirk32::default()`).
    pub fn new() -> Self {
        Esdirk32
    }
}
/// Butcher tableau for `Esdirk32`: TR-BDF2 (Hosea & Shampine, 1996) written as a
/// 3-stage ESDIRK with `GAMMA = 1 - 1/sqrt(2)`.
///
/// The propagated weights `B` equal the last row of `A` (stiffly accurate) and satisfy
/// the order-2 conditions `Σ B = 1` and `Σ B·C = 1/2`. `W = (1 - GAMMA) / 2` is the value
/// fixed by those two conditions. The embedded weights `B_HAT` are third order and are
/// used only to estimate the local error of the second-order solution.
mod esdirk32_tableau {
    /// Diagonal coefficient shared by the implicit stages: 1 - 1/sqrt(2).
    pub const GAMMA: f64 = 0.2928932188134525;
    /// Equal first two weights of the last row: (1 - GAMMA) / 2 = sqrt(2) / 4.
    const W: f64 = (1.0 - GAMMA) / 2.0;

    pub const C: [f64; 3] = [0.0, 2.0 * GAMMA, 1.0];
    pub const A: [[f64; 3]; 3] = [
        [0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0],
        [W, W, GAMMA],
    ];
    pub const B: [f64; 3] = [W, W, GAMMA];

    /// Third-order embedded weights of TR-BDF2.
    const B_HAT: [f64; 3] = [(1.0 - W) / 3.0, (3.0 * W + 1.0) / 3.0, GAMMA / 3.0];

    /// Error weights `B - B_HAT`; they sum to zero.
    pub const E: [f64; 3] = [B[0] - B_HAT[0], B[1] - B_HAT[1], B[2] - B_HAT[2]];
}
/// Four-stage ESDIRK integrator (tableau in `esdirk43_tableau`).
///
/// The step-size controller treats this method as order 3. The type carries no
/// state; it only selects the method through its `Solver` implementation.
#[derive(Clone, Debug, Default)]
pub struct Esdirk43;

impl Esdirk43 {
    /// Creates the stateless solver marker (equivalent to `Esdirk43::default()`).
    pub fn new() -> Self {
        Esdirk43
    }
}
/// Butcher tableau for `Esdirk43`: a 4-stage ESDIRK whose weights `B` equal the last
/// row of `A` (stiffly accurate).
///
/// The error estimate compares `B` with the embedded weights formed by the third row,
/// `[A31, A32, GAMMA, 0]`.
mod esdirk43_tableau {
    /// Diagonal coefficient shared by the implicit stages.
    pub const GAMMA: f64 = 0.4358665215084590;

    // Off-diagonal coefficients of rows three and four.
    const A31: f64 = 0.4905633884217806;
    const A32: f64 = 0.0735700900697604;
    const A41: f64 = 0.3088099699767466;
    const A42: f64 = 1.4905633884217800;
    const A43: f64 = -1.2352398799069855;

    pub const C: [f64; 4] = [0.0, 2.0 * GAMMA, 1.0, 1.0];
    pub const A: [[f64; 4]; 4] = [
        [0.0, 0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0, 0.0],
        [A31, A32, GAMMA, 0.0],
        [A41, A42, A43, GAMMA],
    ];
    pub const B: [f64; 4] = [A41, A42, A43, GAMMA];
    /// `B` minus the embedded weights `[A31, A32, GAMMA, 0]`.
    pub const E: [f64; 4] = [A41 - A31, A42 - A32, A43 - GAMMA, GAMMA];
}
/// Six-stage ESDIRK integrator (tableau in `esdirk54_tableau`).
///
/// The step-size controller treats this method as order 4.
/// NOTE(review): the name suggests a fifth-order method; confirm the intended naming.
/// The type carries no state; it only selects the method through its `Solver` impl.
#[derive(Clone, Debug, Default)]
pub struct Esdirk54;

impl Esdirk54 {
    /// Creates the stateless solver marker (equivalent to `Esdirk54::default()`).
    pub fn new() -> Self {
        Esdirk54
    }
}
/// Butcher tableau for `Esdirk54`: a 6-stage ESDIRK with `GAMMA = 1/4`.
///
/// The weights `B` equal the last row of `A` (stiffly accurate). In every implicit row
/// the first two coefficients coincide, so they are named once below.
mod esdirk54_tableau {
    /// Diagonal coefficient shared by the implicit stages.
    pub const GAMMA: f64 = 0.25;
    pub const C: [f64; 6] = [0.0, 0.5, 0.14644660940672624, 0.625, 1.04, 1.0];

    // Row 3.
    const A31: f64 = -0.05177669529663689;
    // Row 4.
    const A41: f64 = -0.07655460838455727;
    const A43: f64 = 0.5281092167691145;
    // Row 5.
    const A51: f64 = -0.7274063478261299;
    const A53: f64 = 1.5849950617406794;
    const A54: f64 = 0.6598176339115805;
    // Row 6 (= propagated weights).
    const B1: f64 = -0.01558763503571651;
    const B3: f64 = 0.3876576709132033;
    const B4: f64 = 0.5017726195721631;
    const B5: f64 = -0.10825502041393352;

    pub const A: [[f64; 6]; 6] = [
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0, 0.0, 0.0, 0.0],
        [A31, A31, GAMMA, 0.0, 0.0, 0.0],
        [A41, A41, A43, GAMMA, 0.0, 0.0],
        [A51, A51, A53, A54, GAMMA, 0.0],
        [B1, B1, B3, B4, B5, GAMMA],
    ];
    pub const B: [f64; 6] = [B1, B1, B3, B4, B5, GAMMA];
    /// Error weights (difference between propagated and embedded weights; they sum to ~0).
    pub const E: [f64; 6] = [
        -0.08092570713246382,
        -0.08092570713246382,
        0.13516228008303094,
        0.01879524505002539,
        0.0256969660063123,
        -0.01780307687444085,
    ];
}
/// Integrates `problem` from `t0` to `tf` with an adaptive ESDIRK method given by its
/// Butcher tableau.
///
/// Stage 0 is explicit (`k[0] = f(t, y)`). Stages `1..STAGES` are solved by
/// `compute_esdirk_stages` with a simplified Newton iteration on `I - h*gamma*J`.
/// The propagated solution is `y + h * Σ b[s] k[s]`. The local error estimate is
/// `h * Σ e[s] k[s]`, measured with `error_norm` (a value <= 1 accepts the step).
///
/// Jacobian strategy: `J` is evaluated at the start. It is re-evaluated after a Newton
/// failure, or after three or more consecutive error-test rejections. The iteration
/// matrix is refactorized whenever `h` changes or a new Jacobian has been computed.
///
/// # Parameters
/// - `c`, `a`, `b`, `e`: tableau nodes, coefficients, propagating weights and error
///   weights (each at least `STAGES` long).
/// - `gamma`: diagonal coefficient shared by the implicit stages.
/// - `order`: order of the propagated solution. The new step is
///   `0.9 * err^(-1/(order+1))` times the old one, clamped to `[0.2, 3.0]` after an
///   accepted step. After a rejection only the lower bound of 0.2 applies.
///
/// # Output
/// Without `options.t_eval`, the initial point and every accepted step are recorded.
/// With it, values at the requested grid points are produced by `TEvalEmitter`.
///
/// # Errors
/// - `SolverError::DimensionMismatch` if `y0.len() != problem.dim()`.
/// - Errors from `validate_grid`, the LU factorization and the linear solves.
/// - `SolverError::MaxIterationsExceeded` once `options.max_steps` steps (accepted plus
///   error-test rejections) have been taken.
/// - `SolverError::Other` when a Newton failure brings the number of consecutive
///   failed attempts (Newton failures and error-test rejections) to 5.
/// - `SolverError::StepSizeTooSmall` when the proposed step falls below `options.h_min`.
// Private helper whose arguments mirror the tableau components one-to-one.
#[allow(clippy::too_many_arguments)]
fn solve_esdirk<S, Sys, const STAGES: usize>(
    problem: &Sys,
    t0: S,
    tf: S,
    y0: &[S],
    options: &SolverOptions<S>,
    c: &[f64],
    a: &[[f64; STAGES]; STAGES],
    b: &[f64],
    e: &[f64],
    gamma: f64,
    order: usize,
) -> Result<SolverResult<S>, SolverError>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    Sys: OdeSystem<S>,
{
    let dim = problem.dim();
    if y0.len() != dim {
        return Err(SolverError::DimensionMismatch {
            expected: dim,
            actual: y0.len(),
        });
    }

    // +1 when integrating forward in time, -1 when integrating backward.
    let direction = if tf > t0 { S::ONE } else { -S::ONE };
    if let Some(grid) = options.t_eval.as_deref() {
        validate_grid(grid, t0, tf)?;
    }
    let mut grid_emitter = options
        .t_eval
        .as_deref()
        .map(|g| TEvalEmitter::new(g, direction));
    // With an output grid only the emitted grid points are reported; otherwise every
    // accepted step is stored, starting with the initial condition.
    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
        (Vec::new(), Vec::new())
    } else {
        (vec![t0], y0.to_vec())
    };

    let mut t = t0;
    let mut y = y0.to_vec();
    let mut dy_old_buf = vec![S::ZERO; dim];
    let mut k: Vec<Vec<S>> = (0..STAGES).map(|_| vec![S::ZERO; dim]).collect();
    let mut y_stage = vec![S::ZERO; dim];
    let mut y_new = vec![S::ZERO; dim];
    let mut err = vec![S::ZERO; dim];
    let mut jac_data = vec![S::ZERO; dim * dim];
    let mut f0 = vec![S::ZERO; dim];
    let mut stats = SolverStats::default();

    // Explicit first stage; `f0` keeps f(t, y) for the Jacobian and dense output.
    problem.rhs(t, &y, &mut k[0]);
    stats.n_eval += 1;
    f0.copy_from_slice(&k[0]);

    let mut h = initial_step_size(&y, &k[0], options, dim);
    let h_min = options.h_min;
    let h_max = options.h_max.min((tf - t0).abs());

    // Step-size controller constants.
    let safety = S::from_f64(0.9);
    let fac_max = S::from_f64(3.0);
    let fac_min = S::from_f64(0.2);
    let order_f = S::from_usize(order + 1);
    // Stop once the remaining interval is a negligible fraction of the whole span.
    let end_tol = S::from_f64(1e-10) * (tf - t0).abs();

    let mut lu: Option<LUFactorization<S>> = None;
    let mut need_jac = true;
    // Set whenever the Jacobian is re-evaluated. Then `I - h*gamma*J` is rebuilt even if
    // `h` equals the step size of the previous factorization (e.g. clamped to `h_min`).
    let mut need_lu = true;
    let mut jac_h = h;
    let mut step_count = 0_usize;
    let mut consecutive_failures = 0_usize;

    while (tf - t) * direction > end_tol {
        if step_count >= options.max_steps {
            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
        }
        // Do not step past tf, then enforce |h| in [h_min, h_max] with the integration sign.
        if (t + h - tf) * direction > S::ZERO {
            h = tf - t;
        }
        h = h.abs().max(h_min) * direction;
        if h.abs() > h_max {
            h = h_max * direction;
        }

        if need_jac {
            compute_jacobian(problem, t, &y, &f0, &mut jac_data, dim);
            stats.n_jac += 1;
            need_jac = false;
            need_lu = true;
        }
        if need_lu || lu.is_none() || (h - jac_h).abs() > S::from_f64(1e-10) * h.abs() {
            let iter_matrix = form_iteration_matrix(&jac_data, h * S::from_f64(gamma), dim);
            lu = Some(LUFactorization::new(&iter_matrix)?);
            stats.n_lu += 1;
            jac_h = h;
            need_lu = false;
        }
        let lu_ref = lu
            .as_ref()
            .expect("iteration matrix is factorized before every step attempt");

        let step_ok = compute_esdirk_stages::<S, Sys, STAGES>(
            problem,
            t,
            h,
            &y,
            c,
            a,
            gamma,
            lu_ref,
            &mut k,
            &mut y_stage,
            &mut stats,
            dim,
        )?;
        if !step_ok {
            // A stage Newton iteration diverged or stalled: halve h and refresh J.
            stats.n_reject += 1;
            consecutive_failures += 1;
            h = h * S::from_f64(0.5);
            need_jac = true;
            if consecutive_failures >= 5 {
                return Err(SolverError::Other(format!(
                    "Too many consecutive failures at t = {}",
                    t.to_f64()
                )));
            }
            continue;
        }

        // Propagated solution and embedded error estimate.
        for i in 0..dim {
            let mut sum_b = S::ZERO;
            let mut sum_e = S::ZERO;
            for s in 0..STAGES {
                sum_b = sum_b + S::from_f64(b[s]) * k[s][i];
                sum_e = sum_e + S::from_f64(e[s]) * k[s][i];
            }
            y_new[i] = y[i] + h * sum_b;
            err[i] = h * sum_e;
        }
        let err_norm = error_norm(&err, &y, &y_new, options, dim);
        // Floor the error so a zero estimate does not yield an infinite growth factor.
        let err_safe = err_norm.max(S::EPSILON * S::from_f64(100.0));
        let fac = safety * err_safe.powf(-S::ONE / order_f);

        if err_norm <= S::ONE {
            stats.n_accept += 1;
            consecutive_failures = 0;
            let t_new = t + h;
            dy_old_buf.copy_from_slice(&f0);
            problem.rhs(t_new, &y_new, &mut f0);
            stats.n_eval += 1;
            if let Some(ref mut emitter) = grid_emitter {
                emitter.emit_step(
                    t,
                    &y,
                    &dy_old_buf,
                    t_new,
                    &y_new,
                    &f0,
                    &mut t_out,
                    &mut y_out,
                );
            } else {
                t_out.push(t_new);
                y_out.extend_from_slice(&y_new);
            }
            t = t_new;
            y.copy_from_slice(&y_new);
            // f(t_new, y_new) is the explicit first stage of the next step.
            k[0].copy_from_slice(&f0);
            h = h * fac.min(fac_max).max(fac_min);
        } else {
            stats.n_reject += 1;
            consecutive_failures += 1;
            h = h * fac.max(fac_min);
            if consecutive_failures >= 3 {
                // Repeated rejections may indicate an outdated Jacobian.
                need_jac = true;
            }
        }

        if h.abs() < h_min {
            return Err(SolverError::StepSizeTooSmall {
                t: t.to_f64(),
                h: h.to_f64(),
                h_min: h_min.to_f64(),
            });
        }
        step_count += 1;
    }
    Ok(SolverResult::new(t_out, y_out, dim, stats))
}
/// Solves the implicit stages `1..STAGES` of one ESDIRK step of size `h` from `(t, y)`.
///
/// `k[0]` must already hold `f(t, y)`. For each stage `s` the stage value `Y_s` solves
/// `Y_s = y + h * Σ_{j<s} a[s][j] k[j] + h*gamma * f(t + c[s] h, Y_s)`.
/// This uses a simplified Newton iteration with the pre-factorized matrix `lu` of
/// `I - h*gamma*J`. On convergence, `k[s]` receives `f(t + c[s] h, Y_s)`, and
/// `y_stage` ends up holding the last stage value.
///
/// Returns `Ok(false)` if any stage fails to converge within `MAX_NEWTON_ITERS`
/// iterations, so the caller can retry with a smaller step. Returns `Err` only when the
/// linear solve fails. Every right-hand-side evaluation is counted in `stats.n_eval`.
// Private helper whose arguments mirror the tableau and the caller's work buffers.
#[allow(clippy::too_many_arguments)]
fn compute_esdirk_stages<S, Sys, const STAGES: usize>(
    problem: &Sys,
    t: S,
    h: S,
    y: &[S],
    c: &[f64],
    a: &[[f64; STAGES]; STAGES],
    gamma: f64,
    lu: &LUFactorization<S>,
    k: &mut [Vec<S>],
    y_stage: &mut [S],
    stats: &mut SolverStats,
    dim: usize,
) -> Result<bool, SolverError>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    Sys: OdeSystem<S>,
{
    const MAX_NEWTON_ITERS: usize = 10;
    // The convergence test uses the 2-norm of the residual with component i scaled by
    // 1 / (1 + |Y_i|): absolute for small states, relative for large ones. A purely
    // absolute 1e-10 lies below round-off (~eps * |y|) once |y| exceeds ~1e5, so
    // Newton could never report convergence there.
    let newton_tol = S::from_f64(1e-10);
    let h_gamma = h * S::from_f64(gamma);

    // Scratch buffers reused across stages and Newton iterations.
    // `explicit[i]` = h * Σ_{j<s} a[s][j] k[j][i], which is fixed during a stage's iteration.
    let mut explicit = vec![S::ZERO; dim];
    let mut f_stage = vec![S::ZERO; dim];
    let mut residual = vec![S::ZERO; dim];

    for s in 1..STAGES {
        for i in 0..dim {
            let mut sum = S::ZERO;
            for j in 0..s {
                sum = sum + S::from_f64(a[s][j]) * k[j][i];
            }
            explicit[i] = h * sum;
            // Predictor: the stage equation with the implicit term dropped.
            y_stage[i] = y[i] + explicit[i];
        }
        let t_stage = t + S::from_f64(c[s]) * h;

        let mut converged = false;
        for _ in 0..MAX_NEWTON_ITERS {
            problem.rhs(t_stage, y_stage, &mut f_stage);
            stats.n_eval += 1;

            let mut res_sq = S::ZERO;
            for i in 0..dim {
                residual[i] = y_stage[i] - y[i] - explicit[i] - h_gamma * f_stage[i];
                let scaled = residual[i] / (S::ONE + y_stage[i].abs());
                res_sq = res_sq + scaled * scaled;
            }
            if res_sq.sqrt() < newton_tol {
                k[s].copy_from_slice(&f_stage);
                converged = true;
                break;
            }

            // Newton update: Y <- Y - (I - h*gamma*J)^{-1} * residual.
            let delta = lu.solve(&residual)?;
            for i in 0..dim {
                y_stage[i] = y_stage[i] - delta[i];
            }
        }
        if !converged {
            return Ok(false);
        }
    }
    Ok(true)
}
/// Approximates the Jacobian `∂f/∂y` at `(t, y)` with one-sided forward differences.
///
/// `f0` must hold `f(t, y)`; it is reused as the unperturbed value, so each column
/// costs one extra `problem.rhs` call (`dim` calls in total, not counted in any stats
/// here). The result is written row-major: `jac[i * dim + j] ≈ ∂f_i/∂y_j`.
///
/// Component `j` is perturbed by roughly `sqrt(eps) * (1 + |y_j|)`. The divisor is the
/// perturbation actually stored after rounding `y_j + h`, so the difference quotient
/// matches the step that was really applied.
fn compute_jacobian<S, Sys>(problem: &Sys, t: S, y: &[S], f0: &[S], jac: &mut [S], dim: usize)
where
    S: Scalar,
    Sys: OdeSystem<S>,
{
    let h_factor = S::EPSILON.sqrt();
    let mut y_pert = y.to_vec();
    let mut f_pert = vec![S::ZERO; dim];
    for j in 0..dim {
        let yj = y[j];
        let nominal = h_factor * (S::ONE + yj.abs());
        y_pert[j] = yj + nominal;
        // `yj + nominal` is rounded; recover the exactly representable step that was applied.
        let h = y_pert[j] - yj;
        problem.rhs(t, &y_pert, &mut f_pert);
        y_pert[j] = yj;
        for i in 0..dim {
            jac[i * dim + j] = (f_pert[i] - f0[i]) / h;
        }
    }
}
/// Builds the Newton iteration matrix `M = I - h_gamma * J`.
///
/// `jac` is the row-major `dim x dim` Jacobian (`jac[i * dim + j] = ∂f_i/∂y_j`), and
/// `h_gamma` is the step size times the diagonal tableau coefficient.
fn form_iteration_matrix<S>(jac: &[S], h_gamma: S, dim: usize) -> DenseMatrix<S>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
{
    let mut matrix = DenseMatrix::zeros(dim, dim);
    // Walk the flat row-major storage; the range is empty when dim == 0.
    for idx in 0..dim * dim {
        let (row, col) = (idx / dim, idx % dim);
        let scaled = h_gamma * jac[idx];
        let entry = if row == col { S::ONE - scaled } else { -scaled };
        matrix.set(row, col, entry);
    }
    matrix
}
/// Chooses the magnitude of the first step.
///
/// Returns `options.h0` if the user supplied one. Otherwise it uses the heuristic
/// `h = 0.01 * ||y0|| / ||f0||`, capped at `options.h_max`. Both norms are RMS norms
/// with component weights `1 / (atol + rtol * |y0_i|)`. It falls back to `1e-6` when
/// either norm is below `sqrt(eps)` or the system is empty. The caller applies the
/// integration direction.
fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
    if let Some(h0) = options.h0 {
        return h0;
    }
    let fallback = S::from_f64(1e-6);
    if dim == 0 {
        // Avoid 0/0 in the RMS normalisation below.
        return fallback;
    }
    // Floor the weights (as the error norm does) so atol == 0 with a zero component
    // cannot produce 0/0 = NaN.
    let sc_floor = S::from_f64(1e-15);
    let mut y_norm = S::ZERO;
    let mut f_norm = S::ZERO;
    for i in 0..dim {
        let sc = (options.atol + options.rtol * y0[i].abs()).max(sc_floor);
        y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
        f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
    }
    y_norm = (y_norm / S::from_usize(dim)).sqrt();
    f_norm = (f_norm / S::from_usize(dim)).sqrt();
    if y_norm < S::EPSILON.sqrt() || f_norm < S::EPSILON.sqrt() {
        fallback
    } else {
        (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
    }
}
/// Weighted RMS norm of a local error estimate.
///
/// Component `i` is divided by `atol + rtol * max(|y_i|, |y_new_i|)`, floored at
/// `1e-15`. The caller accepts a step when the result is `<= 1`.
fn error_norm<S: Scalar>(
    err: &[S],
    y: &[S],
    y_new: &[S],
    options: &SolverOptions<S>,
    dim: usize,
) -> S {
    let floor = S::from_f64(1e-15);
    let sum_sq = (0..dim).fold(S::ZERO, |acc, i| {
        let weight = (options.atol + options.rtol * y[i].abs().max(y_new[i].abs())).max(floor);
        let ratio = err[i] / weight;
        acc + ratio * ratio
    });
    (sum_sq / S::from_usize(dim)).sqrt()
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk32 {
fn solve<Sys: OdeSystem<S>>(
problem: &Sys,
t0: S,
tf: S,
y0: &[S],
options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError> {
solve_esdirk::<S, Sys, 3>(
problem,
t0,
tf,
y0,
options,
&esdirk32_tableau::C,
&esdirk32_tableau::A,
&esdirk32_tableau::B,
&esdirk32_tableau::E,
esdirk32_tableau::GAMMA,
2,
)
}
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk43 {
fn solve<Sys: OdeSystem<S>>(
problem: &Sys,
t0: S,
tf: S,
y0: &[S],
options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError> {
solve_esdirk::<S, Sys, 4>(
problem,
t0,
tf,
y0,
options,
&esdirk43_tableau::C,
&esdirk43_tableau::A,
&esdirk43_tableau::B,
&esdirk43_tableau::E,
esdirk43_tableau::GAMMA,
3,
)
}
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk54 {
fn solve<Sys: OdeSystem<S>>(
problem: &Sys,
t0: S,
tf: S,
y0: &[S],
options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError> {
solve_esdirk::<S, Sys, 6>(
problem,
t0,
tf,
y0,
options,
&esdirk54_tableau::C,
&esdirk54_tableau::A,
&esdirk54_tableau::B,
&esdirk54_tableau::E,
esdirk54_tableau::GAMMA,
4,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Decay y' = -y on [0, 5]; the end value must be within 1e-3 of e^-5.
    #[test]
    fn test_esdirk32_exponential() {
        let exact = (-5.0_f64).exp();
        let ode = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let sol = Esdirk32::solve(&ode, 0.0, 5.0, &[1.0], &opts).unwrap();
        assert!(sol.success);
        let last = sol.y_final().unwrap();
        assert!((last[0] - exact).abs() < 1e-3);
    }

    /// Stiff decay y' = -50 y on [0, 0.5]; the end value must be within 0.01 of e^-25.
    #[test]
    fn test_esdirk43_stiff() {
        let exact = (-25.0_f64).exp();
        let ode = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -50.0 * y[0];
            },
            0.0,
            0.5,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let sol = Esdirk43::solve(&ode, 0.0, 0.5, &[1.0], &opts).unwrap();
        assert!(sol.success);
        let last = sol.y_final().unwrap();
        assert!((last[0] - exact).abs() < 0.01);
    }

    /// Symmetric exchange system: y0 + y1 is conserved (starts at 1).
    #[test]
    fn test_esdirk54_linear_system() {
        let ode = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = y[0] - y[1];
            },
            0.0,
            5.0,
            vec![1.0, 0.0],
        );
        let opts = SolverOptions::default().rtol(1e-5).atol(1e-7);
        let sol = Esdirk54::solve(&ode, 0.0, 5.0, &[1.0, 0.0], &opts).unwrap();
        assert!(sol.success);
        let last = sol.y_final().unwrap();
        assert!((last[0] + last[1] - 1.0).abs() < 1e-4);
    }

    /// Van der Pol oscillator with mu = 10: the solver must finish without error.
    #[test]
    fn test_esdirk_van_der_pol() {
        let mu = 10.0;
        let ode = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            10.0,
            vec![2.0, 0.0],
        );
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let outcome = Esdirk54::solve(&ode, 0.0, 10.0, &[2.0, 0.0], &opts);
        assert!(outcome.is_ok());
    }

    /// All three methods must reproduce e^-2 for y' = -y on [0, 2] to within 1e-2.
    #[test]
    fn test_esdirk_methods_agree() {
        let expected = (-2.0_f64).exp();
        let ode = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let y32 = Esdirk32::solve(&ode, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y43 = Esdirk43::solve(&ode, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y54 = Esdirk54::solve(&ode, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];
        assert!(
            (y32 - expected).abs() < 1e-2,
            "ESDIRK32: got {}, expected {}",
            y32,
            expected
        );
        assert!(
            (y43 - expected).abs() < 1e-2,
            "ESDIRK43: got {}, expected {}",
            y43,
            expected
        );
        assert!(
            (y54 - expected).abs() < 1e-2,
            "ESDIRK54: got {}, expected {}",
            y54,
            expected
        );
    }
}