use faer::{ComplexField, Conjugate, SimpleEntity};
use numra_core::Scalar;
use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
use crate::error::SolverError;
use crate::problem::OdeSystem;
use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
use crate::t_eval::{validate_grid, TEvalEmitter};
/// The RADAU5 implicit Runge–Kutta solver (three-stage Radau IIA, order 5).
///
/// This is a stateless, zero-sized handle: all of the work is done by its
/// `Solver::solve` implementation.
#[derive(Debug, Default, Clone)]
pub struct Radau5;
impl Radau5 {
pub fn new() -> Self {
Self
}
}
/// Coefficients of the three-stage Radau IIA method (order 5) in the form used by the
/// RADAU5 algorithm of Hairer & Wanner.
///
/// Naming: `C*` are the collocation nodes, `DD*` the error-estimate weights, `U1` and
/// `ALPH ± i·BETA` the real and complex eigenvalues of the inverse Butcher matrix,
/// `T*` / `TI*` the eigenvector transformation `Z = T·W` and its inverse, and `P*` the
/// coefficients of the collocation polynomial used to extrapolate Newton starting values.
mod coefficients {
    pub const SQRT6: f64 = 2.449489742783178;

    // ---- Collocation nodes -------------------------------------------------------
    pub const C1: f64 = (4.0 - SQRT6) / 10.0;
    pub const C2: f64 = (4.0 + SQRT6) / 10.0;
    /// Last node. The solver writes `c3 = 1` directly, so this is kept only for reference
    /// (hence the `dead_code` allowance).
    #[allow(dead_code)]
    pub const C3: f64 = 1.0;

    // ---- Error-estimate weights --------------------------------------------------
    pub const DD1: f64 = -(13.0 + 7.0 * SQRT6) / 3.0;
    pub const DD2: f64 = (-13.0 + 7.0 * SQRT6) / 3.0;
    pub const DD3: f64 = -1.0 / 3.0;

    // ---- Eigenvalues of the inverse Butcher matrix -------------------------------
    // 81^(1/3), 9^(1/3) and sqrt(3), written as literals so they can be used in the
    // `const` expressions below.
    const CUBERT81: f64 = 4.3267487109222245;
    const CUBERT9: f64 = 2.080083823051904;
    const SQRT3: f64 = 1.7320508075688772;

    const U1_RAW: f64 = (6.0 + CUBERT81 - CUBERT9) / 30.0;
    const ALPH_RAW: f64 = (12.0 - CUBERT81 + CUBERT9) / 60.0;
    const BETA_RAW: f64 = (CUBERT81 + CUBERT9) * SQRT3 / 60.0;
    const CNO: f64 = ALPH_RAW * ALPH_RAW + BETA_RAW * BETA_RAW;

    /// Real eigenvalue (`1 / U1_RAW`).
    pub const U1: f64 = 1.0 / U1_RAW;
    /// Real part of the complex eigenvalue pair.
    pub const ALPH: f64 = ALPH_RAW / CNO;
    /// Imaginary part of the complex eigenvalue pair.
    pub const BETA: f64 = BETA_RAW / CNO;

    // ---- Transformation matrix T (Z = T·W) ---------------------------------------
    pub const T11: f64 = 9.1232394870892942792e-02;
    pub const T12: f64 = -0.14125529502095420843;
    pub const T13: f64 = -3.0029194105147424492e-02;
    pub const T21: f64 = 0.24171793270710701896;
    pub const T22: f64 = 0.20412935229379993199;
    pub const T23: f64 = 0.38294211275726193779;
    pub const T31: f64 = 0.96604818261509293619;
    pub const T32: f64 = 1.0;
    /// Zero entry. The solver drops the corresponding term, so this is kept only for
    /// reference (hence the `dead_code` allowance).
    #[allow(dead_code)]
    pub const T33: f64 = 0.0;

    // ---- Inverse transformation TI = T^{-1} (W = TI·Z) ---------------------------
    pub const TI11: f64 = 4.3255798900631553510;
    pub const TI12: f64 = 0.33919925181580986954;
    pub const TI13: f64 = 0.54177053993587487119;
    pub const TI21: f64 = -4.1787185915519047273;
    pub const TI22: f64 = -0.32768282076106238708;
    pub const TI23: f64 = 0.47662355450055045196;
    pub const TI31: f64 = -0.50287263494578687595;
    pub const TI32: f64 = 2.5719269498556054292;
    pub const TI33: f64 = -0.59603920482822492497;

