use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
/// Numeric identifier of a rank.
pub type RankId = u32;
/// Owner value recorded for image ids that were ingested without a rank id (`None`).
///
/// It equals `RankId::MAX`, so a genuine rank id of `u32::MAX` cannot be told apart from it.
pub(crate) const UNRANKED_SENTINEL: RankId = RankId::MAX;
use crate::error::PartialError;
/// Merge state that tracks which rank claimed each image id and, in strict mode, which rank
/// ids have already been seen, so that conflicting inputs can be reported as errors.
#[derive(Debug)]
pub struct BaseMergeAccumulator {
    /// Image id -> rank that first ingested it. Ids ingested without a rank are mapped to
    /// `UNRANKED_SENTINEL`.
    pub image_owner: HashMap<i64, RankId>,
    /// Rank ids registered via `ingest_rank_id`; only populated when `strict` is `true`.
    pub seen_rank_ids: HashSet<RankId>,
    /// When `true`, registering the same rank id twice is an error. Image-id overlap checks
    /// apply regardless of this flag.
    pub strict: bool,
}
impl BaseMergeAccumulator {
pub fn new(strict: bool) -> Self {
Self {
image_owner: HashMap::new(),
seen_rank_ids: HashSet::new(),
strict,
}
}
pub fn ingest_rank_id(&mut self, rank_id: Option<RankId>) -> Result<(), PartialError> {
if !self.strict {
return Ok(());
}
if let Some(rid) = rank_id {
if !self.seen_rank_ids.insert(rid) {
return Err(PartialError::RankCollision { rank_id: rid });
}
}
Ok(())
}
pub fn ingest_image_ids(
&mut self,
rank_id: Option<RankId>,
image_ids: impl IntoIterator<Item = i64>,
) -> Result<(), PartialError> {
let owner = rank_id.unwrap_or(UNRANKED_SENTINEL);
for id in image_ids {
if let Some(&prev) = self.image_owner.get(&id) {
let (a, b) = (prev.min(owner), prev.max(owner));
return Err(PartialError::PartitionOverlap {
rank_a: a,
rank_b: b,
image_id: id,
});
}
self.image_owner.insert(id, owner);
}
Ok(())
}
pub fn image_ids(&self) -> impl Iterator<Item = i64> + '_ {
self.image_owner.keys().copied()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rank_collision_strict() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_rank_id(Some(0)).unwrap();
        let err = acc.ingest_rank_id(Some(0)).unwrap_err();
        assert!(matches!(err, PartialError::RankCollision { rank_id: 0 }));
    }

    #[test]
    fn rank_collision_corrected_tolerated() {
        let mut acc = BaseMergeAccumulator::new(false);
        acc.ingest_rank_id(Some(0)).unwrap();
        acc.ingest_rank_id(Some(0)).unwrap();
    }

    /// A missing rank id is never recorded, so repeating it is not a collision.
    #[test]
    fn strict_mode_accepts_repeated_missing_rank_id() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_rank_id(None).unwrap();
        acc.ingest_rank_id(None).unwrap();
        assert!(acc.seen_rank_ids.is_empty());
    }

    /// Lenient mode skips rank tracking entirely.
    #[test]
    fn lenient_mode_does_not_record_rank_ids() {
        let mut acc = BaseMergeAccumulator::new(false);
        acc.ingest_rank_id(Some(3)).unwrap();
        assert!(acc.seen_rank_ids.is_empty());
    }

    #[test]
    fn partition_overlap_named_ranks() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_image_ids(Some(0), [1, 2, 3]).unwrap();
        let err = acc.ingest_image_ids(Some(1), [3, 4]).unwrap_err();
        match err {
            PartialError::PartitionOverlap {
                rank_a,
                rank_b,
                image_id,
            } => {
                assert_eq!(rank_a, 0);
                assert_eq!(rank_b, 1);
                assert_eq!(image_id, 3);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// The reported pair is ascending even when the later rank has the smaller id.
    #[test]
    fn partition_overlap_ranks_are_ordered() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_image_ids(Some(5), [10]).unwrap();
        let err = acc.ingest_image_ids(Some(2), [10]).unwrap_err();
        assert!(matches!(
            err,
            PartialError::PartitionOverlap {
                rank_a: 2,
                rank_b: 5,
                image_id: 10
            }
        ));
    }

    /// A duplicate inside one batch is reported against the batch's own rank.
    #[test]
    fn partition_overlap_within_single_batch() {
        let mut acc = BaseMergeAccumulator::new(false);
        let err = acc.ingest_image_ids(Some(4), [8, 8]).unwrap_err();
        assert!(matches!(
            err,
            PartialError::PartitionOverlap {
                rank_a: 4,
                rank_b: 4,
                image_id: 8
            }
        ));
    }

    #[test]
    fn partition_overlap_unranked_sentinel() {
        let mut acc = BaseMergeAccumulator::new(false);
        acc.ingest_image_ids(None, [7]).unwrap();
        let err = acc.ingest_image_ids(None, [7]).unwrap_err();
        match err {
            PartialError::PartitionOverlap { rank_a, rank_b, .. } => {
                assert_eq!(rank_a, UNRANKED_SENTINEL);
                assert_eq!(rank_b, UNRANKED_SENTINEL);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn image_ids_lists_every_ingested_id() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_image_ids(Some(0), [3, 1]).unwrap();
        acc.ingest_image_ids(None, [2]).unwrap();
        // Iteration order is unspecified, so sort before comparing.
        let mut ids: Vec<i64> = acc.image_ids().collect();
        ids.sort_unstable();
        assert_eq!(ids, [1, 2, 3]);
    }

    #[test]
    fn empty_batch_is_accepted() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_image_ids(Some(0), std::iter::empty::<i64>())
            .unwrap();
        assert_eq!(acc.image_ids().count(), 0);
    }
}