amareleo_node_bft/helpers/
storage.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21    ledger::{
22        block::{Block, Transaction},
23        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24    },
25    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26    utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::RwLock;
34use rayon::iter::{IntoParallelIterator, ParallelIterator};
35use std::{
36    collections::{HashMap, HashSet},
37    sync::{
38        Arc,
39        atomic::{AtomicU32, AtomicU64, Ordering},
40    },
41};
42use tracing::subscriber::DefaultGuard;
43
/// A cheaply cloneable handle to the BFT memory-pool storage.
///
/// The handle wraps the inner state in an `Arc`, so cloning a `Storage` only clones
/// the `Arc`: every clone shares the same underlying [`StorageInner`].
#[derive(Clone, Debug)]
pub struct Storage<N: Network>(Arc<StorageInner<N>>);
46
47impl<N: Network> std::ops::Deref for Storage<N> {
48    type Target = Arc<StorageInner<N>>;
49
50    fn deref(&self) -> &Self::Target {
51        &self.0
52    }
53}
54
55impl<N: Network> TracingHandlerGuard for Storage<N> {
56    /// Retruns tracing guard
57    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
58        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
59    }
60}
61
/// The storage for the memory pool.
///
/// The storage is used to store the following:
/// - `current_height` tracker.
/// - `current_round` tracker.
/// - `round` to `(certificate ID, batch ID, author)` entries.
/// - `certificate ID` to `certificate` entries.
/// - `batch ID` to `round` entries.
/// - `transmission ID` to `(transmission, certificate IDs)` entries.
///
/// The chain of events is as follows:
/// 1. A `transmission` is received.
/// 2. After a `batch` is ready to be stored:
///    - The `certificate` is inserted, triggering updates to the
///      `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
///    - The missing `transmissions` from storage are inserted into the `transmissions` map.
///    - The certificate ID is inserted into the `transmissions` map.
/// 3. After a `round` reaches quorum threshold:
///    - The next round is inserted into the `current_round`.
///
/// The height and round trackers are atomics. Each in-memory map sits behind its own
/// `RwLock`, and the transmission map is delegated to a [`StorageService`] implementation.
#[derive(Debug)]
pub struct StorageInner<N: Network> {
    /// The ledger service.
    ledger: Arc<dyn LedgerService<N>>,
    /* Once per block */
    /// The current height.
    current_height: AtomicU32,
    /* Once per round */
    /// The current round.
    current_round: AtomicU64,
    /// The `round` for which garbage collection has occurred **up to** (inclusive).
    gc_round: AtomicU64,
    /// The maximum number of rounds to keep in storage.
    max_gc_rounds: u64,
    /* Once per batch */
    /// The map of `round` to a set of `(certificate ID, batch ID, author)` entries.
    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
    /// The map of `certificate ID` to `certificate`.
    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
    /// The map of `batch ID` to `round`.
    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
    transmissions: Arc<dyn StorageService<N>>,
    /// Optional tracing handler; its guard is exposed via `TracingHandlerGuard::get_tracing_guard`.
    tracing: Option<TracingHandler>,
}
107
108impl<N: Network> Storage<N> {
109    /// Initializes a new instance of storage.
110    pub fn new(
111        ledger: Arc<dyn LedgerService<N>>,
112        transmissions: Arc<dyn StorageService<N>>,
113        max_gc_rounds: u64,
114        tracing: Option<TracingHandler>,
115    ) -> Self {
116        // Retrieve the current committee.
117        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
118        // Retrieve the current round.
119        let current_round = committee.starting_round().max(1);
120
121        // Return the storage.
122        let storage = Self(Arc::new(StorageInner {
123            ledger,
124            current_height: Default::default(),
125            current_round: Default::default(),
126            gc_round: Default::default(),
127            max_gc_rounds,
128            rounds: Default::default(),
129            certificates: Default::default(),
130            batch_ids: Default::default(),
131            transmissions,
132            tracing,
133        }));
134        // Update the storage to the current round.
135        storage.update_current_round(current_round);
136        // Perform GC on the current round.
137        storage.garbage_collect_certificates(current_round);
138        // Return the storage.
139        storage
140    }
141}
142
143impl<N: Network> Storage<N> {
144    /// Returns the current height.
145    pub fn current_height(&self) -> u32 {
146        // Get the current height.
147        self.current_height.load(Ordering::SeqCst)
148    }
149}
150
151impl<N: Network> Storage<N> {
152    /// Returns the current round.
153    pub fn current_round(&self) -> u64 {
154        // Get the current round.
155        self.current_round.load(Ordering::SeqCst)
156    }
157
158    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
159    pub fn gc_round(&self) -> u64 {
160        // Get the GC round.
161        self.gc_round.load(Ordering::SeqCst)
162    }
163
164    /// Returns the maximum number of rounds to keep in storage.
165    pub fn max_gc_rounds(&self) -> u64 {
166        self.max_gc_rounds
167    }
168
169    /// Increments storage to the next round, updating the current round.
170    /// Note: This method is only called once per round, upon certification of the primary's batch.
171    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
172        // Determine the next round.
173        let next_round = current_round + 1;
174
175        // Check if the next round is less than the current round in storage.
176        {
177            // Retrieve the storage round.
178            let storage_round = self.current_round();
179            // If the next round is less than the current round in storage, return early with the storage round.
180            if next_round < storage_round {
181                return Ok(storage_round);
182            }
183        }
184
185        // Retrieve the current committee.
186        let current_committee = self.ledger.current_committee()?;
187        // Retrieve the current committee's starting round.
188        let starting_round = current_committee.starting_round();
189        // If the primary is behind the current committee's starting round, sync with the latest block.
190        if next_round < starting_round {
191            // Retrieve the latest block round.
192            let latest_block_round = self.ledger.latest_round();
193            // Log the round sync.
194            guard_info!(
195                self,
196                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
197            );
198            // Sync the round with the latest block.
199            self.sync_round_with_block(latest_block_round);
200            // Return the latest block round.
201            return Ok(latest_block_round);
202        }
203
204        // Update the storage to the next round.
205        self.update_current_round(next_round);
206
207        #[cfg(feature = "metrics")]
208        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
209
210        // Retrieve the storage round.
211        let storage_round = self.current_round();
212        // Retrieve the GC round.
213        let gc_round = self.gc_round();
214        // Ensure the next round matches in storage.
215        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
216        // Ensure the next round is greater than or equal to the GC round.
217        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
218
219        // Log the updated round.
220        guard_info!(self, "Starting round {next_round}...");
221        Ok(next_round)
222    }
223
224    /// Updates the storage to the next round.
225    fn update_current_round(&self, next_round: u64) {
226        // Update the current round.
227        self.current_round.store(next_round, Ordering::SeqCst);
228    }
229
230    /// Update the storage by performing garbage collection based on the next round.
231    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
232        // Fetch the current GC round.
233        let current_gc_round = self.gc_round();
234        // Compute the next GC round.
235        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
236        // Check if storage needs to be garbage collected.
237        if next_gc_round > current_gc_round {
238            // Remove the GC round(s) from storage.
239            for gc_round in current_gc_round..=next_gc_round {
240                // Iterate over the certificates for the GC round.
241                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
242                    // Remove the certificate from storage.
243                    self.remove_certificate(id);
244                }
245            }
246            // Update the GC round.
247            self.gc_round.store(next_gc_round, Ordering::SeqCst);
248        }
249    }
250}
251
252impl<N: Network> Storage<N> {
253    /// Returns `true` if the storage contains the specified `round`.
254    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
255        // Check if the round exists in storage.
256        self.rounds.read().contains_key(&round)
257    }
258
259    /// Returns `true` if the storage contains the specified `certificate ID`.
260    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
261        // Check if the certificate ID exists in storage.
262        self.certificates.read().contains_key(&certificate_id)
263    }
264
265    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
266    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
267        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
268    }
269
270    /// Returns `true` if the storage contains the specified `batch ID`.
271    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
272        // Check if the batch ID exists in storage.
273        self.batch_ids.read().contains_key(&batch_id)
274    }
275
276    /// Returns `true` if the storage contains the specified `transmission ID`.
277    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
278        self.transmissions.contains_transmission(transmission_id.into())
279    }
280
281    /// Returns the transmission for the given `transmission ID`.
282    /// If the transmission ID does not exist in storage, `None` is returned.
283    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
284        self.transmissions.get_transmission(transmission_id.into())
285    }
286
287    /// Returns the round for the given `certificate ID`.
288    /// If the certificate ID does not exist in storage, `None` is returned.
289    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
290        // Get the round.
291        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
292    }
293
294    /// Returns the round for the given `batch ID`.
295    /// If the batch ID does not exist in storage, `None` is returned.
296    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
297        // Get the round.
298        self.batch_ids.read().get(&batch_id).copied()
299    }
300
301    /// Returns the certificate round for the given `certificate ID`.
302    /// If the certificate ID does not exist in storage, `None` is returned.
303    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
304        // Get the batch certificate and return the round.
305        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
306    }
307
308    /// Returns the certificate for the given `certificate ID`.
309    /// If the certificate ID does not exist in storage, `None` is returned.
310    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
311        // Get the batch certificate.
312        self.certificates.read().get(&certificate_id).cloned()
313    }
314
315    /// Returns the certificate for the given `round` and `author`.
316    /// If the round does not exist in storage, `None` is returned.
317    /// If the author for the round does not exist in storage, `None` is returned.
318    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
319        // Retrieve the certificates.
320        if let Some(entries) = self.rounds.read().get(&round) {
321            let certificates = self.certificates.read();
322            entries.iter().find_map(
323                |(certificate_id, _, a)| {
324                    if a == &author { certificates.get(certificate_id).cloned() } else { None }
325                },
326            )
327        } else {
328            Default::default()
329        }
330    }
331
332    /// Returns the certificates for the given `round`.
333    /// If the round does not exist in storage, an empty set is returned.
334    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
335        // The genesis round does not have batch certificates.
336        if round == 0 {
337            return Default::default();
338        }
339        // Retrieve the certificates.
340        if let Some(entries) = self.rounds.read().get(&round) {
341            let certificates = self.certificates.read();
342            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
343        } else {
344            Default::default()
345        }
346    }
347
348    /// Returns the certificate IDs for the given `round`.
349    /// If the round does not exist in storage, an empty set is returned.
350    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
351        // The genesis round does not have batch certificates.
352        if round == 0 {
353            return Default::default();
354        }
355        // Retrieve the certificates.
356        if let Some(entries) = self.rounds.read().get(&round) {
357            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
358        } else {
359            Default::default()
360        }
361    }
362
363    /// Returns the certificate authors for the given `round`.
364    /// If the round does not exist in storage, an empty set is returned.
365    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
366        // The genesis round does not have batch certificates.
367        if round == 0 {
368            return Default::default();
369        }
370        // Retrieve the certificates.
371        if let Some(entries) = self.rounds.read().get(&round) {
372            entries.iter().map(|(_, _, author)| *author).collect()
373        } else {
374            Default::default()
375        }
376    }
377
    /// Returns the certificates that have not yet been included in the ledger.
    /// Note that the order of this set is by round and then insertion.
    ///
    /// Both the `rounds` and `certificates` read locks are held for the whole traversal.
    /// `ledger.contains_certificate` is queried once per indexed certificate while they are held.
    /// A ledger lookup error is treated as "not in the ledger" (`unwrap_or(false)`), so such
    /// certificates are reported as pending.
    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
        // Obtain the read locks.
        let rounds = self.rounds.read();
        let certificates = self.certificates.read();

        // Iterate over the rounds.
        // NOTE(review): `cfg_sorted_by!` consumes its input, hence the full clone of `rounds`;
        // with the comparator below it presumably yields rounds in ascending order, possibly
        // via a parallel iterator when rayon is enabled - confirm against the macro definition.
        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
            .flat_map(|(_, certificates_for_round)| {
                // Iterate over the certificates for the round.
                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
                    // Skip the certificate if it already exists in the ledger.
                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
                        None
                    } else {
                        // Add the certificate to the pending certificates.
                        // Entries whose certificate is missing from the map are silently dropped.
                        certificates.get(&certificate_id).cloned()
                    }
                })
            })
            .collect()
    }
