use photom::observation_dataset::observation::Observation;
/// Picks indices from `0..n` spread uniformly, always keeping both endpoints
/// (`0` and `n - 1`).
///
/// * `n == 0` yields an empty vector.
/// * `max_keep >= n` keeps every index.
/// * `max_keep <= 3` (with `n > max_keep`) keeps the first, middle (`n / 2`)
///   and last index. This means up to three indices are returned even when
///   `max_keep < 3`.
/// * Otherwise the indices `i * (n - 1) / (max_keep - 1)` for `i in 0..max_keep`
///   are returned. Because `n > max_keep`, the stride is at least 1.
///
/// The returned indices are strictly increasing (never repeated).
pub(crate) fn downsample_uniform_with_edges(n: usize, max_keep: usize) -> Vec<usize> {
    match n {
        0 => vec![],
        _ if max_keep >= n => (0..n).collect(),
        _ if max_keep <= 3 => {
            let mut picks = vec![0, n / 2, n - 1];
            // For n <= 2 the first/middle/last picks coincide (n = 2 gives
            // [0, 1, 1]); drop the repeats so no observation is used twice.
            picks.dedup();
            picks
        }
        _ => (0..max_keep)
            .map(|i| i * (n - 1) / (max_keep - 1))
            .collect(),
    }
}
/// Inclusive index range `[lo, hi]` of candidate `last` positions for one anchor.
///
/// A window with `lo > hi`, or with `lo` past the end of the epochs, holds no
/// candidate.
#[derive(Debug, Clone, Copy)]
struct LastWindow {
    /// First candidate index for `last`.
    lo: usize,
    /// Final candidate index for `last`.
    hi: usize,
}

impl LastWindow {
    /// Scans forward from the anchor to find the candidate range for `last`.
    ///
    /// `lo` starts at `anchor + 2` and moves past every index whose span
    /// `epochs[idx] - epochs[anchor]` is below `dt_min`. `hi` starts just below
    /// `lo` (but never below `anchor + 1`) and grows while the next span is
    /// `<= dt_max`. Both scans stop at the first failing index, so the range is
    /// only the full feasible set when `epochs` is ascending (assumed, not
    /// checked here).
    ///
    /// Panics if `anchor >= epochs.len()`.
    fn compute(anchor: usize, epochs: &[f64], dt_min: f64, dt_max: f64) -> Self {
        let len = epochs.len();
        let origin = epochs[anchor];
        let span = |idx: usize| epochs[idx] - origin;

        let start = anchor + 2;
        let lo = start + (start..len).take_while(|&idx| span(idx) < dt_min).count();

        let hi_start = lo.saturating_sub(1).max(anchor + 1);
        let hi = hi_start
            + ((hi_start + 1)..len)
                .take_while(|&idx| span(idx) <= dt_max)
                .count();

        LastWindow { lo, hi }
    }

