use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::types::{ODEOptions, ODEResult};
use crate::IntegrateFloat;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
/// Converts an `f64` coefficient into the solver's float type `F`.
///
/// Yields `F::zero()` when `F::from_f64` cannot represent the value.
#[inline]
fn to_f<F: IntegrateFloat>(v: f64) -> F {
    match F::from_f64(v) {
        Some(converted) => converted,
        None => F::zero(),
    }
}
/// Coefficients of an `s`-stage Rosenbrock method.
///
/// Stage `i` of a step of size `h` solves
/// `(I - h·γ·J) k_i = h f(t + c_i h, y + Σ_{j<i} a_ij k_j) + h J Σ_{j<i} γ_ij k_j`,
/// and the step result is `y + Σ b_i k_i` (with `b_hat` giving the embedded
/// solution used for error estimation).
#[derive(Debug, Clone)]
struct RosenbrockTableau {
    /// Number of stages `s`.
    stages: usize,
    /// Strict lower triangle of the stage-coupling matrix, packed row by row:
    /// `[a_10, a_20, a_21, a_30, a_31, a_32, ...]`; read through `a_ij`.
    a: Vec<f64>,
    /// Stage time offsets `c_i`, one per stage.
    c: Vec<f64>,
    /// Lower triangle of the Γ matrix including its diagonal, packed row by
    /// row: `[γ_00, γ_10, γ_11, γ_20, γ_21, γ_22, ...]`; read through `gamma_ij`.
    gamma: Vec<f64>,
    /// Common diagonal entry γ used to build `W = I - h·γ·J`.
    gamma_diag: f64,
    /// Weights of the propagated solution.
    b: Vec<f64>,
    /// Weights of the embedded (error-estimating) solution.
    b_hat: Vec<f64>,
    /// Order of the `b` solution. Informational only: nothing in the solver
    /// reads it, and derived `Debug`/`Clone` do not count as uses for the
    /// dead-code lint, hence the allow.
    #[allow(dead_code)]
    order: usize,
    /// Order of the `b_hat` solution; determines the step-size controller
    /// exponent.
    embedded_order: usize,
}
impl RosenbrockTableau {
    /// Returns `a_ij` from the packed strict lower triangle (`i > j`).
    ///
    /// Row `i` of the strict lower triangle begins at offset `(i-1)·i/2`.
    fn a_ij(&self, i: usize, j: usize) -> f64 {
        debug_assert!(i > j);
        let row_start = (i - 1) * i / 2;
        self.a[row_start + j]
    }

