use ahash::AHashMap;
use itertools::Either;
use crate::{NightId, TrajId, observation_dataset::ObsId};
/// Position of an observation within a dataset. `ObsDatasetIndex::merge_from`
/// shifts these positions by an offset when two indices are combined.
pub type ObsIndex = usize;
/// Lookup table from an observation identifier to its [`ObsIndex`] position.
pub type ObservationIndexMap = AHashMap<ObsId, ObsIndex>;
/// Set of observation positions associated with a single key (a night or a
/// trajectory).
///
/// A run of adjacent positions can be stored as a `Contiguous` range, which
/// keeps only its two bounds. Arbitrary, scattered positions use `Split`.
///
/// Equality is structural: a `Contiguous` range and a `Split` list holding
/// the same positions compare as different values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ObsMapIndex {
    /// Half-open range `start..end` of consecutive positions (`end` excluded).
    #[cfg_attr(not(feature = "polars"), allow(dead_code))]
    Contiguous { start: ObsIndex, end: ObsIndex },
    /// Explicit list of positions, kept in the order they were added.
    Split(Vec<ObsIndex>),
}
impl ObsMapIndex {
    /// Appends `idx` to the end of a `Split` entry.
    ///
    /// # Panics
    ///
    /// Panics when `self` is a `Contiguous` range, which cannot hold an
    /// arbitrary extra position.
    #[cfg_attr(not(feature = "polars"), allow(dead_code))]
    pub(crate) fn push_split(&mut self, idx: ObsIndex) {
        let ObsMapIndex::Split(positions) = self else {
            panic!("push_split called on a Contiguous ObsMapIndex entry")
        };
        positions.push(idx);
    }
}
/// Observation positions grouped by night.
pub type NightIndexMap = AHashMap<NightId, ObsMapIndex>;
/// Observation positions grouped by trajectory identifier.
pub type TrajIndexMap = AHashMap<TrajId, ObsMapIndex>;
/// Maps an alternative trajectory name (alias) to its primary [`TrajId`].
pub type TrajAliasMap = AHashMap<String, TrajId>;
/// Moves every observation position referenced by `idx` forward by `offset`.
///
/// The variant is preserved. A `Contiguous` range has both bounds shifted,
/// and each element of a `Split` list is shifted individually.
fn shift_obs_map_index(idx: ObsMapIndex, offset: usize) -> ObsMapIndex {
    match idx {
        ObsMapIndex::Split(positions) => {
            ObsMapIndex::Split(positions.into_iter().map(|pos| pos + offset).collect())
        }
        ObsMapIndex::Contiguous { start, end } => {
            let (start, end) = (start + offset, end + offset);
            ObsMapIndex::Contiguous { start, end }
        }
    }
}
/// Folds every entry of `other_map` into `self_map`, shifting the incoming
/// positions by `offset` first.
///
/// * A key found only in `other_map` is inserted with its shifted entry, and
///   keeps its variant (a `Contiguous` range stays `Contiguous`).
/// * A key found in both maps becomes a `Split` list: the existing positions
///   first, followed by the shifted incoming ones. The result is always
///   `Split`, even if the concatenation happens to form a contiguous run.
fn merge_obs_map<K>(
    self_map: &mut AHashMap<K, ObsMapIndex>,
    other_map: AHashMap<K, ObsMapIndex>,
    offset: usize,
) where
    K: Eq + std::hash::Hash,
{
    for (key, other_idx) in other_map {
        let incoming = shift_obs_map_index(other_idx, offset);
        match self_map.entry(key) {
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(incoming);
            }
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                let existing = slot.get_mut();
                // Pull the current positions out (a Split buffer is taken, a
                // range is materialised); the slot is overwritten below.
                let mut positions: Vec<ObsIndex> = match existing {
                    ObsMapIndex::Split(prev) => std::mem::take(prev),
                    ObsMapIndex::Contiguous { start, end } => (*start..*end).collect(),
                };
                match incoming {
                    ObsMapIndex::Split(extra) => positions.extend(extra),
                    ObsMapIndex::Contiguous { start, end } => positions.extend(start..end),
                }
                *existing = ObsMapIndex::Split(positions);
            }
        }
    }
}
/// Merges an optional per-key position table into another optional table.
///
/// * `other_opt` is `None`: nothing to merge, and `self_opt` is left as is.
/// * `self_opt` is `Some`: the maps are combined with `merge_obs_map`.
/// * `self_opt` is `None`: it becomes a copy of `other_map` with every entry
///   shifted by `offset`.
fn merge_optional_obs_map<K>(
    self_opt: &mut Option<AHashMap<K, ObsMapIndex>>,
    other_opt: Option<AHashMap<K, ObsMapIndex>>,
    offset: usize,
) where
    K: Eq + std::hash::Hash,
{
    let Some(other_map) = other_opt else { return };
    if let Some(self_map) = self_opt {
        merge_obs_map(self_map, other_map, offset);
    } else {
        let adopted: AHashMap<K, ObsMapIndex> = other_map
            .into_iter()
            .map(|(key, idx)| (key, shift_obs_map_index(idx, offset)))
            .collect();
        *self_opt = Some(adopted);
    }
}
/// Lookup tables mapping observation ids, nights and trajectories to the
/// positions of the matching observations.
///
/// The night and trajectory tables are optional. The accessors that read one
/// of them return `None` when it was not built.
#[derive(Debug, Clone)]
pub struct ObsDatasetIndex {
    /// Observation id -> position of that observation.
    pub(crate) obs_index_by_id: ObservationIndexMap,
    /// Night -> positions of that night's observations, if this table was built.
    pub(crate) obs_index_by_night: Option<NightIndexMap>,
    /// Trajectory -> positions of its observations, if this table was built.
    pub(crate) obs_index_by_trajectory: Option<TrajIndexMap>,
    /// Alias name -> primary trajectory id.
    pub(crate) traj_aliases: TrajAliasMap,
}
impl ObsDatasetIndex {
    /// Builds an index from already-computed lookup tables.
    ///
    /// The trajectory alias table starts empty.
    #[cfg_attr(not(feature = "polars"), allow(dead_code))]
    pub(crate) fn new(
        obs_index_by_id: ObservationIndexMap,
        obs_index_by_night: Option<NightIndexMap>,
        obs_index_by_trajectory: Option<TrajIndexMap>,
    ) -> Self {
        Self {
            obs_index_by_id,
            obs_index_by_night,
            obs_index_by_trajectory,
            traj_aliases: TrajAliasMap::new(),
        }
    }

