use crate::assist_data::AssistData;
use crate::coordinates::{
bary_to_helio, ecliptic_to_equatorial, equatorial_to_ecliptic, helio_to_bary,
rotate_matrix_eq_to_ecl,
};
use crate::orbit::{NonGravParams, Orbit};
use crate::{Error, Result};
use libassist_sys::ffi;
use libassist_sys::{AssistSim, Ephemeris, IntegratorConfig, Simulation};
/// Snapshot of one orbit at one requested epoch.
///
/// When built by this module's propagators (`run_integration`), `state` is
/// heliocentric ecliptic Cartesian `[x, y, z, vx, vy, vz]` and `epoch` echoes
/// the requested target epoch unchanged.
#[derive(Debug, Clone)]
pub struct PropagatedState {
    /// Cartesian state `[x, y, z, vx, vy, vz]`.
    pub state: [f64; 6],
    /// Target epoch this state corresponds to (the value the caller asked for).
    pub epoch: f64,
    /// 6×6 state transition matrix `∂state/∂state0`, rotated to the ecliptic
    /// frame; `None` unless STM computation was requested.
    pub stm: Option<[[f64; 6]; 6]>,
    /// 6×3 partials of the state with respect to the non-grav coefficients
    /// (A1, A2, A3), ecliptic frame; `None` unless both the STM and non-grav
    /// partials were requested.
    pub nongrav_partials: Option<[[f64; 3]; 6]>,
}

impl PropagatedState {
    /// Maps an initial 6×6 state covariance to this epoch as `Φ · P0 · Φᵀ`.
    ///
    /// Returns `None` when no STM is available.
    pub fn propagate_covariance(&self, p0: &[[f64; 6]; 6]) -> Option<[[f64; 6]; 6]> {
        let stm = self.stm.as_ref()?;
        Some(covariance_6x6(stm, p0))
    }

