use hifitime::ut1::Ut1Provider;
use photom::{
observation_dataset::{observation::Observation, ObsDataset},
observer::error_model::{ModelCorrection, ObsErrorModel},
TrajId,
};
use rand::{rngs::SmallRng, SeedableRng};
use std::collections::HashMap;
use crate::{
cache::OutfitCache, constants::FitOrbitResult, trajectory::TrajectoryFit, FullOrbitResult,
IODParams, JPLEphem, OutfitError,
};
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
/// Entry points for running IOD orbit fits on a set of observations.
///
/// Every method takes `self` by value, so the implementor is consumed by the
/// call. All methods share the same inputs: an ephemeris handle, a UT1
/// provider, the IOD parameters, the observational error model to apply, and a
/// caller-supplied random number generator.
pub trait FitIOD {
    /// Fits an orbit for the single trajectory identified by `traj`.
    ///
    /// # Parameters
    /// * `traj` - anything convertible into a [`TrajId`].
    /// * `jpl` - ephemeris handle forwarded to the fit.
    /// * `ut1_provider` - UT1 provider forwarded to the preparation step.
    /// * `params` - IOD parameters.
    /// * `error_model` - error model applied to the observations.
    /// * `rng` - random number generator driving the fit.
    ///
    /// # Errors
    /// Returns an [`OutfitError`] when preparation or the fit itself fails.
    fn fit_iod(
        self,
        traj: impl Into<TrajId>,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FitOrbitResult, OutfitError>;

    /// Fits an orbit for every trajectory and gathers the per-trajectory
    /// outcomes into a [`FullOrbitResult`].
    ///
    /// # Errors
    /// Returns an [`OutfitError`] when the run as a whole cannot be carried
    /// out (as opposed to the outcome of an individual trajectory).
    fn fit_full_iod(
        self,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FullOrbitResult, OutfitError>;

    /// Parallel counterpart of [`FitIOD::fit_full_iod`]; same parameters and
    /// return type. Only compiled when the `parallel` feature is enabled.
    ///
    /// # Errors
    /// Same contract as [`FitIOD::fit_full_iod`].
    #[cfg(feature = "parallel")]
    fn fit_full_iod_parallel(
        self,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FullOrbitResult, OutfitError>;
}
/// [`FitIOD`] implementation for [`ObsDataset`].
///
/// Each entry point starts with `prepare_iod`, which yields the corrected
/// dataset, the matching [`OutfitCache`] and a `u64` base seed drawn from the
/// caller's RNG.
impl FitIOD for ObsDataset {
    /// Fits the single trajectory `traj` after preparing the dataset.
    ///
    /// The caller's `rng` is used both by the preparation step (one `u64`
    /// draw, discarded here) and by the fit itself.
    ///
    /// # Errors
    /// Propagates any error from the preparation step or from the
    /// single-trajectory fit.
    fn fit_iod(
        self,
        traj: impl Into<TrajId>,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FitOrbitResult, OutfitError> {
        // The base seed is only consumed by the multi-trajectory entry points.
        let (dataset, cache, _base_seed) =
            prepare_iod(self, jpl, ut1_provider, params, error_model, rng)?;
        let traj_id: TrajId = traj.into();
        fit_single_traj(&traj_id, &dataset, &cache, jpl, params, rng)
    }

    /// Fits every trajectory of the dataset sequentially.
    ///
    /// The returned map holds one entry per trajectory id; a failing fit is
    /// stored as the `Err` value of its entry rather than aborting the run.
    ///
    /// # Errors
    /// * Any error from the preparation step.
    /// * [`OutfitError::NoTrajectoryIndex`] when the dataset exposes no
    ///   trajectory-id iterator.
    fn fit_full_iod(
        self,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FullOrbitResult, OutfitError> {
        let (dataset, cache, base_seed) =
            prepare_iod(self, jpl, ut1_provider, params, error_model, rng)?;

        let traj_ids = dataset
            .iter_traj_id()
            .ok_or(OutfitError::NoTrajectoryIndex)?;

        let mut fits: FullOrbitResult = HashMap::with_hasher(ahash::RandomState::new());
        fits.extend(
            traj_ids.map(|traj_id| process_traj(traj_id, &dataset, &cache, jpl, params, base_seed)),
        );
        Ok(fits)
    }

    /// Fits every trajectory of the dataset through the parallel
    /// trajectory-id iterator.
    ///
    /// Partial maps are built with `fold` and merged with `reduce`; as in the
    /// sequential variant, per-trajectory failures end up as `Err` values in
    /// the map.
    ///
    /// # Errors
    /// * Any error from the preparation step.
    /// * [`OutfitError::NoTrajectoryIndex`] when the dataset exposes no
    ///   parallel trajectory-id iterator.
    #[cfg(feature = "parallel")]
    fn fit_full_iod_parallel(
        self,
        jpl: &JPLEphem,
        ut1_provider: &Ut1Provider,
        params: &IODParams,
        error_model: ObsErrorModel,
        rng: &mut impl rand::Rng,
    ) -> Result<FullOrbitResult, OutfitError> {
        let (dataset, cache, base_seed) =
            prepare_iod(self, jpl, ut1_provider, params, error_model, rng)?;

        let traj_ids = dataset
            .par_iter_traj_id()
            .ok_or(OutfitError::NoTrajectoryIndex)?;

        // Identity element shared by the fold and the reduce stages.
        let empty_map = || HashMap::with_hasher(ahash::RandomState::new());

        let merged = traj_ids
            .map(|traj_id| process_traj(traj_id, &dataset, &cache, jpl, params, base_seed))
            .fold(empty_map, |mut partial, (traj_id, outcome)| {
                partial.insert(traj_id, outcome);
                partial
            })
            .reduce(empty_map, |mut left, right| {
                left.extend(right);
                left
            });
        Ok(merged)
    }
}
/// Runs the orbit estimation for one trajectory of an already corrected
/// dataset.
///
/// The trajectory is materialized, its observations are gathered as
/// references, ordered by ascending `mjd_tt()` (stable sort, `f64::total_cmp`
/// ordering) and handed to `estimate_best_orbit`.
///
/// # Errors
/// * [`OutfitError::TrajectoryIdNotFound`] (carrying a clone of `traj`) when
///   the dataset cannot materialize the trajectory.
/// * Any error returned by `estimate_best_orbit`.
fn fit_single_traj(
    traj: &TrajId,
    corrected_dataset: &ObsDataset,
    cache: &OutfitCache,
    jpl: &JPLEphem,
    params: &IODParams,
    rng: &mut impl rand::Rng,
) -> Result<FitOrbitResult, OutfitError> {
    let trajectory = match corrected_dataset.materialize_trajectory(traj) {
        Some(found) => found,
        None => return Err(OutfitError::TrajectoryIdNotFound(traj.clone())),
    };

    let mut time_ordered: Vec<&Observation> = trajectory.collect_into_vec();
    // Stable sort: observations sharing the same epoch keep their relative order.
    time_ordered.sort_by(|lhs, rhs| {
        let lhs_epoch = lhs.mjd_tt();
        let rhs_epoch = rhs.mjd_tt();
        lhs_epoch.total_cmp(&rhs_epoch)
    });

    time_ordered.estimate_best_orbit(cache, jpl, params, rng)
}
/// Runs the orbit estimation directly on a slice of observations.
///
/// The slice itself is left untouched: a vector of references is built,
/// ordered by ascending `mjd_tt()` (stable sort, `f64::total_cmp` ordering)
/// and passed to `estimate_best_orbit`.
///
/// # Errors
/// Returns whatever error `estimate_best_orbit` reports.
pub(crate) fn run_iod_on_observations(
    observations: &[Observation],
    cache: &OutfitCache,
    jpl: &JPLEphem,
    params: &IODParams,
    rng: &mut impl rand::Rng,
) -> Result<FitOrbitResult, OutfitError> {
    let mut time_ordered = observations.iter().collect::<Vec<&Observation>>();
    // Stable sort: observations sharing the same epoch keep their input order.
    time_ordered.sort_by(|lhs, rhs| {
        let lhs_epoch = lhs.mjd_tt();
        let rhs_epoch = rhs.mjd_tt();
        lhs_epoch.total_cmp(&rhs_epoch)
    });
    time_ordered.estimate_best_orbit(cache, jpl, params, rng)
}
/// Shared preparation step of every IOD entry point.
///
/// In order, this:
/// 1. attaches `error_model` to the dataset, applies the model errors and the
///    batch RMS correction parameterized by `params.gap_max`;
/// 2. builds the [`OutfitCache`] for the corrected dataset;
/// 3. draws one `u64` from `rng` to serve as base seed.
///
/// The seed is drawn only after the cache has been built, so a cache failure
/// leaves `rng` untouched.
///
/// # Returns
/// The corrected dataset, its cache and the base seed.
///
/// # Errors
/// Propagates the error of `OutfitCache::build`.
fn prepare_iod(
    dataset: ObsDataset,
    jpl: &JPLEphem,
    ut1_provider: &Ut1Provider,
    params: &IODParams,
    error_model: ObsErrorModel,
    rng: &mut impl rand::Rng,
) -> Result<(ObsDataset, OutfitCache, u64), OutfitError> {
    let gap_max = params.gap_max;
    let corrected = dataset
        .with_error_model(error_model)
        .apply_model_errors()
        .apply_batch_rms_correction(gap_max);

    // NOTE(review): the meaning of the trailing `true` flag is defined by
    // `OutfitCache::build` and is not visible from here.
    let cache = OutfitCache::build(&corrected, jpl, ut1_provider, true)?;

    let seed = rng.random::<u64>();
    Ok((corrected, cache, seed))
}
/// Fits one trajectory with its own freshly seeded RNG and pairs the outcome
/// with a clone of the trajectory id.
///
/// The seed is `base_seed` XOR-ed with `traj_id.stable_hash()`, so the local
/// RNG depends only on those two values and not on the order in which
/// trajectories are visited (assuming `stable_hash` is deterministic for a
/// given id).
///
/// A failing fit is returned as the `Err` member of the pair; this function
/// itself never fails.
fn process_traj(
    traj_id: &TrajId,
    corrected_dataset: &ObsDataset,
    cache: &OutfitCache,
    jpl: &JPLEphem,
    params: &IODParams,
    base_seed: u64,
) -> (TrajId, Result<FitOrbitResult, OutfitError>) {
    let mut traj_rng = SmallRng::seed_from_u64(base_seed ^ traj_id.stable_hash());
    let outcome = fit_single_traj(traj_id, corrected_dataset, cache, jpl, params, &mut traj_rng);
    (traj_id.clone(), outcome)
}