use itertools::Itertools;
use nalgebra::Vector3;
use std::{collections::VecDeque, ops::ControlFlow};
use crate::{
constants::{Observations, Radian},
error_models::ErrorModel,
initial_orbit_determination::{gauss::GaussObs, gauss_result::GaussResult, IODParams},
observations::{triplets_iod::generate_triplets, Observation},
orbit_type::equinoctial_element::EquinoctialElements,
outfit::Outfit,
outfit_errors::OutfitError,
};
/// Extension methods on a set of observations ([`Observations`]) used by the initial
/// orbit determination pipeline: triplet generation, selection of the scoring window,
/// RMS scoring of a candidate orbit, and batch re-weighting of astrometric errors.
pub trait ObservationsExt {
    /// Build candidate observation triplets for the Gauss method.
    ///
    /// In the implementation for [`Observations`] every argument is forwarded unchanged to
    /// `generate_triplets`, which defines the exact selection rules.
    /// NOTE(review): `dt_min`/`dt_max` look like bounds on the triplet time separation and
    /// `max_triplet` like a cap on the number of triplets returned — confirm against
    /// `generate_triplets`.
    fn compute_triplets(
        &mut self,
        dt_min: f64,
        dt_max: f64,
        optimal_interval_time: f64,
        max_obs_for_triplets: usize,
        max_triplet: u32,
    ) -> Vec<GaussObs>;
    /// Select the inclusive index range `(start, end)` of observations used to score an
    /// orbit derived from `triplets`.
    ///
    /// In the [`Observations`] implementation the time window is `extf * (t3 - t1)` when
    /// `extf >= 0` (`t1`, `t3`: epochs of the first and third triplet observations) or
    /// `10 * (t_last - t_first)` over the whole set otherwise. It is then raised to at
    /// least `dtmax` when `dtmax >= 0`. Observations are collected outward from the triplet
    /// while they stay within that window, which assumes the set is sorted by time.
    ///
    /// # Errors
    /// [`OutfitError::ObservationNotFound`] when a triplet index is out of range.
    fn select_rms_interval(
        &self,
        triplets: &GaussObs,
        extf: f64,
        dtmax: f64,
    ) -> Result<(usize, usize), OutfitError>;
    /// Score `orbit_element` against the observations chosen by
    /// [`ObservationsExt::select_rms_interval`] (called with the same `extf` / `dtmax`).
    ///
    /// The [`Observations`] implementation returns `sqrt(Σ ephemeris_error / (2·n))`, with
    /// `n` the number of observations in the interval. With `prune_if_rms_ge = Some(p)` the
    /// summation may stop early once the RMS is known to be `>= p`. In that case `p` is
    /// returned instead of the exact value.
    ///
    /// # Errors
    /// Errors from interval selection are propagated. The implementation determines how
    /// per-observation evaluation errors are reported.
    fn rms_orbit_error(
        &self,
        state: &Outfit,
        triplets: &GaussObs,
        orbit_element: &EquinoctialElements,
        extf: f64,
        dtmax: f64,
        prune_if_rms_ge: Option<f64>,
    ) -> Result<f64, OutfitError>;
    /// Sort the observations by time and inflate the `error_ra` / `error_dec` of
    /// observations taken in batches.
    ///
    /// In the [`Observations`] implementation, a batch is a run of consecutive (after
    /// sorting) observations with the same `observer`, in which successive epochs differ by
    /// at most `gap_max`. The errors of a batch of `n` observations are multiplied by
    /// `sqrt(n)`, or by `sqrt(n / 4)` for [`ErrorModel::VFCC17`] when `n >= 5`.
    fn apply_batch_rms_correction(&mut self, error_model: &ErrorModel, gap_max: f64);
    /// Return the `(error_ra, error_dec)` values of the three observations indexed by
    /// `idx_obs`. The [`Observations`] implementation indexes directly and panics if any
    /// index is out of range.
    fn extract_errors(&self, idx_obs: Vector3<usize>) -> (Vector3<Radian>, Vector3<Radian>);
}
/// Initial orbit determination driven by a set of observations.
pub trait ObservationIOD {
    /// Search for the preliminary orbit that best fits the observations.
    ///
    /// `error_model` selects the batch error re-weighting. `rng` drives the noisy triplet
    /// realizations. `params` bundles the triplet-selection, noise and scoring settings.
    /// On success, returns the best Gauss solution together with its RMS score.
    ///
    /// # Errors
    /// In the [`Observations`] implementation:
    /// * [`OutfitError::NoFeasibleTriplets`] when no triplet can be formed;
    /// * [`OutfitError::NoViableOrbit`] when no attempt yields a usable orbit.
    fn estimate_best_orbit(
        &mut self,
        state: &Outfit,
        error_model: &ErrorModel,
        rng: &mut impl rand::Rng,
        params: &IODParams,
    ) -> Result<(GaussResult, f64), OutfitError>;
}
impl ObservationsExt for Observations {
    /// Generate candidate Gauss triplets from this observation set.
    ///
    /// Thin wrapper that forwards every argument unchanged to [`generate_triplets`].
    fn compute_triplets(
        &mut self,
        dt_min: f64,
        dt_max: f64,
        optimal_interval_time: f64,
        max_obs_for_triplets: usize,
        max_triplet: u32,
    ) -> Vec<GaussObs> {
        generate_triplets(
            self,
            dt_min,
            dt_max,
            optimal_interval_time,
            max_obs_for_triplets,
            max_triplet,
        )
    }

