pub mod index_generator;
use nalgebra::{Matrix3, Vector3};
use photom::observation_dataset::observation::Observation;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::cache::OutfitCache;
use crate::initial_orbit_determination::gauss::GaussObs;
use crate::initial_orbit_determination::triplet_generation::index_generator::TripletIndexGenerator;
use crate::IODParams;
/// A candidate observation triplet together with its selection weight.
///
/// Triplets are ordered primarily by `weight` (lower weights are the ones kept by
/// `collect_best_k_triplets`) and then by their `(first_idx, middle_idx, last_idx)`
/// indices. Weights are compared with [`f64::total_cmp`], and equality is defined in
/// terms of that same ordering. As a result:
///
/// * `Ord` is consistent with `Eq`: two triplets compare `Equal` only when every field
///   matches. Ties on weight are therefore resolved deterministically inside a
///   `BinaryHeap`.
/// * `Eq` is reflexive even for a NaN weight, which plain `f64 ==` is not.
#[derive(Clone, Copy, Debug)]
struct WeightedTriplet {
    /// Selection score of the triplet; lower is better.
    weight: f64,
    /// Index of the first observation of the triplet.
    first_idx: usize,
    /// Index of the middle observation of the triplet.
    middle_idx: usize,
    /// Index of the last observation of the triplet.
    last_idx: usize,
}

impl WeightedTriplet {
    /// The triplet's observation indices as a tuple, used as the ordering tie-break.
    fn indices(&self) -> (usize, usize, usize) {
        (self.first_idx, self.middle_idx, self.last_idx)
    }
}