    // ---- Collocation-polynomial coefficients (rows: z1, z2, z3; cols: x, x^2, x^3) --
    pub const P11: f64 = 13.0 / 3.0 + 7.0 * SQRT6 / 3.0;
    pub const P12: f64 = -23.0 / 3.0 - 22.0 * SQRT6 / 3.0;
    pub const P13: f64 = 10.0 / 3.0 + 5.0 * SQRT6;
    pub const P21: f64 = 13.0 / 3.0 - 7.0 * SQRT6 / 3.0;
    pub const P22: f64 = -23.0 / 3.0 + 22.0 * SQRT6 / 3.0;
    pub const P23: f64 = 10.0 / 3.0 - 5.0 * SQRT6;
    pub const P31: f64 = 1.0 / 3.0;
    pub const P32: f64 = -8.0 / 3.0;
    pub const P33: f64 = 10.0 / 3.0;
}
/// Upper bound on simplified-Newton iterations per step. If the iteration has not
/// converged within this many sweeps, the step is treated as a Newton failure.
const MAX_NEWTON_ITER: usize = 7;
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Radau5 {
    /// Integrates `problem` from `t0` to `tf` with the three-stage Radau IIA method.
    ///
    /// Each step solves the collocation equations for the stage increments `z1, z2, z3`
    /// with a simplified Newton iteration in eigen-transformed variables. This needs one
    /// real `dim x dim` and one real-embedded complex `2*dim x 2*dim` LU factorization of
    /// iteration matrices built from the Jacobian, the step size and the optional mass
    /// matrix. The step size is controlled by an embedded error estimate combined with a
    /// predictive controller. Integration runs forwards (`tf > t0`) or backwards
    /// (`tf < t0`); the signed step `h` always points towards `tf`.
    ///
    /// Output: with `options.t_eval` set, the grid emitter produces values at the requested
    /// points from every accepted step. Otherwise the initial state and every accepted
    /// step are recorded.
    ///
    /// # Errors
    /// - `SolverError::DimensionMismatch` if `y0.len() != problem.dim()`.
    /// - Any error returned by `validate_grid` for an invalid `t_eval` grid.
    /// - `SolverError::MaxIterationsExceeded` once `options.max_steps` step attempts
    ///   (accepted, or rejected by the error test) have been made.
    /// - `SolverError::StepSizeTooSmall` when `|h|` drops below `options.h_min`.
    /// - Any error from factorizing the iteration matrices.
    fn solve<Sys: OdeSystem<S>>(
        problem: &Sys,
        t0: S,
        tf: S,
        y0: &[S],
        options: &SolverOptions<S>,
    ) -> Result<SolverResult<S>, SolverError> {
        let dim = problem.dim();
        if y0.len() != dim {
            return Err(SolverError::DimensionMismatch {
                expected: dim,
                actual: y0.len(),
            });
        }
        let mut t = t0;
        let mut y = y0.to_vec();
        // +1 when integrating forwards, -1 when integrating backwards.
        let direction = if tf > t0 { S::ONE } else { -S::ONE };

        if let Some(grid) = options.t_eval.as_deref() {
            validate_grid(grid, t0, tf)?;
        }
        let mut grid_emitter = options
            .t_eval
            .as_deref()
            .map(|g| TEvalEmitter::new(g, direction));
        // With a grid, the emitter decides what is recorded. Otherwise record the initial
        // state here and every accepted step below.
        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
            (Vec::new(), Vec::new())
        } else {
            (vec![t0], y0.to_vec())
        };

        // Work buffers, allocated once and reused by every step.
        let mut dy_old_buf = vec![S::ZERO; dim];
        let mut f0 = vec![S::ZERO; dim];
        let mut z1 = vec![S::ZERO; dim];
        let mut z2 = vec![S::ZERO; dim];
        let mut z3 = vec![S::ZERO; dim];
        let mut w1 = vec![S::ZERO; dim];
        let mut w2 = vec![S::ZERO; dim];
        let mut w3 = vec![S::ZERO; dim];
        let mut cont = vec![S::ZERO; dim];
        let mut scal = vec![S::ZERO; dim];
        let mut y_new = vec![S::ZERO; dim];
        let mut err = vec![S::ZERO; dim];
        let mut jac_data = vec![S::ZERO; dim * dim];

        // Stage increments and step size of the last accepted step. They are used to
        // extrapolate Newton starting values for the next step.
        let mut z1_prev = vec![S::ZERO; dim];
        let mut z2_prev = vec![S::ZERO; dim];
        let mut z3_prev = vec![S::ZERO; dim];
        let mut h_prev: S = S::ONE;
        let mut have_prev = false;

        // History for the predictive step-size controller.
        let mut h_abs_old: Option<S> = None;
        let mut err_norm_old: Option<S> = None;

        // The mass matrix (if any) is queried once, up front.
        let mass_data = if problem.has_mass_matrix() {
            let mut m = vec![S::ZERO; dim * dim];
            problem.mass_matrix(&mut m);
            Some(m)
        } else {
            None
        };
        let mass_ref = mass_data.as_deref();

        let mut stats = SolverStats::default();
        problem.rhs(t, &y, &mut f0);
        stats.n_eval += 1;

        // `initial_step_size` returns a magnitude; orient it towards `tf`.
        let mut h = Self::initial_step_size(&y, &f0, options, dim) * direction;
        let h_min = options.h_min;
        // Honour the user's maximum step, and never let one step span more than half
        // of the integration interval.
        let h_max = options.h_max.min((tf - t0).abs() * S::from_f64(0.5));

        let mut lu_real: Option<LUFactorization<S>> = None;
        let mut lu_complex: Option<LUFactorization<S>> = None;
        let mut need_jac = true;
        let mut first = true;
        let mut reject = false;
        let mut step_count = 0usize;

        while (tf - t) * direction > S::ZERO {
            if step_count >= options.max_steps {
                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
            }

            // Does this step reach (or would it overshoot) the end point?
            let last = (t + h - tf) * direction >= S::ZERO;
            if (t + h - tf) * direction > S::ZERO {
                h = tf - t;
                // The iteration matrices depend on h and must be rebuilt for the new step.
                lu_real = None;
                lu_complex = None;
            }

            if need_jac {
                problem.jacobian(t, &y, &mut jac_data);
                stats.n_jac += 1;
                need_jac = false;
                lu_real = None;
                lu_complex = None;
            }
            if lu_real.is_none() || lu_complex.is_none() {
                let (e1, e2) = Self::form_transformed_matrices(&jac_data, h, dim, mass_ref);
                lu_real = Some(LUFactorization::new(&e1)?);
                lu_complex = Some(LUFactorization::new(&e2)?);
                stats.n_lu += 2;
            }

            for i in 0..dim {
                scal[i] = options.atol + options.rtol * y[i].abs();
            }

            if first || reject || !have_prev {
                // No trustworthy history: start Newton from zero stage increments.
                z1.fill(S::ZERO);
                z2.fill(S::ZERO);
                z3.fill(S::ZERO);
                w1.fill(S::ZERO);
                w2.fill(S::ZERO);
                w3.fill(S::ZERO);
            } else {
                // Extrapolate the previous step's collocation polynomial
                //   u(x) = q0*x + q1*x^2 + q2*x^3,  x = (s - t_prev) / h_prev,
                // with u(c1) = z1_prev, u(c2) = z2_prev and u(1) = z3_prev. The new stage
                // points t + c_i*h lie at x_i = 1 + c_i*h/h_prev. The new increments are
                // measured from y = y_prev + z3_prev, hence the "- z3_prev".
                let p11 = S::from_f64(coefficients::P11);
                let p12 = S::from_f64(coefficients::P12);
                let p13 = S::from_f64(coefficients::P13);
                let p21 = S::from_f64(coefficients::P21);
                let p22 = S::from_f64(coefficients::P22);
                let p23 = S::from_f64(coefficients::P23);
                let p31 = S::from_f64(coefficients::P31);
                let p32 = S::from_f64(coefficients::P32);
                let p33 = S::from_f64(coefficients::P33);
                let c1 = S::from_f64(coefficients::C1);
                let c2 = S::from_f64(coefficients::C2);
                let x1 = S::ONE + h * c1 / h_prev;
                let x2 = S::ONE + h * c2 / h_prev;
                let x3 = S::ONE + h / h_prev;
                for i in 0..dim {
                    let (zp1, zp2, zp3) = (z1_prev[i], z2_prev[i], z3_prev[i]);
                    let q0 = zp1 * p11 + zp2 * p21 + zp3 * p31;
                    let q1 = zp1 * p12 + zp2 * p22 + zp3 * p32;
                    let q2 = zp1 * p13 + zp2 * p23 + zp3 * p33;
                    z1[i] = x1 * (q0 + x1 * (q1 + x1 * q2)) - zp3;
                    z2[i] = x2 * (q0 + x2 * (q1 + x2 * q2)) - zp3;
                    z3[i] = x3 * (q0 + x3 * (q1 + x3 * q2)) - zp3;
                }
                // Transformed starting values W = T^{-1} Z, consistent with Z.
                let ti11 = S::from_f64(coefficients::TI11);
                let ti12 = S::from_f64(coefficients::TI12);
                let ti13 = S::from_f64(coefficients::TI13);
                let ti21 = S::from_f64(coefficients::TI21);
                let ti22 = S::from_f64(coefficients::TI22);
                let ti23 = S::from_f64(coefficients::TI23);
                let ti31 = S::from_f64(coefficients::TI31);
                let ti32 = S::from_f64(coefficients::TI32);
                let ti33 = S::from_f64(coefficients::TI33);
                for i in 0..dim {
                    w1[i] = ti11 * z1[i] + ti12 * z2[i] + ti13 * z3[i];
                    w2[i] = ti21 * z1[i] + ti22 * z2[i] + ti23 * z3[i];
                    w3[i] = ti31 * z1[i] + ti32 * z2[i] + ti33 * z3[i];
                }
            }

            let newton_result = Self::newton_iteration(
                problem,
                t,
                h,
                &y,
                &scal,
                &mut z1,
                &mut z2,
                &mut z3,
                &mut w1,
                &mut w2,
                &mut w3,
                &mut cont,
                lu_real
                    .as_ref()
                    .expect("real LU factorization is rebuilt above when missing"),
                lu_complex
                    .as_ref()
                    .expect("complex LU factorization is rebuilt above when missing"),
                mass_ref,
                &mut stats,
                dim,
                options,
            );
            // A failed linear solve inside the iteration is handled like divergence.
            let (newton_converged, newt_iter) =
                newton_result.unwrap_or((false, MAX_NEWTON_ITER));

            if !newton_converged {
                // Retry the step with half the size and a freshly evaluated Jacobian.
                h = h * S::from_f64(0.5);
                stats.n_reject += 1;
                reject = true;
                need_jac = true;
                if h.abs() < h_min {
                    return Err(SolverError::StepSizeTooSmall {
                        t: t.to_f64(),
                        h: h.to_f64(),
                        h_min: h_min.to_f64(),
                    });
                }
                continue;
            }

            for i in 0..dim {
                y_new[i] = y[i] + z3[i];
            }
            let err_norm = Self::error_estimate(
                problem,
                t,
                &f0,
                &z1,
                &z2,
                &z3,
                &y,
                &y_new,
                h,
                options,
                lu_real
                    .as_ref()
                    .expect("real LU factorization is rebuilt above when missing"),
                &mut err,
                dim,
                first,
                reject,
                &mut stats,
                mass_ref,
            );

            let safety = Self::safety_factor::<S>(newt_iter, MAX_NEWTON_ITER);
            let pred = Self::predict_factor(h.abs(), h_abs_old, err_norm, err_norm_old);
            let factor = (safety * pred).max(S::from_f64(0.2)).min(S::from_f64(8.0));

            if err_norm < S::ONE {
                stats.n_accept += 1;
                z1_prev.copy_from_slice(&z1);
                z2_prev.copy_from_slice(&z2);
                z3_prev.copy_from_slice(&z3);
                h_prev = h;
                have_prev = true;
                h_abs_old = Some(h.abs());
                err_norm_old = Some(err_norm);

                // Land exactly on tf for the final step, so that rounding in t + h cannot
                // leave a sliver of the interval to integrate.
                let t_new = if last { tf } else { t + h };
                dy_old_buf.copy_from_slice(&f0);
                problem.rhs(t_new, &y_new, &mut f0);
                stats.n_eval += 1;
                if let Some(ref mut emitter) = grid_emitter {
                    emitter.emit_step(
                        t,
                        &y,
                        &dy_old_buf,
                        t_new,
                        &y_new,
                        &f0,
                        &mut t_out,
                        &mut y_out,
                    );
                } else {
                    t_out.push(t_new);
                    y_out.extend_from_slice(&y_new);
                }
                t = t_new;
                y.copy_from_slice(&y_new);
                first = false;
                reject = false;

                // Only change h (and pay for new factorizations) when the gain is worth it.
                if factor >= S::from_f64(1.2) {
                    let h_proposed = h * factor;
                    h = if h_proposed.abs() > h_max {
                        if h_proposed > S::ZERO {
                            h_max
                        } else {
                            -h_max
                        }
                    } else {
                        h_proposed
                    };
                    lu_real = None;
                    lu_complex = None;
                }
            } else {
                stats.n_reject += 1;
                reject = true;
                h = h * factor;
                lu_real = None;
                lu_complex = None;
                if h.abs() < h_min {
                    return Err(SolverError::StepSizeTooSmall {
                        t: t.to_f64(),
                        h: h.to_f64(),
                        h_min: h_min.to_f64(),
                    });
                }
            }
            step_count += 1;
        }
        Ok(SolverResult::new(t_out, y_out, dim, stats))
    }
}
impl Radau5 {
    /// Heuristic magnitude of the first step.
    ///
    /// Forms the RMS norms `d0 = ||y||` and `d1 = ||f||`, weighting component `i` by
    /// `atol + rtol * |y_i|`. If either norm is below `1e-5` the fixed guess `1e-6` is
    /// used; otherwise the guess is `0.01 * d0 / d1`. The guess is then clamped to
    /// `[options.h_min, options.h_max]`. The value is a magnitude and carries no
    /// integration direction.
    fn initial_step_size<S: Scalar>(y: &[S], f: &[S], options: &SolverOptions<S>, dim: usize) -> S {
        let mut d0 = S::ZERO;
        let mut d1 = S::ZERO;
        for i in 0..dim {
            let sc = options.atol + options.rtol * y[i].abs();
            let (ry, rf) = (y[i] / sc, f[i] / sc);
            d0 = d0 + ry * ry;
            d1 = d1 + rf * rf;
        }
        let n = S::from_usize(dim);
        let d0 = (d0 / n).sqrt();
        let d1 = (d1 / n).sqrt();
        let threshold = S::from_f64(1e-5);
        let h0 = if d0 < threshold || d1 < threshold {
            S::from_f64(1e-6)
        } else {
            S::from_f64(0.01) * d0 / d1
        };
        h0.min(options.h_max).max(options.h_min)
    }