    /// Select the inclusive index range `(start, end)` of observations used to score an
    /// orbit fitted on `triplets`.
    ///
    /// The time window `dt` is `extf * (t3 - t1)` when `extf >= 0` (`t1`, `t3`: epochs of
    /// the first and third triplet observations), and `10 * (t_last - t_first)` over the
    /// whole set otherwise. When `dtmax >= 0`, it is raised to at least `dtmax`.
    ///
    /// Starting from the triplet endpoints, observations are added backwards while
    /// `t1 - t_i <= dt` and forwards while `t_i - t3 <= dt`. The scan assumes the set is
    /// sorted by time. The triplet endpoints themselves are always inside the range, even
    /// for a degenerate (negative) window.
    ///
    /// # Errors
    /// [`OutfitError::ObservationNotFound`] if the first or third triplet index is out of
    /// range.
    fn select_rms_interval(
        &self,
        triplets: &GaussObs,
        extf: f64,
        dtmax: f64,
    ) -> Result<(usize, usize), OutfitError> {
        let idx_obs1 = triplets.idx_obs[0];
        let idx_obs3 = triplets.idx_obs[2];
        let obs1 = self
            .get(idx_obs1)
            .ok_or(OutfitError::ObservationNotFound(idx_obs1))?;
        let obs3 = self
            .get(idx_obs3)
            .ok_or(OutfitError::ObservationNotFound(idx_obs3))?;

        let mut dt = if extf >= 0.0 {
            (obs3.time - obs1.time) * extf
        } else {
            // Both lookups above succeeded, so the set is non-empty and first/last exist.
            match (self.first(), self.last()) {
                (Some(first), Some(last)) => 10.0 * (last.time - first.time),
                _ => return Err(OutfitError::ObservationNotFound(0)),
            }
        };
        if dtmax >= 0.0 {
            dt = dt.max(dtmax);
        }

        // Scanning backwards from obs1, the first observation outside the window marks the
        // boundary; everything after it (up to and including obs1) is kept.
        let i_start = self[..idx_obs1]
            .iter()
            .rposition(|obs| obs1.time - obs.time > dt)
            .map_or(0, |outside| outside + 1);

        // Same rule forwards from obs3.
        let i_end = self[idx_obs3 + 1..]
            .iter()
            .position(|obs| obs.time - obs3.time > dt)
            .map_or(self.len() - 1, |outside| idx_obs3 + outside);

        Ok((i_start, i_end))
    }

