use crate::mass::MassProperties;
use crate::rotational::*;
use crate::state::TranslationalState;
use astrodyn_math::JeodQuat;
use glam::DVec3;
/// Selects the numerical integration scheme used to propagate state.
///
/// Defaults to [`IntegratorType::Rk4`].
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum IntegratorType {
    /// Fixed-step classic fourth-order Runge-Kutta, implemented in this module by
    /// [`rk4_translational_step`] and [`rk4_sixdof_step`].
    #[default]
    Rk4,
    /// Presumably the embedded Runge-Kutta-Fehlberg 4(5) scheme; its implementation
    /// is not in this file (NOTE(review): confirm).
    Rkf45,
    /// Gauss-Jackson integrator, carrying its own configuration.
    GaussJackson(crate::gauss_jackson::config::GaussJacksonConfig),
    /// Presumably a fourth-order Adams-Bashforth-Moulton predictor-corrector;
    /// its implementation is not in this file (NOTE(review): confirm).
    Abm4,
}
/// Advance a translational state by one classic fourth-order Runge-Kutta step.
///
/// `accel_fn(state, time_frac)` returns the acceleration at a trial state.
/// `time_frac` is the fraction of the step at which that trial state lives:
/// `0.0` for the first stage, `0.5` for the two midpoint stages and `1.0` for
/// the end-point stage.
///
/// # Panics
///
/// Panics if `dt` is NaN or infinite.
pub fn rk4_translational_step(
    state: &TranslationalState,
    accel_fn: impl Fn(&TranslationalState, f64) -> DVec3,
    dt: f64,
) -> TranslationalState {
    assert!(
        dt.is_finite(),
        "rk4_translational_step requires a finite dt, got {dt}"
    );

    let half = dt * 0.5;

    // Build the trial state `state + h * (dx, dv)` and return its derivative
    // pair: (velocity of the trial state, acceleration at the trial state).
    let sample = |dx: DVec3, dv: DVec3, h: f64, frac: f64| -> (DVec3, DVec3) {
        let probe = TranslationalState {
            position: state.position + dx * h,
            velocity: state.velocity + dv * h,
        };
        (probe.velocity, accel_fn(&probe, frac))
    };

    let (v1, a1) = (state.velocity, accel_fn(state, 0.0));
    let (v2, a2) = sample(v1, a1, half, 0.5);
    let (v3, a3) = sample(v2, a2, half, 0.5);
    let (v4, a4) = sample(v3, a3, dt, 1.0);

    // Standard RK4 weights 1-2-2-1, summed left to right, then scaled by dt/6.
    let weight = dt / 6.0;
    TranslationalState {
        position: state.position + (v1 + v2 * 2.0 + v3 * 2.0 + v4) * weight,
        velocity: state.velocity + (a1 + a2 * 2.0 + a3 * 2.0 + a4) * weight,
    }
}
/// Advance a coupled translational + rotational state by one classic RK4 step.
///
/// Each stage evaluates four derivatives at a trial state:
/// - the position rate is the trial velocity;
/// - the acceleration is `accel_fn(trial, time_frac)`, with `time_frac` in
///   `{0.0, 0.5, 0.5, 1.0}`;
/// - the quaternion rate comes from `compute_left_quat_deriv`;
/// - the angular acceleration comes from `compute_rotational_acceleration`,
///   using `mass_props` inertia and `torque_fn(trial)`.
///
/// The quaternion is integrated component-wise like any other state. It is
/// renormalized with `normalize_integ` only once, on the final result; the
/// intermediate stage quaternions are left unnormalized.
///
/// # Panics
///
/// Panics if `dt` is NaN or infinite.
pub fn rk4_sixdof_step(
    state: &SixDofState,
    accel_fn: impl Fn(&SixDofState, f64) -> DVec3,
    torque_fn: impl Fn(&SixDofState) -> DVec3,
    mass_props: &MassProperties,
    dt: f64,
) -> SixDofState {
    assert!(
        dt.is_finite(),
        "rk4_sixdof_step requires a finite dt, got {dt}"
    );

    /// Time derivatives of the full 6-DOF state at one RK4 stage.
    struct Slope {
        vel: DVec3,
        acc: DVec3,
        qdot: [f64; 4],
        alpha: DVec3,
    }

    // Derivatives at `s`. Call order per stage is preserved: accel, then torque.
    let slope_at = |s: &SixDofState, time_frac: f64| -> Slope {
        let vel = s.trans.velocity;
        let acc = accel_fn(s, time_frac);
        let qdot = compute_left_quat_deriv(&s.rot.quaternion, s.rot.ang_vel_body);
        let torque = torque_fn(s);
        let alpha = compute_rotational_acceleration(
            &mass_props.inertia,
            &mass_props.inverse_inertia,
            s.rot.ang_vel_body,
            torque,
        );
        Slope {
            vel,
            acc,
            qdot,
            alpha,
        }
    };

    let q0 = state.rot.quaternion.data;

    // Trial state `state + h * k` (Euler step from the start-of-step state).
    let trial = |k: &Slope, h: f64| -> SixDofState {
        let mut q = q0;
        for (qi, dq) in q.iter_mut().zip(k.qdot.iter()) {
            *qi += *dq * h;
        }
        SixDofState {
            trans: TranslationalState {
                position: state.trans.position + k.vel * h,
                velocity: state.trans.velocity + k.acc * h,
            },
            rot: RotationalState {
                quaternion: JeodQuat::new(q[0], q[1], q[2], q[3]),
                ang_vel_body: state.rot.ang_vel_body + k.alpha * h,
            },
        }
    };

    let half_dt = dt * 0.5;
    // Stage 1 is evaluated on the caller's state itself, not on a rebuilt copy.
    let k1 = slope_at(state, 0.0);
    let k2 = slope_at(&trial(&k1, half_dt), 0.5);
    let k3 = slope_at(&trial(&k2, half_dt), 0.5);
    let k4 = slope_at(&trial(&k3, dt), 1.0);

    // RK4 1-2-2-1 weighted increment, summed left to right, scaled by dt/6.
    let sixth_dt = dt / 6.0;
    let blend = |a: DVec3, b: DVec3, c: DVec3, d: DVec3| (a + b * 2.0 + c * 2.0 + d) * sixth_dt;

    let mut q = q0;
    for (i, qi) in q.iter_mut().enumerate() {
        *qi += (k1.qdot[i] + 2.0 * k2.qdot[i] + 2.0 * k3.qdot[i] + k4.qdot[i]) * sixth_dt;
    }
    let mut quaternion = JeodQuat::new(q[0], q[1], q[2], q[3]);
    normalize_integ(&mut quaternion);

    SixDofState {
        trans: TranslationalState {
            position: state.trans.position + blend(k1.vel, k2.vel, k3.vel, k4.vel),
            velocity: state.trans.velocity + blend(k1.acc, k2.acc, k3.acc, k4.acc),
        },
        rot: RotationalState {
            quaternion,
            ang_vel_body: state.rot.ang_vel_body
                + blend(k1.alpha, k2.alpha, k3.alpha, k4.alpha),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mass::MassProperties;
    use crate::rotational::{RotationalState, SixDofState};
    use astrodyn_math::JeodQuat;
    use glam::DMat3;

    /// Unit-frequency harmonic restoring acceleration, `a = -x`.
    fn spring(s: &TranslationalState, _t: f64) -> DVec3 {
        -s.position
    }

    /// Oscillator initial condition: unit displacement along +x, zero velocity.
    fn displaced_at_rest() -> TranslationalState {
        TranslationalState {
            position: DVec3::new(1.0, 0.0, 0.0),
            velocity: DVec3::new(0.0, 0.0, 0.0),
        }
    }

    /// Take `steps` RK4 steps of size `dt` from `start` under `accel`.
    fn propagate(
        start: TranslationalState,
        accel: impl Fn(&TranslationalState, f64) -> DVec3,
        dt: f64,
        steps: usize,
    ) -> TranslationalState {
        (0..steps).fold(start, |s, _| rk4_translational_step(&s, &accel, dt))
    }

    #[test]
    fn harmonic_oscillator() {
        let dt = 0.01;
        let steps: usize = 628;
        let t_final = dt * steps as f64;
        let state = propagate(displaced_at_rest(), spring, dt, steps);

        let exact_pos = t_final.cos();
        let exact_vel = -t_final.sin();
        let pos_error = (state.position.x - exact_pos).abs();
        let vel_error = (state.velocity.x - exact_vel).abs();
        assert!(pos_error < 1e-8, "Position error {pos_error} exceeds 1e-8");
        assert!(vel_error < 1e-8, "Velocity error {vel_error} exceeds 1e-8");
    }

    #[test]
    fn convergence_order() {
        let dt_coarse = 0.1;
        let dt_fine = dt_coarse / 2.0;
        let total_time: f64 = 1.0;
        let exact_pos = total_time.cos();
        let exact_vel = -total_time.sin();

        // Integrate to `total_time` with step `dt`.
        let run = |dt: f64| {
            let steps = (total_time / dt).round() as usize;
            propagate(displaced_at_rest(), spring, dt, steps)
        };
        let coarse = run(dt_coarse);
        let fine = run(dt_fine);

        // Halving dt should shrink a 4th-order global error by ~2^4 = 16.
        let ratio = (coarse.position.x - exact_pos).abs() / (fine.position.x - exact_pos).abs();
        assert!(
            (ratio - 16.0).abs() < 2.0,
            "Convergence ratio {ratio} is not close to 16 (4th order)"
        );

        let vel_ratio =
            (coarse.velocity.x - exact_vel).abs() / (fine.velocity.x - exact_vel).abs();
        assert!(
            (vel_ratio - 16.0).abs() < 2.0,
            "Velocity convergence ratio {vel_ratio} is not close to 16 (4th order)"
        );
    }

    #[test]
    fn free_particle() {
        let dt = 0.5;
        let num_steps: usize = 10;
        let initial_pos = DVec3::new(1.0, 2.0, 3.0);
        let initial_vel = DVec3::new(4.0, 5.0, 6.0);
        let start = TranslationalState {
            position: initial_pos,
            velocity: initial_vel,
        };

        let state = propagate(
            start,
            |_: &TranslationalState, _t: f64| DVec3::ZERO,
            dt,
            num_steps,
        );

        // With zero acceleration RK4 reproduces straight-line motion exactly.
        let expected_pos = initial_pos + initial_vel * (dt * num_steps as f64);
        let pos_error = (state.position - expected_pos).length();
        assert!(
            pos_error < 1e-12,
            "Free particle position error {pos_error} exceeds 1e-12"
        );
        let vel_error = (state.velocity - initial_vel).length();
        assert!(
            vel_error < 1e-12,
            "Free particle velocity error {vel_error} exceeds 1e-12"
        );
    }

    /// Mass properties with a diagonal inertia tensor at the origin, aligned
    /// with the parent frame.
    fn mass_with_inertia(mass: f64, ix: f64, iy: f64, iz: f64) -> MassProperties {
        let principal = DVec3::new(ix, iy, iz);
        let principal_inv = DVec3::new(1.0 / ix, 1.0 / iy, 1.0 / iz);
        MassProperties {
            mass,
            inverse_mass: 1.0 / mass,
            inertia: DMat3::from_diagonal(principal),
            inverse_inertia: DMat3::from_diagonal(principal_inv),
            position: DVec3::ZERO,
            t_parent_this: DMat3::IDENTITY,
            dirty: false,
        }
    }

    /// Body at the origin, at rest translationally, identity attitude,
    /// spinning at `omega` (body frame).
    fn spinning_at_origin(omega: DVec3) -> SixDofState {
        SixDofState {
            trans: TranslationalState {
                position: DVec3::ZERO,
                velocity: DVec3::ZERO,
            },
            rot: RotationalState {
                quaternion: JeodQuat::identity(),
                ang_vel_body: omega,
            },
        }
    }

    /// Take `steps` force-free, torque-free 6-DOF RK4 steps, calling
    /// `observe` on the state after every step.
    fn coast(
        start: SixDofState,
        mass_props: &MassProperties,
        dt: f64,
        steps: usize,
        mut observe: impl FnMut(&SixDofState),
    ) -> SixDofState {
        let mut state = start;
        for _ in 0..steps {
            state = rk4_sixdof_step(
                &state,
                |_: &SixDofState, _t: f64| DVec3::ZERO,
                |_: &SixDofState| DVec3::ZERO,
                mass_props,
                dt,
            );
            observe(&state);
        }
        state
    }

    #[test]
    fn torque_free_symmetric_body() {
        let mass_props = mass_with_inertia(100.0, 10.0, 10.0, 20.0);
        let omega_z = 0.1;
        let omega_x0 = 0.01;
        // Body-frame precession rate of an axisymmetric body: (Iz - Ix) / Ix * omega_z.
        let omega_p = (20.0 - 10.0) / 10.0 * omega_z;
        let period = std::f64::consts::TAU / omega_p;
        let dt = 0.001;
        let steps = (period / dt).round() as usize;

        let start = spinning_at_origin(DVec3::new(omega_x0, 0.0, omega_z));
        let omega = coast(start, &mass_props, dt, steps, |_| {}).rot.ang_vel_body;

        let omega_z_err = (omega.z - omega_z).abs();
        assert!(
            omega_z_err < 1e-10,
            "omega_z should be constant, error = {}",
            omega_z_err,
        );

        // After one full precession period the transverse rate returns to start.
        let rel_err_x = (omega.x - omega_x0).abs() / omega_x0;
        let rel_err_y = omega.y.abs() / omega_x0;
        assert!(
            rel_err_x < 1e-3,
            "omega_x relative error {} exceeds 0.1% after one precession period",
            rel_err_x,
        );
        assert!(
            rel_err_y < 1e-3,
            "omega_y relative error {} exceeds 0.1% after one precession period",
            rel_err_y,
        );
    }

    #[test]
    fn quaternion_norm_preservation() {
        let mass_props = mass_with_inertia(100.0, 10.0, 20.0, 30.0);
        let dt = 1.0;
        let total_seconds: usize = 86400;

        let mut max_norm_err = 0.0_f64;
        coast(
            spinning_at_origin(DVec3::new(0.01, 0.02, 0.05)),
            &mass_props,
            dt,
            total_seconds,
            |s| max_norm_err = max_norm_err.max((s.rot.quaternion.norm_sq() - 1.0).abs()),
        );

        assert!(
            max_norm_err < 1e-14,
            "Max quaternion norm error over 86400s: {} (exceeds 1e-14)",
            max_norm_err,
        );
    }

    #[test]
    fn sixdof_pure_translation_matches_translational() {
        let mass_props = MassProperties::new(100.0);
        let mu = 3.986_004_415e14;
        // Point-mass central gravity acceleration at position `p`.
        let two_body = |p: DVec3| -> DVec3 {
            let r = p.length();
            -mu / (r * r * r) * p
        };

        let mut state_3dof = TranslationalState {
            position: DVec3::new(7_000_000.0, 0.0, 0.0),
            velocity: DVec3::new(0.0, 7_500.0, 0.0),
        };
        let mut state_6dof = SixDofState {
            trans: state_3dof,
            rot: RotationalState {
                quaternion: JeodQuat::identity(),
                ang_vel_body: DVec3::ZERO,
            },
        };

        let dt = 10.0;
        let steps = 100;
        for _ in 0..steps {
            state_3dof = rk4_translational_step(&state_3dof, |s, _t| two_body(s.position), dt);
            state_6dof = rk4_sixdof_step(
                &state_6dof,
                |s, _t| two_body(s.trans.position),
                |_| DVec3::ZERO,
                &mass_props,
                dt,
            );
        }

        let pos_diff = (state_6dof.trans.position - state_3dof.position).length();
        let vel_diff = (state_6dof.trans.velocity - state_3dof.velocity).length();
        assert!(
            pos_diff < 1e-6,
            "Position difference between 3DOF and 6DOF: {} m",
            pos_diff,
        );
        assert!(
            vel_diff < 1e-9,
            "Velocity difference between 3DOF and 6DOF: {} m/s",
            vel_diff,
        );

        // No angular rate and no torque: attitude must stay at identity.
        let q = state_6dof.rot.quaternion;
        assert!(
            (q.scalar() - 1.0).abs() < 1e-14,
            "Quaternion scalar should be 1.0, got {}",
            q.scalar(),
        );
        assert!(
            q.vector().length() < 1e-14,
            "Quaternion vector should be zero, got {:?}",
            q.vector(),
        );
    }
}