use nalgebra::{Matrix3, Vector3};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::constants::Observations;
use crate::initial_orbit_determination::gauss::GaussObs;
use crate::observations::triplets_generator::TripletIndexGenerator;
use crate::observations::Observation;
/// A candidate observation triplet (indices into the reduced observation set used by
/// the triplet index generator) together with its quality weight.
///
/// Smaller weights are better: `generate_triplets` keeps the candidates with the
/// smallest weights.
///
/// Ordering is *total*: first by `weight` using [`f64::total_cmp`], then lexicographically
/// by `(first_idx, middle_idx, last_idx)`. Equality is defined through the same ordering,
/// so `PartialEq`, `Eq`, `PartialOrd` and `Ord` are mutually consistent, as required by
/// [`BinaryHeap`]. Equal-weight candidates therefore have a deterministic order.
#[derive(Clone, Copy, Debug)]
struct WeightedTriplet {
    /// Triplet weight; lower is better.
    weight: f64,
    /// Index of the first (earliest) observation of the triplet.
    first_idx: usize,
    /// Index of the middle observation of the triplet.
    middle_idx: usize,
    /// Index of the last observation of the triplet.
    last_idx: usize,
}

impl PartialEq for WeightedTriplet {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `a == b` exactly when `a.cmp(b) == Equal`
        // (this also keeps `Eq` reflexive for NaN weights).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for WeightedTriplet {}

impl PartialOrd for WeightedTriplet {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for WeightedTriplet {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order on f64 (NaN included), so the
        // resulting order is transitive; indices break weight ties.
        self.weight.total_cmp(&other.weight).then_with(|| {
            (self.first_idx, self.middle_idx, self.last_idx).cmp(&(
                other.first_idx,
                other.middle_idx,
                other.last_idx,
            ))
        })
    }
}
/// Compute the quality weight of an observation triplet from its three epochs.
///
/// The two gaps `time2 - time1` and `time3 - time2` are each scored with a gap penalty
/// relative to the preferred spacing `dtw`, and the two scores are summed. Lower is
/// better; the minimum per-gap score (1.0) is reached when a gap equals `dtw`.
///
/// This is [`triplet_weight_with_inv`] with `inv_dtw = 1.0 / dtw`. Prefer the latter
/// in loops so the reciprocal is computed once.
///
/// NOTE(review): assumes `time1 <= time2 <= time3`; a negative gap yields a negative
/// (spuriously small) score. A zero gap yields `+inf`.
#[inline]
pub fn triplet_weight(time1: f64, time2: f64, time3: f64, dtw: f64) -> f64 {
    triplet_weight_with_inv(time1, time2, time3, dtw.recip())
}

/// Same as [`triplet_weight`], but takes the precomputed reciprocal `inv_dtw = 1.0 / dtw`
/// of the preferred spacing.
#[inline]
pub fn triplet_weight_with_inv(time1: f64, time2: f64, time3: f64, inv_dtw: f64) -> f64 {
    let dt12 = time2 - time1;
    let dt23 = time3 - time2;
    s_gap(dt12, inv_dtw) + s_gap(dt23, inv_dtw)
}

/// Penalty for a single gap `dt`, given `inv_dtw = 1.0 / dtw`.
///
/// With `r = dt / dtw`:
/// * `r <= 1` → `1 / r` (grows without bound as the gap shrinks; `dt == 0` gives `+inf`);
/// * `r > 1` → `1 + r` (grows linearly for longer gaps).
///
/// NOTE(review): the score jumps from 1 (at `r = 1`) to just above 2 (for `r` slightly
/// above 1); confirm this discontinuity is intended.
#[inline]
fn s_gap(dt: f64, inv_dtw: f64) -> f64 {
    let r = dt * inv_dtw;
    if r <= 1.0 {
        r.recip()
    } else {
        1.0 + r
    }
}
/// Select at most `max_keep` indices from `0..n`, spread approximately uniformly and
/// always including both endpoints `0` and `n - 1`.
///
/// Rules, applied in order:
/// * `n <= 3` or `max_keep >= n` → every index `0..n` (empty for `n == 0`);
/// * `max_keep <= 3` → `[0, n / 2, n - 1]`, the minimum needed to form one triplet;
/// * otherwise → `0`, then `max_keep - 2` interior indices, then `n - 1`.
///
/// The returned indices are always strictly increasing (never duplicated).
pub(crate) fn downsample_uniform_with_edges_indices(n: usize, max_keep: usize) -> Vec<usize> {
    // Too few points to thin out, or nothing to drop: keep everything. Handling
    // n < 3 here avoids returning duplicated indices such as [0, 0, 0].
    if n <= 3 || max_keep >= n {
        return (0..n).collect();
    }
    if max_keep <= 3 {
        return vec![0, n / 2, n - 1];
    }

    let slots = max_keep - 2;
    let interior = n - 2;

    // Interior pick k (1-based) sits at 1 + floor(k * interior / (slots + 1)).
    // Integer arithmetic makes the floor exact: the float form `(k / m) * m` can land
    // just below an integer and produce duplicates. Widening to u128 rules out overflow.
    // Since `slots + 1 <= interior`, consecutive picks differ by at least 1, and every
    // pick lies strictly between 0 and n - 1.
    let denom = slots as u128 + 1;
    let mut picked = Vec::with_capacity(max_keep);
    picked.push(0);
    picked.extend((1..=slots).map(|k| {
        let offset = (k as u128 * interior as u128) / denom;
        // offset < interior, so the narrowing cast is lossless.
        1 + offset as usize
    }));
    picked.push(n - 1);
    picked
}
/// Generate the best-weighted observation triplets for Gauss initial orbit determination.
///
/// Candidate index triplets are enumerated by [`TripletIndexGenerator`] over a (possibly
/// reduced) subset of `observations`. Each candidate is scored with
/// [`triplet_weight_with_inv`], using `optimal_interval_time` as the preferred spacing.
/// Only the `max_triplet` candidates with the smallest finite weight are kept, via a
/// bounded max-heap.
///
/// # Arguments
/// * `observations` – observations of one object. They are handed to the index generator
///   and then read to build each [`GaussObs`].
/// * `dt_min`, `dt_max` – forwarded to [`TripletIndexGenerator::from_observations`]
///   (presumably the allowed time-span bounds of a triplet; confirm there).
/// * `optimal_interval_time` – preferred gap between consecutive epochs of a triplet
///   (`dtw` in [`triplet_weight`]).
/// * `max_obs_for_triplets` – forwarded to the index generator (presumably a cap on the
///   number of observations considered; confirm there).
/// * `max_triplet` – maximum number of triplets returned; `0` returns an empty vector.
///
/// # Returns
/// Up to `max_triplet` [`GaussObs`], ordered by increasing weight (best first). Their
/// `idx_obs` hold indices into the original `observations`, not into the reduced set.
pub fn generate_triplets(
    observations: &mut Observations,
    dt_min: f64,
    dt_max: f64,
    optimal_interval_time: f64,
    max_obs_for_triplets: usize,
    max_triplet: u32,
) -> Vec<GaussObs> {
    // `max_triplet` is caller-controlled and may be huge (e.g. u32::MAX to mean
    // "no limit"). Preallocate only a bounded amount and let the heap grow on demand,
    // instead of eagerly reserving `max_triplet` slots.
    const MAX_PREALLOC: usize = 1024;

    if max_triplet == 0 {
        return Vec::new();
    }

    let mut index_gen = TripletIndexGenerator::from_observations(
        observations,
        dt_min,
        dt_max,
        max_obs_for_triplets,
        usize::MAX,
    );

    let k_cap = max_triplet as usize;
    // Max-heap keyed on weight: its top is the *worst* of the best candidates seen so far.
    let mut heap: BinaryHeap<WeightedTriplet> =
        BinaryHeap::with_capacity(k_cap.min(MAX_PREALLOC));
    let inv_dtw = optimal_interval_time.recip();

    // `while let` rather than `for`: the loop body needs `index_gen.reduced_times()`,
    // which cannot be borrowed while a `for` loop holds the iterator.
    while let Some((i, j, k)) = index_gen.next() {
        let times = index_gen.reduced_times();
        let weight = triplet_weight_with_inv(times[i], times[j], times[k], inv_dtw);
        // Non-finite weights (e.g. a zero gap gives 1/0) are never worth keeping.
        if !weight.is_finite() {
            continue;
        }
        let cand = WeightedTriplet {
            weight,
            first_idx: i,
            middle_idx: j,
            last_idx: k,
        };
        if heap.len() < k_cap {
            heap.push(cand);
        } else if let Some(worst) = heap.peek() {
            // Replace the current worst kept candidate only on a strict improvement.
            if weight < worst.weight {
                heap.pop();
                heap.push(cand);
            }
        }
    }

    // `into_sorted_vec` returns ascending order under `WeightedTriplet`'s `Ord`, i.e. the
    // smallest (best) weight first, so no further sort is needed.
    let best_reduced = heap.into_sorted_vec();
    // Maps reduced-set indices back to indices into `observations`.
    let mapping = index_gen.selected_original_indices();

    best_reduced
        .into_iter()
        .map(|wt| {
            let oi = mapping[wt.first_idx];
            let oj = mapping[wt.middle_idx];
            let ok = mapping[wt.last_idx];
            let o1: &Observation = &observations[oi];
            let o2: &Observation = &observations[oj];
            let o3: &Observation = &observations[ok];
            // One column per observation, in triplet order.
            let observer_matrix: Matrix3<f64> = Matrix3::from_columns(&[
                o1.get_observer_helio_position(),
                o2.get_observer_helio_position(),
                o3.get_observer_helio_position(),
            ]);
            GaussObs::with_observer_position(
                Vector3::new(oi, oj, ok),
                Vector3::new(o1.ra, o2.ra, o3.ra),
                Vector3::new(o1.dec, o2.dec, o3.dec),
                Vector3::new(o1.time, o2.time, o3.time),
                observer_matrix,
            )
        })
        .collect()
}
// Unit tests for `generate_triplets` and `downsample_uniform_with_edges_indices`.
#[cfg(test)]
mod triplets_iod_tests {
    #[cfg(feature = "jpl-download")]
    use approx::assert_relative_eq;