401
    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
    ///
    /// This method ensures the following invariants:
    /// - The batch ID does not already exist in storage.
    /// - The author is a member of the committee for the batch round.
    /// - The timestamp is within the allowed time range.
    /// - None of the transmissions are from any past rounds (up to GC).
    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
    /// - All previous certificates declared in the certificate exist in storage (up to GC).
    /// - All previous certificates are for the previous round (i.e. round - 1).
    /// - All previous certificates contain a unique author.
    /// - The previous certificates reached the quorum threshold (N - f).
    ///
    /// The two transmission invariants are delegated to the transmission storage service
    /// (`find_missing_transmissions`). The previous-certificate invariants are checked only when
    /// the previous round is above the GC round. Certificates at or below the GC round may already
    /// have been garbage collected, so they cannot be checked.
    ///
    /// # Arguments
    ///
    /// * `batch_header` - The header to validate.
    /// * `transmissions` - Transmissions supplied alongside the batch.
    /// * `aborted_transmissions` - IDs of transmissions declared as aborted; passed through to
    ///   the transmission storage service.
    ///
    /// # Returns
    ///
    /// The transmissions reported by `find_missing_transmissions` for this batch.
    ///
    /// # Errors
    ///
    /// Returns an error if any invariant fails, if a committee lookback cannot be retrieved, or if
    /// the timestamp liveness check fails. Most messages include the GC round for context.
    pub fn check_batch_header(
        &self,
        batch_header: &BatchHeader<N>,
        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
        aborted_transmissions: HashSet<TransmissionID<N>>,
    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
        // Retrieve the round.
        let round = batch_header.round();
        // Retrieve the GC round.
        let gc_round = self.gc_round();
        // Construct a GC log message.
        let gc_log = format!("(gc = {gc_round})");

        // Ensure the batch ID does not already exist in storage.
        if self.contains_batch(batch_header.batch_id()) {
            bail!("Batch for round {round} already exists in storage {gc_log}")
        }

        // Retrieve the committee lookback for the batch round.
        // Note: the underlying ledger error is discarded and replaced by this message.
        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
        };
        // Ensure the author is in the committee.
        if !committee_lookback.is_committee_member(batch_header.author()) {
            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
        }

        // Check the timestamp for liveness.
        // (This error is propagated as-is, without the GC context.)
        check_timestamp_for_liveness(batch_header.timestamp())?;

        // Retrieve the missing transmissions in storage from the given transmissions.
        let missing_transmissions = self
            .transmissions
            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;

        // Compute the previous round.
        let previous_round = round.saturating_sub(1);
        // Check if the previous round is within range of the GC round.
        // Rounds at or below the GC round may have had their certificates removed, so skip them.
        if previous_round > gc_round {
            // Retrieve the committee lookback for the previous round.
            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
            };
            // Ensure the previous round certificates exists in storage.
            if !self.contains_certificates_for_round(previous_round) {
                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
            }
            // Ensure the number of previous certificate IDs is at or below the number of committee members.
            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
                bail!("Too many previous certificates for round {round} {gc_log}")
            }
            // Initialize a set of the previous authors.
            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
            // Ensure storage contains all declared previous certificates (up to GC).
            for previous_certificate_id in batch_header.previous_certificate_ids() {
                // Retrieve the previous certificate.
                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
                    bail!(
                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
                        fmt_id(previous_certificate_id)
                    )
                };
                // Ensure the previous certificate is for the previous round.
                if previous_certificate.round() != previous_round {
                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
                }
                // Ensure the previous author is new.
                if previous_authors.contains(&previous_certificate.author()) {
                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
                }
                // Insert the author of the previous certificate.
                previous_authors.insert(previous_certificate.author());
            }
            // Ensure the previous certificates have reached the quorum threshold.
            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
            }
        }
        Ok(missing_transmissions)
    }