    /// Root-mean-square ephemeris residual of `orbit_element` over the observations selected
    /// by [`ObservationsExt::select_rms_interval`].
    ///
    /// The score is `sqrt(Σ ephemeris_error / (2·n))`, with `n` the number of observations
    /// in the interval.
    ///
    /// When `prune_if_rms_ge` is `Some(p)`, summation stops as soon as the partial sum shows
    /// the RMS is `>= p`, and `p` itself is returned instead of the exact value. A
    /// non-finite `p` only triggers the early exit if the partial sum becomes infinite.
    ///
    /// # Errors
    /// Propagates errors from interval selection and from `ephemeris_error`, whether or not
    /// pruning is enabled.
    fn rms_orbit_error(
        &self,
        state: &Outfit,
        triplets: &GaussObs,
        orbit_element: &EquinoctialElements,
        extf: f64,
        dtmax: f64,
        prune_if_rms_ge: Option<f64>,
    ) -> Result<f64, OutfitError> {
        let (start_obs_rms, end_obs_rms) = self.select_rms_interval(triplets, extf, dtmax)?;
        let denom = 2.0 * (end_obs_rms - start_obs_rms + 1) as f64;

        // `(threshold, cutoff)`: the RMS reaches `threshold` exactly when the summed
        // `ephemeris_error` terms reach `threshold² · denom`.
        let prune = prune_if_rms_ge.map(|threshold| {
            let cutoff = if threshold.is_finite() {
                threshold * threshold * denom
            } else {
                f64::INFINITY
            };
            (threshold, cutoff)
        });

        let folded: ControlFlow<Result<f64, OutfitError>, f64> = self
            [start_obs_rms..=end_obs_rms]
            .iter()
            .try_fold(0.0, |acc, obs| {
                let sum = match obs.ephemeris_error(state, orbit_element) {
                    Ok(term) => acc + term,
                    // Surface the real failure so callers can report it; pruning only
                    // short-circuits on the score, never on errors.
                    Err(err) => return ControlFlow::Break(Err(err)),
                };
                match prune {
                    Some((threshold, cutoff)) if sum >= cutoff => {
                        ControlFlow::Break(Ok(threshold))
                    }
                    _ => ControlFlow::Continue(sum),
                }
            });

        match folded {
            ControlFlow::Continue(sum) => Ok((sum / denom).sqrt()),
            ControlFlow::Break(early) => early,
        }
    }

    /// Sort observations by time, then inflate the astrometric errors of batched
    /// observations.
    ///
    /// A batch is a run of consecutive observations (after the time sort) sharing the same
    /// `observer`, in which successive epochs differ by at most `gap_max`. Every observation
    /// in a batch of `n` has `error_ra` and `error_dec` multiplied by `sqrt(n)`, or by
    /// `sqrt(n / 4)` for [`ErrorModel::VFCC17`] when `n >= 5`.
    ///
    /// NOTE(review): runs are grouped by *consecutive* observer after the time sort, so
    /// observations of one observer interleaved in time with another observer's fall into
    /// separate batches — confirm this is the intended weighting.
    fn apply_batch_rms_correction(&mut self, error_model: &ErrorModel, gap_max: f64) {
        // `total_cmp` is a total order, so a NaN epoch cannot make the sort panic.
        self.sort_by(|a, b| a.time.total_cmp(&b.time));

        for (_observer_id, run) in &self.iter_mut().chunk_by(|obs| obs.observer) {
            let mut run = run.peekable();
            let mut batch: VecDeque<&mut Observation> = VecDeque::new();

            while let Some(first) = run.next() {
                let mut last_time = first.time;
                batch.push_back(first);

                // Grow the batch while the gap to the previously accepted epoch fits.
                while let Some(next) = run.next_if(|next| next.time - last_time <= gap_max) {
                    last_time = next.time;
                    batch.push_back(next);
                }

                // A batch always contains at least `first`, so `n >= 1`.
                let n = batch.len();
                let factor = match error_model {
                    ErrorModel::VFCC17 if n >= 5 => (n as f64 * 0.25).sqrt(),
                    _ => (n as f64).sqrt(),
                };
                for obs in batch.drain(..) {
                    obs.error_ra *= factor;
                    obs.error_dec *= factor;
                }
            }
        }
    }

