/// Boxed right-hand-side callback with the `(t, y, params) -> Vec<f64>` signature.
///
/// The `Send + Sync` bounds allow one boxed callback to be shared across threads.
/// Note that [`rk45_step`] and [`OdeEnsemble::integrate`] take a borrowed `&dyn Fn`
/// with the same argument list rather than this boxed type.
pub type OdeRhsFn = Box<dyn Fn(f64, &[f64], &[f64]) -> Vec<f64> + Send + Sync>;
/// How ensemble members are scheduled for integration.
///
/// NOTE(review): [`OdeEnsemble::integrate`] does not read this setting.
/// Every variant integrates the members one after another, in input order, so results are
/// identical across variants. `Simulated` presumably stands in for a batched/parallel
/// backend; confirm the intended semantics before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleDispatch {
    /// Integrate members one at a time, in input order.
    Sequential,
    /// Alternative dispatch mode; currently behaves exactly like `Sequential`.
    Simulated,
}
/// Integration settings shared by every member of an [`OdeEnsemble`].
#[derive(Debug, Clone)]
pub struct OdeEnsembleConfig {
    /// Integration interval `[t_start, t_end]`; `t_end < t_start` integrates backwards.
    pub t_span: [f64; 2],
    /// Relative tolerance. The per-component error scale is
    /// `atol + rtol * max(|y|, |y_new|)`.
    pub rtol: f64,
    /// Absolute tolerance (see `rtol` for how the two combine).
    pub atol: f64,
    /// Maximum number of *accepted* steps per member; rejected trial steps are not counted.
    pub max_steps: usize,
    /// Requested dispatch mode (see [`EnsembleDispatch`]).
    pub dispatch: EnsembleDispatch,
}
/// One ensemble member: its parameter vector and its initial state.
#[derive(Debug, Clone)]
pub struct EnsembleMember {
    /// Passed unchanged as the third argument of the right-hand-side callback.
    pub params: Vec<f64>,
    /// State at `t_span[0]`; its length sets the system dimension for this member.
    pub y0: Vec<f64>,
}
/// Per-member outcomes of [`OdeEnsemble::integrate`].
///
/// Every vector has one entry per input member, in input order.
#[derive(Debug, Clone)]
pub struct EnsembleResult {
    /// State after the last accepted step (the end state when `success` is `true`).
    pub solutions: Vec<Vec<f64>>,
    /// Number of accepted steps taken.
    pub n_steps: Vec<usize>,
    /// Whether the member reached `t_span[1]` (within a small round-off tolerance).
    pub success: Vec<bool>,
    /// Reported final time; equals `t_span[1]` whenever the matching `success` entry is `true`.
    pub t_final: Vec<f64>,
}
/// Adaptive Dormand–Prince (RK45) integrator that solves many initial-value problems.
///
/// All members share one right-hand side and one [`OdeEnsembleConfig`].
pub struct OdeEnsemble {
    config: OdeEnsembleConfig,
}
// Dormand–Prince 5(4) Butcher tableau used by `rk45_step`.
// Stage coefficients a_ij: stage i combines the derivatives of stages 1..i-1.
const A21: f64 = 1.0 / 5.0;
const A31: f64 = 3.0 / 40.0;
const A32: f64 = 9.0 / 40.0;
const A41: f64 = 44.0 / 45.0;
const A42: f64 = -56.0 / 15.0;
const A43: f64 = 32.0 / 9.0;
const A51: f64 = 19372.0 / 6561.0;
const A52: f64 = -25360.0 / 2187.0;
const A53: f64 = 64448.0 / 6561.0;
const A54: f64 = -212.0 / 729.0;
const A61: f64 = 9017.0 / 3168.0;
const A62: f64 = -355.0 / 33.0;
const A63: f64 = 46732.0 / 5247.0;
const A64: f64 = 49.0 / 176.0;
const A65: f64 = -5103.0 / 18656.0;
// Fifth-order solution weights b_i. b_2 and b_7 are zero, so they are omitted.
const B1: f64 = 35.0 / 384.0;
const B3: f64 = 500.0 / 1113.0;
const B4: f64 = 125.0 / 192.0;
const B5: f64 = -2187.0 / 6784.0;
const B6: f64 = 11.0 / 84.0;
// Error-estimate weights: fifth-order minus fourth-order weights (e_2 = 0 is omitted).
// Only their magnitude matters, since `rk45_step` squares the scaled error.
const E1: f64 = 71.0 / 57600.0;
const E3: f64 = -71.0 / 16695.0;
const E4: f64 = 71.0 / 1920.0;
const E5: f64 = -17253.0 / 339200.0;
const E6: f64 = 22.0 / 525.0;
const E7: f64 = -1.0 / 40.0;
// Stage nodes c_i. c_6 = c_7 = 1 are written directly as `t + h` in `rk45_step`.
const C2: f64 = 1.0 / 5.0;
const C3: f64 = 3.0 / 10.0;
const C4: f64 = 4.0 / 5.0;
const C5: f64 = 8.0 / 9.0;
/// Performs one Dormand–Prince 5(4) trial step of size `h` from `(t, y)`.
///
/// `rhs(t, y, params)` is evaluated as the derivative of `y` at seven stage points. The last
/// evaluation is at `(t + h, y_new)`.
///
/// Returns `(y_new, err_vec, err_norm)`:
/// - `y_new`: the fifth-order solution at `t + h`;
/// - `err_vec`: the embedded error estimate per component, i.e. the fifth-order solution
///   minus the fourth-order solution;
/// - `err_norm`: the RMS of `err_vec[i] / (atol + rtol * max(|y[i]|, |y_new[i]|))`.
///   A value `<= 1.0` means the step meets the tolerances.
///
/// Special cases for `err_norm`:
/// - an empty state gives `0.0`;
/// - a component whose scale is zero and whose error is exactly zero contributes `0`
///   instead of `0/0 = NaN`;
/// - any other non-finite input still yields a non-finite norm.
///
/// This function does not decide acceptance or adapt `h`; callers do.
///
/// # Panics
/// Panics if `rhs` returns a vector shorter than `y`.
pub fn rk45_step(
    t: f64,
    y: &[f64],
    params: &[f64],
    h: f64,
    rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    let n = y.len();

    // Stages k1..k6 of the Dormand–Prince tableau.
    let k1 = rhs(t, y, params);
    let y2: Vec<f64> = (0..n).map(|i| y[i] + h * A21 * k1[i]).collect();
    let k2 = rhs(t + C2 * h, &y2, params);
    let y3: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A31 * k1[i] + A32 * k2[i]))
        .collect();
    let k3 = rhs(t + C3 * h, &y3, params);
    let y4: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A41 * k1[i] + A42 * k2[i] + A43 * k3[i]))
        .collect();
    let k4 = rhs(t + C4 * h, &y4, params);
    let y5_tmp: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (A51 * k1[i] + A52 * k2[i] + A53 * k3[i] + A54 * k4[i]))
        .collect();
    let k5 = rhs(t + C5 * h, &y5_tmp, params);
    let y6_tmp: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h * (A61 * k1[i] + A62 * k2[i] + A63 * k3[i] + A64 * k4[i] + A65 * k5[i])
        })
        .collect();
    let k6 = rhs(t + h, &y6_tmp, params);

    // Fifth-order solution (b_2 = 0, so k2 does not appear).
    let y_new: Vec<f64> = (0..n)
        .map(|i| y[i] + h * (B1 * k1[i] + B3 * k3[i] + B4 * k4[i] + B5 * k5[i] + B6 * k6[i]))
        .collect();

    // k7 is the derivative at the new point. Here it only feeds the error estimate.
    let k7 = rhs(t + h, &y_new, params);
    let err_vec: Vec<f64> = (0..n)
        .map(|i| h * (E1 * k1[i] + E3 * k3[i] + E4 * k4[i] + E5 * k5[i] + E6 * k6[i] + E7 * k7[i]))
        .collect();

    if n == 0 {
        // An RMS over zero components would be 0/0 = NaN; an empty state has no error.
        return (y_new, err_vec, 0.0);
    }

    let sum_sq: f64 = err_vec
        .iter()
        .zip(y)
        .zip(&y_new)
        .map(|((&err, &y_old), &y_next)| {
            let scale = atol + rtol * y_old.abs().max(y_next.abs());
            if err == 0.0 && scale == 0.0 {
                // Exact component with zero scale (e.g. atol == 0 and y == 0): no error.
                0.0
            } else {
                let ratio = err / scale;
                ratio * ratio
            }
        })
        .sum();
    let err_norm = (sum_sq / n as f64).sqrt();

    (y_new, err_vec, err_norm)
}
impl OdeEnsemble {
    /// Creates an ensemble integrator that applies `config` to every member.
    pub fn new(config: OdeEnsembleConfig) -> Self {
        Self { config }
    }