    /// Builds the two iteration matrices of the transformed Newton system.
    ///
    /// - `e1 = (U1/h)·M - J` (`dim x dim`), for the real eigenvalue.
    /// - `e2` (`2*dim x 2*dim`) is the real embedding of the complex matrix
    ///   `((ALPH + i·BETA)/h)·M - J`, laid out as `[[a·M - J, -b·M], [b·M, a·M - J]]`
    ///   with `a = ALPH/h` and `b = BETA/h`.
    ///
    /// `jac` and `mass` are row-major `dim x dim`; a missing mass matrix means identity.
    fn form_transformed_matrices<S>(
        jac: &[S],
        h: S,
        dim: usize,
        mass: Option<&[S]>,
    ) -> (DenseMatrix<S>, DenseMatrix<S>)
    where
        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    {
        // Entry (i, j) of the mass matrix, defaulting to the identity.
        let mass_at = |i: usize, j: usize| -> S {
            match mass {
                Some(m) => m[i * dim + j],
                None if i == j => S::ONE,
                None => S::ZERO,
            }
        };
        let fac1 = S::from_f64(coefficients::U1) / h;
        let alphn = S::from_f64(coefficients::ALPH) / h;
        let betan = S::from_f64(coefficients::BETA) / h;
        let mut e1 = DenseMatrix::zeros(dim, dim);
        let mut e2 = DenseMatrix::zeros(2 * dim, 2 * dim);
        for i in 0..dim {
            for j in 0..dim {
                let jij = jac[i * dim + j];
                let mij = mass_at(i, j);
                e1.set(i, j, fac1 * mij - jij);
                e2.set(i, j, alphn * mij - jij);
                e2.set(i, dim + j, -betan * mij);
                e2.set(dim + i, j, betan * mij);
                e2.set(dim + i, dim + j, alphn * mij - jij);
            }
        }
        (e1, e2)
    }