    use super::*;

    /// Assert that two `GaussObs` describe the same triplet.
    ///
    /// `idx_obs` must match exactly. `ra`, `dec` and `time` are compared with
    /// `assert_relative_eq!` using `max_relative = tol`.
    ///
    /// NOTE(review): `observer_helio_position` is not compared, even though the
    /// expected fixtures in `test_compute_triplets` set it.
    #[cfg(feature = "jpl-download")]
    pub(crate) fn assert_gauss_obs_approx_eq(a: &GaussObs, b: &GaussObs, tol: f64) {
        assert_eq!(a.idx_obs, b.idx_obs);
        assert_relative_eq!(a.ra, b.ra, max_relative = tol);
        assert_relative_eq!(a.dec, b.dec, max_relative = tol);
        assert_relative_eq!(a.time, b.time, max_relative = tol);
    }

    /// End-to-end check of `generate_triplets` on recorded data.
    ///
    /// Loads the 80-column file `tests/data/2015AB.obs`, selects object `K09R05F` and
    /// requests 10 triplets. It then compares the first (best-weighted) and the last
    /// returned triplet against reference values. Only built with the `jpl-download`
    /// feature.
    #[test]
    #[cfg(feature = "jpl-download")]
    fn test_compute_triplets() {
        use camino::Utf8Path;

        use crate::{
            trajectories::trajectory_file::TrajectoryFile, unit_test_global::OUTFIT_HORIZON_TEST,
            TrajectorySet,
        };

        let mut env_state = OUTFIT_HORIZON_TEST.0.clone();
        let mut traj_set =
            TrajectorySet::new_from_80col(&mut env_state, Utf8Path::new("tests/data/2015AB.obs"));

        let traj_number = crate::constants::ObjectNumber::String("K09R05F".into());
        // Read the length first (immutable borrow), then take the mutable borrow
        // that `generate_triplets` needs.
        let traj_len = traj_set
            .get(&traj_number)
            .expect("Failed to get trajectory")
            .len();
        let traj_mut = traj_set
            .get_mut(&traj_number)
            .expect("Failed to get trajectory");

        // dt_min = 0.03, dt_max = 150.0, optimal interval = 20.0,
        // max_obs_for_triplets = all observations, at most 10 triplets.
        let triplets = generate_triplets(traj_mut, 0.03, 150.0, 20.0, traj_len, 10);

        assert_eq!(
            triplets.len(),
            10,
            "Expected 10 triplets, got {}",
            triplets.len()
        );

        // Expected best (lowest-weight) triplet.
        let expected_triplets = GaussObs {
            idx_obs: [[23, 24, 33]].into(),
            ra: [[1.6893715963476699, 1.689861452091063, 1.7527345385664372]].into(),
            dec: [[1.082468037385525, 0.9436790189346231, 0.8273762407899986]].into(),
            time: [[57028.479297592596, 57049.2318575926, 57063.97711759259]].into(),
            observer_helio_position: [
                [-0.2645666171486676, 0.8689351643673471, 0.3766996211112465],
                [-0.5889735526502539, 0.7240117187952059, 0.3138734206791042],
                [-0.7743874438017259, 0.5612884709246775, 0.2433497107566823],
            ]
            .into(),
        };

        assert_gauss_obs_approx_eq(&triplets[0], &expected_triplets, 1e-12);

        // Expected 10th (last kept) triplet.
        let expected_triplet = GaussObs {
            idx_obs: [[21, 25, 33]].into(),
            ra: [[1.6894680985108947, 1.6898894500811472, 1.7527345385664372]].into(),
            dec: [[1.0825984522657437, 0.9435805047946215, 0.8273762407899986]].into(),
            time: [[57028.45404759259, 57049.245147592585, 57063.97711759259]].into(),
            observer_helio_position: [
                [-0.26413563361674103, 0.8690466209095019, 0.3767466856686271],
                [-0.5891631852172257, 0.7238872516832191, 0.3138186516545291],
                [-0.7743874438017259, 0.5612884709246775, 0.2433497107566823],
            ]
            .into(),
        };

        assert_gauss_obs_approx_eq(&triplets[9], &expected_triplet, 1e-12);
    }