    /// Return the `(error_ra, error_dec)` values of the three observations in `idx_obs`.
    ///
    /// # Panics
    /// Panics with an "index out of bounds" message if any index is out of range.
    fn extract_errors(&self, idx_obs: Vector3<usize>) -> (Vector3<Radian>, Vector3<Radian>) {
        // `Vector3::map` builds the result in place, with no intermediate heap allocation.
        (
            idx_obs.map(|i| self[i].error_ra),
            idx_obs.map(|i| self[i].error_dec),
        )
    }
}
impl ObservationIOD for Observations {
fn estimate_best_orbit(
&mut self,
state: &Outfit,
error_model: &ErrorModel,
rng: &mut impl rand::Rng,
params: &IODParams,
) -> Result<(GaussResult, f64), OutfitError> {
self.apply_batch_rms_correction(error_model, params.gap_max);
let triplets = self.compute_triplets(
params.dt_min,
params.dt_max_triplet,
params.optimal_interval_time,
params.max_obs_for_triplets,
params.max_triplets,
);
if triplets.is_empty() {
let span = if self.is_empty() {
0.0
} else {
self.last().unwrap().time - self.first().unwrap().time
};
return Err(OutfitError::NoFeasibleTriplets {
span,
n_obs: self.len(),
dt_min: params.dt_min,
dt_max: params.dt_max_triplet,
});
}
let mut best_rms = f64::INFINITY;
let mut best_orbit: Option<GaussResult> = None;
let mut last_error: Option<OutfitError> = None;
let mut n_attempts: usize = 0;
for triplet in triplets {
let (error_ra, error_dec) = self.extract_errors(triplet.idx_obs);
for realization in triplet.realizations_iter(
&error_ra,
&error_dec,
params.n_noise_realizations,
params.noise_scale,
rng,
) {
n_attempts += 1;
let gauss_res = match realization.prelim_orbit(state, params) {
Ok(res) => res,
Err(e) => {
last_error = Some(e);
continue;
}
};
let equinoctial_elements = gauss_res.get_orbit().to_equinoctial()?;
let rms = match self.rms_orbit_error(
state,
&realization,
&equinoctial_elements,
params.extf,
params.dtmax,
Some(best_rms),
) {
Ok(v) => {
if !v.is_finite() {
last_error = Some(OutfitError::NonFiniteScore(v));
continue;
} else {
v
}
}
Err(e) => {
last_error = Some(e);
continue;
}
};
if rms < best_rms {
best_rms = rms;
best_orbit = Some(gauss_res);
}
}
}
if let Some(orbit) = best_orbit {
Ok((orbit, best_rms))
} else {
let root_cause = match last_error {
Some(e) => e,
None => panic!("In estimate_best_orbit: no error captured but best_orbit is None, this should not happen"),
};
Err(OutfitError::NoViableOrbit {
cause: Box::new(root_cause),
attempts: n_attempts,
})
}
}
}
#[cfg(test)]
mod test_obs_ext {
    //! Unit tests for [`ObservationsExt`] and [`ObservationIOD`] on `Observations`.
    //!
    //! Tests gated on the `jpl-download` feature use the shared `OUTFIT_HORIZON_TEST`
    //! fixture (trajectory `K09R05F`) and pin reference values.
    use crate::error_models::ErrorModel;
    use super::*;