495
    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
    ///
    /// This method ensures the following invariants:
    /// - The certificate ID does not already exist in storage.
    /// - The batch ID does not already exist in storage.
    /// - The author is a member of the committee for the batch round.
    /// - The author has not already created a certificate for the batch round.
    /// - The timestamp is within the allowed time range.
    /// - None of the transmissions are from any past rounds (up to GC).
    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
    /// - All previous certificates declared in the certificate exist in storage (up to GC).
    /// - All previous certificates are for the previous round (i.e. round - 1).
    /// - The previous certificates reached the quorum threshold (N - f).
    /// - The timestamps from the signers are all within the allowed time range.
    /// - The signers have reached the quorum threshold (N - f).
    ///
    /// The batch-header invariants are delegated to `check_batch_header`, whose missing
    /// transmissions are returned unchanged on success.
    ///
    /// NOTE(review): only `certificate.timestamp()` is passed to the liveness check below. No
    /// per-signer timestamp check is visible here, so the "timestamps from the signers"
    /// invariant above may be stale or enforced elsewhere - confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if any invariant fails or a committee lookback cannot be retrieved.
    pub fn check_certificate(
        &self,
        certificate: &BatchCertificate<N>,
        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
        aborted_transmissions: HashSet<TransmissionID<N>>,
    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the GC round.
        let gc_round = self.gc_round();
        // Construct a GC log message.
        let gc_log = format!("(gc = {gc_round})");

        // Ensure the certificate ID does not already exist in storage.
        if self.contains_certificate(certificate.id()) {
            bail!("Certificate for round {round} already exists in storage {gc_log}")
        }

        // Ensure the storage does not already contain a certificate for this author in this round.
        if self.contains_certificate_in_round_from(round, certificate.author()) {
            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
        }

        // Ensure the batch header is well-formed.
        let missing_transmissions =
            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;

        // Check the timestamp for liveness.
        check_timestamp_for_liveness(certificate.timestamp())?;

        // Retrieve the committee lookback for the batch round.
        // (Also fetched inside `check_batch_header`; the ledger error is discarded here.)
        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
        };

        // Initialize a set of the signers.
        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
        // Append the batch author.
        // The author counts towards the quorum even though it does not appear in `signatures()`.
        signers.insert(certificate.author());

        // Iterate over the signatures.
        for signature in certificate.signatures() {
            // Retrieve the signer.
            let signer = signature.to_address();
            // Ensure the signer is in the committee.
            if !committee_lookback.is_committee_member(signer) {
                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
            }
            // Append the signer.
            // Duplicates collapse in the `HashSet`, so repeated signers do not inflate the quorum.
            signers.insert(signer);
        }

        // Ensure the signatures have reached the quorum threshold.
        if !committee_lookback.is_quorum_threshold_reached(&signers) {
            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
        }
        Ok(missing_transmissions)
    }