    /// Returns `true` when no index in the window can serve as `last` for
    /// `anchor`, given an epoch list of length `n`.
    fn is_empty(&self, anchor: usize, n: usize) -> bool {
        let usable = self.lo < n && self.lo <= self.hi && self.hi > anchor + 1;
        !usable
    }
}
/// Lazily enumerates index triplets `(first, middle, last)` over a list of
/// epochs, with `first < middle < last`.
///
/// For ascending epochs, every yielded triplet has a span
/// `epochs[last] - epochs[first]` within `[dt_min, dt_max]`. Triplets come out
/// in lexicographic order of `(first, middle, last)`, and iteration stops once
/// the construction-time cap is used up.
#[derive(Debug, Clone)]
pub struct TripletIndexGenerator {
    /// Epochs the yielded indices refer to.
    epochs: Vec<f64>,
    /// Current `first` index of the triplet.
    anchor: usize,
    /// Current `middle` index.
    middle: usize,
    /// Next candidate `last` index.
    last: usize,
    /// Candidate range of `last` for the current anchor.
    window: LastWindow,
    /// Cached `epochs.len()`.
    n: usize,
    /// Minimum span `epochs[last] - epochs[first]`.
    dt_min: f64,
    /// Maximum span `epochs[last] - epochs[first]`.
    dt_max: f64,
    /// Number of triplets still allowed to be yielded.
    remaining: usize,
}
impl TripletIndexGenerator {
    /// Creates a generator over `epochs` that yields at most `cap` index triplets.
    ///
    /// Each triplet `(first, middle, last)` satisfies `first < middle < last`.
    /// Its span `epochs[last] - epochs[first]` is bounded by `dt_min` and
    /// `dt_max`. The window search scans forward from each anchor, so `epochs`
    /// is expected to be ascending (NOTE(review): not checked here; confirm at
    /// call sites).
    ///
    /// With fewer than three epochs the generator yields nothing.
    pub fn new(epochs: Vec<f64>, dt_min: f64, dt_max: f64, cap: usize) -> Self {
        let n = epochs.len();
        // No triplet fits into fewer than 3 epochs; `lo == n` marks the window empty.
        let window = match n {
            0..=2 => LastWindow { lo: n, hi: 0 },
            _ => LastWindow::compute(0, &epochs, dt_min, dt_max),
        };
        // First candidate `last` for anchor 0 / middle 1.
        let last = window.lo.max(2);
        Self {
            epochs,
            anchor: 0,
            middle: 1,
            last,
            window,
            n,
            dt_min,
            dt_max,
            remaining: cap,
        }
    }

    /// Downsamples `observations` with `downsample_uniform_with_edges(_, max_reduced)`
    /// and builds a generator over the kept observations' `mjd_tt()` epochs.
    ///
    /// Yielded indices refer to the reduced epoch list returned by
    /// [`Self::reduced_times`], not to positions in `observations`.
    pub fn from_observations(
        observations: &[&Observation],
        dt_min: f64,
        dt_max: f64,
        max_reduced: usize,
        cap: usize,
    ) -> Self {
        let epochs: Vec<f64> = downsample_uniform_with_edges(observations.len(), max_reduced)
            .into_iter()
            .map(|idx| observations[idx].mjd_tt())
            .collect();
        Self::new(epochs, dt_min, dt_max, cap)
    }

    /// Epochs that the generated indices refer to (after any downsampling).
    pub fn reduced_times(&self) -> &[f64] {
        self.epochs.as_slice()
    }

    /// Moves to the next anchor and recomputes its window.
    ///
    /// Returns `false` when fewer than two indices follow the new anchor, in
    /// which case no further triplets exist.
    fn advance_anchor(&mut self) -> bool {
        self.anchor += 1;
        let has_room = self.anchor + 2 < self.n;
        if has_room {
            self.window =
                LastWindow::compute(self.anchor, &self.epochs, self.dt_min, self.dt_max);
            self.middle = self.anchor + 1;
            self.reset_last_for_middle();
        }
        has_room
    }

    /// Rewinds `last` to the smallest candidate for the current `middle`: at
    /// least the window start and strictly after `middle`.
    fn reset_last_for_middle(&mut self) {
        self.last = self.window.lo.max(self.middle + 1);
    }
}
impl Iterator for TripletIndexGenerator {
type Item = (usize, usize, usize);
fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}
loop {
if self.anchor + 2 >= self.n {
return None;
}
if self.window.is_empty(self.anchor, self.n) {
if !self.advance_anchor() {
return None;
}
continue;
}
if self.middle >= self.window.hi {
if !self.advance_anchor() {
return None;
}
continue;
}
if self.last <= self.middle {
self.reset_last_for_middle();
}
if self.last > self.window.hi {
self.middle += 1;
self.reset_last_for_middle();
continue;
}
let triplet = (self.anchor, self.middle, self.last);
self.last += 1;
self.remaining -= 1;
return Some(triplet);
}
}
}
#[cfg(test)]
mod triplet_generator_tests {
    //! Unit and property tests for `downsample_uniform_with_edges` and
    //! `TripletIndexGenerator`.
    //!
    //! Local generator variables are named `generator` rather than `gen`,
    //! because `gen` is a reserved keyword in the Rust 2024 edition.