    /// Pins the RMS window bounds for several `extf` / `dtmax` combinations.
    #[test]
    #[cfg(feature = "jpl-download")]
    fn test_select_rms_interval() {
        use crate::unit_test_global::OUTFIT_HORIZON_TEST;
        let mut traj_set = OUTFIT_HORIZON_TEST.1.clone();
        let traj_number = crate::constants::ObjectNumber::String("K09R05F".into());
        let traj_len = traj_set
            .get(&traj_number)
            .expect("Failed to get trajectory")
            .len();
        let traj = traj_set
            .get_mut(&traj_number)
            .expect("Failed to get trajectory");
        let triplets = traj.compute_triplets(0.03, 150.0, 20.0, traj_len, 10);
        let (u1, u2) = traj
            .select_rms_interval(triplets.first().unwrap(), -1., 30.)
            .unwrap();
        assert_eq!(u1, 0);
        assert_eq!(u2, 36);
        let (u1, u2) = traj
            .select_rms_interval(triplets.first().unwrap(), 10., 30.)
            .unwrap();
        assert_eq!(u1, 14);
        assert_eq!(u2, 36);
        let (u1, u2) = traj
            .select_rms_interval(triplets.first().unwrap(), 0.001, 3.)
            .unwrap();
        assert_eq!(u1, 17);
        assert_eq!(u2, 33);
    }

    /// Pins the RMS of a reference orbit on the K09R05F trajectory (no pruning).
    #[test]
    #[cfg(feature = "jpl-download")]
    fn test_rms_trajectory() {
        use nalgebra::Matrix3;
        use crate::{
            orbit_type::keplerian_element::KeplerianElements, unit_test_global::OUTFIT_HORIZON_TEST,
        };
        let mut traj_set = OUTFIT_HORIZON_TEST.1.clone();
        let traj = traj_set
            .get_mut(&crate::constants::ObjectNumber::String("K09R05F".into()))
            .expect("Failed to get trajectory");
        traj.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
        let triplets = GaussObs {
            idx_obs: Vector3::new(34, 35, 36),
            ra: [[
                1.789_797_623_341_267,
                1.789_865_909_348_251,
                1.7899347771316527,
            ]]
            .into(),
            dec: [[
                0.779_178_052_350_181,
                0.779_086_664_971_291_9,
                0.778_996_538_107_973_6,
            ]]
            .into(),
            time: [[
                57070.238017592594,
                57_070.250_007_592_59,
                57070.262067592594,
            ]]
            .into(),
            observer_helio_position: Matrix3::zeros(),
        };
        let kepler = KeplerianElements {
            reference_epoch: 57_049.242_334_573_75,
            semi_major_axis: 1.8017360713154256,
            eccentricity: 0.283_559_145_668_705_7,
            inclination: 0.20267383288689386,
            ascending_node_longitude: 7.955_979_023_693_781E-3,
            periapsis_argument: 1.2451951387589135,
            mean_anomaly: 0.44054589015887125,
        };
        let rms = traj
            .rms_orbit_error(
                &OUTFIT_HORIZON_TEST.0,
                &triplets,
                &kepler.into(),
                -1.0,
                30.,
                None,
            )
            .unwrap();
        assert_eq!(rms, 68.88650730830162);
    }

    mod test_batch_rms_correction {
        use crate::constants::MJD;
        use approx::assert_ulps_eq;
        use smallvec::smallvec;
        use super::*;

        /// Minimal observation with fixed errors (1e-6 RA, 2e-6 DEC) for batching tests.
        fn obs(observer: u16, time: MJD) -> Observation {
            Observation {
                observer,
                ra: 1.0,
                error_ra: 1e-6,
                dec: 0.5,
                error_dec: 2e-6,
                time,
                observer_earth_position: Vector3::zeros(),
                observer_helio_position: Vector3::zeros(),
            }
        }

        /// VFCC17 with a batch of 5 uses the reduced factor `sqrt(n / 4)`.
        #[test]
        fn test_single_batch_vfcc17_large() {
            let base_time = 59000.0;
            let mut obs: Observations = smallvec![
                obs(1, base_time),
                obs(1, base_time + 0.01),
                obs(1, base_time + 0.02),
                obs(1, base_time + 0.03),
                obs(1, base_time + 0.04),
            ];
            obs.apply_batch_rms_correction(&ErrorModel::VFCC17, 8.0 / 24.0);
            let factor = (5.0_f64 * 0.25_f64).sqrt();
            for ob in &obs {
                assert_ulps_eq!(ob.error_ra, 1e-6 * factor, max_ulps = 2);
                assert_ulps_eq!(ob.error_dec, 2e-6 * factor, max_ulps = 2);
            }
        }