    // Tests for `downsample_uniform_with_edges_indices`.
    mod downsampling_observations_tests {
        use nalgebra::Vector3;

        use super::*;

        /// Build `n` dummy observations whose `time` equals their index
        /// (0.0, 1.0, ...), with every other field set to zero.
        fn make_obs(n: usize) -> Observations {
            (0..n)
                .map(|i| Observation {
                    observer: 0,
                    ra: 0.0,
                    dec: 0.0,
                    error_ra: 0.0,
                    error_dec: 0.0,
                    time: i as f64,
                    observer_earth_position: Vector3::zeros(),
                    observer_helio_position: Vector3::zeros(),
                })
                .collect()
        }

        /// When `max_keep >= n`, every index is returned.
        #[test]
        fn returns_all_when_max_keep_ge_n() {
            let n = 5;
            let indices = downsample_uniform_with_edges_indices(n, 5);
            assert_eq!(indices, vec![0, 1, 2, 3, 4]);

            let indices = downsample_uniform_with_edges_indices(n, 10);
            assert_eq!(indices, vec![0, 1, 2, 3, 4]);
        }

        /// `n == 0` yields no indices regardless of `max_keep`.
        #[test]
        fn empty_input_returns_empty() {
            assert!(downsample_uniform_with_edges_indices(0, 0).is_empty());
            assert!(downsample_uniform_with_edges_indices(0, 10).is_empty());
        }