    /// Safety factor `0.9 * (2*max_iter + 1) / (2*max_iter + n_iter)` for the step-size
    /// update. It shrinks as the Newton iteration needs more sweeps.
    fn safety_factor<S: Scalar>(n_iter: usize, max_iter: usize) -> S {
        let num = 0.9 * (2.0 * max_iter as f64 + 1.0);
        let den = 2.0 * max_iter as f64 + n_iter as f64;
        S::from_f64(num / den)
    }

    /// Predictive step-size factor.
    ///
    /// Returns `min(1, m) * err_norm^(-1/4)`. When a previous accepted step exists (and
    /// both `err_norm` and `h_old` are positive), `m = (h/h_old) * (err_old/err)^(1/4)`;
    /// otherwise `m = 1`.
    fn predict_factor<S: Scalar>(
        h_abs: S,
        h_abs_old: Option<S>,
        err_norm: S,
        err_norm_old: Option<S>,
    ) -> S {
        let multiplier = match (h_abs_old, err_norm_old) {
            (Some(h_old), Some(err_old)) if err_norm > S::ZERO && h_old > S::ZERO => {
                (h_abs / h_old) * (err_old / err_norm).powf(S::from_f64(0.25))
            }
            _ => S::ONE,
        };
        multiplier.min(S::ONE) * err_norm.powf(S::from_f64(-0.25))
    }

