use crate::integration::rk4_translational_step;
use crate::state::TranslationalState;
use glam::DVec3;
/// Number of derivative samples kept by the four-step Adams methods.
const HIST_LEN: usize = 4;
/// Shared 1/24 scale factor applied to every ABM4 weighted derivative sum.
const WSCALE: f64 = 1.0 / 24.0;
/// Adams–Bashforth (predictor) weights (before the 1/24 scale), newest derivative first.
const PREDICTOR_WEIGHTS: [f64; HIST_LEN] = [55.0, -59.0, 37.0, -9.0];
/// Adams–Moulton (corrector) weights (before the 1/24 scale). The first entry multiplies
/// the derivative at the predicted state; the rest multiply the history, newest first.
const CORRECTOR_WEIGHTS: [f64; HIST_LEN] = [9.0, 19.0, -5.0, 1.0];
/// Derivative history and bookkeeping carried between calls to
/// [`abm4_translational_step`].
///
/// Slot 0 of each history array holds the most recently recorded derivative;
/// older samples move toward the end as new ones are pushed.
#[derive(Debug, Clone)]
pub struct Abm4State {
    /// Recorded position derivatives (velocities), newest first.
    posdot_hist: [DVec3; HIST_LEN],
    /// Recorded velocity derivatives (accelerations), newest first.
    veldot_hist: [DVec3; HIST_LEN],
    /// Steps counted since construction or the last reset.
    primed_steps: usize,
    /// Set by `mark_topology_dirty`; cleared by `reset` / `reset_for_topology_change`.
    topology_dirty: bool,
}

impl Default for Abm4State {
    /// Same as [`Abm4State::new`].
    fn default() -> Self {
        Abm4State::new()
    }
}

impl Abm4State {
    /// `is_priming` reports `true` while fewer than this many steps have been counted.
    const PRIMING_STEPS: usize = HIST_LEN - 1;

    /// Creates an all-zero history with no steps counted and the topology flag clear.
    pub fn new() -> Self {
        let empty = [DVec3::ZERO; HIST_LEN];
        Abm4State {
            posdot_hist: empty,
            veldot_hist: empty,
            primed_steps: 0,
            topology_dirty: false,
        }
    }

    /// Discards all history, zeroes the step count and clears the topology flag.
    pub fn reset(&mut self) {
        *self = Abm4State::new();
    }

    /// Reset to call after the body's mass / attachment topology changes.
    ///
    /// Currently identical to [`Abm4State::reset`].
    pub fn reset_for_topology_change(&mut self) {
        self.reset()
    }

    /// Flags the stored history as stale. `abm4_translational_step` panics on a
    /// flagged state until one of the reset methods clears the flag.
    pub fn mark_topology_dirty(&mut self) {
        self.topology_dirty = true;
    }

    /// Returns whether `mark_topology_dirty` was called since the last reset.
    pub fn is_topology_dirty(&self) -> bool {
        self.topology_dirty
    }

    /// Returns `true` while fewer than `HIST_LEN - 1` steps have been counted.
    pub fn is_priming(&self) -> bool {
        self.primed_steps < Self::PRIMING_STEPS
    }

    /// Moves every history sample one slot older, dropping the oldest one.
    /// Slot 0 keeps its previous value until the caller overwrites it.
    fn rotate_history(&mut self) {
        // `copy_within` handles the overlapping ranges like `memmove`.
        self.posdot_hist.copy_within(..HIST_LEN - 1, 1);
        self.veldot_hist.copy_within(..HIST_LEN - 1, 1);
    }

