amareleo_node_bft/helpers/storage.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_node_bft_ledger_service::LedgerService;
18use amareleo_node_bft_storage_service::StorageService;
19use snarkvm::{
20    ledger::{
21        block::{Block, Transaction},
22        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
23    },
24    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
25};
26
27use indexmap::{IndexMap, IndexSet, map::Entry};
28use parking_lot::RwLock;
29use std::{
30    collections::{HashMap, HashSet},
31    sync::{
32        Arc,
33        atomic::{AtomicU32, AtomicU64, Ordering},
34    },
35};
36
37#[derive(Clone, Debug)]
38pub struct Storage<N: Network>(Arc<StorageInner<N>>);
39
40impl<N: Network> std::ops::Deref for Storage<N> {
41    type Target = Arc<StorageInner<N>>;
42
43    fn deref(&self) -> &Self::Target {
44        &self.0
45    }
46}
47
48/// The storage for the memory pool.
49///
50/// The storage is used to store the following:
51/// - `current_height` tracker.
52/// - `current_round` tracker.
53/// - `round` to `(certificate ID, batch ID, author)` entries.
54/// - `certificate ID` to `certificate` entries.
55/// - `batch ID` to `round` entries.
56/// - `transmission ID` to `(transmission, certificate IDs)` entries.
57///
58/// The chain of events is as follows:
59/// 1. A `transmission` is received.
60/// 2. After a `batch` is ready to be stored:
61///   - The `certificate` is inserted, triggering updates to the
62///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
63///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
64///   - The certificate ID is inserted into the `transmissions` map.
65/// 3. After a `round` reaches quorum threshold:
66///  - The next round is inserted into the `current_round`.
67#[derive(Debug)]
68pub struct StorageInner<N: Network> {
69    /// The ledger service.
70    ledger: Arc<dyn LedgerService<N>>,
71    /* Once per block */
72    /// The current height.
73    current_height: AtomicU32,
74    /* Once per round */
75    /// The current round.
76    current_round: AtomicU64,
77    /// The `round` for which garbage collection has occurred **up to** (inclusive).
78    gc_round: AtomicU64,
79    /// The maximum number of rounds to keep in storage.
80    max_gc_rounds: u64,
81    /* Once per batch */
82    /// The map of `round` to a list of `(certificate ID, batch ID, author)` entries.
83    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
84    /// The map of `certificate ID` to `certificate`.
85    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
86    /// The map of `batch ID` to `round`.
87    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
88    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
89    transmissions: Arc<dyn StorageService<N>>,
90}
91
92impl<N: Network> Storage<N> {
93    /// Initializes a new instance of storage.
94    pub fn new(
95        ledger: Arc<dyn LedgerService<N>>,
96        transmissions: Arc<dyn StorageService<N>>,
97        max_gc_rounds: u64,
98    ) -> Self {
99        // Retrieve the current committee.
100        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
101        // Retrieve the current round.
102        let current_round = committee.starting_round().max(1);
103
104        // Return the storage.
105        let storage = Self(Arc::new(StorageInner {
106            ledger,
107            current_height: Default::default(),
108            current_round: Default::default(),
109            gc_round: Default::default(),
110            max_gc_rounds,
111            rounds: Default::default(),
112            certificates: Default::default(),
113            batch_ids: Default::default(),
114            transmissions,
115        }));
116        // Update the storage to the current round.
117        storage.update_current_round(current_round);
118        // Perform GC on the current round.
119        storage.garbage_collect_certificates(current_round);
120        // Return the storage.
121        storage
122    }
123}
124
125impl<N: Network> Storage<N> {
126    /// Returns the current height.
127    pub fn current_height(&self) -> u32 {
128        // Get the current height.
129        self.current_height.load(Ordering::SeqCst)
130    }
131}
132
133impl<N: Network> Storage<N> {
134    /// Returns the current round.
135    pub fn current_round(&self) -> u64 {
136        // Get the current round.
137        self.current_round.load(Ordering::SeqCst)
138    }
139
140    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
141    pub fn gc_round(&self) -> u64 {
142        // Get the GC round.
143        self.gc_round.load(Ordering::SeqCst)
144    }
145
146    /// Returns the maximum number of rounds to keep in storage.
147    pub fn max_gc_rounds(&self) -> u64 {
148        self.max_gc_rounds
149    }
150
151    /// Increments storage to the next round, updating the current round.
152    /// Note: This method is only called once per round, upon certification of the primary's batch.
153    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
154        // Determine the next round.
155        let next_round = current_round + 1;
156
157        // Check if the next round is less than the current round in storage.
158        {
159            // Retrieve the storage round.
160            let storage_round = self.current_round();
161            // If the next round is less than the current round in storage, return early with the storage round.
162            if next_round < storage_round {
163                return Ok(storage_round);
164            }
165        }
166
167        // Retrieve the current committee.
168        let current_committee = self.ledger.current_committee()?;
169        // Retrieve the current committee's starting round.
170        let starting_round = current_committee.starting_round();
171        // If the primary is behind the current committee's starting round, sync with the latest block.
172        if next_round < starting_round {
173            // Retrieve the latest block round.
174            let latest_block_round = self.ledger.latest_round();
175            // Log the round sync.
176            info!(
177                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
178            );
179            // Sync the round with the latest block.
180            self.sync_round_with_block(latest_block_round);
181            // Return the latest block round.
182            return Ok(latest_block_round);
183        }
184
185        // Update the storage to the next round.
186        self.update_current_round(next_round);
187
188        #[cfg(feature = "metrics")]
189        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
190
191        // Retrieve the storage round.
192        let storage_round = self.current_round();
193        // Retrieve the GC round.
194        let gc_round = self.gc_round();
195        // Ensure the next round matches in storage.
196        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
197        // Ensure the next round is greater than or equal to the GC round.
198        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
199
200        // Log the updated round.
201        info!("Starting round {next_round}...");
202        Ok(next_round)
203    }
204
205    /// Updates the storage to the next round.
206    fn update_current_round(&self, next_round: u64) {
207        // Update the current round.
208        self.current_round.store(next_round, Ordering::SeqCst);
209    }
210
211    /// Update the storage by performing garbage collection based on the next round.
212    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
213        // Fetch the current GC round.
214        let current_gc_round = self.gc_round();
215        // Compute the next GC round.
216        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
217        // Check if storage needs to be garbage collected.
218        if next_gc_round > current_gc_round {
219            // Remove the GC round(s) from storage.
220            for gc_round in current_gc_round..=next_gc_round {
221                // Iterate over the certificates for the GC round.
222                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
223                    // Remove the certificate from storage.
224                    self.remove_certificate(id);
225                }
226            }
227            // Update the GC round.
228            self.gc_round.store(next_gc_round, Ordering::SeqCst);
229        }
230    }
231}
232
233impl<N: Network> Storage<N> {
234    /// Returns `true` if the storage contains the specified `round`.
235    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
236        // Check if the round exists in storage.
237        self.rounds.read().contains_key(&round)
238    }
239
240    /// Returns `true` if the storage contains the specified `certificate ID`.
241    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
242        // Check if the certificate ID exists in storage.
243        self.certificates.read().contains_key(&certificate_id)
244    }
245
246    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
247    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
248        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
249    }
250
251    /// Returns `true` if the storage contains the specified `batch ID`.
252    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
253        // Check if the batch ID exists in storage.
254        self.batch_ids.read().contains_key(&batch_id)
255    }
256
257    /// Returns `true` if the storage contains the specified `transmission ID`.
258    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
259        self.transmissions.contains_transmission(transmission_id.into())
260    }
261
262    /// Returns the transmission for the given `transmission ID`.
263    /// If the transmission ID does not exist in storage, `None` is returned.
264    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
265        self.transmissions.get_transmission(transmission_id.into())
266    }
267
268    /// Returns the round for the given `certificate ID`.
269    /// If the certificate ID does not exist in storage, `None` is returned.
270    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
271        // Get the round.
272        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
273    }
274
275    /// Returns the round for the given `batch ID`.
276    /// If the batch ID does not exist in storage, `None` is returned.
277    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
278        // Get the round.
279        self.batch_ids.read().get(&batch_id).copied()
280    }
281
282    /// Returns the certificate round for the given `certificate ID`.
283    /// If the certificate ID does not exist in storage, `None` is returned.
284    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
285        // Get the batch certificate and return the round.
286        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
287    }
288
289    /// Returns the certificate for the given `certificate ID`.
290    /// If the certificate ID does not exist in storage, `None` is returned.
291    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
292        // Get the batch certificate.
293        self.certificates.read().get(&certificate_id).cloned()
294    }
295
296    /// Returns the certificate for the given `round` and `author`.
297    /// If the round does not exist in storage, `None` is returned.
298    /// If the author for the round does not exist in storage, `None` is returned.
299    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
300        // Retrieve the certificates.
301        if let Some(entries) = self.rounds.read().get(&round) {
302            let certificates = self.certificates.read();
303            entries.iter().find_map(
304                |(certificate_id, _, a)| {
305                    if a == &author { certificates.get(certificate_id).cloned() } else { None }
306                },
307            )
308        } else {
309            Default::default()
310        }
311    }
312
313    /// Returns the certificates for the given `round`.
314    /// If the round does not exist in storage, `None` is returned.
315    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
316        // The genesis round does not have batch certificates.
317        if round == 0 {
318            return Default::default();
319        }
320        // Retrieve the certificates.
321        if let Some(entries) = self.rounds.read().get(&round) {
322            let certificates = self.certificates.read();
323            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
324        } else {
325            Default::default()
326        }
327    }
328
329    /// Returns the certificate IDs for the given `round`.
330    /// If the round does not exist in storage, `None` is returned.
331    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
332        // The genesis round does not have batch certificates.
333        if round == 0 {
334            return Default::default();
335        }
336        // Retrieve the certificates.
337        if let Some(entries) = self.rounds.read().get(&round) {
338            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
339        } else {
340            Default::default()
341        }
342    }
343
344    /// Returns the certificate authors for the given `round`.
345    /// If the round does not exist in storage, `None` is returned.
346    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
347        // The genesis round does not have batch certificates.
348        if round == 0 {
349            return Default::default();
350        }
351        // Retrieve the certificates.
352        if let Some(entries) = self.rounds.read().get(&round) {
353            entries.iter().map(|(_, _, author)| *author).collect()
354        } else {
355            Default::default()
356        }
357    }
358
359    /// Returns the certificates that have not yet been included in the ledger.
360    /// Note that the order of this set is by round and then insertion.
361    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
362        let mut pending_certificates = IndexSet::new();
363
364        // Obtain the read locks.
365        let rounds = self.rounds.read();
366        let certificates = self.certificates.read();
367
368        // Iterate over the rounds.
369        for (_, certificates_for_round) in rounds.clone().sorted_by(|a, _, b, _| a.cmp(b)) {
370            // Iterate over the certificates for the round.
371            for (certificate_id, _, _) in certificates_for_round {
372                // Skip the certificate if it already exists in the ledger.
373                if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
374                    continue;
375                }
376
377                // Add the certificate to the pending certificates.
378                match certificates.get(&certificate_id).cloned() {
379                    Some(certificate) => pending_certificates.insert(certificate),
380                    None => continue,
381                };
382            }
383        }
384
385        pending_certificates
386    }
387
388    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
389    ///
390    /// This method ensures the following invariants:
391    /// - The batch ID does not already exist in storage.
392    /// - The author is a member of the committee for the batch round.
393    /// - The timestamp is within the allowed time range.
394    /// - None of the transmissions are from any past rounds (up to GC).
395    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
396    /// - All previous certificates declared in the certificate exist in storage (up to GC).
397    /// - All previous certificates are for the previous round (i.e. round - 1).
398    /// - All previous certificates contain a unique author.
399    /// - The previous certificates reached the quorum threshold (N - f).
400    pub fn check_batch_header(
401        &self,
402        batch_header: &BatchHeader<N>,
403        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
404        aborted_transmissions: HashSet<TransmissionID<N>>,
405    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
406        // Retrieve the round.
407        let round = batch_header.round();
408        // Retrieve the GC round.
409        let gc_round = self.gc_round();
410        // Construct a GC log message.
411        let gc_log = format!("(gc = {gc_round})");
412
413        // Ensure the batch ID does not already exist in storage.
414        if self.contains_batch(batch_header.batch_id()) {
415            bail!("Batch for round {round} already exists in storage {gc_log}")
416        }
417
418        // Retrieve the committee lookback for the batch round.
419        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
420            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
421        };
422        // Ensure the author is in the committee.
423        if !committee_lookback.is_committee_member(batch_header.author()) {
424            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
425        }
426
427        // Check the timestamp for liveness.
428        check_timestamp_for_liveness(batch_header.timestamp())?;
429
430        // Retrieve the missing transmissions in storage from the given transmissions.
431        let missing_transmissions = self
432            .transmissions
433            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
434            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
435
436        // Compute the previous round.
437        let previous_round = round.saturating_sub(1);
438        // Check if the previous round is within range of the GC round.
439        if previous_round > gc_round {
440            // Retrieve the committee lookback for the previous round.
441            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
442                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
443            };
444            // Ensure the previous round certificates exists in storage.
445            if !self.contains_certificates_for_round(previous_round) {
446                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
447            }
448            // Ensure the number of previous certificate IDs is at or below the number of committee members.
449            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
450                bail!("Too many previous certificates for round {round} {gc_log}")
451            }
452            // Initialize a set of the previous authors.
453            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
454            // Ensure storage contains all declared previous certificates (up to GC).
455            for previous_certificate_id in batch_header.previous_certificate_ids() {
456                // Retrieve the previous certificate.
457                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
458                    bail!(
459                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
460                        fmt_id(previous_certificate_id)
461                    )
462                };
463                // Ensure the previous certificate is for the previous round.
464                if previous_certificate.round() != previous_round {
465                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
466                }
467                // Ensure the previous author is new.
468                if previous_authors.contains(&previous_certificate.author()) {
469                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
470                }
471                // Insert the author of the previous certificate.
472                previous_authors.insert(previous_certificate.author());
473            }
474            // Ensure the previous certificates have reached the quorum threshold.
475            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
476                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
477            }
478        }
479        Ok(missing_transmissions)
480    }
481
482    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
483    ///
484    /// This method ensures the following invariants:
485    /// - The certificate ID does not already exist in storage.
486    /// - The batch ID does not already exist in storage.
487    /// - The author is a member of the committee for the batch round.
488    /// - The author has not already created a certificate for the batch round.
489    /// - The timestamp is within the allowed time range.
490    /// - None of the transmissions are from any past rounds (up to GC).
491    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
492    /// - All previous certificates declared in the certificate exist in storage (up to GC).
493    /// - All previous certificates are for the previous round (i.e. round - 1).
494    /// - The previous certificates reached the quorum threshold (N - f).
495    /// - The timestamps from the signers are all within the allowed time range.
496    /// - The signers have reached the quorum threshold (N - f).
497    pub fn check_certificate(
498        &self,
499        certificate: &BatchCertificate<N>,
500        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
501        aborted_transmissions: HashSet<TransmissionID<N>>,
502    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
503        // Retrieve the round.
504        let round = certificate.round();
505        // Retrieve the GC round.
506        let gc_round = self.gc_round();
507        // Construct a GC log message.
508        let gc_log = format!("(gc = {gc_round})");
509
510        // Ensure the certificate ID does not already exist in storage.
511        if self.contains_certificate(certificate.id()) {
512            bail!("Certificate for round {round} already exists in storage {gc_log}")
513        }
514
515        // Ensure the storage does not already contain a certificate for this author in this round.
516        if self.contains_certificate_in_round_from(round, certificate.author()) {
517            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
518        }
519
520        // Ensure the batch header is well-formed.
521        let missing_transmissions =
522            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
523
524        // Check the timestamp for liveness.
525        check_timestamp_for_liveness(certificate.timestamp())?;
526
527        // Retrieve the committee lookback for the batch round.
528        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
529            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
530        };
531
532        // Initialize a set of the signers.
533        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
534        // Append the batch author.
535        signers.insert(certificate.author());
536
537        // Iterate over the signatures.
538        for signature in certificate.signatures() {
539            // Retrieve the signer.
540            let signer = signature.to_address();
541            // Ensure the signer is in the committee.
542            if !committee_lookback.is_committee_member(signer) {
543                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
544            }
545            // Append the signer.
546            signers.insert(signer);
547        }
548
549        // Ensure the signatures have reached the quorum threshold.
550        if !committee_lookback.is_quorum_threshold_reached(&signers) {
551            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
552        }
553        Ok(missing_transmissions)
554    }
555
556    /// Inserts the given `certificate` into storage.
557    ///
558    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
559    ///
560    /// This method ensures the following invariants:
561    /// - The certificate ID does not already exist in storage.
562    /// - The batch ID does not already exist in storage.
563    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
564    /// - All previous certificates declared in the certificate exist in storage (up to GC).
565    /// - All previous certificates are for the previous round (i.e. round - 1).
566    /// - The previous certificates reached the quorum threshold (N - f).
567    pub fn insert_certificate(
568        &self,
569        certificate: BatchCertificate<N>,
570        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
571        aborted_transmissions: HashSet<TransmissionID<N>>,
572    ) -> Result<()> {
573        // Ensure the certificate round is above the GC round.
574        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
575        // Ensure the certificate and its transmissions are valid.
576        let missing_transmissions =
577            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
578        // Insert the certificate into storage.
579        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
580        Ok(())
581    }
582
583    /// Inserts the given `certificate` into storage.
584    ///
585    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
586    ///
587    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
588    fn insert_certificate_atomic(
589        &self,
590        certificate: BatchCertificate<N>,
591        aborted_transmission_ids: HashSet<TransmissionID<N>>,
592        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
593    ) {
594        // Retrieve the round.
595        let round = certificate.round();
596        // Retrieve the certificate ID.
597        let certificate_id = certificate.id();
598        // Retrieve the batch ID.
599        let batch_id = certificate.batch_id();
600        // Retrieve the author of the batch.
601        let author = certificate.author();
602
603        // Insert the round to certificate ID entry.
604        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
605        // Obtain the certificate's transmission ids.
606        let transmission_ids = certificate.transmission_ids().clone();
607        // Insert the certificate.
608        self.certificates.write().insert(certificate_id, certificate);
609        // Insert the batch ID.
610        self.batch_ids.write().insert(batch_id, round);
611        // Insert the certificate ID for each of the transmissions into storage.
612        self.transmissions.insert_transmissions(
613            certificate_id,
614            transmission_ids,
615            aborted_transmission_ids,
616            missing_transmissions,
617        );
618    }
619
620    /// Removes the given `certificate ID` from storage.
621    ///
622    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
623    ///
624    /// If the certificate was successfully removed, `true` is returned.
625    /// If the certificate did not exist in storage, `false` is returned.
626    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
627        // Retrieve the certificate.
628        let Some(certificate) = self.get_certificate(certificate_id) else {
629            warn!("Certificate {certificate_id} does not exist in storage");
630            return false;
631        };
632        // Retrieve the round.
633        let round = certificate.round();
634        // Retrieve the batch ID.
635        let batch_id = certificate.batch_id();
636        // Compute the author of the batch.
637        let author = certificate.author();
638
639        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
640        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
641        //  It will be better to write tests that compare the union of the sets.
642
643        // Update the round.
644        match self.rounds.write().entry(round) {
645            Entry::Occupied(mut entry) => {
646                // Remove the round to certificate ID entry.
647                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
648                // If the round is empty, remove it.
649                if entry.get().is_empty() {
650                    entry.swap_remove();
651                }
652            }
653            Entry::Vacant(_) => {}
654        }
655        // Remove the certificate.
656        self.certificates.write().swap_remove(&certificate_id);
657        // Remove the batch ID.
658        self.batch_ids.write().swap_remove(&batch_id);
659        // Remove the transmission entries in the certificate from storage.
660        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
661        // Return successfully.
662        true
663    }
664}
665
666impl<N: Network> Storage<N> {
667    /// Syncs the current height with the block.
668    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
669        // If the block height is greater than the current height in storage, sync the height.
670        if next_height > self.current_height() {
671            // Update the current height in storage.
672            self.current_height.store(next_height, Ordering::SeqCst);
673        }
674    }
675
676    /// Syncs the current round with the block.
677    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
678        // Retrieve the current round in the block.
679        let next_round = next_round.max(1);
680        // If the round in the block is greater than the current round in storage, sync the round.
681        if next_round > self.current_round() {
682            // Update the current round in storage.
683            self.update_current_round(next_round);
684            // Log the updated round.
685            info!("Synced to round {next_round}...");
686        }
687    }
688
689    /// Syncs the batch certificate with the block.
690    pub(crate) fn sync_certificate_with_block(
691        &self,
692        block: &Block<N>,
693        certificate: BatchCertificate<N>,
694        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
695    ) {
696        // Skip if the certificate round is below the GC round.
697        if certificate.round() <= self.gc_round() {
698            return;
699        }
700        // If the certificate ID already exists in storage, skip it.
701        if self.contains_certificate(certificate.id()) {
702            return;
703        }
704        // Retrieve the transmissions for the certificate.
705        let mut missing_transmissions = HashMap::new();
706
707        // Retrieve the aborted transmissions for the certificate.
708        let mut aborted_transmissions = HashSet::new();
709
710        // Track the block's aborted solutions and transactions.
711        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
712        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
713
714        // Iterate over the transmission IDs.
715        for transmission_id in certificate.transmission_ids() {
716            // If the transmission ID already exists in the map, skip it.
717            if missing_transmissions.contains_key(transmission_id) {
718                continue;
719            }
720            // If the transmission ID exists in storage, skip it.
721            if self.contains_transmission(*transmission_id) {
722                continue;
723            }
724            // Retrieve the transmission.
725            match transmission_id {
726                TransmissionID::Ratification => (),
727                TransmissionID::Solution(solution_id, _) => {
728                    // Retrieve the solution.
729                    match block.get_solution(solution_id) {
730                        // Insert the solution.
731                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
732                        // Otherwise, try to load the solution from the ledger.
733                        None => match self.ledger.get_solution(solution_id) {
734                            // Insert the solution.
735                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
736                            // Check if the solution is in the aborted solutions.
737                            Err(_) => {
738                                // Insert the aborted solution if it exists in the block or ledger.
739                                match aborted_solutions.contains(solution_id)
740                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
741                                {
742                                    true => {
743                                        aborted_transmissions.insert(*transmission_id);
744                                    }
745                                    false => error!("Missing solution {solution_id} in block {}", block.height()),
746                                }
747                                continue;
748                            }
749                        },
750                    };
751                }
752                TransmissionID::Transaction(transaction_id, _) => {
753                    // Retrieve the transaction.
754                    match unconfirmed_transactions.get(transaction_id) {
755                        // Insert the transaction.
756                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
757                        // Otherwise, try to load the unconfirmed transaction from the ledger.
758                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
759                            // Insert the transaction.
760                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
761                            // Check if the transaction is in the aborted transactions.
762                            Err(_) => {
763                                // Insert the aborted transaction if it exists in the block or ledger.
764                                match aborted_transactions.contains(transaction_id)
765                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
766                                {
767                                    true => {
768                                        aborted_transmissions.insert(*transmission_id);
769                                    }
770                                    false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
771                                }
772                                continue;
773                            }
774                        },
775                    };
776                }
777            }
778        }
779        // Insert the batch certificate into storage.
780        let certificate_id = fmt_id(certificate.id());
781        debug!(
782            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
783            certificate.round(),
784            certificate.transmission_ids().len()
785        );
786        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
787            error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
788        }
789    }
790}
791
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns a reference to the ledger service this storage was built with.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an owning iterator over a cloned snapshot of the `(round, (certificate ID, batch ID, author))` entries.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a cloned snapshot of the `(certificate ID, certificate)` entries.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a cloned snapshot of the `(batch ID, round)` entries.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over the `(transmission ID, (transmission, certificate IDs))` entries
    /// produced by the transmission store's `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let entries = self.transmissions.as_hashmap();
        entries.into_iter()
    }

    /// Writes `certificate` straight into the `rounds`, `certificates`, and `batch_ids` maps. Each of its
    /// transmission IDs is registered with an empty transaction buffer as a placeholder payload.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let (round, certificate_id, batch_id, author) =
            (certificate.round(), certificate.id(), certificate.batch_id(), certificate.author());
        // Capture the transmission IDs before the certificate is moved into the map.
        let transmission_ids = certificate.transmission_ids().clone();
        // Every transmission gets the same empty transaction buffer as a stand-in payload.
        let placeholder_transmissions: HashMap<_, _> = transmission_ids
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect();

        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            placeholder_transmissions,
        );
    }
}
859
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Checks every internal map of `storage` against the expected contents.
    ///
    /// `rounds`, `certificates`, and `batch_ids` are compared as sequences, in the order produced by the
    /// matching `*_iter` helper. `transmissions` is compared as a map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        let actual_rounds: Vec<_> = storage.rounds_iter().collect();
        assert_eq!(actual_rounds, *rounds);
        let actual_certificates: Vec<_> = storage.certificates_iter().collect();
        assert_eq!(actual_certificates, *certificates);
        let actual_batch_ids: Vec<_> = storage.batch_ids_iter().collect();
        assert_eq!(actual_batch_ids, *batch_ids);
        let actual_transmissions: HashMap<_, _> = storage.transmissions_iter().collect();
        assert_eq!(actual_transmissions, *transmissions);
    }

    /// Samples a random transmission: a 512-byte fake solution or a 2048-byte fake transaction.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        let random_bytes =
            |rng: &mut TestRng, len: usize| Bytes::from((0..len).map(|_| rng.gen::<u8>()).collect::<Vec<_>>());
        // The variant is chosen before any payload bytes are drawn.
        if rng.gen::<bool>() {
            Transmission::Solution(Data::Buffer(random_bytes(rng, 512)))
        } else {
            Transmission::Transaction(Data::Buffer(random_bytes(rng, 2048)))
        }
    }

    /// Samples one random transmission per transmission ID of `certificate`.
    ///
    /// Returns the `missing_transmissions` argument for insertion, plus the expected transmission layout,
    /// in which each transmission is referenced by the certificate's ID.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        let certificate_id = certificate.id();
        certificate
            .transmission_ids()
            .iter()
            .map(|id| {
                let transmission = sample_transmission(rng);
                ((*id, transmission.clone()), (*id, (transmission, indexset! { certificate_id })))
            })
            .unzip()
    }

    /// Creates a storage over a mock ledger for a freshly sampled committee (the trailing `1` is the GC-rounds setting).
    fn sample_storage(rng: &mut TestRng) -> Storage<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);

        // A freshly constructed storage holds nothing.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert, then confirm the certificate is reachable by ID, by round, and by (round, author).
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // The internal maps hold exactly this one certificate.
        assert_storage(
            &storage,
            &[(round, indexset! { (certificate_id, batch_id, author) })],
            &[(certificate_id, certificate.clone())],
            &[(batch_id, round)],
            &transmissions,
        );

        // The stored certificate round-trips unchanged.
        assert_eq!(certificate, storage.get_certificate(certificate_id).unwrap());

        // Removal erases every trace of the certificate.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);

        // A freshly constructed storage holds nothing.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());

        // The layout expected after the first insertion, which later insertions must not disturb.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(batch_id, round)];
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert with the transmissions, then again without any, then again with the originals.
        for payload in [missing_transmissions.clone(), Default::default(), missing_transmissions] {
            storage.insert_certificate_atomic(certificate.clone(), Default::default(), payload);
            assert!(storage.contains_certificate(certificate_id));
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }
    }
}
1058
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Maps a strategy over committee contexts to a strategy over storages.
    ///
    /// Each storage is backed by a `MockLedgerService` over the context's committee, with a GC-rounds
    /// value drawn from `0..MAX_GC_ROUNDS`.
    fn storage_strategy<S>(contexts: S) -> BoxedStrategy<Storage<CurrentNetwork>>
    where
        S: Strategy<Value = CommitteeContext> + 'static,
    {
        (contexts, 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
            .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                let ledger = Arc::new(MockLedgerService::new(committee));
                Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
            })
            .boxed()
    }

    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        fn arbitrary() -> Self::Strategy {
            // Draw a fresh committee context as well, instead of using `CommitteeContext::default()`.
            storage_strategy(any::<CommitteeContext>())
        }

        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            storage_strategy(Just(context))
        }
    }

    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// A random fake transmission: 512 bytes tagged as a solution, or 2048 bytes tagged as a transaction.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// A solution ID built from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// A transaction ID built from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// A transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs `batch_header`'s batch ID with the private key of every validator in `validator_set`.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let signer = selector.select(&validators);

        // Deduplicate the sampled transmissions by ID (a later entry overwrites an earlier one).
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Remove the author from the validator set passed to create the batch
        // certificate, the author should not sign their own batch.
        // (The clone is needed because `signer` still borrows the original set.)
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Derive the expected transmission layout from the exact map that gets inserted. Building it from
        // the raw `transmissions` Vec with first-wins semantics would disagree with `transmission_map`
        // (last-wins) whenever the Vec contains a duplicate transmission ID.
        let internal_transmissions: HashMap<_, _> = transmission_map
            .iter()
            .map(|(id, t)| (*id, (t.clone(), indexset! { certificate_id })))
            .collect();

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}