        /// A small batch is scaled by `sqrt(n)`.
        #[test]
        fn test_single_batch_small_n() {
            let base_time = 59000.0;
            let mut obs: Observations = smallvec![
                obs(2, base_time),
                obs(2, base_time + 0.01),
            ];
            obs.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
            let factor = (2.0f64).sqrt();
            for ob in &obs {
                assert_ulps_eq!(ob.error_ra, 1e-6 * factor, max_ulps = 2);
                assert_ulps_eq!(ob.error_dec, 2e-6 * factor, max_ulps = 2);
            }
        }

        /// A gap larger than `gap_max` splits one observer's data into two batches.
        #[test]
        fn test_multiple_batches_same_observer() {
            let base_time = 59000.0;
            let mut obs: Observations = smallvec![
                obs(3, base_time),
                obs(3, base_time + 0.01),
                obs(3, base_time + 1.0),
            ];
            obs.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
            let factor1 = (2.0f64).sqrt();
            let factor2 = 1.0;
            assert_ulps_eq!(obs[0].error_ra, 1e-6 * factor1, max_ulps = 2);
            assert_ulps_eq!(obs[1].error_ra, 1e-6 * factor1, max_ulps = 2);
            assert_ulps_eq!(obs[2].error_ra, 1e-6 * factor2, max_ulps = 2);
        }

        /// Observations from distinct observers never share a batch.
        #[test]
        fn test_different_observers_are_not_grouped() {
            let base_time = 59000.0;
            let mut obs: Observations = smallvec![
                obs(10, base_time),
                obs(11, base_time + 0.01),
                obs(12, base_time + 0.02),
            ];
            obs.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
            for ob in &obs {
                assert_ulps_eq!(ob.error_ra, 1e-6, max_ulps = 2);
                assert_ulps_eq!(ob.error_dec, 2e-6, max_ulps = 2);
            }
        }

        /// Two observations separated by more than `gap_max` stay unscaled.
        #[test]
        fn test_batch_gaps_exceed_gapmax() {
            let mut obs: Observations = smallvec![
                obs(5, 59000.0),
                obs(5, 59001.0),
            ];
            obs.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
            for ob in &obs {
                assert_ulps_eq!(ob.error_ra, 1e-6, max_ulps = 2);
                assert_ulps_eq!(ob.error_dec, 2e-6, max_ulps = 2);
            }
        }

        /// Pins corrected errors on the K09R05F trajectory.
        #[test]
        #[cfg(feature = "jpl-download")]
        fn test_batch_real_data() {
            use crate::unit_test_global::OUTFIT_HORIZON_TEST;
            let mut traj_set = OUTFIT_HORIZON_TEST.1.clone();
            let traj = traj_set
                .get_mut(&crate::constants::ObjectNumber::String("K09R05F".into()))
                .expect("Failed to get trajectory");
            traj.apply_batch_rms_correction(&ErrorModel::FCCT14, 8.0 / 24.0);
            assert_ulps_eq!(traj[0].error_ra, 2.507075226057322e-6, max_ulps = 2);
            assert_ulps_eq!(traj[0].error_dec, 2.036217397086327e-6, max_ulps = 2);
            assert_ulps_eq!(traj[1].error_ra, 2.5070681687218917e-6, max_ulps = 2);
            assert_ulps_eq!(traj[1].error_dec, 2.036217397086327e-6, max_ulps = 2);
            assert_ulps_eq!(traj[2].error_ra, 2.507_059_507_890_695_2E-6, max_ulps = 2);
            assert_ulps_eq!(traj[2].error_dec, 2.036217397086327e-6, max_ulps = 2);
        }
    }

    mod test_extract_errors {
        use super::*;
        use approx::assert_ulps_eq;
        use smallvec::smallvec;

