use numra_ode::{
compute_consistent_initial, compute_consistent_initial_tol, Bdf, DaeProblem, Radau5, Solver,
SolverOptions,
};
/// Builds the RC-discharge test problem as a DAE with a singular mass matrix.
///
/// State vector: `y = [v, i]`.
/// - Differential row: `v' = -v / (R * C)`.
/// - Algebraic row (zero row of the mass matrix): `0 = i - v / R`.
///
/// The mass matrix is `diag(1, 0)` (four entries), and index `1` is registered
/// as the algebraic component. The problem is created on the span `0.0..1.0`
/// with the initial guess `[1.0, 0.0]`; the tests in this file pass their own
/// time span and initial state to the solvers.
// Clippy flags the two `impl Fn` closure types in the return signature as
// `type_complexity`; opaque closure types cannot be named with a type alias on
// stable Rust, so the lint is silenced here instead.
#[allow(clippy::type_complexity)]
fn rc_circuit_dae(
    r: f64,
    c: f64,
) -> DaeProblem<f64, impl Fn(f64, &[f64], &mut [f64]), impl Fn(&mut [f64])> {
    let time_constant = r * c;
    DaeProblem::new(
        move |_t, y: &[f64], dydt: &mut [f64]| {
            let (voltage, current) = (y[0], y[1]);
            dydt[0] = -voltage / time_constant;
            dydt[1] = current - voltage / r;
        },
        |mass: &mut [f64]| {
            // diag(1, 0): only the voltage row carries a time derivative.
            mass[..4].copy_from_slice(&[1.0, 0.0, 0.0, 0.0]);
        },
        0.0,
        1.0,
        vec![1.0, 0.0],
        vec![1],
    )
}
#[test]
fn test_dae_init_coupled_constraints() {
    // One differential state (y1' = -y1) plus two coupled algebraic rows,
    //   0 = y2 + y3 - 1   and   0 = y2 - y3,
    // whose unique solution is y2 = y3 = 0.5. The initial guess deliberately
    // violates both constraints.
    let dae = DaeProblem::new(
        |_t, y: &[f64], dydt: &mut [f64]| {
            let (y1, y2, y3) = (y[0], y[1], y[2]);
            dydt[0] = -y1;
            dydt[1] = y2 + y3 - 1.0;
            dydt[2] = y2 - y3;
        },
        |mass: &mut [f64]| {
            // 3x3 mass matrix: a single 1 in entry 0, zeros everywhere else.
            mass[..9].fill(0.0);
            mass[0] = 1.0;
        },
        0.0,
        1.0,
        vec![1.0, 0.7, 0.2],
        vec![1, 2],
    );

    let y0 = compute_consistent_initial(&dae, 0.0, &[1.0, 0.7, 0.2]).unwrap();

    let y2 = y0[1];
    assert!((y2 - 0.5).abs() < 1e-8, "y2 should be 0.5, got {}", y2);

    let y3 = y0[2];
    assert!((y3 - 0.5).abs() < 1e-8, "y3 should be 0.5, got {}", y3);
}
#[test]
fn test_dae_init_inconsistent_fails_gracefully() {
    // The algebraic row 0 = y2^2 + 1 has no real root, so the consistent
    // initialization routine is expected to return an `Err`.
    let dae = DaeProblem::new(
        |_t, y: &[f64], dydt: &mut [f64]| {
            let (y1, y2) = (y[0], y[1]);
            dydt[0] = -y1;
            dydt[1] = y2 * y2 + 1.0;
        },
        |mass: &mut [f64]| {
            // diag(1, 0): second row is algebraic.
            mass[..4].copy_from_slice(&[1.0, 0.0, 0.0, 0.0]);
        },
        0.0,
        1.0,
        vec![1.0, 0.0],
        vec![1],
    );

    let outcome = compute_consistent_initial_tol(&dae, 0.0, &[1.0, 0.0], 1e-10, 50);
    assert!(outcome.is_err(), "Should fail for unsatisfiable constraint");
}
#[test]
fn test_dae_init_tight_tolerance() {
    // Algebraic row 0 = y2 - y1^2, starting from y1 = 5 (root y2 = 25 for
    // that y1). The `_tol` variant is called with 1e-14, and the resulting
    // residual must end up below 1e-13.
    let dae = DaeProblem::new(
        |_t, y: &[f64], dydt: &mut [f64]| {
            let (y1, y2) = (y[0], y[1]);
            dydt[0] = -y1;
            dydt[1] = y2 - y1 * y1;
        },
        |mass: &mut [f64]| {
            // diag(1, 0): second row is algebraic.
            mass[..4].copy_from_slice(&[1.0, 0.0, 0.0, 0.0]);
        },
        0.0,
        1.0,
        vec![5.0, 0.0],
        vec![1],
    );

    let y0 = compute_consistent_initial_tol(&dae, 0.0, &[5.0, 0.0], 1e-14, 100).unwrap();
    let (y1, y2) = (y0[0], y0[1]);
    let residual = y2 - y1 * y1;
    assert!(
        residual.abs() < 1e-13,
        "Should achieve tight tolerance: residual = {}",
        residual
    );
}
#[test]
fn test_dae_rc_circuit_radau5() {
    let (r, c) = (1.0_f64, 1.0_f64);
    let v0 = 1.0;
    let dae = rc_circuit_dae(r, c);

    // Consistent initialization must enforce the algebraic row: i0 = v0 / R.
    let y0 = compute_consistent_initial(&dae, 0.0, &[v0, 0.0]).unwrap();
    let i0_expected = v0 / r;
    assert!(
        (y0[1] - i0_expected).abs() < 1e-10,
        "i0 should be v0/R = {}",
        i0_expected
    );

    let tf = 2.0;
    let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
    let result = Radau5::solve(&dae, 0.0, tf, &y0, &options).unwrap();
    assert!(result.success, "Radau5 DAE solve should succeed");

    // Closed form: v(t) = v0 * exp(-t / (R C)) and i(t) = v(t) / R.
    let y_final = result.y_final().unwrap();
    let v_exact = v0 * (-tf / (r * c)).exp();
    let i_exact = v_exact / r;
    assert!(
        (y_final[0] - v_exact).abs() < 1e-4,
        "v(tf) error: got {}, expected {}",
        y_final[0],
        v_exact
    );
    assert!(
        (y_final[1] - i_exact).abs() < 1e-4,
        "i(tf) error: got {}, expected {}",
        y_final[1],
        i_exact
    );

    // The algebraic constraint must hold at every stored output time.
    // `result.y` is indexed as a flat buffer with `dim` entries per step.
    let dim = 2;
    for (step, t) in result.t.iter().enumerate() {
        let base = step * dim;
        let (v, i) = (result.y[base], result.y[base + 1]);
        let constraint = i - v / r;
        assert!(
            constraint.abs() < 1e-4,
            "Constraint violated at t={}: i - v/R = {}",
            t,
            constraint
        );
    }
}
#[test]
fn test_dae_rc_circuit_radau5_different_params() {
    // RC discharge with R = 10, C = 0.1 (RC = 1) and v0 = 5, solved by Radau5.
    // Besides v(tf), this also checks the consistent initial current and the
    // algebraic component i(tf) against the closed form
    //   v(t) = v0 * exp(-t / (R C)),   i(t) = v(t) / R,
    // so that the algebraic variable is covered for a non-unit R.
    let r = 10.0_f64;
    let c = 0.1_f64;
    let dae = rc_circuit_dae(r, c);
    let v0 = 5.0;

    let y0 = compute_consistent_initial(&dae, 0.0, &[v0, 0.0]).unwrap();
    // The algebraic row forces i0 = v0 / R; the solve below starts from it.
    assert!(
        (y0[1] - v0 / r).abs() < 1e-10,
        "i0 should be v0/R = {}, got {}",
        v0 / r,
        y0[1]
    );

    let tf = 3.0;
    let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
    let result = Radau5::solve(&dae, 0.0, tf, &y0, &options).unwrap();
    assert!(result.success, "Radau5 DAE solve should succeed");

    let y_final = result.y_final().unwrap();
    let v_exact = v0 * (-tf / (r * c)).exp();
    let i_exact = v_exact / r;
    assert!(
        (y_final[0] - v_exact).abs() < 1e-3,
        "v(tf) error: got {}, expected {}",
        y_final[0],
        v_exact
    );
    assert!(
        (y_final[1] - i_exact).abs() < 1e-3,
        "i(tf) error: got {}, expected {}",
        y_final[1],
        i_exact
    );
}
#[test]
fn test_dae_robertson_radau5() {
    // Robertson chemical kinetics in DAE form: two stiff kinetic rows plus the
    // algebraic conservation row 0 = y1 + y2 + y3 - 1.
    let dae = DaeProblem::new(
        |_t, y: &[f64], dydt: &mut [f64]| {
            let (y1, y2, y3) = (y[0], y[1], y[2]);
            dydt[0] = -0.04 * y1 + 1e4 * y2 * y3;
            dydt[1] = 0.04 * y1 - 1e4 * y2 * y3 - 3e7 * y2 * y2;
            dydt[2] = y1 + y2 + y3 - 1.0;
        },
        |mass: &mut [f64]| {
            // diag(1, 1, 0): the third row is algebraic. The matrix is
            // diagonal, so row- vs column-major storage gives the same entries.
            mass[..9].copy_from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        },
        0.0,
        1.0,
        vec![1.0, 0.0, 0.0],
        vec![2],
    );

    let y0 = vec![1.0, 0.0, 0.0];
    let options = SolverOptions::default()
        .rtol(1e-3)
        .atol(1e-6)
        .max_steps(50000);

    // A solver error on this very stiff problem is tolerated: it is reported
    // on stderr and the test ends without further checks.
    let res = match Radau5::solve(&dae, 0.0, 0.01, &y0, &options) {
        Ok(res) => res,
        Err(e) => {
            eprintln!(
                "Robertson DAE challenge (expected for extreme stiffness): {}",
                e
            );
            return;
        }
    };

    let y_final = res.y_final().unwrap();
    let conservation = y_final[0] + y_final[1] + y_final[2];
    assert!(
        (conservation - 1.0).abs() < 1e-2,
        "Conservation violated: y1+y2+y3 = {}",
        conservation
    );
    for (idx, yi) in y_final.iter().copied().enumerate() {
        assert!(yi >= -1e-4, "y{} should be non-negative, got {}", idx + 1, yi);
    }
}
#[test]
fn test_dae_nonsingular_mass_radau5() {
    // Invertible mass matrix diag(2, 3) with right-hand side -y:
    //   2 y1' = -y1  =>  y1(t) = exp(-t / 2)
    //   3 y2' = -y2  =>  y2(t) = exp(-t / 3)
    // starting from y(0) = [1, 1]. No component is algebraic.
    let dae = DaeProblem::new(
        |_t, y: &[f64], dydt: &mut [f64]| {
            let (a, b) = (y[0], y[1]);
            dydt[0] = -a;
            dydt[1] = -b;
        },
        |mass: &mut [f64]| {
            mass[..4].copy_from_slice(&[2.0, 0.0, 0.0, 3.0]);
        },
        0.0,
        1.0,
        vec![1.0, 1.0],
        Vec::new(),
    );

    let y0 = vec![1.0, 1.0];
    let tf = 2.0;
    let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
    let result = Radau5::solve(&dae, 0.0, tf, &y0, &options).unwrap();
    assert!(result.success);

    let y_final = result.y_final().unwrap();
    let expected = [(-tf / 2.0).exp(), (-tf / 3.0).exp()];
    assert!(
        (y_final[0] - expected[0]).abs() < 1e-4,
        "y1 error: got {}, expected {}",
        y_final[0],
        expected[0]
    );
    assert!(
        (y_final[1] - expected[1]).abs() < 1e-4,
        "y2 error: got {}, expected {}",
        y_final[1],
        expected[1]
    );
}
#[test]
fn test_dae_rc_circuit_bdf() {
    let (r, c) = (1.0_f64, 1.0_f64);
    let v0 = 1.0;
    let dae = rc_circuit_dae(r, c);
    let y0 = compute_consistent_initial(&dae, 0.0, &[v0, 0.0]).unwrap();

    let tf = 2.0;
    let options = SolverOptions::default()
        .rtol(1e-4)
        .atol(1e-6)
        .max_steps(5000);

    // A BDF error on this DAE is tolerated: it is reported on stderr and the
    // test ends without further checks.
    let res = match Bdf::solve(&dae, 0.0, tf, &y0, &options) {
        Ok(res) => res,
        Err(e) => {
            eprintln!("BDF DAE limitation: {}", e);
            return;
        }
    };

    assert!(res.success, "BDF DAE solve should succeed");
    // Loose check of the differential component against v0 * exp(-t / (R C)).
    let y_final = res.y_final().unwrap();
    let v_exact = v0 * (-tf / (r * c)).exp();
    assert!(
        (y_final[0] - v_exact).abs() < 0.1,
        "BDF v(tf) error: got {}, expected {}",
        y_final[0],
        v_exact
    );
}