569
570    /// Inserts the given `certificate` into storage.
571    ///
572    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
573    ///
574    /// This method ensures the following invariants:
575    /// - The certificate ID does not already exist in storage.
576    /// - The batch ID does not already exist in storage.
577    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
578    /// - All previous certificates declared in the certificate exist in storage (up to GC).
579    /// - All previous certificates are for the previous round (i.e. round - 1).
580    /// - The previous certificates reached the quorum threshold (N - f).
581    pub fn insert_certificate(
582        &self,
583        certificate: BatchCertificate<N>,
584        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
585        aborted_transmissions: HashSet<TransmissionID<N>>,
586    ) -> Result<()> {
587        // Ensure the certificate round is above the GC round.
588        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
589        // Ensure the certificate and its transmissions are valid.
590        let missing_transmissions =
591            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
592        // Insert the certificate into storage.
593        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
594        Ok(())
595    }
596
    /// Inserts the given `certificate` into storage.
    ///
    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
    ///
    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
    ///
    /// Despite the name, the maps are updated one lock at a time. Each `.write()` guard below is a
    /// temporary dropped at the end of its statement. Concurrent readers may therefore observe a
    /// partially inserted certificate, e.g. indexed in `rounds` but not yet present in `certificates`.
    fn insert_certificate_atomic(
        &self,
        certificate: BatchCertificate<N>,
        aborted_transmission_ids: HashSet<TransmissionID<N>>,
        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
    ) {
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Insert the round to certificate ID entry.
        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        // Obtain the certificate's transmission ids.
        // (Cloned before `certificate` is moved into the certificate map below.)
        let transmission_ids = certificate.transmission_ids().clone();
        // Insert the certificate.
        self.certificates.write().insert(certificate_id, certificate);
        // Insert the batch ID.
        self.batch_ids.write().insert(batch_id, round);
        // Insert the certificate ID for each of the transmissions into storage.
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            aborted_transmission_ids,
            missing_transmissions,
        );
    }
