use std::{
    collections::{hash_map::Entry, HashMap},
    fmt::{Display, Formatter, Result as FmtResult},
};

use crate::{
    units::{HashFor, Unit, UnitCoord},
    NodeCount, NodeIndex, NodeMap, Round,
};
/// Summary of a [`UnitStore`]'s contents, as produced by `UnitStore::status`.
pub struct UnitStoreStatus {
    // Total number of stored units (every variant, canonical or not).
    size: usize,
    // Per creator: the highest round at which that creator has a canonical unit.
    top_row: NodeMap<Round>,
}
impl UnitStoreStatus {
    /// Returns the highest round at which any creator has a canonical unit.
    ///
    /// Falls back to `0` when no creator has a canonical unit at all, so an empty store and a
    /// store whose highest unit is at round 0 report the same value.
    pub fn top_round(&self) -> Round {
        let highest = self.top_row.values().copied().max();
        highest.unwrap_or(0)
    }
}
impl Display for UnitStoreStatus {
    /// Renders the status as `total units: <size>, top row: <top_row>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.write_fmt(format_args!(
            "total units: {}, top row: {}",
            self.size, self.top_row
        ))
    }
}
/// In-memory store of units keyed by hash.
///
/// For every (creator, round) coordinate, the store also tracks at most one unit that it
/// treats as canonical.
pub struct UnitStore<U: Unit> {
    // Every stored unit, including non-canonical variants, keyed by its hash.
    by_hash: HashMap<HashFor<U>, U>,
    // Per creator: round -> hash of the canonical unit at that coordinate. Every hash stored
    // here is also a key of `by_hash`.
    canonical_units: NodeMap<HashMap<Round, HashFor<U>>>,
    // Per creator: the highest round with a canonical unit. There is no entry for a creator
    // that currently has no canonical unit.
    top_row: NodeMap<Round>,
}
impl<U: Unit> UnitStore<U> {
pub fn new(node_count: NodeCount) -> Self {
let mut canonical_units = NodeMap::with_size(node_count);
for node_id in node_count.into_iterator() {
canonical_units.insert(node_id, HashMap::new());
}
let top_row = NodeMap::with_size(node_count);
UnitStore {
by_hash: HashMap::new(),
canonical_units,
top_row,
}
}
fn mut_hashes_by(&mut self, creator: NodeIndex) -> &mut HashMap<Round, HashFor<U>> {
self.canonical_units
.get_mut(creator)
.expect("all hashmaps initialized")
}
fn hashes_by(&self, creator: NodeIndex) -> &HashMap<Round, HashFor<U>> {
self.canonical_units
.get(creator)
.expect("all hashmaps initialized")
}
fn canonical_by_hash(&self, hash: &HashFor<U>) -> &U {
self.by_hash.get(hash).expect("we have all canonical units")
}
fn maybe_set_canonical(&mut self, unit: &U) {
let unit_coord = unit.coord();
if self.canonical_unit(unit_coord).is_none() {
self.mut_hashes_by(unit_coord.creator())
.insert(unit.round(), unit.hash());
if self
.top_row
.get(unit.creator())
.map(|max_round| unit.round() > *max_round)
.unwrap_or(true)
{
self.top_row.insert(unit_coord.creator(), unit.round());
}
}
}
pub fn insert(&mut self, unit: U) {
self.maybe_set_canonical(&unit);
let unit_hash = unit.hash();
self.by_hash.insert(unit_hash, unit);
}
fn maybe_unset_canonical(&mut self, unit: &U) {
let creator_hashes = self.mut_hashes_by(unit.creator());
if creator_hashes.get(&unit.round()) != Some(&unit.hash()) {
return;
}
creator_hashes.remove(&unit.round());
if self.top_row.get(unit.creator()) == Some(&unit.round()) {
match self.hashes_by(unit.creator()).keys().max().copied() {
Some(max_round) => self.top_row.insert(unit.creator(), max_round),
None => self.top_row.delete(unit.creator()),
}
}
}
pub fn remove(&mut self, hash: &HashFor<U>) {
if let Some(unit) = self.by_hash.remove(hash) {
self.maybe_unset_canonical(&unit);
}
}
pub fn canonical_unit(&self, coord: UnitCoord) -> Option<&U> {
self.hashes_by(coord.creator())
.get(&coord.round())
.map(|hash| self.canonical_by_hash(hash))
}
pub fn canonical_units(&self, creator: NodeIndex) -> impl Iterator<Item = &U> {
let canonical_hashes = self.hashes_by(creator);
let max_round = canonical_hashes.keys().max().cloned().unwrap_or(0);
(0..=max_round)
.filter_map(|round| canonical_hashes.get(&round))
.map(|hash| self.canonical_by_hash(hash))
}
pub fn unit(&self, hash: &HashFor<U>) -> Option<&U> {
self.by_hash.get(hash)
}
pub fn top_round_for(&self, creator: NodeIndex) -> Option<Round> {
self.top_row.get(creator).copied()
}
pub fn status(&self) -> UnitStoreStatus {
UnitStoreStatus {
size: self.by_hash.len(),
top_row: self.top_row.clone(),
}
}
}
// Tests for `UnitStore`. They cover:
// - basic insert/lookup/remove;
// - which variant becomes canonical;
// - bulk storage;
// - canonical iteration when removals leave gaps.
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::{
        units::{random_full_parent_units_up_to, TestingFullUnit, Unit, UnitCoord, UnitStore},
        NodeCount, NodeIndex,
    };

    // A freshly created store has no canonical unit, no canonical units to iterate and no top
    // round for creator 0.
    #[test]
    fn empty_has_no_units() {
        let node_count = NodeCount(7);
        let store = UnitStore::<TestingFullUnit>::new(node_count);
        assert!(store
            .canonical_unit(UnitCoord::new(0, NodeIndex(0)))
            .is_none());
        assert!(store.canonical_units(NodeIndex(0)).next().is_none());
        assert!(store.top_round_for(NodeIndex(0)).is_none());
    }

    // After inserting a single unit:
    // - it is retrievable by hash;
    // - it is canonical at its coordinate;
    // - it defines its creator's top round.
    // Removing it undoes all of that.
    #[test]
    fn single_unit_basic_operations() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let unit = random_full_parent_units_up_to(0, node_count, 43)
            .first()
            .expect("we have the first round")
            .first()
            .expect("we have the initial unit for the zeroth creator")
            .clone();
        store.insert(unit.clone());
        assert_eq!(store.unit(&unit.hash()), Some(&unit));
        assert_eq!(store.canonical_unit(unit.coord()), Some(&unit));
        assert_eq!(store.top_round_for(unit.creator()), Some(unit.round()));
        {
            // Scoped so the iterator's borrow of `store` ends before `remove` below.
            let mut canonical_units = store.canonical_units(unit.creator());
            assert_eq!(canonical_units.next(), Some(&unit));
            assert_eq!(canonical_units.next(), None);
        }
        store.remove(&unit.hash());
        assert_eq!(store.unit(&unit.hash()), None);
        assert_eq!(store.canonical_unit(unit.coord()), None);
        assert_eq!(store.canonical_units(unit.creator()).next(), None);
        assert_eq!(store.top_round_for(unit.creator()), None);
    }

    // Among several variants at the same coordinate (creator 0, round 0), the first one
    // inserted is canonical. Removing it leaves the coordinate without a canonical unit, and
    // no other variant is promoted. The remaining variants stay retrievable by hash.
    #[test]
    fn first_variant_is_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        // NOTE(review): presumably `TestingFullUnit` has interior mutability, which trips this
        // lint. The set is only used for deduplication here — confirm.
        #[allow(clippy::mutable_key_type)]
        let variants: HashSet<_> = (0..15)
            .map(|_| {
                random_full_parent_units_up_to(0, node_count, 43)
                    .first()
                    .expect("we have the first round")
                    .first()
                    .expect("we have the initial unit for the zeroth creator")
                    .clone()
            })
            .collect();
        // Collecting into the set dropped any duplicates. Converting to a `Vec` fixes an
        // order, so "first inserted" is well defined below.
        let variants: Vec<_> = variants.into_iter().collect();
        for unit in &variants {
            store.insert(unit.clone());
        }
        for unit in &variants {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }
        let canonical_unit = variants.first().expect("we have the unit").clone();
        assert_eq!(
            store.canonical_unit(canonical_unit.coord()),
            Some(&canonical_unit)
        );
        {
            let mut canonical_units = store.canonical_units(canonical_unit.creator());
            assert_eq!(canonical_units.next(), Some(&canonical_unit));
            assert_eq!(canonical_units.next(), None);
        }
        store.remove(&canonical_unit.hash());
        assert_eq!(store.unit(&canonical_unit.hash()), None);
        assert_eq!(store.canonical_unit(canonical_unit.coord()), None);
        assert_eq!(store.canonical_units(canonical_unit.creator()).next(), None);
        for unit in variants.iter().skip(1) {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }
    }

    // Every unit generated for rounds 0..=15 and all 7 creators:
    // - is stored and canonical;
    // - is yielded by `canonical_units` for its creator in round order.
    // The generator output is indexed as `units[round][creator]`.
    #[test]
    fn stores_lots_of_units() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for round_units in &units {
            for unit in round_units {
                store.insert(unit.clone());
            }
        }
        for round_units in &units {
            for unit in round_units {
                assert_eq!(store.unit(&unit.hash()), Some(unit));
                assert_eq!(store.canonical_unit(unit.coord()), Some(unit));
            }
        }
        for node_id in node_count.into_iterator() {
            let mut canonical_units = store.canonical_units(node_id);
            for round in 0..=max_round {
                assert_eq!(
                    canonical_units.next(),
                    Some(&units[round as usize][node_id.0])
                );
            }
            assert_eq!(canonical_units.next(), None);
        }
    }

    // For each creator, the test removes every unit whose round is a multiple of
    // (creator index + 1). Round 0 is therefore removed for every creator. `canonical_units`
    // must skip the resulting gaps and yield exactly the surviving units, in round order.
    #[test]
    fn handles_fragmented_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for round_units in &units {
            for unit in round_units {
                store.insert(unit.clone());
            }
        }
        for round_units in &units {
            for unit in round_units {
                if unit.round() as usize % (unit.creator().0 + 1) == 0 {
                    store.remove(&unit.hash());
                }
            }
        }
        for node_id in node_count.into_iterator() {
            let mut canonical_units = store.canonical_units(node_id);
            for round in 0..=max_round {
                // Only rounds that survived the removal pass above are expected.
                if round as usize % (node_id.0 + 1) != 0 {
                    assert_eq!(
                        canonical_units.next(),
                        Some(&units[round as usize][node_id.0])
                    );
                }
            }
            assert_eq!(canonical_units.next(), None);
        }
    }
}