        /// `max_keep < 3` still returns the three indices first / middle / last.
        #[test]
        fn max_keep_less_than_three_returns_first_middle_last() {
            let n = 10;
            let mid = n / 2;
            for max_keep in [0, 1, 2] {
                let indices = downsample_uniform_with_edges_indices(n, max_keep);
                assert_eq!(indices, vec![0, mid, n - 1]);
            }
        }

        /// `max_keep == 3` returns first / middle / last (for `n == 3` that is every index).
        #[test]
        fn max_keep_three_exactly_returns_first_middle_last() {
            let n = 10;
            let indices = downsample_uniform_with_edges_indices(n, 3);
            assert_eq!(indices, vec![0, n / 2, n - 1]);

            let n = 3;
            let indices = downsample_uniform_with_edges_indices(n, 3);
            assert_eq!(indices, vec![0, 1, 2]);
        }

        /// General case: exactly `max_keep` indices, the endpoints are included and
        /// the indices are strictly increasing.
        #[test]
        fn downsampling_uniformity_for_general_case() {
            let n = 10;
            let max_keep = 5;
            let indices = downsample_uniform_with_edges_indices(n, max_keep);

            assert_eq!(indices.len(), max_keep);
            assert_eq!(indices.first().unwrap(), &0);
            assert_eq!(indices.last().unwrap(), &(n - 1));
            assert!(indices.windows(2).all(|w| w[1] > w[0]));
        }

        /// Larger input: length and endpoints are still correct.
        #[test]
        fn works_with_large_data() {
            let n = 1000;
            let max_keep = 100;
            let indices = downsample_uniform_with_edges_indices(n, max_keep);

            assert_eq!(indices.len(), max_keep);
            assert_eq!(indices.first().unwrap(), &0);
            assert_eq!(indices.last().unwrap(), &(n - 1));
        }

        /// The selected indices address observations with strictly increasing times.
        #[test]
        fn indices_match_observations() {
            let obs = make_obs(10);
            let max_keep = 5;
            let indices = downsample_uniform_with_edges_indices(obs.len(), max_keep);

            let times: Vec<_> = indices.iter().map(|&i| obs[i].time).collect();
            assert!(times.windows(2).all(|w| w[1] > w[0]));
        }
    }
}