use std::fmt;
pub trait OdeFunc: Send + Sync {
    /// Evaluate the right-hand side dy/dt = f(t, y, params).
    fn call(&self, t: f64, y: &[f64], params: &[f64]) -> Vec<f64>;

    /// Vector-Jacobian products of `call` contracted with `grad_output`.
    ///
    /// Returns `(grad_y, grad_t, grad_params)` where, writing f = call(t, y, p):
    /// * `grad_y[i]` = sum_k grad_output[k] * df_k/dy_i
    /// * `grad_t` = sum_k grad_output[k] * df_k/dt
    /// * `grad_params[j]` = sum_k grad_output[k] * df_k/dp_j
    ///
    /// The default implementation uses central finite differences with a
    /// fixed perturbation of 1e-6, costing 2 * (y.len() + params.len() + 1)
    /// extra `call` evaluations. Implementors with an analytic Jacobian
    /// should override it for accuracy and speed.
    fn vjp(
        &self,
        t: f64,
        y: &[f64],
        params: &[f64],
        grad_output: &[f64],
    ) -> (Vec<f64>, f64, Vec<f64>) {
        let eps = 1e-6_f64;
        // Contract grad_output against one central-difference Jacobian
        // column: sum_k go[k] * (f_plus[k] - f_minus[k]) / (2 eps).
        let contract = |f_plus: &[f64], f_minus: &[f64]| -> f64 {
            grad_output
                .iter()
                .enumerate()
                .map(|(k, go)| go * (f_plus[k] - f_minus[k]) / (2.0 * eps))
                .sum()
        };
        // d/dy_i, one perturbed pair of evaluations per state component.
        let grad_y: Vec<f64> = (0..y.len())
            .map(|i| {
                let mut y_plus = y.to_vec();
                let mut y_minus = y.to_vec();
                y_plus[i] += eps;
                y_minus[i] -= eps;
                contract(
                    &self.call(t, &y_plus, params),
                    &self.call(t, &y_minus, params),
                )
            })
            .collect();
        // d/dt, one perturbed pair in the time argument.
        let grad_t = contract(
            &self.call(t + eps, y, params),
            &self.call(t - eps, y, params),
        );
        // d/dp_j, one perturbed pair per parameter.
        let grad_params: Vec<f64> = (0..params.len())
            .map(|j| {
                let mut p_plus = params.to_vec();
                let mut p_minus = params.to_vec();
                p_plus[j] += eps;
                p_minus[j] -= eps;
                contract(&self.call(t, y, &p_plus), &self.call(t, y, &p_minus))
            })
            .collect();
        (grad_y, grad_t, grad_params)
    }
}
/// Trajectory produced by an ODE integration.
#[derive(Debug, Clone)]
pub struct OdeSolution {
    // Time points at which states were recorded; the first entry is t0.
    pub times: Vec<f64>,
    // State vector at each entry of `times` (same length as `times`).
    pub states: Vec<Vec<f64>>,
    // Number of right-hand-side evaluations performed.
    pub nfev: usize,
}
/// Result of an adaptive-step solve: the trajectory plus step-control
/// diagnostics.
#[derive(Debug, Clone)]
pub struct AdaptiveSolution {
    // The accepted trajectory.
    pub solution: OdeSolution,
    // Number of trial steps rejected by the error control.
    pub rejected_steps: usize,
    // Magnitude of the step size when integration finished.
    pub final_step_size: f64,
}
/// Output of the adjoint-sensitivity pass.
#[derive(Debug, Clone)]
pub struct AdjointResult {
    // State at the end of the forward integration.
    pub final_state: Vec<f64>,
    // Gradient of the loss with respect to the initial state y0.
    pub grad_y0: Vec<f64>,
    // Gradient of the loss with respect to the parameters.
    pub grad_params: Vec<f64>,
    // Function-evaluation count; `NeuralOde::adjoint` reports the sum of
    // the forward and backward passes here.
    pub total_nfev: usize,
}
/// Settings for the adaptive DOPRI5 solver.
#[derive(Debug, Clone)]
pub struct OdeSolverConfig {
    // Relative tolerance in the per-component error scale.
    pub rtol: f64,
    // Absolute tolerance in the per-component error scale.
    pub atol: f64,
    // Maximum number of accepted steps before `MaxStepsExceeded`.
    pub max_steps: usize,
    // Abort with `StepTooSmall` when |h| drops below this.
    pub min_step: f64,
    // Upper bound on |h|.
    pub max_step: f64,
    // When true, every accepted step is stored in the returned trajectory;
    // otherwise only the endpoints are returned.
    pub dense_output: bool,
}
impl Default for OdeSolverConfig {
    /// Defaults: rtol 1e-4, atol 1e-6, at most 1000 accepted steps,
    /// minimum step 1e-12, no upper bound on the step size, dense output
    /// enabled.
    fn default() -> Self {
        Self {
            rtol: 1e-4,
            atol: 1e-6,
            max_steps: 1000,
            min_step: 1e-12,
            max_step: f64::INFINITY,
            dense_output: true,
        }
    }
}
impl OdeSolverConfig {
    /// Start from the default configuration (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the relative tolerance.
    pub fn rtol(self, v: f64) -> Self {
        Self { rtol: v, ..self }
    }