    /// Simplified Newton iteration for the stage increments of one step.
    ///
    /// Works on the transformed unknowns `W = T^{-1} Z`. On entry `z1..z3` and `w1..w3`
    /// hold consistent starting values. On a converged return they hold the stage
    /// increments, with `y + z_i` the stage value at `t + c_i*h`. `cont` is scratch
    /// space, and `f0`-like RHS values are kept in private buffers, so `z1` is never
    /// used as scratch.
    ///
    /// Each sweep evaluates the right-hand side at the three stage points, solves the
    /// real system with `lu_real` and the real-embedded complex system with `lu_complex`,
    /// and tracks the contraction rate `theta` of successive correction norms.
    ///
    /// Returns `Ok((true, n))` once `faccon * ||dW|| <= fnewt` after `n` sweeps. Returns
    /// `Ok((false, n))` if the iteration is judged divergent or too slow, or if it has
    /// not converged after `MAX_NEWTON_ITER` sweeps.
    ///
    /// # Errors
    /// Propagates failures of the LU solves.
    #[allow(clippy::too_many_arguments)]
    fn newton_iteration<S, Sys>(
        problem: &Sys,
        t: S,
        h: S,
        y: &[S],
        scal: &[S],
        z1: &mut [S],
        z2: &mut [S],
        z3: &mut [S],
        w1: &mut [S],
        w2: &mut [S],
        w3: &mut [S],
        cont: &mut [S],
        lu_real: &LUFactorization<S>,
        lu_complex: &LUFactorization<S>,
        mass: Option<&[S]>,
        stats: &mut SolverStats,
        dim: usize,
        options: &SolverOptions<S>,
    ) -> Result<(bool, usize), SolverError>
    where
        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
        Sys: OdeSystem<S>,
    {
        let c1 = S::from_f64(coefficients::C1);
        let c2 = S::from_f64(coefficients::C2);
        let uround = S::from_f64(1e-16);
        // Newton tolerance: max(10*uround/rtol, min(0.03, sqrt(rtol))).
        let fnewt = (S::from_f64(10.0) * uround / options.rtol)
            .max(S::from_f64(0.03).min(options.rtol.sqrt()));
        let fac1 = S::from_f64(coefficients::U1) / h;
        let alphn = S::from_f64(coefficients::ALPH) / h;
        let betan = S::from_f64(coefficients::BETA) / h;
        let ti11 = S::from_f64(coefficients::TI11);
        let ti12 = S::from_f64(coefficients::TI12);
        let ti13 = S::from_f64(coefficients::TI13);
        let ti21 = S::from_f64(coefficients::TI21);
        let ti22 = S::from_f64(coefficients::TI22);
        let ti23 = S::from_f64(coefficients::TI23);
        let ti31 = S::from_f64(coefficients::TI31);
        let ti32 = S::from_f64(coefficients::TI32);
        let ti33 = S::from_f64(coefficients::TI33);
        let t11 = S::from_f64(coefficients::T11);
        let t12 = S::from_f64(coefficients::T12);
        let t13 = S::from_f64(coefficients::T13);
        let t21 = S::from_f64(coefficients::T21);
        let t22 = S::from_f64(coefficients::T22);
        let t23 = S::from_f64(coefficients::T23);
        let t31 = S::from_f64(coefficients::T31);
        let t32 = S::from_f64(coefficients::T32);

        // Convergence-monitor state: previous correction norm, previous rate, and the
        // factor theta/(1-theta) that estimates the remaining error.
        let mut dynold: S = uround;
        let mut thqold: S = S::ONE;
        let mut faccon: S = S::ONE;
        let n3 = S::from_usize(3 * dim);

        // Right-hand side values at the three stage points.
        let mut f1_temp = vec![S::ZERO; dim];
        let mut f2_temp = vec![S::ZERO; dim];
        let mut f3_temp = vec![S::ZERO; dim];
        let mut z1_orig = vec![S::ZERO; dim];
        let mut z2_orig = vec![S::ZERO; dim];
        let mut z3_orig = vec![S::ZERO; dim];
        let mut mz1_buf = vec![S::ZERO; dim];
        let mut mz2_buf = vec![S::ZERO; dim];
        let mut mz3_buf = vec![S::ZERO; dim];
        let mut rhs1 = vec![S::ZERO; dim];
        let mut rhs2 = vec![S::ZERO; dim];
        let mut rhs3 = vec![S::ZERO; dim];
        let mut rhs_complex = vec![S::ZERO; 2 * dim];

        for newt in 0..MAX_NEWTON_ITER {
            // Right-hand side at the three stage points.
            for i in 0..dim {
                cont[i] = y[i] + z1[i];
            }
            problem.rhs(t + c1 * h, cont, &mut f1_temp);
            for i in 0..dim {
                cont[i] = y[i] + z2[i];
            }
            problem.rhs(t + c2 * h, cont, &mut f2_temp);
            for i in 0..dim {
                cont[i] = y[i] + z3[i];
            }
            problem.rhs(t + h, cont, &mut f3_temp);
            stats.n_eval += 3;

            // Z = T·W (T33 = 0, so the w3 term is absent in the last row).
            for i in 0..dim {
                z1_orig[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
                z2_orig[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
                z3_orig[i] = t31 * w1[i] + t32 * w2[i];
            }
            // M·Z (or Z itself without a mass matrix).
            if let Some(m) = mass {
                mz1_buf.fill(S::ZERO);
                mz2_buf.fill(S::ZERO);
                mz3_buf.fill(S::ZERO);
                for i in 0..dim {
                    for j in 0..dim {
                        let mij = m[i * dim + j];
                        mz1_buf[i] = mz1_buf[i] + mij * z1_orig[j];
                        mz2_buf[i] = mz2_buf[i] + mij * z2_orig[j];
                        mz3_buf[i] = mz3_buf[i] + mij * z3_orig[j];
                    }
                }
            } else {
                mz1_buf.copy_from_slice(&z1_orig);
                mz2_buf.copy_from_slice(&z2_orig);
                mz3_buf.copy_from_slice(&z3_orig);
            }

            // Transformed residuals: T^{-1}F - (Lambda/h)·T^{-1}(M·Z).
            for i in 0..dim {
                let a1 = f1_temp[i];
                let a2 = f2_temp[i];
                let a3 = f3_temp[i];
                let tf1 = ti11 * a1 + ti12 * a2 + ti13 * a3;
                let tf2 = ti21 * a1 + ti22 * a2 + ti23 * a3;
                let tf3 = ti31 * a1 + ti32 * a2 + ti33 * a3;
                let tmz1 = ti11 * mz1_buf[i] + ti12 * mz2_buf[i] + ti13 * mz3_buf[i];
                let tmz2 = ti21 * mz1_buf[i] + ti22 * mz2_buf[i] + ti23 * mz3_buf[i];
                let tmz3 = ti31 * mz1_buf[i] + ti32 * mz2_buf[i] + ti33 * mz3_buf[i];
                rhs1[i] = tf1 - fac1 * tmz1;
                rhs2[i] = tf2 - alphn * tmz2 + betan * tmz3;
                rhs3[i] = tf3 - alphn * tmz3 - betan * tmz2;
            }

            let dw1 = lu_real.solve(&rhs1)?;
            for i in 0..dim {
                rhs_complex[i] = rhs2[i];
                rhs_complex[dim + i] = rhs3[i];
            }
            let dw_complex = lu_complex.solve(&rhs_complex)?;

            // Scaled RMS norm of the full correction dW.
            let mut dyno = S::ZERO;
            for i in 0..dim {
                let denom = scal[i];
                dyno = dyno
                    + (dw1[i] / denom) * (dw1[i] / denom)
                    + (dw_complex[i] / denom) * (dw_complex[i] / denom)
                    + (dw_complex[dim + i] / denom) * (dw_complex[dim + i] / denom);
            }
            dyno = (dyno / n3).sqrt();

            // Contraction-rate test on the intermediate sweeps.
            if (1..MAX_NEWTON_ITER - 1).contains(&newt) {
                let thq = dyno / dynold;
                let theta = if newt == 1 {
                    thq
                } else {
                    (thq * thqold).sqrt()
                };
                thqold = thq;
                if theta < S::from_f64(0.99) {
                    faccon = theta / (S::ONE - theta);
                    // Predicted error after the remaining sweeps; give up early if it
                    // cannot get below the tolerance.
                    let dyth =
                        faccon * dyno * theta.powf(S::from_usize(MAX_NEWTON_ITER - 1 - newt))
                            / fnewt;
                    if dyth >= S::ONE {
                        return Ok((false, newt + 1));
                    }
                } else {
                    return Ok((false, newt + 1));
                }
            }
            dynold = dyno.max(uround);

            for i in 0..dim {
                w1[i] = w1[i] + dw1[i];
                w2[i] = w2[i] + dw_complex[i];
                w3[i] = w3[i] + dw_complex[dim + i];
            }
            for i in 0..dim {
                z1[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
                z2[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
                z3[i] = t31 * w1[i] + t32 * w2[i];
            }
            if faccon * dyno <= fnewt {
                return Ok((true, newt + 1));
            }
        }
        Ok((false, MAX_NEWTON_ITER))
    }

    /// Weighted RMS norm of `v[..dim]`, floored at `1e-10`.
    ///
    /// Component `i` is weighted by `atol + rtol * max(|y_i|, |y_new_i|)`.
    fn scaled_error_norm<S: Scalar>(
        v: &[S],
        y: &[S],
        y_new: &[S],
        options: &SolverOptions<S>,
        dim: usize,
    ) -> S {
        let mut sum = S::ZERO;
        for i in 0..dim {
            let y_max = y[i].abs().max(y_new[i].abs());
            let scale = options.atol + options.rtol * y_max;
            let r = v[i] / scale;
            sum = sum + r * r;
        }
        (sum / S::from_usize(dim)).sqrt().max(S::from_f64(1e-10))
    }

    /// Embedded local error estimate of a converged step.
    ///
    /// Solves `e1 · err = f0 + M·(DD·Z)/h` with the real iteration matrix. `f0` is the
    /// right-hand side at the start of the step, and `M` is the identity without a mass
    /// matrix. Returns the weighted RMS norm of `err`, which is also written to `err`.
    /// On the first step or after a rejection, an estimate `>= 1` is refined once with
    /// an extra RHS evaluation at `y + err` and a second solve.
    ///
    /// A failing linear solve yields the large value `1e6`, which forces a rejection.
    #[allow(clippy::too_many_arguments)]
    fn error_estimate<S, Sys>(
        problem: &Sys,
        t: S,
        f0: &[S],
        z1: &[S],
        z2: &[S],
        z3: &[S],
        y: &[S],
        y_new: &[S],
        h: S,
        options: &SolverOptions<S>,
        lu_real: &LUFactorization<S>,
        err: &mut [S],
        dim: usize,
        first: bool,
        reject: bool,
        stats: &mut SolverStats,
        mass: Option<&[S]>,
    ) -> S
    where
        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
        Sys: OdeSystem<S>,
    {
        let dd1 = S::from_f64(coefficients::DD1);
        let dd2 = S::from_f64(coefficients::DD2);
        let dd3 = S::from_f64(coefficients::DD3);

        // f2 = M·(DD·Z)/h; cont = f2 + f0.
        let mut f2 = vec![S::ZERO; dim];
        for i in 0..dim {
            f2[i] = dd1 * z1[i] + dd2 * z2[i] + dd3 * z3[i];
        }
        let mut cont = vec![S::ZERO; dim];
        if let Some(m) = mass {
            let mut mf2 = vec![S::ZERO; dim];
            for i in 0..dim {
                for j in 0..dim {
                    mf2[i] = mf2[i] + m[i * dim + j] * f2[j];
                }
            }
            for i in 0..dim {
                f2[i] = mf2[i] / h;
                cont[i] = f2[i] + f0[i];
            }
        } else {
            for i in 0..dim {
                f2[i] = f2[i] / h;
                cont[i] = f2[i] + f0[i];
            }
        }

        let solved = match lu_real.solve(&cont) {
            Ok(s) => s,
            Err(_) => return S::from_f64(1e6),
        };
        for i in 0..dim {
            err[i] = solved[i];
        }
        let err_norm = Self::scaled_error_norm(&err[..], y, y_new, options, dim);

        if err_norm >= S::ONE && (first || reject) {
            // Refined estimate: re-evaluate the RHS at y + err and solve once more.
            for i in 0..dim {
                cont[i] = y[i] + err[i];
            }
            let mut f1 = vec![S::ZERO; dim];
            problem.rhs(t, &cont, &mut f1);
            stats.n_eval += 1;
            for i in 0..dim {
                cont[i] = f1[i] + f2[i];
            }
            let solved2 = match lu_real.solve(&cont) {
                Ok(s) => s,
                Err(_) => return S::from_f64(1e6),
            };
            for i in 0..dim {
                err[i] = solved2[i];
            }
            return Self::scaled_error_norm(&err[..], y, y_new, options, dim);
        }
        err_norm
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::{DaeProblem, OdeProblem};

    /// Stiff linear decay y' = -100 y on [0, 0.1]; exact y(0.1) = e^{-10}.
    #[test]
    fn test_radau5_stiff_decay() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-2).atol(1e-4);
        let result = Radau5::solve(&problem, 0.0, 0.1, &[1.0], &options).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = (-10.0_f64).exp();
        assert!(
            (y_final[0] - exact).abs() < 1e-4,
            "Error: {}",
            (y_final[0] - exact).abs()
        );
    }

    /// Non-stiff growth y' = y on [0, 1]; exact y(1) = e.
    #[test]
    fn test_radau5_exponential() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Radau5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = 1.0_f64.exp();
        assert!((y_final[0] - exact).abs() < 1e-5);
    }

    /// Damped rotation y0' = -y0 + y1, y1' = -y0 - y1 with y(0) = (1, 0).
    /// Exact solution: y(t) = e^{-t} (cos t, -sin t).
    #[test]
    fn test_radau5_linear_2d() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = -y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&problem, 0.0, 1.0, &[1.0, 0.0], &options).unwrap();
        assert!(result.success);
        let yf = result.y_final().unwrap();
        let decay = (-1.0_f64).exp();
        let (s, c) = 1.0_f64.sin_cos();
        let exact0 = decay * c;
        let exact1 = -decay * s;
        assert!(
            (yf[0] - exact0).abs() < 1e-3,
            "y0: got {}, expected {}",
            yf[0],
            exact0
        );
        assert!(
            (yf[1] - exact1).abs() < 1e-3,
            "y1: got {}, expected {}",
            yf[1],
            exact1
        );
    }

