aleph_bft/units/store.rs

1use std::{
2    collections::HashMap,
3    fmt::{Display, Formatter, Result as FmtResult},
4};
5
6use crate::{
7    units::{HashFor, Unit, UnitCoord},
8    NodeCount, NodeIndex, NodeMap, Round,
9};
10
11/// An overview of what is in the unit store.
12pub struct UnitStoreStatus {
13    size: usize,
14    top_row: NodeMap<Round>,
15}
16
17impl UnitStoreStatus {
18    /// Highest round among units in the store.
19    pub fn top_round(&self) -> Round {
20        self.top_row.values().max().cloned().unwrap_or(0)
21    }
22}
23
24impl Display for UnitStoreStatus {
25    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
26        write!(f, "total units: {}, top row: {}", self.size, self.top_row)
27    }
28}
29
30/// Stores units, and keeps track of which are canonical, i.e. the first ones inserted with a given coordinate.
31/// See `remove` for limitation on trusting canonical units, although they don't impact our usecases.
32pub struct UnitStore<U: Unit> {
33    by_hash: HashMap<HashFor<U>, U>,
34    canonical_units: NodeMap<HashMap<Round, HashFor<U>>>,
35    top_row: NodeMap<Round>,
36}
37
38impl<U: Unit> UnitStore<U> {
39    /// Create a new unit store for the given number of nodes.
40    pub fn new(node_count: NodeCount) -> Self {
41        let mut canonical_units = NodeMap::with_size(node_count);
42        for node_id in node_count.into_iterator() {
43            canonical_units.insert(node_id, HashMap::new());
44        }
45        let top_row = NodeMap::with_size(node_count);
46        UnitStore {
47            by_hash: HashMap::new(),
48            canonical_units,
49            top_row,
50        }
51    }
52
53    fn mut_hashes_by(&mut self, creator: NodeIndex) -> &mut HashMap<Round, HashFor<U>> {
54        self.canonical_units
55            .get_mut(creator)
56            .expect("all hashmaps initialized")
57    }
58
59    fn hashes_by(&self, creator: NodeIndex) -> &HashMap<Round, HashFor<U>> {
60        self.canonical_units
61            .get(creator)
62            .expect("all hashmaps initialized")
63    }
64
65    // only call this for canonical units
66    fn canonical_by_hash(&self, hash: &HashFor<U>) -> &U {
67        self.by_hash.get(hash).expect("we have all canonical units")
68    }
69
70    fn maybe_set_canonical(&mut self, unit: &U) {
71        let unit_coord = unit.coord();
72        if self.canonical_unit(unit_coord).is_none() {
73            self.mut_hashes_by(unit_coord.creator())
74                .insert(unit.round(), unit.hash());
75            // the top row is only cached information for optimization purposes
76            if self
77                .top_row
78                .get(unit.creator())
79                .map(|max_round| unit.round() > *max_round)
80                .unwrap_or(true)
81            {
82                self.top_row.insert(unit_coord.creator(), unit.round());
83            }
84        }
85    }
86
87    /// Insert a unit. If no other unit with this coord is in the store it becomes canonical.
88    pub fn insert(&mut self, unit: U) {
89        self.maybe_set_canonical(&unit);
90        let unit_hash = unit.hash();
91        self.by_hash.insert(unit_hash, unit);
92    }
93
94    fn maybe_unset_canonical(&mut self, unit: &U) {
95        let creator_hashes = self.mut_hashes_by(unit.creator());
96        if creator_hashes.get(&unit.round()) != Some(&unit.hash()) {
97            return;
98        }
99        creator_hashes.remove(&unit.round());
100        if self.top_row.get(unit.creator()) == Some(&unit.round()) {
101            match self.hashes_by(unit.creator()).keys().max().copied() {
102                Some(max_round) => self.top_row.insert(unit.creator(), max_round),
103                None => self.top_row.delete(unit.creator()),
104            }
105        }
106    }
107
108    /// Remove a unit with a given hash. Notably if you remove a unit another might become canonical in its place in the future.
109    pub fn remove(&mut self, hash: &HashFor<U>) {
110        if let Some(unit) = self.by_hash.remove(hash) {
111            self.maybe_unset_canonical(&unit);
112        }
113    }
114
115    /// The canonical unit for the given coord if it exists.
116    pub fn canonical_unit(&self, coord: UnitCoord) -> Option<&U> {
117        self.hashes_by(coord.creator())
118            .get(&coord.round())
119            .map(|hash| self.canonical_by_hash(hash))
120    }
121
122    /// All the canonical units for the given creator, in order of rounds.
123    pub fn canonical_units(&self, creator: NodeIndex) -> impl Iterator<Item = &U> {
124        let canonical_hashes = self.hashes_by(creator);
125        let max_round = canonical_hashes.keys().max().cloned().unwrap_or(0);
126        (0..=max_round)
127            .filter_map(|round| canonical_hashes.get(&round))
128            .map(|hash| self.canonical_by_hash(hash))
129    }
130
131    /// The unit for the given hash, if present.
132    pub fn unit(&self, hash: &HashFor<U>) -> Option<&U> {
133        self.by_hash.get(hash)
134    }
135
136    /// The highest known round for the given creator.
137    pub fn top_round_for(&self, creator: NodeIndex) -> Option<Round> {
138        self.top_row.get(creator).copied()
139    }
140
141    /// The status summary of this store.
142    pub fn status(&self) -> UnitStoreStatus {
143        UnitStoreStatus {
144            size: self.by_hash.len(),
145            top_row: self.top_row.clone(),
146        }
147    }
148}
149
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::{
        units::{random_full_parent_units_up_to, TestingFullUnit, Unit, UnitCoord, UnitStore},
        NodeCount, NodeIndex,
    };

    #[test]
    fn empty_has_no_units() {
        let node_count = NodeCount(7);
        let store = UnitStore::<TestingFullUnit>::new(node_count);
        assert!(store
            .canonical_unit(UnitCoord::new(0, NodeIndex(0)))
            .is_none());
        assert!(store.canonical_units(NodeIndex(0)).next().is_none());
        assert!(store.top_round_for(NodeIndex(0)).is_none());
        // An empty top row reports round zero.
        assert_eq!(store.status().top_round(), 0);
    }

    #[test]
    fn single_unit_basic_operations() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let unit = random_full_parent_units_up_to(0, node_count, 43)
            .first()
            .expect("we have the first round")
            .first()
            .expect("we have the initial unit for the zeroth creator")
            .clone();
        store.insert(unit.clone());
        assert_eq!(store.unit(&unit.hash()), Some(&unit));
        assert_eq!(store.canonical_unit(unit.coord()), Some(&unit));
        assert_eq!(store.top_round_for(unit.creator()), Some(unit.round()));
        {
            // in block to drop the iterator
            let mut canonical_units = store.canonical_units(unit.creator());
            assert_eq!(canonical_units.next(), Some(&unit));
            assert_eq!(canonical_units.next(), None);
        }
        store.remove(&unit.hash());
        assert_eq!(store.unit(&unit.hash()), None);
        assert_eq!(store.canonical_unit(unit.coord()), None);
        assert_eq!(store.canonical_units(unit.creator()).next(), None);
        assert_eq!(store.top_round_for(unit.creator()), None);
    }

    #[test]
    fn first_variant_is_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        // only unique variants
        // `TestingFullUnit` trips clippy's mutable-key-type lint (presumably it
        // contains interior mutability — TODO confirm); the units are never
        // mutated while they sit in this set, so hashing stays consistent.
        #[allow(clippy::mutable_key_type)]
        let variants: HashSet<_> = (0..15)
            .map(|_| {
                random_full_parent_units_up_to(0, node_count, 43)
                    .first()
                    .expect("we have the first round")
                    .first()
                    .expect("we have the initial unit for the zeroth creator")
                    .clone()
            })
            .collect();
        let variants: Vec<_> = variants.into_iter().collect();
        for unit in &variants {
            store.insert(unit.clone());
        }
        for unit in &variants {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }
        let canonical_unit = variants.first().expect("we have the unit").clone();
        assert_eq!(
            store.canonical_unit(canonical_unit.coord()),
            Some(&canonical_unit)
        );
        {
            // in block to drop the iterator
            let mut canonical_units = store.canonical_units(canonical_unit.creator());
            assert_eq!(canonical_units.next(), Some(&canonical_unit));
            assert_eq!(canonical_units.next(), None);
        }
        store.remove(&canonical_unit.hash());
        assert_eq!(store.unit(&canonical_unit.hash()), None);
        // we don't have a canonical unit any more
        assert_eq!(store.canonical_unit(canonical_unit.coord()), None);
        assert_eq!(store.canonical_units(canonical_unit.creator()).next(), None);
        // we still have all the other units
        for unit in variants.iter().skip(1) {
            assert_eq!(store.unit(&unit.hash()), Some(unit));
        }
    }

    #[test]
    fn stores_lots_of_units() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for round_units in &units {
            for unit in round_units {
                store.insert(unit.clone());
            }
        }
        for round_units in &units {
            for unit in round_units {
                assert_eq!(store.unit(&unit.hash()), Some(unit));
                assert_eq!(store.canonical_unit(unit.coord()), Some(unit));
            }
        }
        // Every creator has a canonical unit in every round up to `max_round`.
        assert_eq!(store.status().top_round(), max_round);
        for node_id in node_count.into_iterator() {
            assert_eq!(store.top_round_for(node_id), Some(max_round));
            let mut canonical_units = store.canonical_units(node_id);
            for round in 0..=max_round {
                assert_eq!(
                    canonical_units.next(),
                    Some(&units[round as usize][node_id.0])
                );
            }
            assert_eq!(canonical_units.next(), None);
        }
    }

    #[test]
    fn handles_fragmented_canonical() {
        let node_count = NodeCount(7);
        let mut store = UnitStore::new(node_count);
        let max_round = 15;
        let units = random_full_parent_units_up_to(max_round, node_count, 43);
        for round_units in &units {
            for unit in round_units {
                store.insert(unit.clone());
            }
        }
        for round_units in &units {
            for unit in round_units {
                // remove an irregular subset: rounds divisible by (creator index + 1)
                if unit.round() as usize % (unit.creator().0 + 1) == 0 {
                    store.remove(&unit.hash());
                }
            }
        }
        for node_id in node_count.into_iterator() {
            // The cached top round must track the highest surviving round,
            // exercising the recomputation done when the top unit is removed.
            let expected_top = (0..=max_round)
                .filter(|round| *round as usize % (node_id.0 + 1) != 0)
                .max();
            assert_eq!(store.top_round_for(node_id), expected_top);
            let mut canonical_units = store.canonical_units(node_id);
            for round in 0..=max_round {
                if round as usize % (node_id.0 + 1) != 0 {
                    assert_eq!(
                        canonical_units.next(),
                        Some(&units[round as usize][node_id.0])
                    );
                }
            }
            assert_eq!(canonical_units.next(), None);
        }
    }
}
303}