    use super::*;
    use proptest::prelude::*;

    /// `(first, middle, last)` index triplet as yielded by the generator.
    type Triplet = (usize, usize, usize);

    /// Generator with no cap, so every feasible triplet is produced.
    fn gen_from_epochs(epochs: Vec<f64>, dt_min: f64, dt_max: f64) -> TripletIndexGenerator {
        TripletIndexGenerator::new(epochs, dt_min, dt_max, usize::MAX)
    }

    /// Drains an uncapped generator. For each triplet it asserts `i < j < k`
    /// and that the span `epochs[k] - epochs[i]` lies in `[dt_min, dt_max]`
    /// (with 1e-12 slack).
    fn collect_and_validate(epochs: &[f64], dt_min: f64, dt_max: f64) -> Vec<Triplet> {
        let generator = gen_from_epochs(epochs.to_vec(), dt_min, dt_max);
        let mut out = Vec::new();
        for (i, j, k) in generator {
            assert!(i < j, "first < middle violated: ({i},{j},{k})");
            assert!(j < k, "middle < last violated: ({i},{j},{k})");
            let span = epochs[k] - epochs[i];
            assert!(
                span >= dt_min - 1e-12,
                "span {span} < dt_min {dt_min}: ({i},{j},{k})"
            );
            assert!(
                span <= dt_max + 1e-12,
                "span {span} > dt_max {dt_max}: ({i},{j},{k})"
            );
            out.push((i, j, k));
        }
        out
    }

    /// Reference O(n^3) enumeration of every feasible triplet.
    fn brute_force(epochs: &[f64], dt_min: f64, dt_max: f64) -> Vec<Triplet> {
        let n = epochs.len();
        let mut out = Vec::new();
        for i in 0..n {
            for j in (i + 1)..n {
                for k in (j + 1)..n {
                    let span = epochs[k] - epochs[i];
                    if span >= dt_min - 1e-12 && span <= dt_max + 1e-12 {
                        out.push((i, j, k));
                    }
                }
            }
        }
        out
    }

    /// Returns `(generated, expected)`, both sorted so they compare directly.
    fn sorted_generated_and_expected(
        epochs: &[f64],
        dt_min: f64,
        dt_max: f64,
    ) -> (Vec<Triplet>, Vec<Triplet>) {
        let mut got = collect_and_validate(epochs, dt_min, dt_max);
        let mut expected = brute_force(epochs, dt_min, dt_max);
        got.sort_unstable();
        expected.sort_unstable();
        (got, expected)
    }

    /// Builds epochs `[0, s0, s0 + s1, ...]` from the first `n - 1` steps.
    /// The result is ascending when the steps are positive.
    fn cumulative_epochs(n: usize, steps: &[f64]) -> Vec<f64> {
        let mut epochs = vec![0.0f64];
        for &s in steps.iter().take(n - 1) {
            epochs.push(epochs.last().unwrap() + s);
        }
        epochs
    }

    #[test]
    fn downsample_empty() {
        assert!(downsample_uniform_with_edges(0, 10).is_empty());
    }