    /// Pushes `(velocity, accel)` as the newest history sample.
    fn save_priming_derivatives(&mut self, velocity: DVec3, accel: DVec3) {
        self.rotate_history();
        self.posdot_hist[0] = velocity;
        self.veldot_hist[0] = accel;
    }
}
/// Advances `state` by one fixed step `dt` with the fourth-order
/// Adams–Bashforth–Moulton predictor–corrector.
///
/// Each call first samples the acceleration at `state` and pushes the pair
/// `(velocity, acceleration)` onto `abm_state`'s history (newest first).
///
/// * While `abm_state.is_priming()` is true, the step is delegated to
///   [`rk4_translational_step`] and the priming counter is advanced.
/// * Otherwise it predicts with the Adams–Bashforth weights over the four newest
///   derivatives, evaluates the acceleration at the predicted state, and corrects
///   with the Adams–Moulton weights. The corrector combines that predicted
///   derivative with the three newest history samples.
///
/// `accel_fn` is called with `0.0` as its second argument at `state` and `1.0` at
/// the predicted end-of-step state. This is presumably the elapsed fraction of the
/// step; confirm against `rk4_translational_step`.
///
/// The history stores derivatives only, not `dt`, so the multistep formulas assume
/// the same `dt` was used on every call since the last reset.
///
/// # Panics
///
/// Panics if `dt` is not finite and strictly positive. Also panics if `abm_state`
/// was marked topology-dirty and not reset with
/// [`Abm4State::reset_for_topology_change`].
pub fn abm4_translational_step(
    state: &TranslationalState,
    accel_fn: impl Fn(&TranslationalState, f64) -> DVec3,
    dt: f64,
    abm_state: &mut Abm4State,
) -> TranslationalState {
    assert!(
        dt.is_finite() && dt > 0.0,
        "abm4_translational_step requires a finite positive dt, got {dt}"
    );
    assert!(
        !abm_state.topology_dirty,
        "abm4_translational_step called with stale predictor history: the body's \
         mass / attachment topology changed but Abm4State::reset_for_topology_change() \
         was not called. Wire the attach / detach handler to reset the integrator state \
         (astrodyn::reset_integrators) — see JEOD's dyn_body_attach.cc::reset_integrators() \
         and JEOD_invariants.md row IG.37."
    );

    /// Returns `seed + Σ history[j] * weights[j]`, accumulated from j = 0 upward over
    /// the shorter of the two slices.
    fn weighted_sum(seed: DVec3, history: &[DVec3], weights: &[f64]) -> DVec3 {
        history
            .iter()
            .zip(weights)
            .fold(seed, |acc, (&deriv, &w)| acc + deriv * w)
    }

    // Both paths start the same way: sample the derivative at the current state
    // and record it as the newest history entry. Decide the path first, because the
    // push does not change the priming counter.
    let priming = abm_state.is_priming();
    let accel_now = accel_fn(state, 0.0);
    abm_state.save_priming_derivatives(state.velocity, accel_now);

    if priming {
        abm_state.primed_steps += 1;
        return rk4_translational_step(state, &accel_fn, dt);
    }

    let step_scale = WSCALE * dt;
    let x0 = state.position;
    let v0 = state.velocity;

    // Predictor: explicit Adams–Bashforth over the four newest derivatives.
    let ab_posdot = weighted_sum(DVec3::ZERO, &abm_state.posdot_hist, &PREDICTOR_WEIGHTS);
    let ab_veldot = weighted_sum(DVec3::ZERO, &abm_state.veldot_hist, &PREDICTOR_WEIGHTS);
    let predicted = TranslationalState {
        position: x0 + ab_posdot * step_scale,
        velocity: v0 + ab_veldot * step_scale,
    };

    // Corrector: Adams–Moulton. The lead weight multiplies the derivative at the
    // predicted state; the remaining weights pair with the three newest history slots.
    let accel_predicted = accel_fn(&predicted, 1.0);
    let [lead_w, tail_w @ ..] = CORRECTOR_WEIGHTS;
    let am_posdot = weighted_sum(lead_w * predicted.velocity, &abm_state.posdot_hist, &tail_w);
    let am_veldot = weighted_sum(lead_w * accel_predicted, &abm_state.veldot_hist, &tail_w);

    abm_state.primed_steps = abm_state.primed_steps.saturating_add(1);
    TranslationalState {
        position: x0 + am_posdot * step_scale,
        velocity: v0 + am_veldot * step_scale,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-frequency harmonic oscillator integrated for ~one period.
    #[test]
    fn abm4_harmonic_oscillator() {
        let dt = 0.01;
        let steps = 628;
        let t_final = dt * steps as f64;
        let mut state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let mut abm = Abm4State::new();
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        for _ in 0..steps {
            state = abm4_translational_step(&state, accel_fn, dt, &mut abm);
        }
        let exact_pos = t_final.cos();
        let exact_vel = -t_final.sin();
        let pos_err = (state.position.x - exact_pos).abs();
        let vel_err = (state.velocity.x - exact_vel).abs();
        assert!(pos_err < 1e-6, "ABM4 harmonic osc pos err: {pos_err:.2e}");
        assert!(vel_err < 1e-6, "ABM4 harmonic osc vel err: {vel_err:.2e}");
    }

    /// Halving dt should shrink the global error by roughly 2^4 = 16.
    #[test]
    fn abm4_convergence_order() {
        let dt_coarse = 0.01_f64;
        let dt_fine = dt_coarse / 2.0;
        let total_time = 10.0_f64;
        let initial = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        let exact_pos = total_time.cos();
        let steps_coarse = (total_time / dt_coarse).round() as usize;
        let mut state_c = initial;
        let mut abm_c = Abm4State::new();
        for _ in 0..steps_coarse {
            state_c = abm4_translational_step(&state_c, accel_fn, dt_coarse, &mut abm_c);
        }
        let err_coarse = (state_c.position.x - exact_pos).abs();
        let steps_fine = (total_time / dt_fine).round() as usize;
        let mut state_f = initial;
        let mut abm_f = Abm4State::new();
        for _ in 0..steps_fine {
            state_f = abm4_translational_step(&state_f, accel_fn, dt_fine, &mut abm_f);
        }
        let err_fine = (state_f.position.x - exact_pos).abs();
        let ratio = err_coarse / err_fine;
        assert!(
            (10.0..=25.0).contains(&ratio),
            "ABM4 convergence ratio {ratio:.1} not consistent with 4th order \
             (err_coarse={err_coarse:.3e}, err_fine={err_fine:.3e})"
        );
    }

    /// With zero acceleration the motion is linear and the method should be exact
    /// up to round-off.
    #[test]
    fn abm4_free_particle() {
        let dt = 0.5;
        let initial_pos = DVec3::new(1.0, 2.0, 3.0);
        let initial_vel = DVec3::new(4.0, 5.0, 6.0);
        let mut state = TranslationalState {
            position: initial_pos,
            velocity: initial_vel,
        };
        let mut abm = Abm4State::new();
        let zero_accel = |_: &TranslationalState, _t: f64| DVec3::ZERO;
        let n = 50;
        for _ in 0..n {
            state = abm4_translational_step(&state, zero_accel, dt, &mut abm);
        }
        let expected = initial_pos + initial_vel * (dt * n as f64);
        let pos_err = (state.position - expected).length();
        let vel_err = (state.velocity - initial_vel).length();
        assert!(pos_err < 1e-11, "Free particle pos err: {pos_err}");
        assert!(vel_err < 1e-14, "Free particle vel err: {vel_err}");
    }

    /// ABM4 and RK4 should land within two orders of magnitude of each other.
    #[test]
    fn abm4_comparable_to_rk4() {
        let dt = 0.01_f64;
        let steps = 1000;
        let t_final = dt * steps as f64;
        let initial = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        let mut s_rk4 = initial;
        for _ in 0..steps {
            s_rk4 = rk4_translational_step(&s_rk4, accel_fn, dt);
        }
        let mut s_abm = initial;
        let mut abm = Abm4State::new();
        for _ in 0..steps {
            s_abm = abm4_translational_step(&s_abm, accel_fn, dt, &mut abm);
        }
        let exact = t_final.cos();
        let err_rk4 = (s_rk4.position.x - exact).abs();
        let err_abm = (s_abm.position.x - exact).abs();
        let ratio = err_abm.max(err_rk4) / err_abm.min(err_rk4).max(1e-20);
        assert!(
            ratio < 100.0,
            "ABM4 ({err_abm:.2e}) vs RK4 ({err_rk4:.2e}) differ by {ratio:.1}x"
        );
    }

    /// One full circular low-Earth orbit should close to within a few km.
    #[test]
    fn abm4_kepler_orbit() {
        let mu: f64 = 3.986_004_415e14;
        let r0: f64 = 7_000_000.0;
        let v0 = (mu / r0).sqrt();
        let period = 2.0 * std::f64::consts::PI * (r0.powi(3) / mu).sqrt();
        let target_dt = 10.0_f64;
        let steps = (period / target_dt).round() as usize;
        let dt = period / steps as f64;
        let mut state = TranslationalState {
            position: DVec3::new(r0, 0.0, 0.0),
            velocity: DVec3::new(0.0, v0, 0.0),
        };
        let mut abm = Abm4State::new();
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 {
            let r = s.position.length();
            -mu / (r * r * r) * s.position
        };
        for _ in 0..steps {
            state = abm4_translational_step(&state, accel_fn, dt, &mut abm);
        }
        let pos_err = (state.position - DVec3::new(r0, 0.0, 0.0)).length();
        assert!(
            pos_err < 5_000.0,
            "ABM4 Kepler orbit closure err: {pos_err:.2e} m"
        );
    }

    /// `reset` returns the state to priming with an empty history.
    #[test]
    fn abm4_reset() {
        let mut abm = Abm4State::new();
        assert!(abm.is_priming());
        let mut state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        for _ in 0..5 {
            state = abm4_translational_step(&state, accel_fn, 0.01, &mut abm);
        }
        assert!(!abm.is_priming());
        abm.reset();
        assert!(abm.is_priming());
        assert_eq!(abm.primed_steps, 0);
        assert_eq!(abm.posdot_hist[0], DVec3::ZERO);
    }

    /// `reset_for_topology_change` clears both the history and the dirty flag.
    #[test]
    fn abm4_reset_for_topology_change_clears_history_and_dirty_flag() {
        let mut abm = Abm4State::new();
        let mut state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        for _ in 0..5 {
            state = abm4_translational_step(&state, accel_fn, 0.01, &mut abm);
        }
        assert!(!abm.is_priming());
        assert_ne!(abm.posdot_hist[0], DVec3::ZERO);
        assert_ne!(abm.veldot_hist[0], DVec3::ZERO);
        abm.mark_topology_dirty();
        assert!(abm.is_topology_dirty());
        abm.reset_for_topology_change();
        assert!(abm.is_priming(), "history must be cleared back to priming");
        assert!(
            !abm.is_topology_dirty(),
            "topology-dirty flag must be cleared by reset"
        );
        assert_eq!(abm.primed_steps, 0);
        for i in 0..HIST_LEN {
            assert_eq!(abm.posdot_hist[i], DVec3::ZERO);
            assert_eq!(abm.veldot_hist[i], DVec3::ZERO);
        }
    }

    /// Stepping with a topology-dirty state must refuse to use the stale history.
    #[test]
    #[should_panic(expected = "stale predictor history")]
    fn abm4_step_with_topology_dirty_panics() {
        let mut abm = Abm4State::new();
        let state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        abm.mark_topology_dirty();
        let _ = abm4_translational_step(&state, accel_fn, 0.01, &mut abm);
    }

    /// The first `HIST_LEN - 1` steps must be bit-identical RK4 steps. After them the
    /// state leaves priming, and the next step takes the predictor–corrector path.
    #[test]
    fn abm4_priming_steps_match_rk4_exactly() {
        let dt = 0.01;
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        let mut s_abm = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let mut s_rk4 = s_abm;
        let mut abm = Abm4State::new();
        for step in 0..HIST_LEN - 1 {
            assert!(abm.is_priming(), "step {step} should still be priming");
            s_abm = abm4_translational_step(&s_abm, accel_fn, dt, &mut abm);
            s_rk4 = rk4_translational_step(&s_rk4, accel_fn, dt);
            assert_eq!(s_abm.position, s_rk4.position, "priming step {step} position");
            assert_eq!(s_abm.velocity, s_rk4.velocity, "priming step {step} velocity");
        }
        assert!(!abm.is_priming());
        assert_eq!(abm.primed_steps, HIST_LEN - 1);

        let s_abm = abm4_translational_step(&s_abm, accel_fn, dt, &mut abm);
        let s_rk4 = rk4_translational_step(&s_rk4, accel_fn, dt);
        assert_ne!(
            s_abm.position, s_rk4.position,
            "first post-priming step should use the ABM formulas, not RK4"
        );
        assert_eq!(abm.primed_steps, HIST_LEN);
    }

    /// A zero step size violates the precondition.
    #[test]
    #[should_panic(expected = "finite positive dt")]
    fn abm4_rejects_zero_dt() {
        let mut abm = Abm4State::new();
        let state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        let _ = abm4_translational_step(&state, accel_fn, 0.0, &mut abm);
    }

    /// A non-finite step size violates the precondition.
    #[test]
    #[should_panic(expected = "finite positive dt")]
    fn abm4_rejects_nan_dt() {
        let mut abm = Abm4State::new();
        let state = TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::ZERO,
        };
        let accel_fn = |s: &TranslationalState, _t: f64| -> DVec3 { -s.position };
        let _ = abm4_translational_step(&state, accel_fn, f64::NAN, &mut abm);
    }
}