    /// Van der Pol oscillator with moderate stiffness (mu = 10); must complete.
    #[test]
    fn test_radau5_van_der_pol_mild() {
        let mu = 10.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            2.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&problem, 0.0, 2.0, &[2.0, 0.0], &options);
        assert!(result.is_ok());
    }

    /// Van der Pol oscillator with strong stiffness (mu = 100); must complete.
    #[test]
    fn test_radau5_van_der_pol_stiff() {
        let mu = 100.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            20.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Radau5::solve(&problem, 0.0, 20.0, &[2.0, 0.0], &options);
        assert!(
            result.is_ok(),
            "Van der Pol μ=100 failed: {:?}",
            result.err()
        );
    }

    /// Regression guard on step-size control: the stiff Van der Pol problem should need
    /// far fewer than 200 accepted steps.
    #[test]
    fn test_radau5_step_efficiency() {
        let mu = 100.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            20.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Radau5::solve(&problem, 0.0, 20.0, &[2.0, 0.0], &options).unwrap();
        assert!(
            result.stats.n_accept < 200,
            "Too many accepted steps: {} (expected < 200, ~15 typical)",
            result.stats.n_accept
        );
        assert!(result.success);
    }

    /// Index-1 DAE with mass matrix diag(1, 0): y0' = -y0 + y1, 0 = y0 - y1.
    /// Starting on the constraint manifold at (1, 1), the solution stays at (1, 1).
    #[test]
    fn test_radau5_simple_dae() {
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = y[0] - y[1];
            },
            |mass: &mut [f64]| {
                for i in 0..4 {
                    mass[i] = 0.0;
                }
                mass[0] = 1.0;
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
            vec![1],
        );
        let options = SolverOptions::default()
            .rtol(1e-4)
            .atol(1e-6)
            .max_steps(500_000);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0, 1.0], &options);
        assert!(result.is_ok(), "DAE solve failed: {:?}", result.err());
        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        assert!(
            (yf[0] - 1.0).abs() < 1e-4,
            "y1 deviated: {} (expected 1.0)",
            yf[0]
        );
        assert!(
            (yf[1] - 1.0).abs() < 1e-4,
            "y2 deviated: {} (expected 1.0)",
            yf[1]
        );
        let constraint = yf[0] - yf[1];
        assert!(
            constraint.abs() < 1e-4,
            "Constraint violated: {} (y1={}, y2={})",
            constraint,
            yf[0],
            yf[1]
        );
    }

    /// Identity mass matrix must reproduce the plain ODE y' = -y; exact y(1) = e^{-1}.
    #[test]
    fn test_radau5_dae_with_mass_identity() {
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 1.0;
            },
            0.0,
            1.0,
            vec![1.0],
            vec![],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0], &options);
        assert!(
            result.is_ok(),
            "DAE with identity mass failed: {:?}",
            result.err()
        );
        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-1.0_f64).exp();
        assert!(
            (yf[0] - exact).abs() < 1e-5,
            "Error: {} (expected {}, got {})",
            (yf[0] - exact).abs(),
            exact,
            yf[0]
        );
    }

    /// Scaled mass 2 y' = -y, i.e. y' = -y/2; exact y(1) = e^{-1/2}.
    #[test]
    fn test_radau5_dae_scaled_mass() {
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 2.0;
            },
            0.0,
            1.0,
            vec![1.0],
            vec![],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0], &options);
        assert!(
            result.is_ok(),
            "DAE with scaled mass failed: {:?}",
            result.err()
        );
        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-0.5_f64).exp();
        assert!(
            (yf[0] - exact).abs() < 1e-3,
            "Error: {} (expected {}, got {})",
            (yf[0] - exact).abs(),
            exact,
            yf[0]
        );
    }
}