impl PartialEq for WeightedTriplet {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` <=> `a.cmp(&b) == Equal`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for WeightedTriplet {}

impl PartialOrd for WeightedTriplet {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for WeightedTriplet {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order on f64 (NaN included); the index
        // tie-break keeps equal-weight triplets distinct.
        self.weight
            .total_cmp(&other.weight)
            .then_with(|| self.indices().cmp(&other.indices()))
    }
}
#[inline(always)]
pub fn triplet_weight(time1: f64, time2: f64, time3: f64, dtw: f64) -> f64 {
let inv_dtw = dtw.recip();
let dt12 = time2 - time1;
let dt23 = time3 - time2;
s_gap(dt12, inv_dtw) + s_gap(dt23, inv_dtw)
}
/// Computes the selection weight of an observation triplet, given the reciprocal of
/// the optimal interval.
///
/// Taking `inv_dtw` directly lets callers that score many triplets compute `1 / dtw`
/// only once.
///
/// # Arguments
/// * `time1`, `time2`, `time3` – epochs of the first, middle and last observation.
/// * `inv_dtw` – reciprocal of the optimal interval between consecutive observations.
///
/// # Returns
/// The penalty of the first gap (`time2 - time1`) plus the penalty of the second gap
/// (`time3 - time2`), each gap scaled by `inv_dtw` before being penalised.
#[inline(always)]
pub fn triplet_weight_with_inv(time1: f64, time2: f64, time3: f64, inv_dtw: f64) -> f64 {
    let first_gap_penalty = s_gap(time2 - time1, inv_dtw);
    let second_gap_penalty = s_gap(time3 - time2, inv_dtw);
    first_gap_penalty + second_gap_penalty
}
/// Penalty assigned to a single gap between two consecutive observations.
///
/// The gap is first expressed relative to the optimal interval, as `ratio = dt * inv_dtw`.
/// The penalty is then:
/// * `1 / ratio` when `ratio <= 1`, so gaps shorter than the optimum are penalised
///   hyperbolically;
/// * `1 + ratio` otherwise, so longer gaps are penalised linearly.
///
/// The function jumps at the threshold. It equals `1` at `ratio == 1` and just over
/// `2` immediately above it.
///
/// Non-positive ratios take the first branch. A zero ratio yields `+inf` and a negative
/// one yields a negative penalty.
#[inline(always)]
fn s_gap(dt: f64, inv_dtw: f64) -> f64 {
    let ratio = dt * inv_dtw;
    let shorter_than_optimal = ratio <= 1.0;
    if shorter_than_optimal {
        1.0 / ratio
    } else {
        ratio + 1.0
    }
}
/// Builds at most `params.max_triplets` Gauss observation triplets from `observations`.
///
/// Pipeline:
/// 1. Candidate index triplets are enumerated by a [`TripletIndexGenerator`]. The
///    generator is configured with `params.dt_min`, `params.dt_max_triplet` and
///    `params.max_obs_for_triplets`.
/// 2. Each candidate is scored against `params.optimal_interval_time`.
/// 3. The lowest-weight candidates are kept.
/// 4. The kept candidates are converted to [`GaussObs`], in ascending-weight order.
///
/// Returns an empty vector when `params.max_triplets == 0` or when fewer than three
/// observations are supplied.
pub fn generate_triplets(
    observations: &[&Observation],
    cache: &OutfitCache,
    params: &IODParams,
) -> Vec<GaussObs> {
    let nothing_to_select = params.max_triplets == 0 || observations.len() < 3;
    if nothing_to_select {
        return Vec::new();
    }

    // NOTE(review): the meaning of the trailing `usize::MAX` argument is defined by
    // `TripletIndexGenerator::from_observations`; presumably "no limit" — confirm.
    let mut candidates = TripletIndexGenerator::from_observations(
        observations,
        params.dt_min,
        params.dt_max_triplet,
        params.max_obs_for_triplets,
        usize::MAX,
    );

    let keep = params.max_triplets as usize;
    let inv_optimal_interval = params.optimal_interval_time.recip();

    collect_best_k_triplets(&mut candidates, keep, inv_optimal_interval)
        .into_iter()
        .map(|selected| build_gauss_obs(cache, observations, selected))
        .collect()
}
/// Scans every triplet produced by `index_gen` and keeps the `max_triplets` with the
/// lowest weight.
///
/// A bounded max-heap holds the current best candidates. Its top is the worst triplet
/// kept so far, and it is replaced in place whenever a strictly lighter candidate
/// appears. Candidates whose weight is not finite are skipped.
///
/// # Arguments
/// * `index_gen` – source of `(first, middle, last)` index triplets. The indices
///   address `index_gen.reduced_times()`.
/// * `max_triplets` – maximum number of triplets to keep. `0` returns immediately
///   with an empty result.
/// * `inv_optimal_interval` – reciprocal of the optimal interval between consecutive
///   observations, forwarded to [`triplet_weight_with_inv`].
///
/// # Returns
/// The kept triplets sorted by ascending weight. Ties are broken by ascending
/// `(first, middle, last)` so that the output order is deterministic.
fn collect_best_k_triplets(
    index_gen: &mut TripletIndexGenerator,
    max_triplets: usize,
    inv_optimal_interval: f64,
) -> Vec<WeightedTriplet> {
    // Upper bound on the up-front heap allocation: `max_triplets` comes from user
    // parameters and may be far larger than the number of candidates actually seen.
    const INITIAL_CAPACITY_LIMIT: usize = 1024;

    if max_triplets == 0 {
        return Vec::new();
    }

    let mut heap: BinaryHeap<WeightedTriplet> =
        BinaryHeap::with_capacity(max_triplets.min(INITIAL_CAPACITY_LIMIT));

    // `while let` rather than `for`: the loop body needs `index_gen.reduced_times()`,
    // which cannot be borrowed while a `for` loop holds the generator mutably.
    while let Some((first, middle, last)) = index_gen.next() {
        let times = index_gen.reduced_times();
        let weight = triplet_weight_with_inv(
            times[first],
            times[middle],
            times[last],
            inv_optimal_interval,
        );
        if !weight.is_finite() {
            continue;
        }

        let candidate = WeightedTriplet {
            weight,
            first_idx: first,
            middle_idx: middle,
            last_idx: last,
        };

        if heap.len() < max_triplets {
            heap.push(candidate);
        } else if let Some(mut worst) = heap.peek_mut() {
            // Overwrite the current worst in place; `PeekMut` restores the heap
            // invariant on drop with a single sift instead of a pop + push.
            if candidate.weight < worst.weight {
                *worst = candidate;
            }
        }
    }

    let mut best = heap.into_vec();
    // All weights here are finite, so `total_cmp` matches numeric order. The index
    // tie-break makes the otherwise unstable sort deterministic.
    best.sort_unstable_by(|a, b| {
        a.weight.total_cmp(&b.weight).then_with(|| {
            (a.first_idx, a.middle_idx, a.last_idx).cmp(&(b.first_idx, b.middle_idx, b.last_idx))
        })
    });
    best
}
/// Converts a selected [`WeightedTriplet`] into a [`GaussObs`].
///
/// The three triplet indices select entries of `observations`. For each selected
/// observation, this gathers:
/// * the observer position returned by `OutfitCache::get_helio_position`, looked up by
///   `Observation::index()`;
/// * the right ascension and declination from `equ_coord()`;
/// * the `mjd_tt()` epoch.
///
/// The triplet indices themselves are stored as the observation indices of the
/// resulting `GaussObs`.
///
/// # Panics
/// Panics if any triplet index is out of bounds for `observations`.
fn build_gauss_obs(
    cache: &OutfitCache,
    observations: &[&Observation],
    wt: WeightedTriplet,
) -> GaussObs {
    let indices = Vector3::new(wt.first_idx, wt.middle_idx, wt.last_idx);
    let (first, middle, last) = (
        observations[wt.first_idx],
        observations[wt.middle_idx],
        observations[wt.last_idx],
    );

    // One column per observation, in triplet order.
    let positions = Matrix3::from_columns(&[
        *cache.get_helio_position(first.index()),
        *cache.get_helio_position(middle.index()),
        *cache.get_helio_position(last.index()),
    ]);

    let (c1, c2, c3) = (first.equ_coord(), middle.equ_coord(), last.equ_coord());
    let right_ascensions = Vector3::new(c1.ra, c2.ra, c3.ra);
    let declinations = Vector3::new(c1.dec, c2.dec, c3.dec);
    let epochs = Vector3::new(first.mjd_tt(), middle.mjd_tt(), last.mjd_tt());

    GaussObs::with_observer_position(
        indices,
        right_ascensions,
        declinations,
        epochs,
        positions.map(|x| x.into_inner()),
    )
}
// Unit and integration tests for triplet generation and observation downsampling.
#[cfg(test)]
mod triplets_iod_tests {
    use super::*;

    // End-to-end check on a real trajectory. The generated triplets must:
    // * respect the `max_triplets` cap;
    // * come out in ascending weight order;
    // * have strictly increasing observation indices.
    #[test]
    fn test_compute_triplets() {
        use crate::cache::OutfitCache;
        use crate::test_fixture::{DATASET_2015AB, JPL_EPHEM_HORIZON, UT1_PROVIDER};
        use crate::IODParams;
        use photom::observer::error_model::{ModelCorrection, ObsErrorModel};

        let dataset = DATASET_2015AB
            .clone()
            .with_error_model(ObsErrorModel::FCCT14)
            .apply_batch_rms_correction(30.0);
        let cache =
            OutfitCache::build(&dataset, &JPL_EPHEM_HORIZON, &UT1_PROVIDER, false).unwrap();
        let traj = dataset
            .materialize_trajectory("K09R05F")
            .unwrap()
            .collect_into_vec();
        assert!(
            traj.len() >= 3,
            "trajectory must have at least 3 observations"
        );

        let params = IODParams {
            dt_min: 0.03,
            dt_max_triplet: 150.0,
            optimal_interval_time: 20.0,
            max_obs_for_triplets: traj.len(),
            max_triplets: 10,
            ..Default::default()
        };

        let triplets = generate_triplets(&traj, &cache, &params);
        assert!(!triplets.is_empty(), "expected at least one triplet");
        assert!(
            triplets.len() <= params.max_triplets as usize,
            "got {} triplets, expected ≤ {}",
            triplets.len(),
            params.max_triplets
        );

        // The selection scored `TripletIndexGenerator::reduced_times()`, whereas this
        // check recomputes weights from the epochs stored in each `GaussObs`.
        // NOTE(review): the small tolerance assumes the two differ only by rounding.
        for window in triplets.windows(2) {
            let t1 = &window[0].time;
            let t2 = &window[1].time;
            let w1 = triplet_weight(t1[0], t1[1], t1[2], params.optimal_interval_time);
            let w2 = triplet_weight(t2[0], t2[1], t2[2], params.optimal_interval_time);
            assert!(
                w1 <= w2 + 1e-12,
                "triplets not sorted by ascending weight: w1={w1} > w2={w2}"
            );
        }

        for t in &triplets {
            assert!(
                t.idx_obs[0] < t.idx_obs[1] && t.idx_obs[1] < t.idx_obs[2],
                "triplet indices not strictly increasing: {:?}",
                t.idx_obs
            );
        }
    }

    // Behaviour of `downsample_uniform_with_edges(n, max_keep)` across edge cases
    // and typical sizes.
    mod downsampling_observations_tests {
        use crate::initial_orbit_determination::triplet_generation::index_generator::downsample_uniform_with_edges;

        #[test]
        fn returns_all_when_max_keep_ge_n() {
            let n = 5;
            let indices = downsample_uniform_with_edges(n, 5);
            assert_eq!(indices, vec![0, 1, 2, 3, 4]);
            let indices = downsample_uniform_with_edges(n, 10);
            assert_eq!(indices, vec![0, 1, 2, 3, 4]);
        }

        #[test]
        fn empty_input_returns_empty() {
            assert!(downsample_uniform_with_edges(0, 0).is_empty());
            assert!(downsample_uniform_with_edges(0, 10).is_empty());
        }

        // Despite the name, the loop also covers `max_keep == 3`.
        #[test]
        fn max_keep_less_than_three_returns_first_middle_last() {
            let n = 10;
            let mid = n / 2;
            for max_keep in [1, 2, 3] {
                let indices = downsample_uniform_with_edges(n, max_keep);
                assert_eq!(indices, vec![0, mid, n - 1]);
            }
        }

        #[test]
        fn max_keep_three_exactly_returns_first_middle_last() {
            let n = 10;
            let indices = downsample_uniform_with_edges(n, 3);
            assert_eq!(indices, vec![0, n / 2, n - 1]);
            let n = 3;
            let indices = downsample_uniform_with_edges(n, 3);
            assert_eq!(indices, vec![0, 1, 2]);
        }

        #[test]
        fn downsampling_uniformity_for_general_case() {
            let n = 10;
            let max_keep = 5;
            let indices = downsample_uniform_with_edges(n, max_keep);
            assert_eq!(indices.len(), max_keep);
            assert_eq!(indices.first().unwrap(), &0);
            assert_eq!(indices.last().unwrap(), &(n - 1));
            assert!(indices.windows(2).all(|w| w[1] > w[0]));
        }

        #[test]
        fn works_with_large_data() {
            let n = 1000;
            let max_keep = 100;
            let indices = downsample_uniform_with_edges(n, max_keep);
            assert_eq!(indices.len(), max_keep);
            assert_eq!(indices.first().unwrap(), &0);
            assert_eq!(indices.last().unwrap(), &(n - 1));
        }
    }
}