633
    /// Removes the given `certificate ID` from storage.
    ///
    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
    ///
    /// If the certificate was successfully removed, `true` is returned.
    /// If the certificate did not exist in storage, a warning is logged and `false` is returned.
    ///
    /// The certificate is first cloned out of the map (via `get_certificate`) so that its round,
    /// batch ID, author and transmission IDs are available while the maps are updated. As on
    /// insertion, each map is locked separately, so the removal is not atomic across maps.
    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
        // Retrieve the certificate.
        let Some(certificate) = self.get_certificate(certificate_id) else {
            guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
            return false;
        };
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Compute the author of the batch.
        let author = certificate.author();

        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
        //  It will be better to write tests that compare the union of the sets.

        // Update the round.
        // The `rounds` write guard is a temporary in the `match` scrutinee, so it is held until the
        // end of the `match` and released before any other map is locked.
        match self.rounds.write().entry(round) {
            Entry::Occupied(mut entry) => {
                // Remove the round to certificate ID entry.
                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
                // If the round is empty, remove it.
                // This keeps `contains_certificates_for_round` false for rounds with no certificates.
                if entry.get().is_empty() {
                    entry.swap_remove();
                }
            }
            Entry::Vacant(_) => {}
        }
        // Remove the certificate.
        self.certificates.write().swap_remove(&certificate_id);
        // Remove the batch ID.
        self.batch_ids.write().swap_remove(&batch_id);
        // Remove the transmission entries in the certificate from storage.
        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
        // Return successfully.
        true
    }