    /// Number of positions held by one map entry.
    ///
    /// `Contiguous { start, end }` is the half-open range `start..end`, so its
    /// width is `end - start`. `Split` holds one position per element.
    fn entry_len(indices: &ObsMapIndex) -> usize {
        match indices {
            ObsMapIndex::Contiguous { start, end } => end - start,
            ObsMapIndex::Split(vec) => vec.len(),
        }
    }

    /// Iterates the positions held by one map entry, in stored order.
    ///
    /// This is the single place where both `ObsMapIndex` variants are turned
    /// into an iterator. Both the per-key and the flattened iterators below
    /// use it.
    fn entry_positions(indices: &ObsMapIndex) -> impl Iterator<Item = ObsIndex> + '_ {
        match indices {
            ObsMapIndex::Contiguous { start, end } => Either::Left(*start..*end),
            ObsMapIndex::Split(vec) => Either::Right(vec.iter().copied()),
        }
    }

    /// Number of observations indexed under `night_id`.
    ///
    /// Returns `None` if no night index was built or `night_id` is unknown.
    pub(crate) fn len_night(&self, night_id: &NightId) -> Option<usize> {
        self.get_by_night(night_id).map(Self::entry_len)
    }

    /// Number of observations indexed under the given trajectory.
    ///
    /// Returns `None` if no trajectory index was built or the id is unknown.
    pub(crate) fn len_trajectory(&self, traj_id: impl Into<TrajId>) -> Option<usize> {
        self.get_by_trajectory(traj_id).map(Self::entry_len)
    }

    /// Iterates over the nights present in the night index (hash-map order,
    /// unspecified), or returns `None` if no night index was built.
    pub(crate) fn iter_night_id(&self) -> Option<impl Iterator<Item = &NightId>> {
        self.obs_index_by_night
            .as_ref()
            .map(|night_map| night_map.keys())
    }

    /// Iterates over the trajectories present in the trajectory index
    /// (hash-map order, unspecified), or returns `None` if none was built.
    pub(crate) fn iter_traj_id(&self) -> Option<impl Iterator<Item = &TrajId>> {
        self.obs_index_by_trajectory
            .as_ref()
            .map(|traj_map| traj_map.keys())
    }

    /// Position of the observation identified by `obs_id`, if it is indexed.
    pub(crate) fn get_by_id(&self, obs_id: &ObsId) -> Option<ObsIndex> {
        self.obs_index_by_id.get(obs_id).copied()
    }

    /// Raw index entry for `night_id`.
    ///
    /// Returns `None` if no night index was built or the night is unknown.
    pub(crate) fn get_by_night(&self, night_id: &NightId) -> Option<&ObsMapIndex> {
        self.obs_index_by_night.as_ref()?.get(night_id)
    }

    /// Iterates the positions of the observations of `night_id`, in stored
    /// order. Returns `None` under the same conditions as [`Self::get_by_night`].
    pub(crate) fn iter_night_obs_index(
        &self,
        night_id: &NightId,
    ) -> Option<impl Iterator<Item = ObsIndex> + '_> {
        self.get_by_night(night_id).map(Self::entry_positions)
    }

    /// Flattens the night index into `(night, position)` pairs.
    ///
    /// Nights are visited in hash-map order (unspecified), and positions within
    /// a night in stored order. Returns `None` if no night index was built.
    pub(crate) fn iter_full_night(&self) -> Option<impl Iterator<Item = (NightId, ObsIndex)> + '_> {
        self.obs_index_by_night.as_ref().map(|night_map| {
            night_map.iter().flat_map(|(night_id, indices)| {
                Self::entry_positions(indices).map(move |idx| (*night_id, idx))
            })
        })
    }

    /// Raw index entry for the given trajectory.
    ///
    /// Returns `None` if no trajectory index was built or the id is unknown.
    pub(crate) fn get_by_trajectory(&self, traj_id: impl Into<TrajId>) -> Option<&ObsMapIndex> {
        let traj_id = traj_id.into();
        self.obs_index_by_trajectory.as_ref()?.get(&traj_id)
    }

    /// Iterates the positions of the observations of the given trajectory, in
    /// stored order. Returns `None` under the same conditions as
    /// [`Self::get_by_trajectory`].
    pub(crate) fn iter_traj_obs_index(
        &self,
        traj_id: impl Into<TrajId>,
    ) -> Option<impl Iterator<Item = ObsIndex> + '_> {
        self.get_by_trajectory(traj_id).map(Self::entry_positions)
    }

    /// Flattens the trajectory index into `(trajectory, position)` pairs.
    ///
    /// The trajectory id is cloned for every yielded pair. Trajectories are
    /// visited in hash-map order (unspecified). Returns `None` if no trajectory
    /// index was built.
    pub(crate) fn iter_full_trajectory(
        &self,
    ) -> Option<impl Iterator<Item = (TrajId, ObsIndex)> + '_> {
        self.obs_index_by_trajectory.as_ref().map(|traj_map| {
            traj_map.iter().flat_map(|(traj_id, indices)| {
                Self::entry_positions(indices).map(move |idx| (traj_id.clone(), idx))
            })
        })
    }

    /// Registers `alias` as another name for the trajectory `primary`.
    ///
    /// An existing mapping for the same alias is overwritten.
    #[cfg(feature = "mpc_80_col")]
    pub(crate) fn register_alias(&mut self, alias: String, primary: TrajId) {
        self.traj_aliases.insert(alias, primary);
    }

    /// Primary trajectory id registered for `alias`, if any.
    pub(crate) fn resolve_alias(&self, alias: &str) -> Option<&TrajId> {
        self.traj_aliases.get(alias)
    }

    /// Iterates over all `(alias, primary)` pairs, in hash-map order.
    #[cfg(feature = "serde")]
    pub(crate) fn iter_aliases(&self) -> impl Iterator<Item = (&str, &TrajId)> {
        self.traj_aliases.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Replaces the whole alias table with `aliases`.
    #[cfg(feature = "serde")]
    pub(crate) fn set_aliases(&mut self, aliases: TrajAliasMap) {
        self.traj_aliases = aliases;
    }

    /// Absorbs `other` into `self`, adding `offset` to every position coming
    /// from `other`.
    ///
    /// `offset` is presumably the number of observations `self` already
    /// covers; confirm this at the call sites.
    ///
    /// * Ids from `other` are inserted with shifted positions. An id already
    ///   present in `self` is overwritten.
    /// * Night and trajectory tables are merged by `merge_optional_obs_map`.
    ///   A key present on both sides becomes a `Split` list. If `self` lacks a
    ///   table that `other` has, it adopts `other`'s shifted table. If `other`
    ///   lacks a table, `self`'s table is left untouched.
    /// * Aliases from `other` are added; on a name clash, `other`'s wins.
    #[cfg_attr(not(any(feature = "ades", feature = "mpc_80_col")), allow(dead_code))]
    pub(crate) fn merge_from(&mut self, other: ObsDatasetIndex, offset: usize) {
        self.obs_index_by_id.reserve(other.obs_index_by_id.len());
        for (id, pos) in other.obs_index_by_id {
            self.obs_index_by_id.insert(id, pos + offset);
        }
        merge_optional_obs_map(
            &mut self.obs_index_by_night,
            other.obs_index_by_night,
            offset,
        );
        merge_optional_obs_map(
            &mut self.obs_index_by_trajectory,
            other.obs_index_by_trajectory,
            offset,
        );
        self.traj_aliases.extend(other.traj_aliases);
    }

    /// Records `obs_index` as the positions of `traj_id` and returns the
    /// updated index (builder style).
    ///
    /// The entry is stored as [`ObsMapIndex::Split`] and replaces any existing
    /// entry for `traj_id`. If no trajectory index was built, the call is a
    /// no-op and `self` is returned unchanged.
    #[must_use = "push_trajectory consumes the index and returns the updated one"]
    pub(crate) fn push_trajectory(mut self, traj_id: TrajId, obs_index: &[ObsIndex]) -> Self {
        if let Some(traj_map) = self.obs_index_by_trajectory.as_mut() {
            traj_map.insert(traj_id, ObsMapIndex::Split(obs_index.to_vec()));
        }
        self
    }
}
#[cfg(test)]
mod obs_map_index_unit_tests {
    // Unit tests for `ObsMapIndex` and for the `ObsDatasetIndex` accessors and
    // merge logic in `observation_dataset::index`.
    use crate::{
        NightId, TrajId,
        observation_dataset::index::{
            NightIndexMap, ObsDatasetIndex, ObsMapIndex, ObservationIndexMap, TrajIndexMap,
        },
    };