    /// Maps an initial 9×9 covariance over (state, A1, A2, A3) to the 6×6
    /// state covariance at this epoch, using the augmented Jacobian `[Φ | N]`.
    ///
    /// Returns `None` unless both the STM and the non-grav partials are present.
    pub fn propagate_covariance_with_nongrav(&self, p0: &[[f64; 9]; 9]) -> Option<[[f64; 6]; 6]> {
        match (self.stm.as_ref(), self.nongrav_partials.as_ref()) {
            (Some(stm), Some(partials)) => Some(covariance_9x9(stm, partials, p0)),
            _ => None,
        }
    }
}
/// Propagates a 6×6 covariance through a linear map: returns `Φ · P0 · Φᵀ`.
///
/// Two passes: first `Φ · P0`, then a product with `Φᵀ` (read in place as
/// `stm[col][k]`). Every entry is accumulated left-to-right over `k`, starting
/// from `0.0`.
fn covariance_6x6(stm: &[[f64; 6]; 6], p0: &[[f64; 6]; 6]) -> [[f64; 6]; 6] {
    let phi_p0: [[f64; 6]; 6] = std::array::from_fn(|row| {
        std::array::from_fn(|col| (0..6).fold(0.0, |acc, k| acc + stm[row][k] * p0[k][col]))
    });
    std::array::from_fn(|row| {
        std::array::from_fn(|col| {
            (0..6).fold(0.0, |acc, k| acc + phi_p0[row][k] * stm[col][k])
        })
    })
}
/// Propagates a 9×9 covariance over (state, A1, A2, A3) to a 6×6 state
/// covariance: returns `J · P0 · Jᵀ` with the augmented 6×9 Jacobian
/// `J = [Φ | N]` (`Φ` = `stm`, `N` = `ng`).
///
/// Entries are accumulated left-to-right over `k = 0..9`, starting from `0.0`.
fn covariance_9x9(stm: &[[f64; 6]; 6], ng: &[[f64; 3]; 6], p0: &[[f64; 9]; 9]) -> [[f64; 6]; 6] {
    // Element (row, k) of J, with columns 0..6 taken from Φ and 6..9 from N.
    let jac = |row: usize, k: usize| -> f64 {
        match k {
            0..=5 => stm[row][k],
            _ => ng[row][k - 6],
        }
    };
    let j_p0: [[f64; 9]; 6] = std::array::from_fn(|row| {
        std::array::from_fn(|col| (0..9).fold(0.0, |acc, k| acc + jac(row, k) * p0[k][col]))
    });
    std::array::from_fn(|row| {
        std::array::from_fn(|col| (0..9).fold(0.0, |acc, k| acc + j_p0[row][k] * jac(col, k)))
    })
}
/// Propagates one orbit to every epoch in `target_epochs` using a freshly
/// built ASSIST simulation.
///
/// Returns one [`PropagatedState`] per target epoch, in input order. An empty
/// `target_epochs` returns an empty vector without building a simulation.
/// When `compute_stm` is set, each result carries the STM, plus non-grav
/// partials if the orbit has non-grav parameters.
///
/// # Errors
/// Propagates failures from simulation construction, the initial frame
/// conversion and the integration itself.
pub fn assist_propagate_single(
    data: &AssistData,
    orbit: &Orbit,
    target_epochs: &[f64],
    compute_stm: bool,
    integrator: &IntegratorConfig,
) -> Result<Vec<PropagatedState>> {
    if target_epochs.is_empty() {
        return Ok(Vec::new());
    }

    let ephem = &data.ephem;
    let start_time = ephem.mjd_to_assist_time(orbit.epoch);
    let nongrav = orbit.non_grav.as_ref();
    let uses_nongrav = nongrav.is_some();
    let with_partials = compute_stm && uses_nongrav;

    let mut asim = {
        let mut sim = Simulation::new()?;
        sim.set_t(start_time);
        integrator.apply(&mut sim);
        AssistSim::new(sim, ephem)?
    };
    configure_forces(&mut asim, uses_nongrav);

    let initial_state = ecl_orbit_to_bary_eq(&orbit.state, ephem, start_time)?;
    let variations = variational_count(compute_stm, uses_nongrav);
    add_particles_and_variationals(&mut asim, &initial_state, variations);

    if let Some(params) = nongrav {
        apply_nongrav_scalars(&mut asim, params);
        install_particle_params(&mut asim, params, with_partials);
    }

    run_integration(&mut asim, ephem, target_epochs, compute_stm, with_partials)
}
/// Propagates many orbits to the same `target_epochs`, one fresh simulation
/// per orbit, optionally in parallel (see `map_with_threads` for the meaning
/// of `num_threads`).
///
/// The outer vector follows the order of `orbits`; each inner vector follows
/// the order of `target_epochs`.
///
/// # Errors
/// Fails if `num_threads` is `Some(0)` or any single propagation fails.
pub fn assist_propagate(
    data: &AssistData,
    orbits: &[Orbit],
    target_epochs: &[f64],
    compute_stm: bool,
    num_threads: Option<usize>,
    integrator: &IntegratorConfig,
) -> Result<Vec<Vec<PropagatedState>>> {
    map_with_threads(orbits, num_threads, |single: &Orbit| {
        assist_propagate_single(data, single, target_epochs, compute_stm, integrator)
    })
}
/// Applies `f` to every item and collects the results in input order,
/// returning an error if any call fails.
///
/// `num_threads` semantics:
/// * `Some(0)` is rejected with `Error::Other`.
/// * `Some(1)` always runs sequentially.
/// * With the `parallel` feature, `None` uses rayon's global pool and
///   `Some(n)` runs inside a dedicated `n`-thread rayon pool.
/// * Without the `parallel` feature, every valid value runs sequentially.
pub(crate) fn map_with_threads<T, U, F>(
    items: &[T],
    num_threads: Option<usize>,
    f: F,
) -> Result<Vec<U>>
where
    T: Sync,
    U: Send,
    F: Fn(&T) -> Result<U> + Send + Sync,
{
    if matches!(num_threads, Some(0)) {
        return Err(Error::Other(
            "num_threads = Some(0) is not valid; use None for the default pool".into(),
        ));
    }
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        match num_threads {
            Some(1) => items.iter().map(&f).collect(),
            Some(threads) => rayon::ThreadPoolBuilder::new()
                .num_threads(threads)
                .build()
                .map_err(|e| Error::Other(format!("rayon pool build failed: {e}")))?
                .install(|| items.par_iter().map(&f).collect()),
            None => items.par_iter().map(&f).collect(),
        }
    }
    #[cfg(not(feature = "parallel"))]
    {
        // Thread count is meaningless without rayon; only the Some(0) check applies.
        let _ = num_threads;
        items.iter().map(&f).collect()
    }
}
/// Fixed layout/force choices for a [`PropagatorPool`].
#[derive(Debug, Clone, Copy)]
pub struct PropagatorConfig {
    /// Allocate variational particles and report the STM.
    pub compute_stm: bool,
    /// Enable ASSIST's non-gravitational force; orbits must carry non-grav params.
    pub has_nongrav: bool,
}