        /// Three observations with distinct, easily recognisable error values.
        fn make_observations() -> Observations {
            smallvec![
                Observation {
                    observer: 0,
                    ra: 1.0,
                    dec: 0.5,
                    error_ra: 1e-6,
                    error_dec: 2e-6,
                    time: 59000.0,
                    observer_earth_position: Vector3::zeros(),
                    observer_helio_position: Vector3::zeros(),
                },
                Observation {
                    observer: 0,
                    ra: 1.1,
                    dec: 0.6,
                    error_ra: 3e-6,
                    error_dec: 4e-6,
                    time: 59000.1,
                    observer_earth_position: Vector3::zeros(),
                    observer_helio_position: Vector3::zeros(),
                },
                Observation {
                    observer: 0,
                    ra: 1.2,
                    dec: 0.7,
                    error_ra: 5e-6,
                    error_dec: 6e-6,
                    time: 59000.2,
                    observer_earth_position: Vector3::zeros(),
                    observer_helio_position: Vector3::zeros(),
                },
            ]
        }

        /// Errors are returned in index order.
        #[test]
        fn test_extract_errors_basic() {
            let obs = make_observations();
            let idx_obs = Vector3::new(0, 1, 2);
            let (ra_errors, dec_errors) = obs.extract_errors(idx_obs);
            assert_ulps_eq!(ra_errors[0], 1e-6, max_ulps = 2);
            assert_ulps_eq!(ra_errors[1], 3e-6, max_ulps = 2);
            assert_ulps_eq!(ra_errors[2], 5e-6, max_ulps = 2);
            assert_ulps_eq!(dec_errors[0], 2e-6, max_ulps = 2);
            assert_ulps_eq!(dec_errors[1], 4e-6, max_ulps = 2);
            assert_ulps_eq!(dec_errors[2], 6e-6, max_ulps = 2);
        }

        /// An out-of-range index panics.
        #[test]
        #[should_panic(expected = "index out of bounds")]
        fn test_extract_errors_out_of_bounds() {
            let obs = make_observations();
            let idx_obs = Vector3::new(0, 1, 10);
            let _ = obs.extract_errors(idx_obs);
        }
    }

    /// End-to-end IOD on K09R05F with a fixed RNG seed; pins the best orbit and its RMS.
    #[test]
    #[cfg(feature = "jpl-download")]
    fn test_estimate_best_orbit() {
        use approx::assert_relative_eq;
        use rand::{rngs::StdRng, SeedableRng};
        use crate::{
            orbit_type::{
                keplerian_element::KeplerianElements, orbit_type_test::approx_equal,
                OrbitalElements,
            },
            unit_test_global::OUTFIT_HORIZON_TEST,
        };
        let mut traj_set = OUTFIT_HORIZON_TEST.1.clone();
        let traj_number = crate::constants::ObjectNumber::String("K09R05F".into());
        let traj_len = traj_set
            .get(&traj_number)
            .expect("Failed to get trajectory")
            .len();
        let traj = traj_set
            .get_mut(&traj_number)
            .expect("Failed to get trajectory");
        let mut rng = StdRng::seed_from_u64(42_u64);
        let gap_max = 8.0 / 24.0;
        let params = IODParams {
            n_noise_realizations: 5,
            max_obs_for_triplets: traj_len,
            gap_max,
            ..Default::default()
        };
        let (best_orbit, best_rms) = traj
            .estimate_best_orbit(
                &OUTFIT_HORIZON_TEST.0,
                &ErrorModel::FCCT14,
                &mut rng,
                &params,
            )
            .unwrap();
        let binding = best_orbit;
        let orbit = binding.get_orbit();
        let expected_orbit = OrbitalElements::Keplerian(KeplerianElements {
            reference_epoch: 57049.22904488294,
            semi_major_axis: 1.801748431600605,
            eccentricity: 0.283572284127787,
            inclination: 0.20266779609836036,
            ascending_node_longitude: 0.008022659889281067,
            periapsis_argument: 1.245060173584828,
            mean_anomaly: 0.44047943792316746,
        });
        assert!(approx_equal(orbit, &expected_orbit, 1e-14));
        assert_relative_eq!(best_rms, 55.14810894219461, epsilon = 1e-14);
    }
}