    #[test]
    fn downsample_no_op_when_max_ge_n() {
        let result = downsample_uniform_with_edges(5, 10);
        assert_eq!(result, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn downsample_exact_n() {
        let result = downsample_uniform_with_edges(5, 5);
        assert_eq!(result, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn downsample_max_keep_3() {
        let result = downsample_uniform_with_edges(9, 3);
        assert_eq!(result, vec![0, 4, 8]);
    }

    #[test]
    fn downsample_max_keep_le_3_small_n() {
        let result = downsample_uniform_with_edges(3, 2);
        assert_eq!(result[0], 0);
        assert_eq!(*result.last().unwrap(), 2);
    }

    #[test]
    fn downsample_endpoints_always_present() {
        for n in 4..=20 {
            for max_keep in 3..=n {
                let result = downsample_uniform_with_edges(n, max_keep);
                assert_eq!(
                    result[0], 0,
                    "first endpoint missing for n={n} max={max_keep}"
                );
                assert_eq!(
                    *result.last().unwrap(),
                    n - 1,
                    "last endpoint missing for n={n} max={max_keep}"
                );
            }
        }
    }

    #[test]
    fn downsample_length_respects_max_keep() {
        for n in 4..=30 {
            for max_keep in 3..n {
                let result = downsample_uniform_with_edges(n, max_keep);
                assert!(
                    result.len() <= max_keep,
                    "len={} > max_keep={max_keep} for n={n}",
                    result.len()
                );
            }
        }
    }

    #[test]
    fn downsample_strictly_increasing() {
        let result = downsample_uniform_with_edges(100, 10);
        for w in result.windows(2) {
            assert!(w[0] < w[1], "not strictly increasing: {:?}", result);
        }
    }

    #[test]
    fn generator_empty_on_fewer_than_3_obs() {
        for n in 0..=2 {
            let epochs: Vec<f64> = (0..n).map(|i| i as f64).collect();
            let triplets = collect_and_validate(&epochs, 0.0, 10.0);
            assert!(triplets.is_empty(), "expected empty for n={n}");
        }
    }

    #[test]
    fn generator_empty_when_dt_min_gt_dt_max() {
        let epochs = vec![0.0, 1.0, 2.0, 3.0];
        let triplets = collect_and_validate(&epochs, 5.0, 2.0);
        assert!(triplets.is_empty());
    }

    #[test]
    fn generator_empty_when_all_spans_below_dt_min() {
        let epochs = vec![0.0, 1.0, 2.0, 3.0];
        let triplets = collect_and_validate(&epochs, 10.0, 100.0);
        assert!(triplets.is_empty());
    }

    #[test]
    fn generator_empty_when_all_spans_above_dt_max() {
        let epochs = vec![0.0, 10.0, 20.0, 30.0];
        let triplets = collect_and_validate(&epochs, 0.0, 1.0);
        assert!(triplets.is_empty());
    }

    #[test]
    fn generator_single_feasible_triplet() {
        let epochs = vec![0.0, 1.0, 2.0];
        let triplets = collect_and_validate(&epochs, 2.0, 2.0);
        assert_eq!(triplets, vec![(0, 1, 2)]);
    }

    #[test]
    fn generator_stays_exhausted() {
        let mut generator = gen_from_epochs(vec![0.0, 1.0, 2.0], 0.0, 10.0);
        assert_eq!(generator.next(), Some((0, 1, 2)));
        assert_eq!(generator.next(), None);
        assert_eq!(generator.next(), None);
    }

    #[test]
    fn generator_matches_brute_force_small() {
        let epochs = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let (got, expected) = sorted_generated_and_expected(&epochs, 1.5, 3.5);
        assert_eq!(got, expected);
    }

    #[test]
    fn generator_matches_brute_force_no_constraint() {
        let epochs: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let (got, expected) = sorted_generated_and_expected(&epochs, 0.0, f64::MAX);
        assert_eq!(got, expected);
    }

    #[test]
    fn generator_matches_brute_force_equal_spacing() {
        let epochs: Vec<f64> = (0..7).map(|i| i as f64 * 2.0).collect();
        let (got, expected) = sorted_generated_and_expected(&epochs, 3.0, 9.0);
        assert_eq!(got, expected);
    }

    #[test]
    fn generator_no_duplicates() {
        let epochs: Vec<f64> = (0..8).map(|i| i as f64).collect();
        let triplets = collect_and_validate(&epochs, 1.0, 6.0);
        let mut unique = triplets.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), triplets.len(), "duplicates detected");
    }

    #[test]
    fn generator_respects_max_triplets_cap() {
        let epochs: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let cap = 5;
        let generator = TripletIndexGenerator::new(epochs, 1.0, 20.0, cap);
        assert_eq!(generator.count(), cap);
    }

    #[test]
    fn generator_cap_zero_yields_nothing() {
        let epochs = vec![0.0, 1.0, 2.0, 3.0];
        let generator = TripletIndexGenerator::new(epochs, 0.0, 10.0, 0);
        assert_eq!(generator.count(), 0);
    }

    #[test]
    fn reduced_times_match_input_epochs() {
        let epochs: Vec<f64> = (0..5).map(|i| i as f64).collect();
        let generator = TripletIndexGenerator::new(epochs.clone(), 0.0, 10.0, usize::MAX);
        assert_eq!(generator.reduced_times(), epochs.as_slice());
    }

    #[test]
    fn reduced_times_aligned_with_mapping() {
        let epochs = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let generator = TripletIndexGenerator::new(epochs.clone(), 0.0, 20.0, usize::MAX);
        let reduced_len = generator.reduced_times().len();
        assert_eq!(reduced_len, epochs.len());
        // Every yielded index must be valid for the reduced epoch list.
        for (_, _, k) in generator {
            assert!(k < reduced_len);
        }
    }

    proptest! {
        #[test]
        fn prop_all_invariants_hold(
            n in 3usize..=12,
            steps in prop::collection::vec(0.1f64..5.0, 11),
            dt_min in 0.0f64..10.0,
            dt_range in 0.0f64..20.0,
        ) {
            let dt_max = dt_min + dt_range;
            let epochs = cumulative_epochs(n, &steps);
            let _triplets = collect_and_validate(&epochs, dt_min, dt_max);
        }

        #[test]
        fn prop_matches_brute_force(
            n in 3usize..=8,
            steps in prop::collection::vec(0.5f64..3.0, 7),
            dt_min in 0.0f64..5.0,
            dt_range in 0.0f64..10.0,
        ) {
            let dt_max = dt_min + dt_range;
            let epochs = cumulative_epochs(n, &steps);
            let (got, expected) = sorted_generated_and_expected(&epochs, dt_min, dt_max);
            prop_assert_eq!(got, expected);
        }

        #[test]
        fn prop_no_duplicates(
            n in 3usize..=10,
            steps in prop::collection::vec(0.1f64..4.0, 9),
            dt_min in 0.0f64..5.0,
            dt_range in 0.0f64..15.0,
        ) {
            let dt_max = dt_min + dt_range;
            let epochs = cumulative_epochs(n, &steps);
            let mut triplets = collect_and_validate(&epochs, dt_min, dt_max);
            let total = triplets.len();
            triplets.sort_unstable();
            triplets.dedup();
            prop_assert_eq!(triplets.len(), total, "duplicate triplets found");
        }

        #[test]
        fn prop_cap_respected(
            n in 3usize..=12,
            steps in prop::collection::vec(0.5f64..3.0, 11),
            cap in 0usize..=20,
        ) {
            let epochs = cumulative_epochs(n, &steps);
            let generator = TripletIndexGenerator::new(epochs, 0.0, f64::MAX, cap);
            prop_assert!(generator.count() <= cap);
        }

        #[test]
        fn prop_downsample_endpoints_and_length(
            n in 1usize..=50,
            max_keep in 3usize..=50,
        ) {
            let result = downsample_uniform_with_edges(n, max_keep);
            prop_assert!(result.len() <= max_keep.min(n).max(3));
            // n >= 1 here, so both endpoints must be present.
            prop_assert_eq!(result[0], 0);
            prop_assert_eq!(*result.last().unwrap(), n - 1);
        }

        #[test]
        fn prop_downsample_strictly_increasing(
            n in 2usize..=50,
            max_keep in 3usize..=50,
        ) {
            let result = downsample_uniform_with_edges(n, max_keep);
            for w in result.windows(2) {
                prop_assert!(w[0] < w[1]);
            }
        }
    }
}