impl PropagatorConfig {
    // Single construction point shared by the named presets below.
    fn from_flags(compute_stm: bool, has_nongrav: bool) -> Self {
        Self {
            compute_stm,
            has_nongrav,
        }
    }

    /// Gravity only, no STM.
    pub fn gravity_only() -> Self {
        Self::from_flags(false, false)
    }

    /// Gravity only, with the 6×6 STM.
    pub fn gravity_with_stm() -> Self {
        Self::from_flags(true, false)
    }

    /// Gravity plus non-grav forces, with the STM and non-grav partials.
    pub fn nongrav_with_stm() -> Self {
        Self::from_flags(true, true)
    }
}
/// Reusable propagator that keeps one ASSIST simulation alive across calls to
/// [`PropagatorPool::propagate`], overwriting the particle state per orbit
/// instead of rebuilding the simulation.
///
/// The particle layout (one real particle followed by `n_var` variational
/// particles) and the enabled forces are fixed by the [`PropagatorConfig`]
/// given to [`PropagatorPool::new`]. Every orbit propagated through the pool
/// must match that config's `has_nongrav` flag.
pub struct PropagatorPool<'a> {
    /// Long-lived simulation; its particles are overwritten for each orbit.
    asim: AssistSim,
    /// Ephemeris used for time conversion and Sun states.
    ephem: &'a Ephemeris,
    /// Configuration fixed at construction.
    config: PropagatorConfig,
    /// Number of variational particles allocated (see `variational_count`).
    n_var: usize,
}

impl<'a> PropagatorPool<'a> {
    // Non-grav model scalars a freshly created ASSIST simulation starts with,
    // which is what `assist_propagate_single` relies on. The pool restores them
    // before every orbit. Without this, a scalar set by one orbit would leak into
    // later orbits that leave it as `None`.
    // NOTE(review): these mirror the values ASSIST's `assist_init` assigns
    // (alpha = 1, nk = 0, nm = 2, nn = 5.093, r0 = 1, i.e. a 1/r² model).
    // Confirm they match the vendored ASSIST version and that `AssistSim::new`
    // does not override them.
    const ASSIST_DEFAULT_ALPHA: f64 = 1.0;
    const ASSIST_DEFAULT_NK: f64 = 0.0;
    const ASSIST_DEFAULT_NM: f64 = 2.0;
    const ASSIST_DEFAULT_NN: f64 = 5.093;
    const ASSIST_DEFAULT_R0: f64 = 1.0;

    /// Builds the shared simulation once, at `t = 0`, with a zeroed placeholder
    /// particle. The real state is written by [`PropagatorPool::propagate`].
    ///
    /// # Errors
    /// Propagates failures from simulation or ASSIST construction.
    pub fn new(
        data: &'a AssistData,
        config: PropagatorConfig,
        integrator: &IntegratorConfig,
    ) -> Result<Self> {
        let ephem = &data.ephem;
        let mut sim = Simulation::new()?;
        sim.set_t(0.0);
        integrator.apply(&mut sim);
        let mut asim = AssistSim::new(sim, ephem)?;
        configure_forces(&mut asim, config.has_nongrav);
        let placeholder = [0.0f64; 6];
        let n_var = variational_count(config.compute_stm, config.has_nongrav);
        add_particles_and_variationals(&mut asim, &placeholder, n_var);
        if config.has_nongrav {
            // Zero coefficients for now. `propagate` installs each orbit's real
            // A1..A3 through `update_nongrav_coeffs`, while the per-variational
            // partial seeds written here stay in place.
            let placeholder_ng = NonGravParams::new(0.0, 0.0, 0.0);
            install_particle_params(&mut asim, &placeholder_ng, config.compute_stm);
        }
        Ok(Self {
            asim,
            ephem,
            config,
            n_var,
        })
    }