    // `push_split` appends to the end of an existing `Split` list.
    #[test]
    fn push_split_appends_to_split_entry() {
        let mut entry = ObsMapIndex::Split(vec![0, 1]);
        entry.push_split(2);
        match entry {
            ObsMapIndex::Split(v) => assert_eq!(v, vec![0, 1, 2]),
            ObsMapIndex::Contiguous { .. } => panic!("expected Split"),
        }
    }

    // `push_split` rejects `Contiguous` entries with a panic.
    #[test]
    #[should_panic(expected = "push_split called on a Contiguous ObsMapIndex entry")]
    fn push_split_panics_on_contiguous() {
        let mut entry = ObsMapIndex::Contiguous { start: 0, end: 5 };
        entry.push_split(5);
    }

    // `len_night` handles both variants and returns None for unknown nights.
    #[test]
    fn len_night_both_variants() {
        let mut night_map: NightIndexMap = ahash::AHashMap::new();
        night_map.insert(NightId(1), ObsMapIndex::Contiguous { start: 0, end: 3 });
        night_map.insert(NightId(2), ObsMapIndex::Split(vec![4, 5]));
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(night_map), None);
        assert_eq!(
            idx.len_night(&NightId(1)),
            Some(3),
            "Contiguous(0..3) must report len 3"
        );
        assert_eq!(
            idx.len_night(&NightId(2)),
            Some(2),
            "Split([4,5]) must report len 2"
        );
        assert_eq!(
            idx.len_night(&NightId(99)),
            None,
            "unknown night must return None"
        );
    }

    // `len_trajectory` handles both variants and returns None for unknown ids.
    #[test]
    fn len_trajectory_both_variants() {
        let mut traj_map: TrajIndexMap = ahash::AHashMap::new();
        traj_map.insert(
            TrajId::Int(10),
            ObsMapIndex::Contiguous { start: 0, end: 4 },
        );
        traj_map.insert(TrajId::Int(20), ObsMapIndex::Split(vec![0, 2, 4]));
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(traj_map));
        assert_eq!(
            idx.len_trajectory(TrajId::Int(10)),
            Some(4),
            "Contiguous(0..4) must report len 4"
        );
        assert_eq!(
            idx.len_trajectory(TrajId::Int(20)),
            Some(3),
            "Split([0,2,4]) must report len 3"
        );
        assert_eq!(
            idx.len_trajectory(TrajId::Int(99)),
            None,
            "unknown traj must return None"
        );
    }

    // Per-night iteration yields the half-open range for `Contiguous` and the
    // stored order for `Split`.
    #[test]
    fn iter_night_obs_index_both_variants() {
        let mut night_map: NightIndexMap = ahash::AHashMap::new();
        night_map.insert(NightId(1), ObsMapIndex::Contiguous { start: 2, end: 5 });
        night_map.insert(NightId(2), ObsMapIndex::Split(vec![0, 7, 9]));
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(night_map), None);
        let contiguous_indices: Vec<_> = idx
            .iter_night_obs_index(&NightId(1))
            .expect("night 1 must be present")
            .collect();
        assert_eq!(contiguous_indices, vec![2, 3, 4]);
        let split_indices: Vec<_> = idx
            .iter_night_obs_index(&NightId(2))
            .expect("night 2 must be present")
            .collect();
        assert_eq!(split_indices, vec![0, 7, 9]);
    }

    // Same as above for trajectories. Passing `&TrajId` relies on a
    // conversion into `TrajId` provided elsewhere in the crate.
    #[test]
    fn iter_traj_obs_index_both_variants() {
        let mut traj_map: TrajIndexMap = ahash::AHashMap::new();
        traj_map.insert(
            TrajId::Int(10),
            ObsMapIndex::Contiguous { start: 0, end: 3 },
        );
        traj_map.insert(TrajId::Int(20), ObsMapIndex::Split(vec![5, 6]));
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(traj_map));
        let contiguous_indices: Vec<_> = idx
            .iter_traj_obs_index(&TrajId::Int(10))
            .expect("traj 10 must be present")
            .collect();
        assert_eq!(contiguous_indices, vec![0, 1, 2]);
        let split_indices: Vec<_> = idx
            .iter_traj_obs_index(&TrajId::Int(20))
            .expect("traj 20 must be present")
            .collect();
        assert_eq!(split_indices, vec![5, 6]);
    }

    // `push_trajectory` replaces (rather than appends to) an existing entry
    // and always stores a `Split` list.
    #[test]
    fn push_trajectory_stores_split() {
        let mut traj_map: TrajIndexMap = ahash::AHashMap::new();
        traj_map.insert(TrajId::Int(10), ObsMapIndex::Split(vec![99]));
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(traj_map));
        let idx_with_new_traj = idx.push_trajectory(TrajId::Int(10), &[0, 2, 4]);
        let entry = idx_with_new_traj
            .get_by_trajectory(TrajId::Int(10))
            .expect("traj 10 must exist after push");
        match entry {
            ObsMapIndex::Split(v) => assert_eq!(v, &[0, 2, 4]),
            ObsMapIndex::Contiguous { .. } => panic!("push_trajectory must produce Split"),
        }
    }

    // Without a trajectory table, `push_trajectory` does not create one.
    #[test]
    fn push_trajectory_noop_when_no_traj_index() {
        let idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, None);
        let idx_with_new_traj = idx.push_trajectory(TrajId::Int(42), &[0, 1, 2]);
        assert!(
            idx_with_new_traj
                .get_by_trajectory(TrajId::Int(42))
                .is_none(),
            "no traj index → get_by_trajectory must return None"
        );
    }

    // Observation ids are keys and are never shifted; only the stored position
    // gets `offset` added. Existing entries of `self` are untouched.
    #[test]
    fn merge_from_obs_id_key_not_offset() {
        let mut id_map_self = ObservationIndexMap::new();
        id_map_self.insert(10, 0);
        let mut self_idx = ObsDatasetIndex::new(id_map_self, None, None);
        let mut id_map_other = ObservationIndexMap::new();
        id_map_other.insert(20, 0);
        let other_idx = ObsDatasetIndex::new(id_map_other, None, None);
        self_idx.merge_from(other_idx, 1);
        assert_eq!(
            self_idx.get_by_id(&20),
            Some(1),
            "id key must not be offset; position must be shifted by 1"
        );
        assert_eq!(self_idx.get_by_id(&10), Some(0));
    }

    // `self` has no trajectory table, so it adopts `other`'s table shifted by
    // the offset; a `Contiguous` entry stays `Contiguous`.
    #[test]
    fn merge_from_traj_contiguous_preserved_for_new_key() {
        let self_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, None);
        let mut traj_map = TrajIndexMap::new();
        traj_map.insert(TrajId::Int(1), ObsMapIndex::Contiguous { start: 0, end: 3 });
        let other_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(traj_map));
        let mut self_idx = self_idx;
        self_idx.merge_from(other_idx, 5);
        match self_idx.get_by_trajectory(TrajId::Int(1)).unwrap() {
            ObsMapIndex::Contiguous { start, end } => {
                assert_eq!(*start, 5, "start must be shifted by offset");
                assert_eq!(*end, 8, "end must be shifted by offset");
            }
            ObsMapIndex::Split(_) => panic!("expected Contiguous, got Split"),
        }
    }

    // `self` has a trajectory table but not this key: the new key keeps its
    // `Contiguous` shape after shifting.
    #[test]
    fn merge_from_traj_contiguous_preserved_when_self_has_other_keys() {
        let mut self_traj = TrajIndexMap::new();
        self_traj.insert(TrajId::Int(99), ObsMapIndex::Split(vec![0]));
        let mut self_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(self_traj));
        let mut other_traj = TrajIndexMap::new();
        other_traj.insert(TrajId::Int(1), ObsMapIndex::Contiguous { start: 0, end: 2 });
        let other_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(other_traj));
        self_idx.merge_from(other_idx, 4);
        match self_idx.get_by_trajectory(TrajId::Int(1)).unwrap() {
            ObsMapIndex::Contiguous { start, end } => {
                assert_eq!(*start, 4);
                assert_eq!(*end, 6);
            }
            ObsMapIndex::Split(_) => panic!("expected Contiguous, got Split"),
        }
    }

    // The same key on both sides: 0..2 followed by (0..2 shifted by 2) gives
    // [0, 1, 2, 3]. The result is `Split` even though it is contiguous.
    #[test]
    fn merge_from_traj_collision_produces_split() {
        let mut self_traj = TrajIndexMap::new();
        self_traj.insert(
            TrajId::Int(1),
            ObsMapIndex::Contiguous { start: 0, end: 2 },
        );
        let mut self_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(self_traj));
        let mut other_traj = TrajIndexMap::new();
        other_traj.insert(
            TrajId::Int(1),
            ObsMapIndex::Contiguous { start: 0, end: 2 },
        );
        let other_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), None, Some(other_traj));
        self_idx.merge_from(other_idx, 2);
        match self_idx.get_by_trajectory(TrajId::Int(1)).unwrap() {
            ObsMapIndex::Split(v) => assert_eq!(v, &[0, 1, 2, 3]),
            ObsMapIndex::Contiguous { .. } => panic!("expected Split after collision"),
        }
    }

    // A new night key is added as a shifted `Contiguous` range, and the
    // existing night is left unchanged (not shifted).
    #[test]
    fn merge_from_night_contiguous_preserved_for_new_key() {
        let mut self_night = NightIndexMap::new();
        self_night.insert(NightId(1), ObsMapIndex::Contiguous { start: 0, end: 2 });
        let mut self_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(self_night), None);
        let mut other_night = NightIndexMap::new();
        other_night.insert(NightId(2), ObsMapIndex::Contiguous { start: 0, end: 3 });
        let other_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(other_night), None);
        self_idx.merge_from(other_idx, 2);
        match self_idx.get_by_night(&NightId(1)).unwrap() {
            ObsMapIndex::Contiguous { start, end } => {
                assert_eq!((*start, *end), (0, 2));
            }
            _ => panic!("expected Contiguous for night 1"),
        }
        match self_idx.get_by_night(&NightId(2)).unwrap() {
            ObsMapIndex::Contiguous { start, end } => {
                assert_eq!((*start, *end), (2, 5));
            }
            ObsMapIndex::Split(_) => panic!("expected Contiguous for new night key"),
        }
    }

    // A night collision between two `Split` lists concatenates them, with the
    // incoming positions shifted.
    #[test]
    fn merge_from_night_collision_produces_split() {
        let mut self_night = NightIndexMap::new();
        self_night.insert(NightId(1), ObsMapIndex::Split(vec![0, 1]));
        let mut self_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(self_night), None);
        let mut other_night = NightIndexMap::new();
        other_night.insert(NightId(1), ObsMapIndex::Split(vec![0, 1]));
        let other_idx = ObsDatasetIndex::new(ObservationIndexMap::new(), Some(other_night), None);
        self_idx.merge_from(other_idx, 2);
        match self_idx.get_by_night(&NightId(1)).unwrap() {
            ObsMapIndex::Split(v) => assert_eq!(v, &[0, 1, 2, 3]),
            ObsMapIndex::Contiguous { .. } => panic!("expected Split after collision"),
        }
    }
}