    /// Integrates every member over `config.t_span` with the shared right-hand side `rhs`.
    ///
    /// `rhs(t, y, params)` receives each member's own `params`. Members are integrated one
    /// after another, in input order. `config.dispatch` is not consulted here, so every
    /// dispatch mode yields identical results.
    ///
    /// Every vector of the returned [`EnsembleResult`] has one entry per member, in input
    /// order. `t_final[i]` is `t_span[1]` for a member that converged. Otherwise it is the
    /// time of the member's last accepted step, so callers can see how far it got.
    pub fn integrate(
        &self,
        members: &[EnsembleMember],
        rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    ) -> EnsembleResult {
        let n = members.len();
        let mut result = EnsembleResult {
            solutions: Vec::with_capacity(n),
            n_steps: Vec::with_capacity(n),
            success: Vec::with_capacity(n),
            t_final: Vec::with_capacity(n),
        };
        for member in members {
            let (y_final, n_steps, ok, t_reached) = self.integrate_single(member, rhs);
            result.solutions.push(y_final);
            result.n_steps.push(n_steps);
            result.success.push(ok);
            result.t_final.push(t_reached);
        }
        result
    }

    /// Integrates a single member with adaptive Dormand–Prince steps.
    ///
    /// Returns `(y, accepted_steps, converged, t_final)`. `t_final` is `t_span[1]` when
    /// `converged` is true, otherwise the time reached by the last accepted step.
    ///
    /// A trial step is accepted only when its error norm is `<= 1.0` (which excludes NaN)
    /// and the proposed state is entirely finite. Integration stops early in two cases:
    /// - `max_steps` steps have been accepted;
    /// - the step size becomes too small to make progress.
    fn integrate_single(
        &self,
        member: &EnsembleMember,
        rhs: &dyn Fn(f64, &[f64], &[f64]) -> Vec<f64>,
    ) -> (Vec<f64>, usize, bool, f64) {
        // Step-size controller: safety factor and bounds on the per-step change of `h`.
        const FAC: f64 = 0.9;
        const FAC_MAX: f64 = 10.0;
        const FAC_MIN: f64 = 0.2;
        // Controller exponent -1/(q + 1) for the embedded fourth-order estimate (q = 4).
        const ERR_EXPONENT: f64 = -0.2;

        let [t_start, t_end] = self.config.t_span;
        let rtol = self.config.rtol;
        let atol = self.config.atol;
        let max_steps = self.config.max_steps;

        let mut y = member.y0.clone();
        if y.is_empty() {
            // Nothing to integrate: an empty state trivially reaches the end time.
            return (y, 0, true, t_end);
        }

        let span = (t_end - t_start).abs();
        // Distance to `t_end` below which the end point counts as reached.
        let end_tol = 1e-12 * span.max(f64::EPSILON);
        let direction = if t_end >= t_start { 1.0_f64 } else { -1.0 };
        let mut h = direction * (span * 1e-3);
        let mut t = t_start;
        let mut steps = 0_usize;
        let mut converged = false;

        while (t_end - t).abs() > end_tol {
            if steps >= max_steps {
                break;
            }
            // Never step past `t_end`.
            if direction * (t + h - t_end) > 0.0 {
                h = t_end - t;
            }
            // Give up when `h` is negligible relative to the span, or too small to change `t`
            // in floating point (otherwise "accepted" steps would make no progress).
            if h.abs() < f64::EPSILON * span || t + h == t {
                break;
            }

            let (y_new, _err_vec, err_norm) =
                rk45_step(t, &y, &member.params, h, rhs, rtol, atol);
            let y_new_finite = y_new.iter().all(|v| v.is_finite());

            // `err_norm <= 1.0` is false for NaN, so a step that produced NaNs is rejected
            // rather than propagated into the solution.
            if err_norm <= 1.0 && y_new_finite {
                t += h;
                y = y_new;
                steps += 1;
                if (t_end - t).abs() < end_tol {
                    converged = true;
                    break;
                }
            }

            let factor = if err_norm.is_nan() || !y_new_finite {
                // No usable error estimate: shrink as much as the controller allows.
                FAC_MIN
            } else {
                (FAC * err_norm.max(f64::EPSILON).powf(ERR_EXPONENT)).clamp(FAC_MIN, FAC_MAX)
            };
            h *= factor;
        }

        // Also count an end point missed only by round-off as reached.
        if (t - t_end).abs() < 1e-8 * span.max(f64::EPSILON) {
            converged = true;
        }
        let t_final = if converged { t_end } else { t };
        (y, steps, converged, t_final)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit interval with tight tolerances; the baseline for most tests below.
    fn default_config() -> OdeEnsembleConfig {
        OdeEnsembleConfig {
            t_span: [0.0, 1.0],
            rtol: 1e-7,
            atol: 1e-9,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        }
    }

    /// Linear decay `dy/dt = -k * y`, with the rate `k` taken from `p[0]`.
    fn decay(_t: f64, y: &[f64], p: &[f64]) -> Vec<f64> {
        vec![-p[0] * y[0]]
    }

    /// One decay member per rate, each starting from `y0 = 1.0`.
    fn unit_decay_members(rates: &[f64]) -> Vec<EnsembleMember> {
        rates
            .iter()
            .map(|&k| EnsembleMember {
                params: vec![k],
                y0: vec![1.0],
            })
            .collect()
    }

    #[test]
    fn test_identical_params_same_solution() {
        let ensemble = OdeEnsemble::new(default_config());
        let result = ensemble.integrate(&unit_decay_members(&[2.0; 5]), &decay);
        let reference = result.solutions[0][0];
        for (i, sol) in result.solutions.iter().enumerate().skip(1) {
            assert!(
                (sol[0] - reference).abs() < 1e-14,
                "member {i} diverges from member 0: {:.6e} vs {:.6e}",
                sol[0],
                reference
            );
        }
    }

    #[test]
    fn test_different_params_different_solutions() {
        let ks = [0.5, 1.0, 2.0, 4.0, 8.0];
        let ensemble = OdeEnsemble::new(default_config());
        let result = ensemble.integrate(&unit_decay_members(&ks), &decay);
        // A faster decay rate must leave a strictly smaller value at t = 1.
        for (j, pair) in result.solutions.windows(2).enumerate() {
            let (y_prev, y_curr) = (pair[0][0], pair[1][0]);
            assert!(
                y_curr < y_prev,
                "k={} solution ({:.6e}) should be < k={} solution ({:.6e})",
                ks[j + 1],
                y_curr,
                ks[j],
                y_prev
            );
        }
    }

    #[test]
    fn test_exponential_decay_analytical() {
        let (k, y0) = (3.0_f64, 2.5_f64);
        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, 2.0],
            rtol: 1e-8,
            atol: 1e-10,
            max_steps: 100_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        let member = EnsembleMember {
            params: vec![k],
            y0: vec![y0],
        };
        let y_numerical = ensemble.integrate(&[member], &decay).solutions[0][0];
        let y_analytical = y0 * (-k * 2.0_f64).exp();
        assert!(
            (y_numerical - y_analytical).abs() < 1e-6,
            "y_numerical = {y_numerical:.8e}, y_analytical = {y_analytical:.8e}"
        );
    }