    /// The configuration this pool was built with.
    pub fn config(&self) -> PropagatorConfig {
        self.config
    }

    /// The underlying simulation's `steps_done` counter.
    // NOTE(review): whether `reset_integrator` clears this counter depends on
    // the binding. Confirm before treating it as per-orbit.
    pub fn steps_done(&self) -> u64 {
        self.asim.sim().steps_done()
    }

    /// Propagates `orbit` to each of `target_epochs` (in order), reusing the
    /// pool's simulation.
    ///
    /// # Errors
    /// * The orbit's non-grav presence does not match the pool's `has_nongrav`.
    /// * Frame conversion, non-grav coefficient update or integration fails.
    pub fn propagate(
        &mut self,
        orbit: &Orbit,
        target_epochs: &[f64],
    ) -> Result<Vec<PropagatedState>> {
        if target_epochs.is_empty() {
            return Ok(vec![]);
        }
        if orbit.non_grav.is_some() != self.config.has_nongrav {
            return Err(Error::Other(format!(
                "orbit non-grav flag ({}) does not match pool config \
                has_nongrav ({}) — rebuild the pool with a matching config",
                orbit.non_grav.is_some(),
                self.config.has_nongrav,
            )));
        }
        let want_nongrav_partials = self.config.compute_stm && self.config.has_nongrav;
        let t0 = self.ephem.mjd_to_assist_time(orbit.epoch);
        self.asim.reset_integrator();
        self.asim.sim_mut().set_t(t0);
        // Fixed initial step after the reset.
        // NOTE(review): this overrides any initial step `IntegratorConfig::apply`
        // may have set in `new`. Confirm that is intended.
        self.asim.sim_mut().set_dt(0.001);
        let bary_state = ecl_orbit_to_bary_eq(&orbit.state, self.ephem, t0)?;
        overwrite_particles(&mut self.asim, &bary_state, self.n_var);
        if let Some(ng) = orbit.non_grav.as_ref() {
            // Start from ASSIST's defaults so only this orbit's explicit
            // overrides apply, matching a fresh simulation.
            self.reset_nongrav_scalars();
            apply_nongrav_scalars(&mut self.asim, ng);
            self.asim.update_nongrav_coeffs(ng.a1, ng.a2, ng.a3)?;
        }
        run_integration(
            &mut self.asim,
            self.ephem,
            target_epochs,
            self.config.compute_stm,
            want_nongrav_partials,
        )
    }

