use crate::error::IntegrateResult;
use crate::ode::types::{MassMatrix, MassMatrixType, ODEMethod, ODEOptions, ODEResult};
use crate::ode::utils::common::calculate_error_weights;
use crate::ode::utils::dense_output::DenseSolution;
use crate::ode::utils::interpolation::ContinuousOutputMethod;
use crate::ode::utils::mass_matrix;
use crate::IntegrateFloat;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
// ---------------------------------------------------------------------------
// Radau IIA (s = 3, order 5) Butcher tableau, stored as f64 and converted to
// the working float type `F` at run time.
// ---------------------------------------------------------------------------

/// Node c₁ = (4 − √6) / 10.
const C1_F64: f64 = 0.155_051_025_721_682_2;
/// Node c₂ = (4 + √6) / 10. (c₃ = 1 is used directly in code.)
const C2_F64: f64 = 0.644_948_974_278_317_8;

/// a₁₁ = (88 − 7√6) / 360.
const A11_F64: f64 = 0.196_815_477_223_660_4;
/// a₁₂ = (296 − 169√6) / 1800.
const A12_F64: f64 = -0.065_535_425_850_198_4;
/// a₁₃ = (−2 + 3√6) / 225.
const A13_F64: f64 = 0.023_770_974_348_220_15;
/// a₂₁ = (296 + 169√6) / 1800.
const A21_F64: f64 = 0.394_424_314_739_087_3;
/// a₂₂ = (88 + 7√6) / 360.
const A22_F64: f64 = 0.292_073_411_665_228_4;
/// a₂₃ = (−2 − 3√6) / 225.
const A23_F64: f64 = -0.041_548_752_125_997_9;
/// a₃₁ = (16 − √6) / 36 (also the weight b₁).
const A31_F64: f64 = 0.376_403_062_700_467_3;
/// a₃₂ = (16 + √6) / 36 (also the weight b₂).
const A32_F64: f64 = 0.512_485_826_188_421_6;
/// a₃₃ = 1 / 9 (also the weight b₃).
const A33_F64: f64 = 1.0 / 9.0;

/// Upper bound on simplified-Newton iterations per step.
const MAX_NEWTON_ITER: usize = 7;
/// Newton stops once the predicted remaining residual drops below this value.
const NEWTON_KAPPA: f64 = 0.1;
/// Relative perturbation used by the forward-difference Jacobian.
const FD_EPS: f64 = 1e-7;

/// Error-estimator coefficients acting on the stage derivatives k_j.
/// Each equals Σᵢ dᵢ·aᵢⱼ, with RADAU5's d = (−(13+7√6)/3, (−13+7√6)/3, −1/3).
const EC_K1_F64: f64 = -1.558_078_204_724_922_6;
const EC_K2_F64: f64 = 0.891_411_538_058_255_2;
const EC_K3_F64: f64 = -1.0 / 3.0;