    #[test]
    fn test_all_converge() {
        let ensemble = OdeEnsemble::new(default_config());
        let members = unit_decay_members(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let result = ensemble.integrate(&members, &decay);
        result
            .success
            .iter()
            .enumerate()
            .for_each(|(i, &ok)| assert!(ok, "member {i} did not converge"));
    }

    #[test]
    fn test_n_steps_positive() {
        let ensemble = OdeEnsemble::new(default_config());
        let members = unit_decay_members(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let result = ensemble.integrate(&members, &decay);
        result
            .n_steps
            .iter()
            .enumerate()
            .for_each(|(i, &ns)| assert!(ns > 0, "member {i} took 0 steps"));
    }

    #[test]
    fn test_2d_system_vanderpol() {
        let ensemble = OdeEnsemble::new(OdeEnsembleConfig {
            t_span: [0.0, 5.0],
            rtol: 1e-6,
            atol: 1e-8,
            max_steps: 500_000,
            dispatch: EnsembleDispatch::Sequential,
        });
        let member = EnsembleMember {
            params: vec![0.1],
            y0: vec![2.0, 0.0],
        };
        // Van der Pol oscillator as a first-order system, with mu = p[0].
        let van_der_pol = |_t: f64, y: &[f64], p: &[f64]| {
            let mu = p[0];
            vec![y[1], mu * (1.0 - y[0] * y[0]) * y[1] - y[0]]
        };
        let result = ensemble.integrate(&[member], &van_der_pol);
        assert!(result.success[0], "van-der-Pol did not converge");
        assert!(
            result.solutions[0].iter().all(|v| v.is_finite()),
            "van-der-Pol solution is non-finite"
        );
    }

    #[test]
    fn test_simulated_dispatch_matches_sequential() {
        let config_seq = OdeEnsembleConfig {
            t_span: [0.0, 1.0],
            rtol: 1e-7,
            atol: 1e-9,
            max_steps: 50_000,
            dispatch: EnsembleDispatch::Sequential,
        };
        let config_sim = OdeEnsembleConfig {
            dispatch: EnsembleDispatch::Simulated,
            ..config_seq.clone()
        };
        let members = vec![
            EnsembleMember {
                params: vec![1.0],
                y0: vec![1.0],
            },
            EnsembleMember {
                params: vec![2.0],
                y0: vec![3.0],
            },
        ];
        let res_seq = OdeEnsemble::new(config_seq).integrate(&members, &decay);
        let res_sim = OdeEnsemble::new(config_sim).integrate(&members, &decay);
        for (i, (seq, sim)) in res_seq
            .solutions
            .iter()
            .zip(&res_sim.solutions)
            .enumerate()
        {
            assert!(
                (seq[0] - sim[0]).abs() < 1e-14,
                "member {i}: sequential={:.6e}, simulated={:.6e}",
                seq[0],
                sim[0]
            );
        }
    }
}