use scivex_core::Float;
use super::{OdeOptions, OdeResult};
use crate::error::{OptimError, Result};
/// Maximum number of fixed-point iterations attempted when solving the implicit
/// equation of a single step.
const FIXED_POINT_MAX_ITER: usize = 50;
/// Convergence threshold for the fixed-point iteration: the largest absolute
/// component-wise change between two successive iterates must fall below this value.
const FIXED_POINT_TOL: f64 = 1e-10;
/// Integrates `y' = f(t, y)` from `t_span[0]` to `t_span[1]` with the fixed-step,
/// two-step backward differentiation formula (BDF2).
///
/// The step size is `options.first_step` when set, otherwise `(tf - t0) / 200`. Every step
/// is clipped so that it does not go past `tf`.
///
/// * Startup: the first step uses backward Euler, `y1 = y0 + h * f(t1, y1)`, seeded with an
///   explicit Euler predictor.
/// * Subsequent steps use
///   `y_{n+1} = 4/3 * y_n - 1/3 * y_{n-1} + 2/3 * h * f(t_{n+1}, y_{n+1})`.
///
/// Both implicit equations are solved by plain fixed-point iteration (no Jacobian), limited
/// to `FIXED_POINT_MAX_ITER` passes and stopped once the largest component change is below
/// `FIXED_POINT_TOL`.
///
/// If `options.event_fn` is set, it is evaluated after every BDF2 step (not after the startup
/// step). Integration stops early, still returning `Ok`, when the event value has magnitude
/// below `1e-12` or when its "is positive" status differs from the one at the previous
/// stored point.
///
/// # Arguments
///
/// * `f` - right-hand side; must return one derivative per state component.
/// * `t_span` - `[t0, tf]`.
/// * `y0` - state at `t0`.
/// * `options` - supplies `first_step`, `max_steps` and `event_fn`.
///
/// # Errors
///
/// Returns [`OptimError::ConvergenceFailure`] (carrying the number of completed steps) when
///
/// * the fixed-point iteration of any step, including the startup step, does not converge, or
/// * `options.max_steps` steps have been taken and `tf` has not been reached.
///
/// # Panics
///
/// Panics if `f` returns fewer values than `y0.len()`.
pub fn bdf2<T, F>(f: F, t_span: [T; 2], y0: &[T], options: &OdeOptions<T>) -> Result<OdeResult<T>>
where
    T: Float,
    F: Fn(T, &[T]) -> Vec<T>,
{
    let t0 = t_span[0];
    let tf = t_span[1];
    let n = y0.len();
    let h = options
        .first_step
        .unwrap_or_else(|| (tf - t0) / T::from_f64(200.0));
    let max_steps = options.max_steps;

    let mut t = t0;
    let mut t_values = vec![t];
    let mut y_values = vec![y0.to_vec()];
    let mut n_evals: usize = 0;
    let mut n_steps: usize = 0;

    // --- Startup: one backward Euler step, because BDF2 needs two past points. ---
    let t1 = t + h.min(tf - t);
    let h_actual = t1 - t;
    let dy0 = f(t, y0);
    n_evals += 1;
    // Explicit Euler predictor as the initial guess for the implicit solve.
    let mut y1 = y0.to_vec();
    for i in 0..n {
        y1[i] += h_actual * dy0[i];
    }
    // A non-converged startup value would silently corrupt every later BDF2 step, so it is
    // reported the same way as a non-converged BDF2 step.
    if !fixed_point_solve(&f, t1, y0, h_actual, &mut y1, &mut n_evals) {
        return Err(OptimError::ConvergenceFailure {
            iterations: n_steps,
        });
    }
    t = t1;
    t_values.push(t);
    y_values.push(y1);
    n_steps += 1;

    if t >= tf {
        return Ok(OdeResult {
            t: t_values,
            y: y_values,
            n_evals,
            n_steps,
            success: true,
        });
    }

    // --- Main BDF2 loop. ---
    let four_thirds = T::from_f64(4.0 / 3.0);
    let one_third = T::from_f64(1.0 / 3.0);
    let two_thirds = T::from_f64(2.0 / 3.0);
    let event_tol = T::from_f64(1e-12);

    while t < tf {
        if n_steps >= max_steps {
            return Err(OptimError::ConvergenceFailure {
                iterations: n_steps,
            });
        }

        // NOTE: the final step may be shorter than `h`; the constant-step BDF2
        // coefficients are still used for it.
        let step = h.min(tf - t);
        let t_next = t + step;

        // History part of the formula; it also serves as the initial guess.
        let y_nm1 = &y_values[y_values.len() - 2];
        let y_n = &y_values[y_values.len() - 1];
        let base: Vec<T> = y_n
            .iter()
            .zip(y_nm1)
            .map(|(&cur, &prev)| four_thirds * cur - one_third * prev)
            .collect();
        let mut y_next = base.clone();

        if !fixed_point_solve(
            &f,
            t_next,
            &base,
            two_thirds * step,
            &mut y_next,
            &mut n_evals,
        ) {
            return Err(OptimError::ConvergenceFailure {
                iterations: n_steps,
            });
        }

        t = t_next;
        t_values.push(t);
        y_values.push(y_next);
        n_steps += 1;

        if let Some(ref event_fn) = options.event_fn {
            let y_cur = &y_values[y_values.len() - 1];
            let val = event_fn(t, y_cur);
            if val.abs() < event_tol {
                break;
            }
            // Two points are always stored before this loop starts, so the previous
            // point exists.
            let prev_y = &y_values[y_values.len() - 2];
            let prev_t = t_values[t_values.len() - 2];
            let prev_val = event_fn(prev_t, prev_y);
            if (prev_val > T::zero()) != (val > T::zero()) {
                break;
            }
        }
    }

    Ok(OdeResult {
        t: t_values,
        y: y_values,
        n_evals,
        n_steps,
        success: true,
    })
}