/// Real eigenvalue of A⁻¹, i.e. 30 / (6 + ∛81 − ∛9).
const MU_REAL_F64: f64 = 3.637_834_252_744_496;
#[allow(dead_code)]
pub fn radau_method_with_mass<F, Func>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
mass_matrix: MassMatrix<F>,
opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat + std::iter::Sum,
Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
let [t_start, t_end] = t_span;
let n = y0.len();
let sn = 3 * n;
mass_matrix::check_mass_compatibility(&mass_matrix, t_start, y0.view())?;
let c1 = F::from_f64(C1_F64).expect("from_f64");
let c2 = F::from_f64(C2_F64).expect("from_f64");
let c3 = F::one();
let a11 = F::from_f64(A11_F64).expect("from_f64");
let a12 = F::from_f64(A12_F64).expect("from_f64");
let a13 = F::from_f64(A13_F64).expect("from_f64");
let a21 = F::from_f64(A21_F64).expect("from_f64");
let a22 = F::from_f64(A22_F64).expect("from_f64");
let a23 = F::from_f64(A23_F64).expect("from_f64");
let a31 = F::from_f64(A31_F64).expect("from_f64");
let a32 = F::from_f64(A32_F64).expect("from_f64");
let a33 = F::from_f64(A33_F64).expect("from_f64");
let b1 = a31;
let b2 = a32;
let b3 = a33;
let span = t_end - t_start;
let h0 = opts
.h0
.unwrap_or_else(|| span / F::from_f64(100.0).expect("from_f64"));
let min_step = opts
.min_step
.unwrap_or_else(|| span * F::from_f64(1e-10).expect("from_f64"));
let max_step = opts.max_step.unwrap_or(span);
let mut t = t_start;
let mut y = y0.clone();
let mut h = h0;
let mut t_values = vec![t];
let mut y_values = vec![y.clone()];
let mut dy_values: Vec<Array1<F>> = Vec::new();
if opts.dense_output {
let f_y0 = f(t, y.view());
let dy0 = mass_matrix::solve_mass_system(&mass_matrix, t, y.view(), f_y0.view())?;
dy_values.push(dy0);
}
let mut func_evals: usize = 1;
let mut step_count: usize = 0;
let mut accepted_steps: usize = 0;
let mut rejected_steps: usize = 0;
let mut n_lu: usize = 0;
let mut n_jac: usize = 0;
let rtol = opts.rtol;
let atol = opts.atol;
let mut cached_jac: Option<Array2<F>> = None;
let mut cached_m: Option<Option<Array2<F>>> = None;
while t < t_end && step_count < opts.max_steps {
if t + h > t_end {
h = t_end - t;
}
h = h.min(max_step).max(min_step);
let t1 = t + c1 * h;
let t2 = t + c2 * h;
let t3 = t + c3 * h;
step_count += 1;
if cached_jac.is_none() {
let f0 = f(t, y.view());
func_evals += 1;
let jac = finite_diff_jac(&f, t, &y, &f0, F::from_f64(FD_EPS).expect("from_f64"));
cached_jac = Some(jac);
n_jac += 1;
let m_opt = mass_matrix.evaluate(t, y.view());
cached_m = Some(m_opt);
}
let jac = cached_jac.as_ref().expect("jacobian must be set");
let m_eval = cached_m.as_ref().expect("mass matrix must be set");
let newton_mat = build_coupled_matrix(
n,
sn,
h,
&a11,
&a12,
&a13,
&a21,
&a22,
&a23,
&a31,
&a32,
&a33,
m_eval,
jac,
&mass_matrix,
t,
&y,
);
n_lu += 1;
let k_init = compute_initial_k(&mass_matrix, t, &y, m_eval, &f, &mut func_evals)?;
let mut k = Array1::<F>::zeros(sn);
for i in 0..n {
k[i] = k_init[i];
k[n + i] = k_init[i];
k[2 * n + i] = k_init[i];
}
let error_weights = calculate_error_weights(&y, atol, rtol);
let mut newton_converged = false;
let mut prev_res_norm = F::from_f64(f64::MAX).expect("from_f64");
for newton_iter in 0..MAX_NEWTON_ITER {
let k1 = k.slice(scirs2_core::ndarray::s![0..n]).to_owned();
let k2 = k.slice(scirs2_core::ndarray::s![n..2 * n]).to_owned();
let k3 = k.slice(scirs2_core::ndarray::s![2 * n..3 * n]).to_owned();
let y1 = &y + &((&k1 * a11 + &k2 * a12 + &k3 * a13) * h);
let y2 = &y + &((&k1 * a21 + &k2 * a22 + &k3 * a23) * h);
let y3 = &y + &((&k1 * a31 + &k2 * a32 + &k3 * a33) * h);
let f1 = f(t1, y1.view());
let f2 = f(t2, y2.view());
let f3 = f(t3, y3.view());
func_evals += 3;
let r = compute_residual(
n,
&k1,
&k2,
&k3,
&f1,
&f2,
&f3,
&mass_matrix,
t1,
t2,
t3,
&y1,
&y2,
&y3,
)?;
let res_norm = rms_norm_weighted(&r, &error_weights, 3);
if newton_iter > 0 {
let theta = res_norm / prev_res_norm;
if theta > F::one() {
break;
}
let kappa = F::from_f64(NEWTON_KAPPA).expect("from_f64");
let predicted_final = theta / (F::one() - theta) * res_norm;
if predicted_final < kappa || res_norm < F::from_f64(1e-10).expect("from_f64") {
newton_converged = true;
break;
}
}
prev_res_norm = res_norm;
if res_norm < F::from_f64(1e-10).expect("from_f64") {
newton_converged = true;
break;
}
let neg_r = r.mapv(|x| -x);
let dk = solve_coupled_system(&newton_mat, &neg_r, sn)?;
k = k + dk;
}
if !newton_converged {
let k1 = k.slice(scirs2_core::ndarray::s![0..n]).to_owned();
let k2 = k.slice(scirs2_core::ndarray::s![n..2 * n]).to_owned();
let k3 = k.slice(scirs2_core::ndarray::s![2 * n..3 * n]).to_owned();
let y1 = &y + &((&k1 * a11 + &k2 * a12 + &k3 * a13) * h);
let y2 = &y + &((&k1 * a21 + &k2 * a22 + &k3 * a23) * h);
let y3 = &y + &((&k1 * a31 + &k2 * a32 + &k3 * a33) * h);
let f1 = f(t1, y1.view());
let f2 = f(t2, y2.view());
let f3 = f(t3, y3.view());
func_evals += 3;
let r = compute_residual(
n,
&k1,
&k2,
&k3,
&f1,
&f2,
&f3,
&mass_matrix,
t1,
t2,
t3,
&y1,
&y2,
&y3,
)?;
let res_norm = rms_norm_weighted(&r, &error_weights, 3);
if res_norm < F::from_f64(1e-6).expect("from_f64") {
newton_converged = true;
}
}
if !newton_converged {
h *= F::from_f64(0.5).expect("from_f64");
rejected_steps += 1;
cached_jac = None;
cached_m = None;
if h < min_step {
return Err(crate::error::IntegrateError::ComputationError(
"Radau Newton iteration failed to converge even at minimum step size"
.to_string(),
));
}
continue;
}
let k1 = k.slice(scirs2_core::ndarray::s![0..n]).to_owned();
let k2 = k.slice(scirs2_core::ndarray::s![n..2 * n]).to_owned();
let k3 = k.slice(scirs2_core::ndarray::s![2 * n..3 * n]).to_owned();
let y_new = &y + &((&k1 * b1 + &k2 * b2 + &k3 * b3) * h);
let ec_k1 = F::from_f64(EC_K1_F64).expect("from_f64");
let ec_k2 = F::from_f64(EC_K2_F64).expect("from_f64");
let ec_k3 = F::from_f64(EC_K3_F64).expect("from_f64");
let ze = &k1 * ec_k1 + &k2 * ec_k2 + &k3 * ec_k3;
let rhs_err = &k_init + &ze;
let mu_h = F::from_f64(MU_REAL_F64).expect("from_f64") / h;
let err_mat = build_error_matrix(n, mu_h, m_eval, jac);
let error_vec = {
use crate::ode::utils::linear_solvers::solve_linear_system;
solve_linear_system(&err_mat.view(), &rhs_err.view())
.unwrap_or_else(|_| rhs_err.clone())
};
let scale: Array1<F> = y
.iter()
.zip(y_new.iter())
.map(|(&yi, &yi_new)| {
let mx = if yi.abs() > yi_new.abs() {
yi.abs()
} else {
yi_new.abs()
};
atol + mx * rtol
})
.collect();
let error_norm = error_vec
.iter()
.zip(scale.iter())
.map(|(&e, &w)| (e / w).powi(2))
.sum::<F>()
.sqrt()
/ F::from_f64((n as f64).sqrt()).expect("from_f64");
if error_norm <= F::one() {
t += h;
y = y_new;
t_values.push(t);
y_values.push(y.clone());
if opts.dense_output {
let f_y = f(t, y.view());
func_evals += 1;
let dy = mass_matrix::solve_mass_system(&mass_matrix, t, y.view(), f_y.view())?;
dy_values.push(dy);
}
accepted_steps += 1;
let fac_min = F::from_f64(0.2).expect("from_f64");
let fac_max = F::from_f64(5.0).expect("from_f64");
let safety = F::from_f64(0.9).expect("from_f64");
let factor = if error_norm < F::from_f64(1e-14).expect("from_f64") {
fac_max
} else {
let exponent = F::from_f64(0.25).expect("from_f64");
let raw = safety * (F::one() / error_norm).powf(exponent);
raw.max(fac_min).min(fac_max)
};
h *= factor;
} else {
let exponent = F::from_f64(0.25).expect("from_f64");
let safety = F::from_f64(0.9).expect("from_f64");
let factor = (safety * (F::one() / error_norm).powf(exponent))
.max(F::from_f64(0.2).expect("from_f64"))
.min(F::from_f64(1.0).expect("from_f64"));
h *= factor;
rejected_steps += 1;
cached_jac = None;
cached_m = None;
}
}
let success = t >= t_end;
let message = if success {
Some(format!("Integration successful, reached t = {t:?}"))
} else {
Some(format!("Integration incomplete, stopped at t = {t:?}"))
};
let _dense_output = if opts.dense_output {
Some(DenseSolution::new(
t_values.clone(),
y_values.clone(),
Some(dy_values),
Some(ContinuousOutputMethod::CubicHermite),
None,
))
} else {
None
};
Ok(ODEResult {
t: t_values,
y: y_values,
success,
message,
n_eval: func_evals,
n_steps: step_count,
n_accepted: accepted_steps,
n_rejected: rejected_steps,
n_lu,
n_jac,
method: ODEMethod::Radau,
})
}
/// Forward-difference approximation of the Jacobian `∂f/∂y` at `(t, y)`.
///
/// `f0` must be `f(t, y)`. Column `j` is perturbed by `eps · (1 + |y_j|)`.
/// Exactly `y.len()` evaluations of `f` are performed.
fn finite_diff_jac<F, Func>(f: &Func, t: F, y: &Array1<F>, f0: &Array1<F>, eps: F) -> Array2<F>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let n = y.len();
    let mut jac = Array2::<F>::zeros((n, n));
    // One perturbation buffer reused for every column (restored after use),
    // instead of cloning the whole state n times.
    let mut yp = y.clone();
    for j in 0..n {
        let yj = y[j];
        let h_nominal = eps * (F::one() + yj.abs());
        yp[j] = yj + h_nominal;
        // Divide by the step that is actually representable in floating point.
        // This removes rounding error from the denominator, which matters when
        // `eps` is close to the type's machine epsilon (e.g. f32).
        let h_actual = yp[j] - yj;
        let h_j = if h_actual == F::zero() {
            h_nominal
        } else {
            h_actual
        };
        let fp = f(t, yp.view());
        yp[j] = yj;
        for i in 0..n {
            jac[[i, j]] = (fp[i] - f0[i]) / h_j;
        }
    }
    jac
}
/// Assemble the 3n × 3n simplified-Newton matrix `I₃ ⊗ M − h (A ⊗ J)`.
///
/// Block (p, q) equals `δ_pq · M − h · a_pq · J`. `M` is taken from `m_opt`;
/// `None` means the identity. `_mass`, `_t` and `_y` are accepted but unused.
#[allow(clippy::too_many_arguments)]
fn build_coupled_matrix<F>(
    n: usize,
    sn: usize,
    h: F,
    a11: &F,
    a12: &F,
    a13: &F,
    a21: &F,
    a22: &F,
    a23: &F,
    a31: &F,
    a32: &F,
    a33: &F,
    m_opt: &Option<Array2<F>>,
    jac: &Array2<F>,
    _mass: &MassMatrix<F>,
    _t: F,
    _y: &Array1<F>,
) -> Array2<F>
where
    F: IntegrateFloat,
{
    let coeffs = [[*a11, *a12, *a13], [*a21, *a22, *a23], [*a31, *a32, *a33]];
    // Entry (row, col) of M, falling back to the identity.
    let mass_entry = |row: usize, col: usize| -> F {
        match m_opt {
            Some(m) => m[[row, col]],
            None if row == col => F::one(),
            None => F::zero(),
        }
    };

    let mut newton = Array2::<F>::zeros((sn, sn));
    for (stage_r, coeff_row) in coeffs.iter().enumerate() {
        for (stage_c, &coeff) in coeff_row.iter().enumerate() {
            let diagonal_block = stage_r == stage_c;
            for row in 0..n {
                for col in 0..n {
                    let jac_part = h * coeff * jac[[row, col]];
                    let mass_part = if diagonal_block {
                        mass_entry(row, col)
                    } else {
                        F::zero()
                    };
                    newton[[stage_r * n + row, stage_c * n + col]] = mass_part - jac_part;
                }
            }
        }
    }
    newton
}
/// Build the n × n error-estimation matrix `mu_h · M − J`.
///
/// `M` is taken from `m_opt`; `None` means the identity.
fn build_error_matrix<F>(n: usize, mu_h: F, m_opt: &Option<Array2<F>>, jac: &Array2<F>) -> Array2<F>
where
    F: IntegrateFloat,
{
    Array2::from_shape_fn((n, n), |(row, col)| {
        let mass_entry = match m_opt {
            Some(m) => m[[row, col]],
            None if row == col => F::one(),
            None => F::zero(),
        };
        mu_h * mass_entry - jac[[row, col]]
    })
}
/// Stage-derivative predictor at `(t, y)`.
///
/// Evaluates `f(t, y)` once and increments `func_evals`. If `m_opt` is `None`,
/// that value is returned directly. Otherwise it is passed through
/// `mass_matrix::solve_mass_system`. Only the `None`/`Some` state of `m_opt` is
/// inspected here, not its contents.
fn compute_initial_k<F, Func>(
    mass: &MassMatrix<F>,
    t: F,
    y: &Array1<F>,
    m_opt: &Option<Array2<F>>,
    f: &Func,
    func_evals: &mut usize,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let derivative = f(t, y.view());
    *func_evals += 1;
    if m_opt.is_none() {
        return Ok(derivative);
    }
    let solved = mass_matrix::solve_mass_system(mass, t, y.view(), derivative.view())?;
    Ok(solved)
}
/// Stacked collocation residual for the three Radau stages.
///
/// Segment `s` (covering indices `s·n .. (s+1)·n`) holds
/// `apply_m_vec(mass, t_s, y_s, k_s) − f_s`. Stages are processed in order
/// 1, 2, 3, and the first mass-application error is returned.
#[allow(clippy::too_many_arguments)]
fn compute_residual<F>(
    n: usize,
    k1: &Array1<F>,
    k2: &Array1<F>,
    k3: &Array1<F>,
    f1: &Array1<F>,
    f2: &Array1<F>,
    f3: &Array1<F>,
    mass: &MassMatrix<F>,
    t1: F,
    t2: F,
    t3: F,
    y1: &Array1<F>,
    y2: &Array1<F>,
    y3: &Array1<F>,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
{
    let stages = [(t1, y1, k1, f1), (t2, y2, k2, f2), (t3, y3, k3, f3)];
    let mut residual = Array1::<F>::zeros(3 * n);
    for (stage, (ts, ys, ks, fs)) in stages.iter().enumerate() {
        let m_times_k = apply_m_vec(mass, *ts, ys, ks)?;
        let offset = stage * n;
        for i in 0..n {
            residual[offset + i] = m_times_k[i] - fs[i];
        }
    }
    Ok(residual)
}
/// Apply the mass matrix to `v` at `(t, y)` by delegating to
/// `mass_matrix::apply_mass` with array views. This is presumably `M(t, y) · v`;
/// see that helper for the exact semantics.
fn apply_m_vec<F>(
    mass: &MassMatrix<F>,
    t: F,
    y: &Array1<F>,
    v: &Array1<F>,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
{
    let state = y.view();
    let vector = v.view();
    mass_matrix::apply_mass(mass, t, state, vector)
}
/// Weighted norm of a stacked multi-stage vector.
///
/// Entry `i` of `r` is scaled by `weights[i % weights.len()]`, so the weights
/// repeat once per stage. The sum of squares is divided by
/// `max(r.len(), 1) / stages` before taking the square root.
fn rms_norm_weighted<F: IntegrateFloat>(r: &Array1<F>, weights: &Array1<F>, stages: usize) -> F {
    let stride = weights.len();
    let sum_sq = r
        .iter()
        .enumerate()
        .map(|(pos, &value)| {
            let scaled = value / weights[pos % stride];
            scaled.powi(2)
        })
        .sum::<F>();
    let per_stage_len = F::from_usize(r.len().max(1)).expect("from_usize")
        / F::from_usize(stages).expect("from_usize");
    (sum_sq / per_stage_len).sqrt()
}
/// Solve the coupled Newton system `a · x = b`, where `a` is `sn × sn`.
///
/// The shape is checked only in debug builds. The actual solve is delegated to
/// `linear_solvers::solve_linear_system`.
fn solve_coupled_system<F: IntegrateFloat>(
    a: &Array2<F>,
    b: &Array1<F>,
    sn: usize,
) -> IntegrateResult<Array1<F>> {
    use crate::ode::utils::linear_solvers::solve_linear_system;
    debug_assert_eq!(a.shape(), &[sn, sn]);
    solve_linear_system(&a.view(), &b.view())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tableau as plain arrays: (A, b, c).
    fn tableau() -> ([[f64; 3]; 3], [f64; 3], [f64; 3]) {
        let a = [
            [A11_F64, A12_F64, A13_F64],
            [A21_F64, A22_F64, A23_F64],
            [A31_F64, A32_F64, A33_F64],
        ];
        let b = [A31_F64, A32_F64, A33_F64];
        let c = [C1_F64, C2_F64, 1.0];
        (a, b, c)
    }

    #[test]
    fn test_butcher_tableau_consistency() {
        let b1 = A31_F64;
        let b2 = A32_F64;
        let b3 = A33_F64;
        let c1 = C1_F64;
        let c2 = C2_F64;
        let c3 = 1.0_f64;
        assert_relative_eq!(b1 + b2 + b3, 1.0, epsilon = 1e-12);
        assert_relative_eq!(b1 * c1 + b2 * c2 + b3 * c3, 0.5, epsilon = 1e-12);
        assert_relative_eq!(
            b1 * c1.powi(2) + b2 * c2.powi(2) + b3 * c3.powi(2),
            1.0 / 3.0,
            epsilon = 1e-12
        );
        let row_sum_1 = A11_F64 + A12_F64 + A13_F64;
        let row_sum_2 = A21_F64 + A22_F64 + A23_F64;
        let row_sum_3 = A31_F64 + A32_F64 + A33_F64;
        assert_relative_eq!(row_sum_1, c1, epsilon = 1e-13);
        assert_relative_eq!(row_sum_2, c2, epsilon = 1e-13);
        assert_relative_eq!(row_sum_3, c3, epsilon = 1e-13);
    }

    /// B(5): Σ b_i c_i^(k-1) = 1/k for k = 1..=5. The order-5 method needs the
    /// Radau quadrature to be exact up to degree 4.
    #[test]
    fn test_quadrature_order_conditions() {
        let (_, b, c) = tableau();
        for k in 1..=5 {
            let lhs: f64 = b
                .iter()
                .zip(c.iter())
                .map(|(&bi, &ci)| bi * ci.powi(k - 1))
                .sum();
            assert_relative_eq!(lhs, 1.0 / f64::from(k), epsilon = 1e-12);
        }
    }

    /// C(3): Σ_j a_ij c_j^(k-1) = c_i^k / k for k = 1..=3 (stage order 3).
    #[test]
    fn test_stage_order_conditions() {
        let (a, _, c) = tableau();
        for (row, &ci) in a.iter().zip(c.iter()) {
            for k in 1..=3 {
                let lhs: f64 = row
                    .iter()
                    .zip(c.iter())
                    .map(|(&aij, &cj)| aij * cj.powi(k - 1))
                    .sum();
                assert_relative_eq!(lhs, ci.powi(k) / f64::from(k), epsilon = 1e-12);
            }
        }
    }

    /// EC_K_j = Σ_i d_i a_ij, with RADAU5's d = (−(13+7√6)/3, (−13+7√6)/3, −1/3).
    #[test]
    fn test_error_estimator_constants() {
        let (a, _, _) = tableau();
        let s6 = 6.0_f64.sqrt();
        let d = [-(13.0 + 7.0 * s6) / 3.0, (-13.0 + 7.0 * s6) / 3.0, -1.0 / 3.0];
        let expected = [EC_K1_F64, EC_K2_F64, EC_K3_F64];
        for (j, &ec) in expected.iter().enumerate() {
            let lhs: f64 = (0..3).map(|i| d[i] * a[i][j]).sum();
            assert_relative_eq!(lhs, ec, epsilon = 1e-9);
        }
    }

    /// 1 / MU_REAL must be the real eigenvalue of A: det(A − I / μ) = 0.
    #[test]
    fn test_mu_real_is_eigenvalue_of_a_inverse() {
        let (a, _, _) = tableau();
        let lambda = 1.0 / MU_REAL_F64;
        let m = |i: usize, j: usize| a[i][j] - if i == j { lambda } else { 0.0 };
        let det = m(0, 0) * (m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1))
            - m(0, 1) * (m(1, 0) * m(2, 2) - m(1, 2) * m(2, 0))
            + m(0, 2) * (m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0));
        assert_relative_eq!(det, 0.0, epsilon = 1e-12);
    }
}