    /// Returns `γ_ij` from the packed lower triangle including the diagonal
    /// (`i >= j`).
    ///
    /// Row `i` of the lower triangle begins at offset `i·(i+1)/2`.
    fn gamma_ij(&self, i: usize, j: usize) -> f64 {
        debug_assert!(i >= j);
        let row_start = (i + 1) * i / 2;
        self.gamma[row_start + j]
    }
}
/// Coefficients of ROS3w: a three-stage Rosenbrock-W method of order 3 with
/// an embedded order-2 solution and γ = 0.435866521508459.
///
/// With `β_ij = a_ij + γ_ij` and `β_i = Σ_{j<i} β_ij`, the `b` weights satisfy
/// the Rosenbrock order conditions up to order 3:
/// `Σ b_i = 1`, `Σ b_i β_i = 1/2 - γ`, `Σ b_i c_i² = 1/3`,
/// `Σ b_i β_ij β_j = 1/6 - γ + γ²`.
/// `b_hat` satisfies only the first two, which makes it second order.
fn ros3w_tableau() -> RosenbrockTableau {
    let gamma_val = 0.435_866_521_508_459;
    let two_thirds = 2.0 / 3.0;
    RosenbrockTableau {
        stages: 3,
        // Packed strict lower triangle: [a_10, a_20, a_21].
        a: vec![two_thirds, two_thirds, 0.0],
        // c_i = Σ_j a_ij (needed for consistency on non-autonomous problems).
        c: vec![0.0, two_thirds, two_thirds],
        // Packed lower triangle with diagonal: [γ_00, γ_10, γ_11, γ_20, γ_21, γ_22].
        gamma: vec![
            gamma_val,
            0.363_506_836_890_068_1,
            gamma_val,
            -0.899_686_679_199_231_9,
            -0.153_799_782_262_688_5,
            gamma_val,
        ],
        gamma_diag: gamma_val,
        b: vec![0.25, 0.25, 0.5],
        b_hat: vec![
            0.746_704_703_274_010_9,
            0.114_406_407_837_100_2,
            0.138_888_888_888_888_9,
        ],
        order: 3,
        embedded_order: 2,
    }
}
/// Coefficients of ROS34PW2: a four-stage Rosenbrock-W method of order 3
/// with an embedded order-2 solution and γ = 0.435866521508459.
///
/// `b` satisfies `Σ b_i c_i² = 1/3` but not `Σ b_i c_i³ = 1/4`, and `b_hat`
/// fails `Σ b̂_i c_i² = 1/3`. So the propagated solution is third order and
/// the embedded one second order; the controller exponent depends on the
/// latter.
fn ros34pw2_tableau() -> RosenbrockTableau {
    let gamma_val = 0.435_866_521_508_459;
    RosenbrockTableau {
        stages: 4,
        // Packed strict lower triangle: [a_10, a_20, a_21, a_30, a_31, a_32].
        a: vec![
            0.871_733_043_016_918,
            0.844_570_600_153_694_4,
            -0.112_990_642_363_971_6,
            0.0,
            0.0,
            1.0,
        ],
        // c_i = Σ_j a_ij; c_2 = 0.8445706001536944 - 0.1129906423639716.
        c: vec![0.0, 0.871_733_043_016_918, 0.731_579_957_789_722_8, 1.0],
        // Packed lower triangle with diagonal:
        // [γ_00, γ_10, γ_11, γ_20, γ_21, γ_22, γ_30, γ_31, γ_32, γ_33].
        gamma: vec![
            gamma_val,
            -0.871_733_043_016_918,
            gamma_val,
            -0.903_380_570_130_440_8,
            0.054_180_672_388_095_47,
            gamma_val,
            0.242_123_807_060_954_64,
            -1.223_250_583_904_514_7,
            0.545_260_255_335_102_3,
            gamma_val,
        ],
        gamma_diag: gamma_val,
        b: vec![
            0.242_123_807_060_954_64,
            -1.223_250_583_904_514_7,
            1.545_260_255_335_102_3,
            0.435_866_521_508_459,
        ],
        b_hat: vec![
            0.378_109_031_458_193_7,
            -0.096_042_292_212_423_18,
            0.5,
            0.217_933_260_754_229_5,
        ],
        order: 3,
        embedded_order: 2,
    }
}
fn rodas3_tableau() -> RosenbrockTableau {
ros3w_tableau()
}
/// Selects the coefficient tableau used by [`rosenbrock_method`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RosenbrockVariant {
    /// Three-stage tableau built by `ros3w_tableau`.
    ROS3w,
    /// Four-stage tableau built by `ros34pw2_tableau`; the default variant.
    #[default]
    ROS34PW2,
    /// Tableau built by `rodas3_tableau`.
    RODAS3,
}
/// Approximates the Jacobian `∂f/∂y` at `(t, y)` by forward differences.
///
/// Column `j` is `(f(t, y + δ_j e_j) - f(t, y)) / δ_j` with
/// `δ_j ≈ sqrt(ε)·(1 + |y_j|)`. Costs `n = y.len()` evaluations of `f`.
///
/// `f_at_y` must equal `f(t, y)`; it serves as the unperturbed reference so
/// that value is not recomputed here.
fn numerical_jacobian<F, Func>(f: &Func, t: F, y: &Array1<F>, f_at_y: &Array1<F>) -> Array2<F>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let n = y.len();
    let mut jac = Array2::zeros((n, n));
    let eps_root = F::epsilon().sqrt();
    // One scratch copy of `y`, perturbed and restored per column, instead of
    // a fresh clone for every column.
    let mut y_pert = y.clone();
    for j in 0..n {
        let y_j = y[j];
        // Divide by the increment that was actually applied after rounding
        // `y_j + δ`, not the nominal δ, to avoid biasing the quotient.
        let delta = (y_j + eps_root * (F::one() + y_j.abs())) - y_j;
        y_pert[j] = y_j + delta;
        let f_pert = f(t, y_pert.view());
        y_pert[j] = y_j;
        for i in 0..n {
            jac[[i, j]] = (f_pert[i] - f_at_y[i]) / delta;
        }
    }
    jac
}
/// Solves the dense system `a · x = b` by LU factorisation with partial
/// (row) pivoting.
///
/// Row exchanges are tracked in a permutation vector instead of being applied
/// to the matrix. `a` is copied and never modified.
///
/// # Errors
/// * `IntegrateError::DimensionMismatch` if `a` is not square or
///   `b.len() != a.nrows()`.
/// * `IntegrateError::LinearSolveError` if `a` or `b` holds a NaN/infinite
///   entry, or if no pivot of magnitude at least `1e-30` exists in some
///   column (singular or near-singular matrix).
fn lu_solve<F: IntegrateFloat>(a: &Array2<F>, b: &Array1<F>) -> IntegrateResult<Array1<F>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return Err(IntegrateError::DimensionMismatch(
            "lu_solve: incompatible dimensions".into(),
        ));
    }
    // NaN defeats the `<`/`>` comparisons used for pivoting and would flow
    // silently into an `Ok` result; report it as a failed solve instead.
    if a.iter().chain(b.iter()).any(|v| !v.is_finite()) {
        return Err(IntegrateError::LinearSolveError(
            "Non-finite entry in Rosenbrock linear system".into(),
        ));
    }
    let mut lu = a.clone();
    // piv[k] is the physical row of `lu` that holds logical row k.
    let mut piv: Vec<usize> = (0..n).collect();
    let tiny = F::from_f64(1e-30).unwrap_or_else(|| F::epsilon());
    for k in 0..n {
        // Partial pivoting: largest |entry| of column k among logical rows k..n.
        let mut max_val = lu[[piv[k], k]].abs();
        let mut max_idx = k;
        for i in (k + 1)..n {
            let v = lu[[piv[i], k]].abs();
            if v > max_val {
                max_val = v;
                max_idx = i;
            }
        }
        // Overflow during elimination can still create NaN from finite input.
        if max_val.is_nan() || max_val < tiny {
            return Err(IntegrateError::LinearSolveError(
                "Singular or near-singular matrix in Rosenbrock solver".into(),
            ));
        }
        piv.swap(k, max_idx);
        let pivot_row = piv[k];
        let pivot = lu[[pivot_row, k]];
        for i in (k + 1)..n {
            let row = piv[i];
            let factor = lu[[row, k]] / pivot;
            // Store the L multiplier in the eliminated slot.
            lu[[row, k]] = factor;
            for j in (k + 1)..n {
                let val = lu[[pivot_row, j]];
                lu[[row, j]] -= factor * val;
            }
        }
    }
    // Forward substitution: L z = P b (L has an implicit unit diagonal).
    let mut z = Array1::zeros(n);
    for i in 0..n {
        let mut s = b[piv[i]];
        for j in 0..i {
            s -= lu[[piv[i], j]] * z[j];
        }
        z[i] = s;
    }
    // Back substitution: U x = z. Every diagonal of U is a pivot that passed
    // the `tiny` check above, so the division is safe.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut s = z[i];
        for j in (i + 1)..n {
            s -= lu[[piv[i], j]] * x[j];
        }
        x[i] = s / lu[[piv[i], i]];
    }
    Ok(x)
}
pub fn rosenbrock_method<F, Func>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
opts: ODEOptions<F>,
variant: RosenbrockVariant,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat + Default,
Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
let tab = match variant {
RosenbrockVariant::ROS3w => ros3w_tableau(),
RosenbrockVariant::ROS34PW2 => ros34pw2_tableau(),
RosenbrockVariant::RODAS3 => rodas3_tableau(),
};
let [t0, tf] = t_span;
if tf <= t0 {
return Err(IntegrateError::ValueError(
"t_end must be greater than t_start".into(),
));
}
let n = y0.len();
let s = tab.stages;
let span = tf - t0;
let mut h = opts.h0.unwrap_or_else(|| span * to_f::<F>(0.001));
let h_min = opts.min_step.unwrap_or_else(|| span * to_f::<F>(1e-12));
let h_max = opts.max_step.unwrap_or(span);
let rtol = opts.rtol;
let atol = opts.atol;
let max_steps = opts.max_steps;
let mut ks: Vec<Array1<F>> = (0..s).map(|_| Array1::zeros(n)).collect();
let mut t_vals = vec![t0];
let mut y_vals = vec![y0.clone()];
let mut t = t0;
let mut y = y0;
let mut step_count = 0_usize;
let mut func_evals = 0_usize;
let mut jac_evals = 0_usize;
let mut accepted = 0_usize;
let mut rejected = 0_usize;
let mut f_current = f(t, y.view());
func_evals += 1;
let mut jac = numerical_jacobian(&f, t, &y, &f_current);
jac_evals += 1;
let safety: F = to_f(0.9);
let fac_min: F = to_f(0.2);
let fac_max: F = to_f(2.5);
while t < tf && step_count < max_steps {
if t + h > tf {
h = tf - t;
}
if h < h_min {
h = h_min;
}
let h_gamma = h * to_f::<F>(tab.gamma_diag);
let mut w_mat = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
w_mat[[i, j]] = -h_gamma * jac[[i, j]];
}
w_mat[[i, i]] += F::one();
}
let mut stage_ok = true;
for i in 0..s {
let mut y_stage = y.clone();
for j in 0..i {
let a_ij: F = to_f(tab.a_ij(i, j));
y_stage += &(&ks[j] * a_ij);
}
let t_stage = t + to_f::<F>(tab.c[i]) * h;
let f_stage = f(t_stage, y_stage.view());
func_evals += 1;
let mut rhs = &f_stage * h;
if i > 0 {
let mut gamma_sum = Array1::zeros(n);
for j in 0..i {
let g_ij: F = to_f(tab.gamma_ij(i, j));
gamma_sum += &(&ks[j] * g_ij);
}
let mut j_times_gs = Array1::zeros(n);
for row in 0..n {
let mut val = F::zero();
for col in 0..n {
val += jac[[row, col]] * gamma_sum[col];
}
j_times_gs[row] = val;
}
rhs += &(&j_times_gs * h);
}
match lu_solve(&w_mat, &rhs) {
Ok(k_i) => {
ks[i] = k_i;
}
Err(_) => {
stage_ok = false;
break;
}
}
}
if !stage_ok {
h *= to_f::<F>(0.5);
rejected += 1;
step_count += 1;
continue;
}
let mut y_new = y.clone();
let mut y_hat = y.clone();
for i in 0..s {
let b_i: F = to_f(tab.b[i]);
let bh_i: F = to_f(tab.b_hat[i]);
y_new += &(&ks[i] * b_i);
y_hat += &(&ks[i] * bh_i);
}
let mut err_norm = F::zero();
for i in 0..n {
let sc = atol + rtol * y_new[i].abs().max(y[i].abs());
let e = (y_new[i] - y_hat[i]) / sc;
err_norm += e * e;
}
err_norm = (err_norm / to_f::<F>(n as f64)).sqrt();
if err_norm <= F::one() {
t += h;
y = y_new;
f_current = f(t, y.view());
func_evals += 1;
jac = numerical_jacobian(&f, t, &y, &f_current);
jac_evals += 1;
t_vals.push(t);
y_vals.push(y.clone());
accepted += 1;
} else {
rejected += 1;
}
let q: F = to_f((tab.embedded_order + 1) as f64);
let err_safe = err_norm.max(to_f::<F>(1e-6));
let factor = safety * (F::one() / err_safe).powf(F::one() / q);
let factor = factor.max(fac_min).min(fac_max);
h *= factor;
h = h.min(h_max).max(h_min);
step_count += 1;
}
if t < tf {
let _last_t = t_vals
.last()
.copied()
.ok_or_else(|| IntegrateError::ComputationError("Empty solution".into()))?;
}
Ok(ODEResult {
t: t_vals,
y: y_vals,
n_steps: step_count,
n_accepted: accepted,
n_rejected: rejected,
n_eval: func_evals,
n_jac: jac_evals,
n_lu: accepted + rejected,
success: t >= tf - h_min,
message: if t >= tf - h_min {
Some("Integration completed successfully".to_string())
} else {
Some(format!("Integration stopped at t={t} (max steps reached)"))
},
method: crate::ode::types::ODEMethod::Radau, })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Last stored state of a finished solve.
    fn final_state(result: &ODEResult<f64>) -> &Array1<f64> {
        result.y.last().expect("has solution")
    }

    #[test]
    fn test_exponential_decay() {
        let opts: ODEOptions<f64> = ODEOptions {
            rtol: 1e-8,
            atol: 1e-10,
            ..Default::default()
        };
        let result = rosenbrock_method(
            |_t: f64, y: ArrayView1<f64>| array![-y[0]],
            [0.0, 1.0],
            array![1.0],
            opts,
            RosenbrockVariant::ROS34PW2,
        )
        .expect("rosenbrock solve");
        let y_final = final_state(&result)[0];
        let exact = (-1.0_f64).exp();
        let abs_err = (y_final - exact).abs();
        assert!(abs_err < 1e-5, "exp decay: got {y_final}, expected {exact}");
    }

    #[test]
    fn test_linear_growth() {
        let result = rosenbrock_method(
            |_t: f64, _y: ArrayView1<f64>| array![1.0],
            [0.0, 2.0],
            array![0.0],
            ODEOptions::default(),
            RosenbrockVariant::ROS3w,
        )
        .expect("rosenbrock solve");
        let y_final = final_state(&result)[0];
        let abs_err = (y_final - 2.0).abs();
        assert!(abs_err < 1e-6, "linear: got {y_final}, expected 2.0");
    }

    #[test]
    fn test_harmonic_oscillator() {
        let t_end: f64 = 2.0;
        let opts: ODEOptions<f64> = ODEOptions {
            rtol: 1e-8,
            atol: 1e-10,
            max_steps: 5000,
            ..Default::default()
        };
        let result = rosenbrock_method(
            |_t: f64, y: ArrayView1<f64>| array![y[1], -y[0]],
            [0.0, t_end],
            array![1.0, 0.0],
            opts,
            RosenbrockVariant::ROS34PW2,
        )
        .expect("rosenbrock solve");
        let y_final = final_state(&result);
        let (exact_y1, exact_y2) = (t_end.cos(), -(t_end.sin()));
        assert!(
            (y_final[0] - exact_y1).abs() < 1e-3,
            "harmonic y1: got {}, expected {exact_y1}",
            y_final[0]
        );
        assert!(
            (y_final[1] - exact_y2).abs() < 1e-3,
            "harmonic y2: got {}, expected {exact_y2}",
            y_final[1]
        );
    }

    #[test]
    fn test_stiff_robertson() {
        let opts: ODEOptions<f64> = ODEOptions {
            rtol: 1e-4,
            atol: 1e-8,
            max_steps: 5000,
            ..Default::default()
        };
        let result = rosenbrock_method(
            |_t: f64, y: ArrayView1<f64>| {
                array![
                    -0.04 * y[0] + 1e4 * y[1] * y[2],
                    0.04 * y[0] - 1e4 * y[1] * y[2] - 3e7 * y[1] * y[1],
                    3e7 * y[1] * y[1]
                ]
            },
            [0.0, 0.1],
            array![1.0, 0.0, 0.0],
            opts,
            RosenbrockVariant::ROS34PW2,
        )
        .expect("rosenbrock Robertson");
        let y_final = final_state(&result);
        // Robertson's kinetics conserve total mass.
        let sum = y_final[0] + y_final[1] + y_final[2];
        assert!(
            (sum - 1.0).abs() < 1e-3,
            "Robertson conservation: sum = {sum}"
        );
        assert!(result.success, "Robertson should complete");
    }

    #[test]
    fn test_rodas3_variant() {
        let opts: ODEOptions<f64> = ODEOptions {
            rtol: 1e-8,
            atol: 1e-10,
            max_steps: 5000,
            ..Default::default()
        };
        let result = rosenbrock_method(
            |_t: f64, y: ArrayView1<f64>| array![-y[0]],
            [0.0, 1.0],
            array![1.0],
            opts,
            RosenbrockVariant::RODAS3,
        )
        .expect("RODAS3 solve");
        let y_final = final_state(&result)[0];
        let exact = (-1.0_f64).exp();
        let abs_err = (y_final - exact).abs();
        assert!(
            abs_err < 1e-3,
            "RODAS3 exp decay: got {y_final}, expected {exact}"
        );
    }

    #[test]
    fn test_van_der_pol_stiff() {
        let mu = 5.0;
        let opts: ODEOptions<f64> = ODEOptions {
            rtol: 1e-5,
            atol: 1e-8,
            max_steps: 10_000,
            ..Default::default()
        };
        let result = rosenbrock_method(
            move |_t: f64, y: ArrayView1<f64>| array![y[1], mu * (1.0 - y[0] * y[0]) * y[1] - y[0]],
            [0.0, 10.0],
            array![2.0, 0.0],
            opts,
            RosenbrockVariant::ROS34PW2,
        )
        .expect("Van der Pol");
        assert!(result.success, "Van der Pol should complete");
        assert!(result.t.len() > 10, "Should have multiple solution points");
    }

    #[test]
    fn test_invalid_span() {
        let outcome = rosenbrock_method(
            |_t: f64, _y: ArrayView1<f64>| array![0.0],
            [1.0, 0.0],
            array![0.0],
            ODEOptions::default(),
            RosenbrockVariant::ROS34PW2,
        );
        assert!(outcome.is_err(), "t_end < t_start should error");
    }
}