    /// Builder-style setter for the absolute tolerance.
    pub fn atol(self, v: f64) -> Self {
        Self { atol: v, ..self }
    }

    /// Builder-style setter for the accepted-step limit.
    pub fn max_steps(self, n: usize) -> Self {
        Self { max_steps: n, ..self }
    }

    /// Builder-style switch that disables storing intermediate steps.
    pub fn no_dense_output(self) -> Self {
        Self {
            dense_output: false,
            ..self
        }
    }
}
/// Failure modes of the adaptive solver and the NeuralOde wrapper.
#[derive(Debug)]
pub enum OdeError {
    // `max_steps` accepted steps were taken without reaching t1.
    MaxStepsExceeded,
    // The controller shrank |h| below `min_step`.
    StepTooSmall,
    // NaN or Inf appeared in the state vector.
    DivergentSolution,
    // Caller-supplied input was malformed; the message says how.
    InvalidInput(String),
}
impl fmt::Display for OdeError {
    /// Human-readable diagnostics for each failure mode.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OdeError::InvalidInput(msg) => {
                write!(f, "ODE solver received invalid input: {msg}")
            }
            OdeError::MaxStepsExceeded => f.write_str(
                "ODE solver exceeded the maximum number of steps; consider relaxing tolerances or increasing max_steps",
            ),
            OdeError::StepTooSmall => f.write_str(
                "ODE solver step size fell below the minimum threshold; the problem may be too stiff for this explicit solver",
            ),
            OdeError::DivergentSolution => {
                f.write_str("ODE solution diverged (NaN or Inf encountered in state vector)")
            }
        }
    }
}
// Marker impl: OdeError carries no underlying source error, so the default
// `std::error::Error` methods suffice.
impl std::error::Error for OdeError {}
/// Element-wise sum of two slices; output length is the shorter of the two.
#[inline]
#[allow(dead_code)]
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x + y);
    }
    out
}
/// Multiply every element of `v` by the scalar `s`.
#[inline]
#[allow(dead_code)]
fn vec_scale(v: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(v.len());
    for x in v {
        out.push(x * s);
    }
    out
}
/// BLAS-style axpy: returns `y + alpha * x` element-wise; output length is
/// the shorter of the two slices.
#[inline]
fn vec_axpy(y: &[f64], alpha: f64, x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(y.len());
    for (yi, xi) in y.iter().zip(x.iter()) {
        out.push(yi + alpha * xi);
    }
    out
}
/// Scaled RMS error norm used for step acceptance: each component of `err`
/// is divided by `atol + rtol * max(|y|, |y_new|)` before taking the RMS
/// over all components. Returns 0.0 for empty input; a result <= 1.0 means
/// the error is within tolerance.
fn error_norm(err: &[f64], y: &[f64], y_new: &[f64], rtol: f64, atol: f64) -> f64 {
    if err.is_empty() {
        return 0.0;
    }
    let mut total = 0.0_f64;
    for ((e, yi), yn) in err.iter().zip(y.iter()).zip(y_new.iter()) {
        let scale = atol + rtol * yi.abs().max(yn.abs());
        total += (e / scale).powi(2);
    }
    (total / err.len() as f64).sqrt()
}
/// True when any component is non-finite (NaN or +/-Inf).
fn has_diverged(v: &[f64]) -> bool {
    !v.iter().all(|x| x.is_finite())
}
/// Integrate `func` from `t0` to `t1` with the classic fixed-step
/// fourth-order Runge-Kutta method.
///
/// `num_steps` is clamped to at least 1; the step is `(t1 - t0) / num_steps`,
/// so backward integration (t1 < t0) works as well. Every intermediate
/// state is stored, so the returned solution holds `num_steps + 1`
/// time/state pairs and `nfev == 4 * num_steps`.
pub fn rk4_solve(
    func: &dyn OdeFunc,
    t0: f64,
    t1: f64,
    y0: &[f64],
    params: &[f64],
    num_steps: usize,
) -> OdeSolution {
    let n_steps = num_steps.max(1);
    let h = (t1 - t0) / n_steps as f64;
    let mut nfev = 0usize;
    let mut times = Vec::with_capacity(n_steps + 1);
    let mut states = Vec::with_capacity(n_steps + 1);
    let mut t = t0;
    let mut y = y0.to_vec();
    times.push(t);
    states.push(y.clone());
    for _ in 0..n_steps {
        // The four RK4 stages; intermediate states are built inline.
        let k1 = func.call(t, &y, params);
        let k2 = func.call(t + h * 0.5, &vec_axpy(&y, h * 0.5, &k1), params);
        let k3 = func.call(t + h * 0.5, &vec_axpy(&y, h * 0.5, &k2), params);
        let k4 = func.call(t + h, &vec_axpy(&y, h, &k3), params);
        nfev += 4;
        // Weighted combination: y += h/6 * (k1 + 2 k2 + 2 k3 + k4).
        let y_next: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .map(|((((yi, a), b), c), d)| yi + h / 6.0 * (a + 2.0 * b + 2.0 * c + d))
            .collect();
        y = y_next;
        t += h;
        times.push(t);
        states.push(y.clone());
    }
    OdeSolution {
        times,
        states,
        nfev,
    }
}
// Dormand-Prince 5(4) Butcher tableau.
// a_ij stage couplings: stage i is evaluated at y + h * sum_j a_ij * k_j.
const DOPRI5_A21: f64 = 1.0 / 5.0;
const DOPRI5_A31: f64 = 3.0 / 40.0;
const DOPRI5_A32: f64 = 9.0 / 40.0;
const DOPRI5_A41: f64 = 44.0 / 45.0;
const DOPRI5_A42: f64 = -56.0 / 15.0;
const DOPRI5_A43: f64 = 32.0 / 9.0;
const DOPRI5_A51: f64 = 19372.0 / 6561.0;
const DOPRI5_A52: f64 = -25360.0 / 2187.0;
const DOPRI5_A53: f64 = 64448.0 / 6561.0;
const DOPRI5_A54: f64 = -212.0 / 729.0;
const DOPRI5_A61: f64 = 9017.0 / 3168.0;
const DOPRI5_A62: f64 = -355.0 / 33.0;
const DOPRI5_A63: f64 = 46732.0 / 5247.0;
const DOPRI5_A64: f64 = 49.0 / 176.0;
const DOPRI5_A65: f64 = -5103.0 / 18656.0;
// Row 7 doubles as the 5th-order solution weights b_j; the a_72/b_2
// coefficient is zero and therefore omitted.
const DOPRI5_A71: f64 = 35.0 / 384.0;
const DOPRI5_A73: f64 = 500.0 / 1113.0;
const DOPRI5_A74: f64 = 125.0 / 192.0;
const DOPRI5_A75: f64 = -2187.0 / 6784.0;
const DOPRI5_A76: f64 = 11.0 / 84.0;
// e_j = b_j - b*_j: difference between the 5th- and embedded 4th-order
// weights, used to form the local error estimate (e_2 is zero and omitted).
const DOPRI5_E1: f64 = 71.0 / 57600.0;
const DOPRI5_E3: f64 = -71.0 / 16695.0;
const DOPRI5_E4: f64 = 71.0 / 1920.0;
const DOPRI5_E5: f64 = -17253.0 / 339200.0;
const DOPRI5_E6: f64 = 22.0 / 525.0;
const DOPRI5_E7: f64 = -1.0 / 40.0;
// Step-size controller: h_new = h * clamp(SAFETY * err^(-1/ORDER), MIN, MAX).
const DOPRI5_SAFETY: f64 = 0.9;
const DOPRI5_MIN_FACTOR: f64 = 0.2;
const DOPRI5_MAX_FACTOR: f64 = 10.0;
const DOPRI5_ORDER: f64 = 5.0;
/// Integrate `func` from `t0` to `t1` with the adaptive Dormand-Prince
/// 5(4) method.
///
/// The 5th-order solution is propagated; the embedded 4th-order solution
/// (via the `DOPRI5_E*` weights) supplies the local error estimate used
/// for step acceptance and step-size control. Backward integration
/// (t1 < t0) is supported. With `config.dense_output` set, every accepted
/// step is recorded; otherwise only the two endpoints are returned.
///
/// # Errors
/// * `InvalidInput` when `y0` is empty (and t0 != t1).
/// * `MaxStepsExceeded` after `config.max_steps` accepted steps.
/// * `StepTooSmall` when |h| falls below `config.min_step`.
/// * `DivergentSolution` when the candidate state contains NaN/Inf.
pub fn dopri5_solve(
    func: &dyn OdeFunc,
    t0: f64,
    t1: f64,
    y0: &[f64],
    params: &[f64],
    config: &OdeSolverConfig,
) -> Result<AdaptiveSolution, OdeError> {
    // Zero-length interval: return the initial state untouched.
    if t0 == t1 {
        return Ok(AdaptiveSolution {
            solution: OdeSolution {
                times: vec![t0],
                states: vec![y0.to_vec()],
                nfev: 0,
            },
            rejected_steps: 0,
            final_step_size: 0.0,
        });
    }
    if y0.is_empty() {
        return Err(OdeError::InvalidInput("state vector is empty".into()));
    }
    // `sign` flips the step direction for backward integration.
    let forward = t1 > t0;
    let sign = if forward { 1.0_f64 } else { -1.0_f64 };
    let span = (t1 - t0).abs();
    let f0 = func.call(t0, y0, params);
    // Initial step-size heuristic from the RMS magnitudes of the state and
    // its derivative, clipped to the interval length and max_step.
    let d0 = (y0.iter().map(|x| x * x).sum::<f64>() / y0.len() as f64).sqrt();
    let d1 = (f0.iter().map(|x| x * x).sum::<f64>() / f0.len() as f64).sqrt();
    let h0 = if d0 < 1e-5 || d1 < 1e-5 {
        1e-6
    } else {
        0.01 * d0 / d1
    };
    let mut h = sign * h0.min(span).min(config.max_step);
    let mut t = t0;
    let mut y = y0.to_vec();
    // First-same-as-last (FSAL): k1 of the next step reuses k7 of the
    // previous accepted step, saving one evaluation per step.
    let mut k1 = f0;
    let mut nfev = 1usize;
    let mut times = vec![t0];
    let mut states = vec![y0.to_vec()];
    let mut rejected_steps = 0usize;
    let mut steps = 0usize;
    while (sign * (t1 - t)).abs() > f64::EPSILON * span.max(1.0) {
        if steps >= config.max_steps {
            return Err(OdeError::MaxStepsExceeded);
        }
        // Clip the step so the final step lands exactly on t1.
        if (t + h - t1) * sign > 0.0 {
            h = t1 - t;
        }
        let h_abs = h.abs();
        if h_abs < config.min_step {
            return Err(OdeError::StepTooSmall);
        }
        // Stages 2..6 of the Dormand-Prince tableau.
        let y2 = vec_axpy(&y, DOPRI5_A21 * h, &k1);
        let k2 = func.call(t + h / 5.0, &y2, params);
        nfev += 1;
        let y3: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .map(|((yi, k1i), k2i)| yi + h * (DOPRI5_A31 * k1i + DOPRI5_A32 * k2i))
            .collect();
        let k3 = func.call(t + h * 3.0 / 10.0, &y3, params);
        nfev += 1;
        let y4: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .map(|(((yi, k1i), k2i), k3i)| {
                yi + h * (DOPRI5_A41 * k1i + DOPRI5_A42 * k2i + DOPRI5_A43 * k3i)
            })
            .collect();
        let k4 = func.call(t + h * 4.0 / 5.0, &y4, params);
        nfev += 1;
        let y5: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .map(|((((yi, k1i), k2i), k3i), k4i)| {
                yi + h * (DOPRI5_A51 * k1i + DOPRI5_A52 * k2i + DOPRI5_A53 * k3i + DOPRI5_A54 * k4i)
            })
            .collect();
        let k5 = func.call(t + h * 8.0 / 9.0, &y5, params);
        nfev += 1;
        let y6: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k2.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .map(|(((((yi, k1i), k2i), k3i), k4i), k5i)| {
                yi + h
                    * (DOPRI5_A61 * k1i
                        + DOPRI5_A62 * k2i
                        + DOPRI5_A63 * k3i
                        + DOPRI5_A64 * k4i
                        + DOPRI5_A65 * k5i)
            })
            .collect();
        let k6 = func.call(t + h, &y6, params);
        nfev += 1;
        // 5th-order candidate solution (row 7 of the tableau = b weights).
        let y_new: Vec<f64> = y
            .iter()
            .zip(k1.iter())
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .zip(k6.iter())
            .map(|(((((yi, k1i), k3i), k4i), k5i), k6i)| {
                yi + h
                    * (DOPRI5_A71 * k1i
                        + DOPRI5_A73 * k3i
                        + DOPRI5_A74 * k4i
                        + DOPRI5_A75 * k5i
                        + DOPRI5_A76 * k6i)
            })
            .collect();
        if has_diverged(&y_new) {
            return Err(OdeError::DivergentSolution);
        }
        // Stage 7: derivative at (t + h, y_new); reused as k1 on acceptance.
        let k7 = func.call(t + h, &y_new, params);
        nfev += 1;
        // Embedded error estimate err = h * sum_j e_j * k_j.
        let err: Vec<f64> = k1
            .iter()
            .zip(k3.iter())
            .zip(k4.iter())
            .zip(k5.iter())
            .zip(k6.iter())
            .zip(k7.iter())
            .map(|(((((e1, e3), e4), e5), e6), e7)| {
                h * (DOPRI5_E1 * e1
                    + DOPRI5_E3 * e3
                    + DOPRI5_E4 * e4
                    + DOPRI5_E5 * e5
                    + DOPRI5_E6 * e6
                    + DOPRI5_E7 * e7)
            })
            .collect();
        let error_norm_val = error_norm(&err, &y, &y_new, config.rtol, config.atol);
        if error_norm_val <= 1.0 {
            // Accepted: advance and reuse k7 as the next k1 (FSAL).
            t += h;
            y = y_new;
            k1 = k7;
            if config.dense_output {
                times.push(t);
                states.push(y.clone());
            }
            steps += 1;
            // Grow the step, bounded by the controller limits; a zero error
            // norm would make powf blow up, so it maps to the max factor.
            let factor = if error_norm_val == 0.0 {
                DOPRI5_MAX_FACTOR
            } else {
                (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
                    .clamp(DOPRI5_MIN_FACTOR, DOPRI5_MAX_FACTOR)
            };
            h *= factor;
            h = h.abs().min(config.max_step) * sign;
        } else {
            // Rejected: shrink the step (factor < 1) and retry from (t, y).
            rejected_steps += 1;
            let factor = (DOPRI5_SAFETY * error_norm_val.powf(-1.0 / DOPRI5_ORDER))
                .clamp(DOPRI5_MIN_FACTOR, 1.0);
            h *= factor;
        }
    }
    // Store the endpoint exactly once (dense output already pushed it).
    if !config.dense_output || times.last().map(|&last| last != t).unwrap_or(true) {
        times.push(t);
        states.push(y.clone());
    }
    Ok(AdaptiveSolution {
        solution: OdeSolution {
            times,
            states,
            nfev,
        },
        rejected_steps,
        final_step_size: h.abs(),
    })
}
/// Differentiable ODE layer: integrates `func` over the fixed interval
/// [t0, t1] and computes gradients via the adjoint method.
pub struct NeuralOde<F: OdeFunc> {
    // Right-hand side of the ODE.
    func: F,
    // Integration start time.
    t0: f64,
    // Integration end time.
    t1: f64,
    // Adaptive-solver settings used by the forward and adjoint passes.
    config: OdeSolverConfig,
}
impl<F: OdeFunc> NeuralOde<F> {
    /// Build a solver over [t0, t1] with the default configuration.
    pub fn new(func: F, t0: f64, t1: f64) -> Self {
        Self {
            func,
            t0,
            t1,
            config: OdeSolverConfig::default(),
        }
    }

    /// Build a solver over [t0, t1] with an explicit configuration.
    pub fn with_config(func: F, t0: f64, t1: f64, config: OdeSolverConfig) -> Self {
        Self {
            func,
            t0,
            t1,
            config,
        }
    }

    /// Forward pass: integrate from t0 to t1 with DOPRI5.
    ///
    /// # Errors
    /// `InvalidInput` when `y0` is empty, plus any error propagated from
    /// `dopri5_solve`.
    pub fn forward(&self, y0: &[f64], params: &[f64]) -> Result<OdeSolution, OdeError> {
        if y0.is_empty() {
            return Err(OdeError::InvalidInput("initial state is empty".into()));
        }
        let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &self.config)?;
        Ok(adaptive.solution)
    }

    /// Adjoint pass: run the forward solve, then integrate the adjoint
    /// system backward over the stored trajectory to obtain gradients
    /// with respect to `y0` and `params`.
    ///
    /// `grad_output` is the gradient of the loss with respect to the final
    /// state and must match the state dimension.
    ///
    /// # Errors
    /// `InvalidInput` when `grad_output` and `y0` lengths differ, plus any
    /// forward-solver error.
    pub fn adjoint(
        &self,
        y0: &[f64],
        params: &[f64],
        grad_output: &[f64],
    ) -> Result<AdjointResult, OdeError> {
        if y0.len() != grad_output.len() {
            return Err(OdeError::InvalidInput(format!(
                "grad_output length {} does not match state dimension {}",
                grad_output.len(),
                y0.len()
            )));
        }
        // The backward pass interpolates over the stored forward trajectory,
        // so dense output is forced on for this solve regardless of config.
        let fwd_config = OdeSolverConfig {
            dense_output: true,
            ..self.config.clone()
        };
        let adaptive = dopri5_solve(&self.func, self.t0, self.t1, y0, params, &fwd_config)?;
        let fwd_nfev = adaptive.solution.nfev;
        let adj_result = adjoint_backward(
            &self.func,
            &adaptive.solution,
            params,
            grad_output,
            &self.config,
        );
        // Report the combined cost of the forward and backward passes.
        Ok(AdjointResult {
            total_nfev: fwd_nfev + adj_result.total_nfev,
            ..adj_result
        })
    }
}
/// Integrate the adjoint ODE backward over the stored forward trajectory.
///
/// Solves the augmented system
///   da/dt  = -(df/dy)^T a,   a(t1)  = grad_output,
///   dgp/dt = -(df/dp)^T a,   gp(t1) = 0,
/// from t1 back to t0 with fixed-step RK4 (4 sub-steps per stored
/// interval), linearly interpolating the forward state between stored
/// points. Returns `grad_y0 = a(t0)` and `grad_params = gp(t0)`;
/// `total_nfev` counts only the `vjp` evaluations of this backward pass.
/// `_config` is currently unused; it is kept for interface stability.
fn adjoint_backward(
    func: &dyn OdeFunc,
    solution: &OdeSolution,
    params: &[f64],
    grad_output: &[f64],
    _config: &OdeSolverConfig,
) -> AdjointResult {
    let n_params = params.len();
    let final_state = solution
        .states
        .last()
        .cloned()
        .unwrap_or_else(|| grad_output.to_vec());
    // Adjoint state, seeded with the loss gradient at t1.
    let mut a = grad_output.to_vec();
    let mut grad_params = vec![0.0_f64; n_params];
    let mut total_nfev = 0usize;
    // Right-hand side of the augmented adjoint system; hoisted out of the
    // loops so the closure is not rebuilt on every sub-step.
    let aug_rhs = |t_local: f64, a_local: &[f64], y_local: &[f64]| -> (Vec<f64>, Vec<f64>) {
        let (da_dy, _da_dt, da_dp) = func.vjp(t_local, y_local, params, a_local);
        let a_dot: Vec<f64> = da_dy.iter().map(|x| -x).collect();
        let gp_dot: Vec<f64> = da_dp.iter().map(|x| -x).collect();
        (a_dot, gp_dot)
    };
    let adj_steps_per_interval = 4usize;
    let n_intervals = solution.times.len().saturating_sub(1);
    // Walk the stored intervals in reverse time order.
    for interval_idx in (0..n_intervals).rev() {
        let t_start = solution.times[interval_idx + 1];
        let t_end = solution.times[interval_idx];
        let y_start = &solution.states[interval_idx + 1];
        let y_end = &solution.states[interval_idx];
        // h is negative when the forward solve ran forward in time.
        let h = (t_end - t_start) / adj_steps_per_interval as f64;
        // Linear interpolation of the forward state within this interval;
        // alpha = 0 at t_start, alpha = 1 at t_end.
        let lerp = |alpha: f64| -> Vec<f64> {
            y_start
                .iter()
                .zip(y_end.iter())
                .map(|(ys, ye)| ys + alpha * (ye - ys))
                .collect()
        };
        let mut t_cur = t_start;
        for step_idx in 0..adj_steps_per_interval {
            // Classic RK4 step over the augmented (a, gp) system.
            let alpha = step_idx as f64 / adj_steps_per_interval as f64;
            let y_interp = lerp(alpha);
            let (k1_a, k1_gp) = aug_rhs(t_cur, &a, &y_interp);
            total_nfev += 1;
            let a2 = vec_axpy(&a, h * 0.5, &k1_a);
            let alpha2 = (step_idx as f64 + 0.5) / adj_steps_per_interval as f64;
            let y2 = lerp(alpha2);
            let (k2_a, k2_gp) = aug_rhs(t_cur + h * 0.5, &a2, &y2);
            total_nfev += 1;
            let a3 = vec_axpy(&a, h * 0.5, &k2_a);
            let (k3_a, k3_gp) = aug_rhs(t_cur + h * 0.5, &a3, &y2);
            total_nfev += 1;
            let a4 = vec_axpy(&a, h, &k3_a);
            let alpha_end = (step_idx + 1) as f64 / adj_steps_per_interval as f64;
            let y4 = lerp(alpha_end);
            let (k4_a, k4_gp) = aug_rhs(t_cur + h, &a4, &y4);
            total_nfev += 1;
            a = a
                .iter()
                .zip(k1_a.iter())
                .zip(k2_a.iter())
                .zip(k3_a.iter())
                .zip(k4_a.iter())
                .map(|((((ai, k1i), k2i), k3i), k4i)| {
                    ai + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
                })
                .collect();
            grad_params = grad_params
                .iter()
                .zip(k1_gp.iter())
                .zip(k2_gp.iter())
                .zip(k3_gp.iter())
                .zip(k4_gp.iter())
                .map(|((((gp, k1i), k2i), k3i), k4i)| {
                    gp + h / 6.0 * (k1i + 2.0 * k2i + 2.0 * k3i + k4i)
                })
                .collect();
            t_cur += h;
        }
    }
    AdjointResult {
        final_state,
        grad_y0: a,
        grad_params,
        total_nfev,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests: fixed-step RK4, adaptive DOPRI5, and the adjoint pass,
    //! exercised on small ODEs with known closed-form solutions.
    // Note: two call sites below originally contained mojibake (`¶ms`
    // instead of `&params`), which did not compile; fixed here.
    use super::*;

    /// dy/dt = 0: the state must never change.
    struct ConstantOde;
    impl OdeFunc for ConstantOde {
        fn call(&self, _t: f64, _y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![0.0]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            _grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (vec![0.0], 0.0, vec![])
        }
    }

    /// dy/dt = y: exact solution y(t) = y0 * e^t.
    struct ExponentialGrowthOde;
    impl OdeFunc for ExponentialGrowthOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.to_vec(), 0.0, vec![])
        }
    }

    /// dy/dt = -y: exact solution y(t) = y0 * e^-t.
    struct ExponentialDecayOde;
    impl OdeFunc for ExponentialDecayOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.iter().map(|g| -g).collect(), 0.0, vec![])
        }
    }

    /// Harmonic oscillator x' = v, v' = -x; period 2*pi.
    struct OscillatorOde;
    impl OdeFunc for OscillatorOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![y[1], -y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            // J = [[0, 1], [-1, 0]], so J^T grad = [-grad[1], grad[0]].
            let ga = grad[1];
            let gb = grad[0];
            (vec![-ga, gb], 0.0, vec![])
        }
    }

    /// dy/dt = p * y with one learnable parameter p.
    struct LinearParamOde;
    impl OdeFunc for LinearParamOde {
        fn call(&self, _t: f64, y: &[f64], params: &[f64]) -> Vec<f64> {
            vec![params[0] * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            y: &[f64],
            params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            let grad_y = vec![grad[0] * params[0]];
            let grad_p = vec![grad[0] * y[0]];
            (grad_y, 0.0, grad_p)
        }
    }

    /// dy/dt = -1000 * y: stresses the adaptive step-size controller.
    struct StiffOde;
    impl OdeFunc for StiffOde {
        fn call(&self, _t: f64, y: &[f64], _params: &[f64]) -> Vec<f64> {
            vec![-1000.0 * y[0]]
        }
        fn vjp(
            &self,
            _t: f64,
            _y: &[f64],
            _params: &[f64],
            grad: &[f64],
        ) -> (Vec<f64>, f64, Vec<f64>) {
            (grad.iter().map(|g| -1000.0 * g).collect(), 0.0, vec![])
        }
    }

    #[test]
    fn test_rk4_constant_ode() {
        let init_val = 42.0_f64;
        let sol = rk4_solve(&ConstantOde, 0.0, 1.0, &[init_val], &[], 100);
        let final_y = sol.states.last().unwrap()[0];
        assert!(
            (final_y - init_val).abs() < 1e-12,
            "constant ODE should stay at {init_val}, got {final_y}"
        );
    }

    #[test]
    fn test_rk4_exponential_growth() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential growth: got {final_y}, expected {exact}"
        );
    }

    #[test]
    fn test_rk4_exponential_decay() {
        let sol = rk4_solve(&ExponentialDecayOde, 0.0, 1.0, &[1.0], &[], 10_000);
        let final_y = sol.states.last().unwrap()[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (final_y - exact).abs() < 1e-6,
            "RK4 exponential decay: got {final_y}, expected {exact}"
        );
    }

    #[test]
    fn test_rk4_oscillator_2d() {
        use std::f64::consts::PI;
        // After one full period the oscillator must return to (1, 0).
        let sol = rk4_solve(&OscillatorOde, 0.0, 2.0 * PI, &[1.0, 0.0], &[], 100_000);
        let last = sol.states.last().unwrap();
        assert!(
            (last[0] - 1.0).abs() < 1e-4,
            "oscillator x: got {}",
            last[0]
        );
        assert!(last[1].abs() < 1e-4, "oscillator y: got {}", last[1]);
    }

    #[test]
    fn test_dopri5_more_accurate_than_rk4() {
        let exact = std::f64::consts::E;
        let rk4_sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        let rk4_err = (rk4_sol.states.last().unwrap()[0] - exact).abs();
        let config = OdeSolverConfig::new().rtol(1e-8).atol(1e-10);
        let dp5 = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        let dp5_err = (dp5.solution.states.last().unwrap()[0] - exact).abs();
        assert!(
            dp5_err < rk4_err,
            "DOPRI5 (tight tol) error {dp5_err} should be less than coarse RK4 error {rk4_err}"
        );
        assert!(
            dp5_err < 1e-6,
            "DOPRI5 with rtol=1e-8/atol=1e-10 should achieve < 1e-6 error, got {dp5_err}"
        );
    }

    #[test]
    fn test_dopri5_step_rejection_on_stiff() {
        // A stiff problem may succeed (with rejections) or bail out with a
        // step-control error; anything else is a bug.
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8).max_steps(5000);
        let result = dopri5_solve(&StiffOde, 0.0, 0.01, &[1.0], &[], &config);
        match result {
            Ok(adaptive) => {
                let _ = adaptive.rejected_steps;
            }
            Err(OdeError::StepTooSmall) | Err(OdeError::MaxStepsExceeded) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    #[test]
    fn test_solver_config_builder() {
        let cfg = OdeSolverConfig::new().rtol(1e-8).atol(1e-10).max_steps(500);
        assert!((cfg.rtol - 1e-8).abs() < 1e-15);
        assert!((cfg.atol - 1e-10).abs() < 1e-18);
        assert_eq!(cfg.max_steps, 500);
    }

    #[test]
    fn test_neural_ode_forward_correct_endpoint() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol = ode.forward(&[1.0], &[]).unwrap();
        let final_y = sol.states.last().unwrap()[0];
        let exact = std::f64::consts::E;
        assert!(
            (final_y - exact).abs() < 1e-4,
            "NeuralOde forward: got {final_y}, expected ~{exact}"
        );
    }

    #[test]
    fn test_neural_ode_forward_t0_equals_t1() {
        let init_val = 7.5_f64;
        let ode = NeuralOde::new(ExponentialGrowthOde, 1.5, 1.5);
        let sol = ode.forward(&[init_val], &[]).unwrap();
        assert!((sol.states[0][0] - init_val).abs() < 1e-12);
    }

    #[test]
    fn test_max_steps_exceeded_on_stiff() {
        let config = OdeSolverConfig::new().rtol(1e-12).atol(1e-14).max_steps(5);
        let result = dopri5_solve(&StiffOde, 0.0, 1.0, &[1.0], &[], &config);
        assert!(
            matches!(
                result,
                Err(OdeError::MaxStepsExceeded) | Err(OdeError::StepTooSmall)
            ),
            "expected MaxStepsExceeded or StepTooSmall"
        );
    }

    #[test]
    fn test_nfev_is_positive() {
        let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], 10);
        assert!(sol.nfev > 0, "nfev should be > 0, got {}", sol.nfev);
        assert_eq!(sol.nfev, 40, "RK4 should use 4 * num_steps evaluations");
    }

    #[test]
    fn test_rejected_steps_field_exists() {
        let config = OdeSolverConfig::new();
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        let _ = adaptive.rejected_steps;
        assert!(adaptive.solution.nfev > 0);
    }

    #[test]
    fn test_dense_output_stores_intermediate_steps() {
        let config = OdeSolverConfig::new().rtol(1e-6).atol(1e-8);
        let adaptive = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &config).unwrap();
        assert!(
            adaptive.solution.times.len() > 2,
            "dense output should contain more than 2 time points, got {}",
            adaptive.solution.times.len()
        );
        assert_eq!(
            adaptive.solution.times.len(),
            adaptive.solution.states.len(),
            "times and states must have the same length"
        );
    }

    #[test]
    fn test_adjoint_grad_y0_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_y0.len(),
            y0.len(),
            "grad_y0 must have same dim as y0"
        );
    }

    #[test]
    fn test_adjoint_grad_params_dimension() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 0.5);
        let y0 = vec![1.0_f64];
        let params = vec![-1.0_f64];
        let grad_out = vec![1.0_f64];
        let adj = ode.adjoint(&y0, &params, &grad_out).unwrap();
        assert_eq!(
            adj.grad_params.len(),
            params.len(),
            "grad_params must have same dim as params"
        );
    }

    #[test]
    fn test_ode_error_display() {
        let msgs = [
            (OdeError::MaxStepsExceeded, "max"),
            (OdeError::StepTooSmall, "step"),
            (OdeError::DivergentSolution, "diverged"),
            (OdeError::InvalidInput("bad".into()), "bad"),
        ];
        for (err, keyword) in msgs {
            let msg = format!("{err}");
            assert!(
                msg.to_lowercase().contains(keyword),
                "Display for {err:?} should contain '{keyword}', got: '{msg}'"
            );
        }
    }

    #[test]
    fn test_forward_is_deterministic() {
        let ode = NeuralOde::new(ExponentialGrowthOde, 0.0, 1.0);
        let sol1 = ode.forward(&[1.0], &[]).unwrap();
        let sol2 = ode.forward(&[1.0], &[]).unwrap();
        let y1 = sol1.states.last().unwrap()[0];
        let y2 = sol2.states.last().unwrap()[0];
        assert_eq!(y1, y2, "repeated forward passes must be deterministic");
    }

    #[test]
    fn test_rk4_convergence_with_steps() {
        // Error must decrease monotonically as the step count grows.
        let exact = std::f64::consts::E;
        let steps_list = [10usize, 100, 1000, 10_000];
        let mut prev_err = f64::INFINITY;
        for &n in &steps_list {
            let sol = rk4_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], n);
            let err = (sol.states.last().unwrap()[0] - exact).abs();
            assert!(
                err < prev_err,
                "error {err} at n={n} is not less than prev {prev_err}"
            );
            prev_err = err;
        }
        assert!(
            prev_err < 1e-13,
            "RK4 with 10_000 steps: error {prev_err} > 1e-13"
        );
    }

    #[test]
    fn test_dopri5_tolerance_affects_accuracy() {
        let exact = std::f64::consts::E;
        let coarse = OdeSolverConfig::new().rtol(1e-3).atol(1e-5);
        let fine = OdeSolverConfig::new().rtol(1e-9).atol(1e-11);
        let sol_coarse =
            dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &coarse).unwrap();
        let sol_fine = dopri5_solve(&ExponentialGrowthOde, 0.0, 1.0, &[1.0], &[], &fine).unwrap();
        let err_coarse = (sol_coarse.solution.states.last().unwrap()[0] - exact).abs();
        let err_fine = (sol_fine.solution.states.last().unwrap()[0] - exact).abs();
        assert!(
            err_fine < err_coarse,
            "fine tol error {err_fine} should be less than coarse tol error {err_coarse}"
        );
    }

    #[test]
    fn test_neural_ode_params_affect_trajectory() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let sol_pos = ode.forward(&[1.0], &[1.0]).unwrap();
        let sol_neg = ode.forward(&[1.0], &[-1.0]).unwrap();
        let y_pos = sol_pos.states.last().unwrap()[0];
        let y_neg = sol_neg.states.last().unwrap()[0];
        assert!(
            y_pos > y_neg,
            "positive param should give larger y: y_pos={y_pos}, y_neg={y_neg}"
        );
        assert!(
            (y_pos - std::f64::consts::E).abs() < 1e-3,
            "y_pos ~ e, got {y_pos}"
        );
        assert!(
            (y_neg - (-1.0_f64).exp()).abs() < 1e-3,
            "y_neg ~ e^-1, got {y_neg}"
        );
    }

    #[test]
    fn test_adjoint_result_fields() {
        let ode = NeuralOde::new(LinearParamOde, 0.0, 1.0);
        let adj = ode.adjoint(&[1.0], &[-1.0], &[1.0]).unwrap();
        assert!(adj.total_nfev > 0, "total_nfev should be > 0");
        assert!(!adj.final_state.is_empty(), "final_state must not be empty");
        assert!(!adj.grad_y0.is_empty(), "grad_y0 must not be empty");
        assert_eq!(adj.grad_params.len(), 1);
    }

    #[test]
    fn test_solution_first_state_is_y0() {
        let y0 = vec![42.0_f64, -7.5];
        let sol = rk4_solve(&OscillatorOde, 0.0, 1.0, &y0, &[], 100);
        assert_eq!(&sol.states[0], &y0, "first stored state must equal y0");
        assert_eq!(sol.times[0], 0.0);
    }
}