1use std::{
2 collections::HashMap,
3 fmt::{Display, Formatter, Result as FmtResult},
4};
5
6use crate::{
7 units::{HashFor, Unit, UnitCoord},
8 NodeCount, NodeIndex, NodeMap, Round,
9};
10
11pub struct UnitStoreStatus {
13 size: usize,
14 top_row: NodeMap<Round>,
15}
16
17impl UnitStoreStatus {
18 pub fn top_round(&self) -> Round {
20 self.top_row.values().max().cloned().unwrap_or(0)
21 }
22}
23
24impl Display for UnitStoreStatus {
25 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
26 write!(f, "total units: {}, top row: {}", self.size, self.top_row)
27 }
28}
29
30pub struct UnitStore<U: Unit> {
33 by_hash: HashMap<HashFor<U>, U>,
34 canonical_units: NodeMap<HashMap<Round, HashFor<U>>>,
35 top_row: NodeMap<Round>,
36}
37
38impl<U: Unit> UnitStore<U> {
39 pub fn new(node_count: NodeCount) -> Self {
41 let mut canonical_units = NodeMap::with_size(node_count);
42 for node_id in node_count.into_iterator() {
43 canonical_units.insert(node_id, HashMap::new());
44 }
45 let top_row = NodeMap::with_size(node_count);
46 UnitStore {
47 by_hash: HashMap::new(),
48 canonical_units,
49 top_row,
50 }
51 }
52
53 fn mut_hashes_by(&mut self, creator: NodeIndex) -> &mut HashMap<Round, HashFor<U>> {
54 self.canonical_units
55 .get_mut(creator)
56 .expect("all hashmaps initialized")
57 }
58
59 fn hashes_by(&self, creator: NodeIndex) -> &HashMap<Round, HashFor<U>> {
60 self.canonical_units
61 .get(creator)
62 .expect("all hashmaps initialized")
63 }
64
65 fn canonical_by_hash(&self, hash: &HashFor<U>) -> &U {
67 self.by_hash.get(hash).expect("we have all canonical units")
68 }
69
70 fn maybe_set_canonical(&mut self, unit: &U) {
71 let unit_coord = unit.coord();
72 if self.canonical_unit(unit_coord).is_none() {
73 self.mut_hashes_by(unit_coord.creator())
74 .insert(unit.round(), unit.hash());
75 if self
77 .top_row
78 .get(unit.creator())
79 .map(|max_round| unit.round() > *max_round)
80 .unwrap_or(true)
81 {
82 self.top_row.insert(unit_coord.creator(), unit.round());
83 }
84 }
85 }
86
87 pub fn insert(&mut self, unit: U) {
89 self.maybe_set_canonical(&unit);
90 let unit_hash = unit.hash();
91 self.by_hash.insert(unit_hash, unit);
92 }
93
94 fn maybe_unset_canonical(&mut self, unit: &U) {
95 let creator_hashes = self.mut_hashes_by(unit.creator());
96 if creator_hashes.get(&unit.round()) != Some(&unit.hash()) {
97 return;
98 }
99 creator_hashes.remove(&unit.round());
100 if self.top_row.get(unit.creator()) == Some(&unit.round()) {
101 match self.hashes_by(unit.creator()).keys().max().copied() {
102 Some(max_round) => self.top_row.insert(unit.creator(), max_round),
103 None => self.top_row.delete(unit.creator()),
104 }
105 }
106 }
107
108 pub fn remove(&mut self, hash: &HashFor<U>) {
110 if let Some(unit) = self.by_hash.remove(hash) {
111 self.maybe_unset_canonical(&unit);
112 }
113 }
114
115 pub fn canonical_unit(&self, coord: UnitCoord) -> Option<&U> {
117 self.hashes_by(coord.creator())
118 .get(&coord.round())
119 .map(|hash| self.canonical_by_hash(hash))
120 }
121
122 pub fn canonical_units(&self, creator: NodeIndex) -> impl Iterator<Item = &U> {
124 let canonical_hashes = self.hashes_by(creator);
125 let max_round = canonical_hashes.keys().max().cloned().unwrap_or(0);
126 (0..=max_round)
127 .filter_map(|round| canonical_hashes.get(&round))
128 .map(|hash| self.canonical_by_hash(hash))
129 }
130
131 pub fn unit(&self, hash: &HashFor<U>) -> Option<&U> {
133 self.by_hash.get(hash)
134 }
135
136 pub fn top_round_for(&self, creator: NodeIndex) -> Option<Round> {
138 self.top_row.get(creator).copied()
139 }
140
141 pub fn status(&self) -> UnitStoreStatus {
143 UnitStoreStatus {
144 size: self.by_hash.len(),
145 top_row: self.top_row.clone(),
146 }
147 }
148}
149
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::{
        units::{random_full_parent_units_up_to, TestingFullUnit, Unit, UnitCoord, UnitStore},
        NodeCount, NodeIndex,
    };

    /// A freshly created store answers every query with "nothing".
    #[test]
    fn empty_has_no_units() {
        let store = UnitStore::<TestingFullUnit>::new(NodeCount(7));
        let creator = NodeIndex(0);
        assert!(store.top_round_for(creator).is_none());
        assert!(store.canonical_units(creator).next().is_none());
        assert!(store
            .canonical_unit(UnitCoord::new(0, creator))
            .is_none());
    }

    /// Inserting and then removing a single unit is reflected by all accessors.
    #[test]
    fn single_unit_basic_operations() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let rounds = random_full_parent_units_up_to(0, node_count, 43);
        let unit = rounds
            .first()
            .expect("we have the first round")
            .first()
            .expect("we have the initial unit for the zeroth creator")
            .clone();
        let (hash, coord, creator) = (unit.hash(), unit.coord(), unit.creator());

        store.insert(unit.clone());
        assert_eq!(store.unit(&hash), Some(&unit));
        assert_eq!(store.canonical_unit(coord), Some(&unit));
        assert_eq!(store.top_round_for(creator), Some(unit.round()));
        let canonical: Vec<_> = store.canonical_units(creator).collect();
        assert_eq!(canonical, vec![&unit]);

        store.remove(&hash);
        assert_eq!(store.unit(&hash), None);
        assert_eq!(store.canonical_unit(coord), None);
        assert_eq!(store.canonical_units(creator).next(), None);
        assert_eq!(store.top_round_for(creator), None);
    }

    /// Among several units sharing a coordinate, the one inserted first is canonical, and
    /// removing it does not promote any of the others.
    #[test]
    fn first_variant_is_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        // Collect through a set so that repeated variants are dropped; the lint about
        // mutable key types is silenced on purpose here.
        #[allow(clippy::mutable_key_type)]
        let distinct: HashSet<_> = (0..15)
            .map(|_| {
                random_full_parent_units_up_to(0, node_count, 43)
                    .first()
                    .expect("we have the first round")
                    .first()
                    .expect("we have the initial unit for the zeroth creator")
                    .clone()
            })
            .collect();
        let variants: Vec<_> = distinct.into_iter().collect();
        variants
            .iter()
            .cloned()
            .for_each(|unit| store.insert(unit));
        for unit in &variants {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }

        let (canonical_unit, other_variants) =
            variants.split_first().expect("we have the unit");
        assert_eq!(
            store.canonical_unit(canonical_unit.coord()),
            Some(canonical_unit)
        );
        let canonical: Vec<_> = store.canonical_units(canonical_unit.creator()).collect();
        assert_eq!(canonical, vec![canonical_unit]);

        store.remove(&canonical_unit.hash());
        assert_eq!(store.unit(&canonical_unit.hash()), None);
        assert_eq!(store.canonical_unit(canonical_unit.coord()), None);
        assert_eq!(store.canonical_units(canonical_unit.creator()).next(), None);
        // The remaining variants are still stored, just not canonical.
        for unit in other_variants {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }
    }

    /// Many rounds of units from all creators are all stored and all canonical.
    #[test]
    fn stores_lots_of_units() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for unit in units.iter().flatten() {
            store.insert(unit.clone());
        }
        for unit in units.iter().flatten() {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
            assert_eq!(store.canonical_unit(unit.coord()), Some(unit));
        }
        for node_id in node_count.into_iterator() {
            let expected: Vec<_> = (0..=max_round)
                .map(|round| &units[round as usize][node_id.0])
                .collect();
            let actual: Vec<_> = store.canonical_units(node_id).collect();
            assert_eq!(actual, expected);
        }
    }

    /// Canonical units are iterated in round order even when some rounds have been removed.
    #[test]
    fn handles_fragmented_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for unit in units.iter().flatten() {
            store.insert(unit.clone());
        }
        // Punch holes: for each creator, drop every round divisible by (creator index + 1).
        units
            .iter()
            .flatten()
            .filter(|unit| unit.round() as usize % (unit.creator().0 + 1) == 0)
            .for_each(|unit| store.remove(&unit.hash()));
        for node_id in node_count.into_iterator() {
            let expected: Vec<_> = (0..=max_round)
                .filter(|round| *round as usize % (node_id.0 + 1) != 0)
                .map(|round| &units[round as usize][node_id.0])
                .collect();
            let actual: Vec<_> = store.canonical_units(node_id).collect();
            assert_eq!(actual, expected);
        }
    }
}