678}
679
680impl<N: Network> Storage<N> {
681    /// Syncs the current height with the block.
682    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
683        // If the block height is greater than the current height in storage, sync the height.
684        if next_height > self.current_height() {
685            // Update the current height in storage.
686            self.current_height.store(next_height, Ordering::SeqCst);
687        }
688    }
689
690    /// Syncs the current round with the block.
691    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
692        // Retrieve the current round in the block.
693        let next_round = next_round.max(1);
694        // If the round in the block is greater than the current round in storage, sync the round.
695        if next_round > self.current_round() {
696            // Update the current round in storage.
697            self.update_current_round(next_round);
698            // Log the updated round.
699            guard_info!(self, "Synced to round {next_round}...");
700        }
701    }
702
703    /// Syncs the batch certificate with the block.
704    pub(crate) fn sync_certificate_with_block(
705        &self,
706        block: &Block<N>,
707        certificate: BatchCertificate<N>,
708        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
709    ) {
710        // Skip if the certificate round is below the GC round.
711        if certificate.round() <= self.gc_round() {
712            return;
713        }
714        // If the certificate ID already exists in storage, skip it.
715        if self.contains_certificate(certificate.id()) {
716            return;
717        }
718        // Retrieve the transmissions for the certificate.
719        let mut missing_transmissions = HashMap::new();
720
721        // Retrieve the aborted transmissions for the certificate.
722        let mut aborted_transmissions = HashSet::new();
723
724        // Track the block's aborted solutions and transactions.
725        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
726        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
727
728        // Iterate over the transmission IDs.
729        for transmission_id in certificate.transmission_ids() {
730            // If the transmission ID already exists in the map, skip it.
731            if missing_transmissions.contains_key(transmission_id) {
732                continue;
733            }
734            // If the transmission ID exists in storage, skip it.
735            if self.contains_transmission(*transmission_id) {
736                continue;
737            }
738            // Retrieve the transmission.
739            match transmission_id {
740                TransmissionID::Ratification => (),
741                TransmissionID::Solution(solution_id, _) => {
742                    // Retrieve the solution.
743                    match block.get_solution(solution_id) {
744                        // Insert the solution.
745                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
746                        // Otherwise, try to load the solution from the ledger.
747                        None => match self.ledger.get_solution(solution_id) {
748                            // Insert the solution.
749                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
750                            // Check if the solution is in the aborted solutions.
751                            Err(_) => {
752                                // Insert the aborted solution if it exists in the block or ledger.
753                                match aborted_solutions.contains(solution_id)
754                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
755                                {
756                                    true => {
757                                        aborted_transmissions.insert(*transmission_id);
758                                    }
759                                    false => {
760                                        guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
761                                    }
762                                }
763                                continue;
764                            }
765                        },
766                    };
767                }
768                TransmissionID::Transaction(transaction_id, _) => {
769                    // Retrieve the transaction.
770                    match unconfirmed_transactions.get(transaction_id) {
771                        // Insert the transaction.
772                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
773                        // Otherwise, try to load the unconfirmed transaction from the ledger.
774                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
775                            // Insert the transaction.
776                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
777                            // Check if the transaction is in the aborted transactions.
778                            Err(_) => {
779                                // Insert the aborted transaction if it exists in the block or ledger.
780                                match aborted_transactions.contains(transaction_id)
781                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
782                                {
783                                    true => {
784                                        aborted_transmissions.insert(*transmission_id);
785                                    }
786                                    false => guard_warn!(
787                                        self,
788                                        "Missing transaction {transaction_id} in block {}",
789                                        block.height()
790                                    ),
791                                }
792                                continue;
793                            }
794                        },
795                    };
796                }
797            }
798        }
799        // Insert the batch certificate into storage.
800        let certificate_id = fmt_id(certificate.id());
801        guard_debug!(
802            self,
803            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
804            certificate.round(),
805            certificate.transmission_ids().len()
806        );
807        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
808            guard_error!(
809                self,
810                "Failed to insert certificate '{certificate_id}' from block {} - {error}",
811                block.height()
812            );
813        }
814    }
815}
816
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Provides access to the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Yields every `(round, {(certificate ID, batch ID, author)})` entry.
    ///
    /// The entries come from a clone taken under the read lock, so the returned iterator does
    /// not hold the lock.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Yields every `(certificate ID, certificate)` entry from a snapshot of the map.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Yields every `(batch ID, round)` entry from a snapshot of the map.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Yields every `(transmission ID, (transmission, certificate IDs))` entry, taken from the
    /// transmissions store's `as_hashmap` view.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Writes `certificate` straight into every storage map, without any validation, and backs
    /// each of its transmission IDs with an empty placeholder transaction buffer.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[cfg(test)]
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let (round, certificate_id, batch_id, author) =
            (certificate.round(), certificate.id(), certificate.batch_id(), certificate.author());
        // Take a copy of the transmission IDs now: the certificate is moved into storage below.
        let transmission_ids = certificate.transmission_ids().clone();

        // Register the certificate under its round, by ID, and by batch ID.
        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);

        // Every referenced transmission gets an empty transaction payload (dummy data for tests).
        let mut missing_transmissions = HashMap::new();
        for transmission_id in transmission_ids.iter() {
            missing_transmissions.insert(
                *transmission_id,
                Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new())),
            );
        }
        // Link each transmission to this certificate; no transmissions are marked aborted.
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
884
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for inserting, removing, and re-inserting certificates in `Storage`.
    //!
    //! Also hosts shared helpers: `assert_storage`, which `prop_tests` imports, and
    //! `sample_transmissions`, which is `pub(crate)` so other test code can reuse it.
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    /// The network used by these tests.
    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the storage matches the expected layout.
    ///
    /// `rounds`, `certificates`, and `batch_ids` are compared as ordered sequences, so their
    /// insertion order must match. `transmissions` is compared as a `HashMap`, so its order is
    /// irrelevant.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Ensure the rounds are well-formed.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Ensure the certificates are well-formed.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Ensure the batch IDs are well-formed.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Ensure the transmissions are well-formed.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }

    /// Samples a random transmission.
    ///
    /// The RNG first picks the variant, then fills the payload with random bytes: 512 bytes for a
    /// solution or 2048 bytes for a transaction. The payloads are opaque buffers, not valid
    /// solutions or transactions.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Sample random fake solution bytes.
        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample random fake transaction bytes.
        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample a random transmission.
        match rng.gen::<bool>() {
            true => Transmission::Solution(s(rng)),
            false => Transmission::Transaction(t(rng)),
        }
    }

    /// Samples the random transmissions, returning the missing transmissions and the transmissions.
    ///
    /// For every transmission ID in `certificate`, one random transmission is sampled. The first
    /// returned map (`transmission ID -> transmission`) is suitable as the `missing_transmissions`
    /// argument when inserting the certificate. The second map
    /// (`transmission ID -> (transmission, {certificate ID})`) is the expected `transmissions`
    /// layout for `assert_storage` after that insertion.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            // Initialize the transmission.
            let transmission = sample_transmission(rng);
            // Update the missing transmissions.
            missing_transmissions.insert(*transmission_id, transmission.clone());
            // Update the transmissions map.
            transmissions
                .entry(*transmission_id)
                .or_insert((transmission, Default::default()))
                .1
                .insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    /// Inserts one certificate and checks the lookups and the exact storage layout. Then removes
    /// it and checks that storage is empty again.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the certificate is stored in the correct round.
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        // Ensure the certificate is stored for the correct round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check that the underlying storage representation is correct.
        {
            // Construct the expected layout for 'rounds'.
            let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
            // Construct the expected layout for 'certificates'.
            let certificates = [(certificate_id, certificate.clone())];
            // Construct the expected layout for 'batch_ids'.
            let batch_ids = [(batch_id, round)];
            // Assert the storage is well-formed.
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // Retrieve the certificate.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        // Ensure the retrieved certificate is the same as the inserted certificate.
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate.
        assert!(storage.remove_certificate(certificate_id));
        // Ensure the certificate does not exist in storage.
        assert!(!storage.contains_certificate(certificate_id));
        // Ensure the certificate is no longer stored in the round.
        assert!(storage.get_certificates_for_round(round).is_empty());
        // Ensure the certificate is no longer stored for the round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    /// Re-inserting an already stored certificate must leave the storage layout unchanged. This
    /// holds both without missing transmissions and with the original ones.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];
        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
}
1083
#[cfg(test)]
pub mod prop_tests {
    //! Property-based tests for `Storage`, plus `proptest` strategies for transmissions,
    //! transmission IDs, and storages built from an arbitrary committee.
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    /// The network used by these tests.
    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a storage backed by a `MockLedgerService` for a committee, using an in-memory
    /// storage service and a GC-rounds value drawn from `0..MAX_GC_ROUNDS`.
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        /// Draws both the committee and the GC-rounds value at random.
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        /// Uses the given committee and draws only the GC-rounds value at random.
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// A random transmission (see `any_transmission`), wrapped so it can implement `Arbitrary`.
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// A random transmission ID (see `any_transmission_id`), wrapped so it can implement `Arbitrary`.
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Strategy for a solution (512 random bytes) or a transaction (2048 random bytes).
    /// The payloads are opaque buffers, not valid solutions or transactions.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Strategy for a solution ID built from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Strategy for a transaction ID built from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Strategy for a transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with the private key of every validator in
    /// `validator_set`, returning the collected signatures.
    ///
    /// # Panics
    /// Panics if signing fails.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    // Re-inserting an already stored certificate must leave the storage layout unchanged. This
    // holds both without missing transmissions and with the original ones.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let signer = selector.select(&validators);

        // Deduplicate the generated transmissions by ID. For a repeated ID, `IndexMap::insert`
        // keeps the key's first position but stores the last value.
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Remove the author from the validator set passed to create the batch
        // certificate, the author should not sign their own batch.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Construct the expected layout for 'transmissions'. It is derived from the deduplicated
        // `transmission_map`, which is exactly what is handed to storage below, so the expected
        // payload for a repeated ID matches the stored one. Building it from the raw input with
        // `or_insert` would keep the first payload instead.
        let internal_transmissions: HashMap<_, (_, IndexSet<Field<CurrentNetwork>>)> = transmission_map
            .iter()
            .map(|(id, transmission)| (*id, (transmission.clone(), indexset! { certificate_id })))
            .collect();

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}