    /// Restores ASSIST's default non-grav model scalars on the shared simulation.
    fn reset_nongrav_scalars(&mut self) {
        self.asim.set_alpha(Self::ASSIST_DEFAULT_ALPHA);
        self.asim.set_nk(Self::ASSIST_DEFAULT_NK);
        self.asim.set_nm(Self::ASSIST_DEFAULT_NM);
        self.asim.set_nn(Self::ASSIST_DEFAULT_NN);
        self.asim.set_r0(Self::ASSIST_DEFAULT_R0);
    }
}
/// Converts a heliocentric ecliptic state into the barycentric equatorial
/// frame used by the integrator, using the Sun's state at ASSIST time `t`.
///
/// # Errors
/// Fails if the Sun's state cannot be obtained from the ephemeris.
fn ecl_orbit_to_bary_eq(ecl_state: &[f64; 6], ephem: &Ephemeris, t: f64) -> Result<[f64; 6]> {
    let sun_state = ephem.get_body_state_array(ffi::ASSIST_BODY_SUN, t)?;
    let helio_eq = ecliptic_to_equatorial(ecl_state);
    Ok(helio_to_bary(&helio_eq, &sun_state))
}
/// Enables ASSIST's default force set, adding the non-gravitational force
/// when `has_nongrav` is true.
pub(crate) fn configure_forces(asim: &mut AssistSim, has_nongrav: bool) {
    let forces = if has_nongrav {
        ffi::ASSIST_FORCES_DEFAULT | ffi::ASSIST_FORCE_NON_GRAVITATIONAL
    } else {
        ffi::ASSIST_FORCES_DEFAULT
    };
    asim.set_forces(forces);
}
/// Number of first-order variational particles to allocate.
///
/// None without an STM. Otherwise 6 (one per state component), plus 3 (one
/// per non-grav coefficient) when non-grav is active.
fn variational_count(compute_stm: bool, has_nongrav: bool) -> usize {
    if !compute_stm {
        return 0;
    }
    let nongrav_extra = if has_nongrav { 3 } else { 0 };
    6 + nongrav_extra
}
pub(crate) fn apply_nongrav_scalars(asim: &mut AssistSim, ng: &NonGravParams) {
if let Some(v) = ng.alpha {
asim.set_alpha(v);
}
if let Some(v) = ng.nk {
asim.set_nk(v);
}
if let Some(v) = ng.nm {
asim.set_nm(v);
}
if let Some(v) = ng.nn {
asim.set_nn(v);
}
if let Some(v) = ng.r0 {
asim.set_r0(v);
}
}
/// Adds the real test particle at `bary_state`, then `n_var` first-order
/// variational particles attached to particle 0, and seeds their initial
/// perturbations.
pub(crate) fn add_particles_and_variationals(
    asim: &mut AssistSim,
    bary_state: &[f64; 6],
    n_var: usize,
) {
    let [x, y, z, vx, vy, vz] = *bary_state;
    {
        let sim = asim.sim_mut();
        sim.add_test_particle(x, y, z, vx, vy, vz);
        for _ in 0..n_var {
            sim.add_variation_1st_order(0);
        }
    }
    init_variational_state_perturbations(asim, n_var);
}
/// Overwrites the position and velocity of the real particle (index 0) with
/// `bary_state`, then re-seeds the `n_var` variational particles that follow
/// it. The existing particle layout is reused.
///
/// # Panics
/// Panics if the simulation holds no particles or reports a null particle
/// array, rather than dereferencing an invalid pointer.
fn overwrite_particles(asim: &mut AssistSim, bary_state: &[f64; 6], n_var: usize) {
    assert!(
        asim.sim().n_particles() >= 1,
        "overwrite_particles: simulation has no particle to overwrite"
    );
    // SAFETY: the assertion above guarantees the particle array holds at
    // least one element, the pointer is null-checked before use, and no other
    // Rust reference into the array is alive while we write through it.
    unsafe {
        let ptr = librebound_sys::ffi::assist_rs_sim_get_particles(asim.sim().as_ptr());
        assert!(
            !ptr.is_null(),
            "overwrite_particles: simulation returned a null particle array"
        );
        let p = &mut *ptr;
        p.x = bary_state[0];
        p.y = bary_state[1];
        p.z = bary_state[2];
        p.vx = bary_state[3];
        p.vy = bary_state[4];
        p.vz = bary_state[5];
    }
    init_variational_state_perturbations(asim, n_var);
}
/// Seeds the `n_var` variational particles stored right after the real
/// particle (index 0).
///
/// Every variational particle is reset to `reb_particle::default()`. The
/// first six then get a unit value in x, y, z, vx, vy, vz respectively. Any
/// further ones (indices 7..=9 when `n_var == 9`) stay zero; their driving
/// terms come from the particle params set in `install_particle_params`.
///
/// # Panics
/// Panics if `n_var` is not 0, 6 or 9, or if the simulation holds fewer than
/// `1 + n_var` particles. Either condition would otherwise lead to
/// out-of-bounds raw-pointer writes.
fn init_variational_state_perturbations(asim: &mut AssistSim, n_var: usize) {
    // A hard assert rather than a debug_assert: the unsafe writes below depend
    // on it in release builds too.
    assert!(
        matches!(n_var, 0 | 6 | 9),
        "unsupported variational particle count: {n_var}"
    );
    if n_var == 0 {
        return;
    }
    let n_particles = asim.sim().n_particles();
    assert!(
        n_particles > n_var,
        "simulation has {n_particles} particles, too few for 1 real + {n_var} variational"
    );
    // SAFETY: the assertions above guarantee that indices 1..=n_var are in
    // bounds of the particle array. The pointer is null-checked, and the
    // mutable slice is the only live reference into that memory while it is
    // used.
    unsafe {
        let ptr = librebound_sys::ffi::assist_rs_sim_get_particles(asim.sim().as_ptr());
        assert!(!ptr.is_null(), "simulation returned a null particle array");
        let vars = std::slice::from_raw_parts_mut(ptr.add(1), n_var);
        for (d, vp) in vars.iter_mut().enumerate() {
            *vp = ffi::reb_particle::default();
            match d {
                0 => vp.x = 1.0,
                1 => vp.y = 1.0,
                2 => vp.z = 1.0,
                3 => vp.vx = 1.0,
                4 => vp.vy = 1.0,
                5 => vp.vz = 1.0,
                // Non-grav variational particles start at zero.
                _ => {}
            }
        }
    }
}
/// Builds and installs the per-particle non-grav parameter table (three
/// values per particle).
///
/// The real particle (index 0) gets `(A1, A2, A3)`. When
/// `want_nongrav_partials` is set, the three non-grav variational particles
/// (indices 7, 8, 9) each get a 1.0 in their own coefficient slot. All other
/// entries stay zero.
pub(crate) fn install_particle_params(
    asim: &mut AssistSim,
    ng: &NonGravParams,
    want_nongrav_partials: bool,
) {
    // One real particle followed by six state variationals.
    const FIRST_NONGRAV_VAR: usize = 1 + 6;

    let mut table = vec![0.0f64; 3 * asim.sim().n_particles()];
    table[..3].copy_from_slice(&[ng.a1, ng.a2, ng.a3]);
    if want_nongrav_partials {
        for (coeff, particle) in (FIRST_NONGRAV_VAR..FIRST_NONGRAV_VAR + 3).enumerate() {
            table[3 * particle + coeff] = 1.0;
        }
    }
    asim.set_particle_params(table);
}
fn run_integration(
asim: &mut AssistSim,
ephem: &Ephemeris,
target_epochs: &[f64],
compute_stm: bool,
want_nongrav_partials: bool,
) -> Result<Vec<PropagatedState>> {
let mut results = Vec::with_capacity(target_epochs.len());
for &target_mjd in target_epochs {
let t_target = ephem.mjd_to_assist_time(target_mjd);
asim.integrate(t_target)?;
let particles = asim.sim().particles();
if particles.is_empty() {
return Err(Error::Other("No particles after integration".into()));
}
let p = &particles[0];
let bary_eq = [p.x, p.y, p.z, p.vx, p.vy, p.vz];
let sun_t = ephem.get_body_state_array(ffi::ASSIST_BODY_SUN, t_target)?;
let helio_ecl = equatorial_to_ecliptic(&bary_to_helio(&bary_eq, &sun_t));
let (stm, nongrav_partials) = if compute_stm {
extract_stm_and_partials(particles, want_nongrav_partials)
} else {
(None, None)
};
results.push(PropagatedState {
state: helio_ecl,
epoch: target_mjd,
stm,
nongrav_partials,
});
}
Ok(results)
}
/// Reads the STM and, optionally, the non-grav partials out of the
/// variational particles and rotates them into the ecliptic frame.
///
/// Layout assumed, matching `add_particles_and_variationals`:
/// * particle 0 is the real particle;
/// * particles 1..=6 give the STM columns;
/// * particles 7..=9 give the partials with respect to A1..A3.
///
/// The STM is always returned (`Some`). The partials are `Some` only when
/// `want_nongrav_partials` is set.
fn extract_stm_and_partials(
    particles: &[ffi::reb_particle],
    want_nongrav_partials: bool,
) -> (Option<[[f64; 6]; 6]>, Option<[[f64; 3]; 6]>) {
    const FIRST_VAR: usize = 1;
    let as_state = |vp: &ffi::reb_particle| -> [f64; 6] { [vp.x, vp.y, vp.z, vp.vx, vp.vy, vp.vz] };

    // Each variational particle is one column of the equatorial STM.
    let mut stm_eq = [[0.0f64; 6]; 6];
    for (col, vp) in particles[FIRST_VAR..FIRST_VAR + 6].iter().enumerate() {
        for (row, &value) in as_state(vp).iter().enumerate() {
            stm_eq[row][col] = value;
        }
    }
    let stm = rotate_matrix_eq_to_ecl(&stm_eq);

    // The partials only need their output side rotated, one column at a time.
    let nongrav = want_nongrav_partials.then(|| {
        let mut ng_ecl = [[0.0f64; 3]; 6];
        for (col, vp) in particles[FIRST_VAR + 6..FIRST_VAR + 9].iter().enumerate() {
            let rotated = equatorial_to_ecliptic(&as_state(vp));
            for (row, &value) in rotated.iter().enumerate() {
                ng_ecl[row][col] = value;
            }
        }
        ng_ecl
    });

    (Some(stm), nongrav)
}
// Unit tests for the pure covariance helpers and the variational-layout helper.
#[cfg(test)]
mod tests {
    use super::*;
    fn identity6() -> [[f64; 6]; 6] {
        let mut m = [[0.0; 6]; 6];
        for i in 0..6 {
            m[i][i] = 1.0;
        }
        m
    }
    fn identity9() -> [[f64; 9]; 9] {
        let mut m = [[0.0; 9]; 9];
        for i in 0..9 {
            m[i][i] = 1.0;
        }
        m
    }
    fn sample_stm() -> [[f64; 6]; 6] {
        [
            [1.0, 0.0, 0.0, 30.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 30.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 30.0],
            [0.0001, 0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0001, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0001, 0.0, 0.0, 1.0],
        ]
    }
    // Non-zero 6×3 non-grav partials shared by the wrapper tests below.
    fn sample_ng() -> [[f64; 3]; 6] {
        [
            [1.0, 0.5, 0.2],
            [0.3, 0.8, 0.1],
            [0.1, 0.2, 0.6],
            [0.01, 0.005, 0.002],
            [0.003, 0.008, 0.001],
            [0.001, 0.002, 0.006],
        ]
    }
    fn max_abs_diff_6x6(a: &[[f64; 6]; 6], b: &[[f64; 6]; 6]) -> f64 {
        let mut m: f64 = 0.0;
        for i in 0..6 {
            for j in 0..6 {
                m = m.max((a[i][j] - b[i][j]).abs());
            }
        }
        m
    }
    #[test]
    fn covariance_6x6_identity_p0() {
        let stm = sample_stm();
        let got = covariance_6x6(&stm, &identity6());
        let mut want = [[0.0f64; 6]; 6];
        for i in 0..6 {
            for j in 0..6 {
                let mut s = 0.0;
                for k in 0..6 {
                    s += stm[i][k] * stm[j][k];
                }
                want[i][j] = s;
            }
        }
        assert!(max_abs_diff_6x6(&got, &want) < 1e-14);
        for i in 0..6 {
            for j in 0..6 {
                assert!((got[i][j] - got[j][i]).abs() < 1e-14);
            }
        }
    }
    #[test]
    fn covariance_6x6_zero_stm_is_zero() {
        let stm = [[0.0f64; 6]; 6];
        let p0 = identity6();
        let got = covariance_6x6(&stm, &p0);
        for i in 0..6 {
            for j in 0..6 {
                assert_eq!(got[i][j], 0.0);
            }
        }
    }
    #[test]
    fn propagate_covariance_method_wraps_helper() {
        let state = PropagatedState {
            state: [0.0; 6],
            epoch: 0.0,
            stm: Some(sample_stm()),
            nongrav_partials: None,
        };
        let p0 = identity6();
        let via_method = state.propagate_covariance(&p0).unwrap();
        let via_helper = covariance_6x6(state.stm.as_ref().unwrap(), &p0);
        assert!(max_abs_diff_6x6(&via_method, &via_helper) < 1e-14);
    }
    #[test]
    fn propagate_covariance_none_when_no_stm() {
        let state = PropagatedState {
            state: [0.0; 6],
            epoch: 0.0,
            stm: None,
            nongrav_partials: None,
        };
        assert!(state.propagate_covariance(&identity6()).is_none());
    }
    #[test]
    fn propagate_covariance_with_nongrav_wraps_helper() {
        let state = PropagatedState {
            state: [0.0; 6],
            epoch: 0.0,
            stm: Some(sample_stm()),
            nongrav_partials: Some(sample_ng()),
        };
        let p0 = identity9();
        let via_method = state.propagate_covariance_with_nongrav(&p0).unwrap();
        let via_helper = covariance_9x9(&sample_stm(), &sample_ng(), &p0);
        assert_eq!(via_method, via_helper);
    }
    #[test]
    fn propagate_covariance_with_nongrav_none_when_partials_missing() {
        let state = PropagatedState {
            state: [0.0; 6],
            epoch: 0.0,
            stm: Some(sample_stm()),
            nongrav_partials: None,
        };
        assert!(state.propagate_covariance_with_nongrav(&identity9()).is_none());
    }
    #[test]
    fn propagate_covariance_with_nongrav_none_when_stm_missing() {
        let state = PropagatedState {
            state: [0.0; 6],
            epoch: 0.0,
            stm: None,
            nongrav_partials: Some(sample_ng()),
        };
        assert!(state.propagate_covariance_with_nongrav(&identity9()).is_none());
    }
    #[test]
    fn covariance_9x9_reduces_to_6x6_when_nongrav_block_is_zero() {
        let stm = sample_stm();
        let ng: [[f64; 3]; 6] = [
            [1e-3, 2e-3, -1e-3],
            [0.0, 1e-3, 0.0],
            [0.0, 0.0, 1e-3],
            [1e-5, 0.0, 0.0],
            [0.0, 1e-5, 0.0],
            [0.0, 0.0, 1e-5],
        ];
        let p_xx = {
            let mut m = [[0.0; 6]; 6];
            for (i, v) in [0.1, 0.1, 0.1, 1e-4, 1e-4, 1e-4].iter().enumerate() {
                m[i][i] = *v;
            }
            m
        };
        let mut p0_9 = [[0.0; 9]; 9];
        for i in 0..6 {
            for j in 0..6 {
                p0_9[i][j] = p_xx[i][j];
            }
        }
        let got = covariance_9x9(&stm, &ng, &p0_9);
        let want = covariance_6x6(&stm, &p_xx);
        assert!(
            max_abs_diff_6x6(&got, &want) < 1e-14,
            "9×9 with zero A-block should match 6×6 path; diff={:.3e}",
            max_abs_diff_6x6(&got, &want)
        );
    }
    #[test]
    fn covariance_9x9_picks_up_pure_nongrav_covariance() {
        let stm = sample_stm();
        let ng: [[f64; 3]; 6] = [
            [10.0, 5.0, 2.0],
            [3.0, 8.0, 1.0],
            [1.0, 2.0, 6.0],
            [0.1, 0.05, 0.02],
            [0.03, 0.08, 0.01],
            [0.01, 0.02, 0.06],
        ];
        let p_aa = [[1.0, 0.2, 0.1], [0.2, 1.0, 0.3], [0.1, 0.3, 1.0]];
        let mut p0_9 = [[0.0; 9]; 9];
        for i in 0..3 {
            for j in 0..3 {
                p0_9[6 + i][6 + j] = p_aa[i][j];
            }
        }
        let got = covariance_9x9(&stm, &ng, &p0_9);
        let mut tmp = [[0.0; 3]; 6];
        for i in 0..6 {
            for j in 0..3 {
                let mut s = 0.0;
                for k in 0..3 {
                    s += ng[i][k] * p_aa[k][j];
                }
                tmp[i][j] = s;
            }
        }
        let mut want = [[0.0; 6]; 6];
        for i in 0..6 {
            for j in 0..6 {
                let mut s = 0.0;
                for k in 0..3 {
                    s += tmp[i][k] * ng[j][k];
                }
                want[i][j] = s;
            }
        }
        assert!(max_abs_diff_6x6(&got, &want) < 1e-12);
    }
    #[test]
    fn covariance_9x9_identity_p0_includes_both_blocks() {
        let stm = sample_stm();
        let ng = sample_ng();
        let got = covariance_9x9(&stm, &ng, &identity9());
        let state_part = covariance_6x6(&stm, &identity6());
        let mut ng_part = [[0.0; 6]; 6];
        for i in 0..6 {
            for j in 0..6 {
                let mut s = 0.0;
                for k in 0..3 {
                    s += ng[i][k] * ng[j][k];
                }
                ng_part[i][j] = s;
            }
        }
        let mut want = [[0.0; 6]; 6];
        for i in 0..6 {
            for j in 0..6 {
                want[i][j] = state_part[i][j] + ng_part[i][j];
            }
        }
        assert!(max_abs_diff_6x6(&got, &want) < 1e-12);
    }
    #[test]
    fn variational_count_matches_particle_layout() {
        assert_eq!(variational_count(false, false), 0);
        assert_eq!(variational_count(false, true), 0);
        assert_eq!(variational_count(true, false), 6);
        assert_eq!(variational_count(true, true), 9);
    }
}