/// Runs the fixed-point iteration `y <- base + coeff * f(t_new, y)` in place.
///
/// `y` holds the initial guess on entry and the most recent iterate on exit. `n_evals` is
/// incremented once per call to `f`. Returns `true` as soon as the largest absolute
/// component change of a pass is below `FIXED_POINT_TOL`, and `false` if
/// `FIXED_POINT_MAX_ITER` passes complete without that happening.
fn fixed_point_solve<T, F>(
    f: &F,
    t_new: T,
    base: &[T],
    coeff: T,
    y: &mut [T],
    n_evals: &mut usize,
) -> bool
where
    T: Float,
    F: Fn(T, &[T]) -> Vec<T>,
{
    let tol = T::from_f64(FIXED_POINT_TOL);
    for _ in 0..FIXED_POINT_MAX_ITER {
        let dy = f(t_new, &*y);
        *n_evals += 1;
        let mut max_diff = T::zero();
        for i in 0..y.len() {
            let y_new = base[i] + coeff * dy[i];
            let diff = (y_new - y[i]).abs();
            if diff > max_diff {
                max_diff = diff;
            }
            y[i] = y_new;
        }
        if max_diff < tol {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `y' = -y`, `y(0) = 1` on `[0, 1]` with default options; compared against `exp(-1)`.
    #[test]
    fn test_bdf2_exponential_decay() {
        let rhs = |_t: f64, y: &[f64]| vec![-y[0]];
        let opts = OdeOptions::default();
        let sol = bdf2(rhs, [0.0, 1.0], &[1.0], &opts).unwrap();

        let y_final = sol.y.last().unwrap()[0];
        let expected = (-1.0_f64).exp();
        let err = (y_final - expected).abs();
        assert!(err < 1e-4, "y_final={y_final}, expected={expected}");
    }

    /// `y' = -50 y` on `[0, 0.5]` with an explicit step of `0.002`; compared against `exp(-25)`.
    #[test]
    fn test_bdf2_stiff_system() {
        let rhs = |_t: f64, y: &[f64]| vec![-50.0 * y[0]];
        let opts = OdeOptions {
            first_step: Some(0.002),
            max_steps: 5000,
            ..OdeOptions::default()
        };
        let sol = bdf2(rhs, [0.0, 0.5], &[1.0], &opts).unwrap();

        let y_final = sol.y.last().unwrap()[0];
        let expected = (-25.0_f64).exp();
        let err = (y_final - expected).abs();
        assert!(
            err < 1e-3,
            "y_final={y_final}, expected={expected}, err={}",
            err
        );
        assert!(sol.success);
    }

    /// Constant slope `y' = 1` from `y(0) = 0` on `[0, 2]`; the end value should be `2`.
    #[test]
    fn test_bdf2_linear() {
        let rhs = |_t: f64, _y: &[f64]| vec![1.0];
        let opts = OdeOptions::default();
        let sol = bdf2(rhs, [0.0, 2.0], &[0.0], &opts).unwrap();

        let y_final = sol.y.last().unwrap()[0];
        let err = (y_final - 2.0).abs();
        assert!(err < 1e-6, "y_final={y_final}, expected=2.0");
    }

    /// Coupled two-component decaying system; both components should be near zero at `t = 1`.
    #[test]
    fn test_bdf2_system() {
        let rhs = |_t: f64, y: &[f64]| vec![-20.0 * y[0] + y[1], y[0] - 20.0 * y[1]];
        let opts = OdeOptions {
            first_step: Some(0.002),
            max_steps: 5000,
            ..OdeOptions::default()
        };
        let sol = bdf2(rhs, [0.0, 1.0], &[1.0, 0.0], &opts).unwrap();

        let y_final = sol.y.last().unwrap();
        let first = y_final[0];
        let second = y_final[1];
        assert!(first.abs() < 1e-3, "y0={} should be near zero", first);
        assert!(second.abs() < 1e-3, "y1={} should be near zero", second);
    }
}