// amareleo_node_bft/helpers/storage.rs

// Copyright 2024 Aleo Network Foundation
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::TracingHandler;
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21    ledger::{
22        block::{Block, Transaction},
23        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24    },
25    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26    utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30use parking_lot::RwLock;
31use rayon::iter::{IntoParallelIterator, ParallelIterator};
32use std::{
33    collections::{HashMap, HashSet},
34    sync::{
35        Arc,
36        atomic::{AtomicU32, AtomicU64, Ordering},
37    },
38};
39use tracing::subscriber::DefaultGuard;
40
41#[derive(Clone, Debug)]
42pub struct Storage<N: Network>(Arc<StorageInner<N>>);
43
44impl<N: Network> std::ops::Deref for Storage<N> {
45    type Target = Arc<StorageInner<N>>;
46
47    fn deref(&self) -> &Self::Target {
48        &self.0
49    }
50}
51
52/// The storage for the memory pool.
53///
54/// The storage is used to store the following:
55/// - `current_height` tracker.
56/// - `current_round` tracker.
57/// - `round` to `(certificate ID, batch ID, author)` entries.
58/// - `certificate ID` to `certificate` entries.
59/// - `batch ID` to `round` entries.
60/// - `transmission ID` to `(transmission, certificate IDs)` entries.
61///
62/// The chain of events is as follows:
63/// 1. A `transmission` is received.
64/// 2. After a `batch` is ready to be stored:
65///   - The `certificate` is inserted, triggering updates to the
66///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
67///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
68///   - The certificate ID is inserted into the `transmissions` map.
69/// 3. After a `round` reaches quorum threshold:
70///  - The next round is inserted into the `current_round`.
71#[derive(Debug)]
72pub struct StorageInner<N: Network> {
73    /// The ledger service.
74    ledger: Arc<dyn LedgerService<N>>,
75    /* Once per block */
76    /// The current height.
77    current_height: AtomicU32,
78    /* Once per round */
79    /// The current round.
80    current_round: AtomicU64,
81    /// The `round` for which garbage collection has occurred **up to** (inclusive).
82    gc_round: AtomicU64,
83    /// The maximum number of rounds to keep in storage.
84    max_gc_rounds: u64,
85    /* Once per batch */
86    /// The map of `round` to a list of `(certificate ID, batch ID, author)` entries.
87    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
88    /// The map of `certificate ID` to `certificate`.
89    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
90    /// The map of `batch ID` to `round`.
91    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
92    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
93    transmissions: Arc<dyn StorageService<N>>,
94    /// Tracing handle
95    tracing: Option<TracingHandler>,
96}
97
98impl<N: Network> Storage<N> {
99    /// Initializes a new instance of storage.
100    pub fn new(
101        ledger: Arc<dyn LedgerService<N>>,
102        transmissions: Arc<dyn StorageService<N>>,
103        max_gc_rounds: u64,
104        tracing: Option<TracingHandler>,
105    ) -> Self {
106        // Retrieve the current committee.
107        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
108        // Retrieve the current round.
109        let current_round = committee.starting_round().max(1);
110
111        // Return the storage.
112        let storage = Self(Arc::new(StorageInner {
113            ledger,
114            current_height: Default::default(),
115            current_round: Default::default(),
116            gc_round: Default::default(),
117            max_gc_rounds,
118            rounds: Default::default(),
119            certificates: Default::default(),
120            batch_ids: Default::default(),
121            transmissions,
122            tracing,
123        }));
124        // Update the storage to the current round.
125        storage.update_current_round(current_round);
126        // Perform GC on the current round.
127        storage.garbage_collect_certificates(current_round);
128        // Return the storage.
129        storage
130    }
131}
132
133impl<N: Network> Storage<N> {
134    /// Returns the current height.
135    pub fn current_height(&self) -> u32 {
136        // Get the current height.
137        self.current_height.load(Ordering::SeqCst)
138    }
139}
140
141impl<N: Network> Storage<N> {
142    /// Returns the current round.
143    pub fn current_round(&self) -> u64 {
144        // Get the current round.
145        self.current_round.load(Ordering::SeqCst)
146    }
147
148    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
149    pub fn gc_round(&self) -> u64 {
150        // Get the GC round.
151        self.gc_round.load(Ordering::SeqCst)
152    }
153
154    /// Returns the maximum number of rounds to keep in storage.
155    pub fn max_gc_rounds(&self) -> u64 {
156        self.max_gc_rounds
157    }
158
159    /// Retruns tracing guard
160    pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
161        self.tracing.clone().map(|trace_handle| trace_handle.subscribe_thread())
162    }
163
164    /// Increments storage to the next round, updating the current round.
165    /// Note: This method is only called once per round, upon certification of the primary's batch.
166    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
167        let _guard = self.get_tracing_guard();
168
169        // Determine the next round.
170        let next_round = current_round + 1;
171
172        // Check if the next round is less than the current round in storage.
173        {
174            // Retrieve the storage round.
175            let storage_round = self.current_round();
176            // If the next round is less than the current round in storage, return early with the storage round.
177            if next_round < storage_round {
178                return Ok(storage_round);
179            }
180        }
181
182        // Retrieve the current committee.
183        let current_committee = self.ledger.current_committee()?;
184        // Retrieve the current committee's starting round.
185        let starting_round = current_committee.starting_round();
186        // If the primary is behind the current committee's starting round, sync with the latest block.
187        if next_round < starting_round {
188            // Retrieve the latest block round.
189            let latest_block_round = self.ledger.latest_round();
190            // Log the round sync.
191            info!(
192                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
193            );
194            // Sync the round with the latest block.
195            self.sync_round_with_block(latest_block_round);
196            // Return the latest block round.
197            return Ok(latest_block_round);
198        }
199
200        // Update the storage to the next round.
201        self.update_current_round(next_round);
202
203        #[cfg(feature = "metrics")]
204        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
205
206        // Retrieve the storage round.
207        let storage_round = self.current_round();
208        // Retrieve the GC round.
209        let gc_round = self.gc_round();
210        // Ensure the next round matches in storage.
211        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
212        // Ensure the next round is greater than or equal to the GC round.
213        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
214
215        // Log the updated round.
216        info!("Starting round {next_round}...");
217        Ok(next_round)
218    }
219
220    /// Updates the storage to the next round.
221    fn update_current_round(&self, next_round: u64) {
222        // Update the current round.
223        self.current_round.store(next_round, Ordering::SeqCst);
224    }
225
226    /// Update the storage by performing garbage collection based on the next round.
227    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
228        // Fetch the current GC round.
229        let current_gc_round = self.gc_round();
230        // Compute the next GC round.
231        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
232        // Check if storage needs to be garbage collected.
233        if next_gc_round > current_gc_round {
234            // Remove the GC round(s) from storage.
235            for gc_round in current_gc_round..=next_gc_round {
236                // Iterate over the certificates for the GC round.
237                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
238                    // Remove the certificate from storage.
239                    self.remove_certificate(id);
240                }
241            }
242            // Update the GC round.
243            self.gc_round.store(next_gc_round, Ordering::SeqCst);
244        }
245    }
246}
247
248impl<N: Network> Storage<N> {
249    /// Returns `true` if the storage contains the specified `round`.
250    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
251        // Check if the round exists in storage.
252        self.rounds.read().contains_key(&round)
253    }
254
255    /// Returns `true` if the storage contains the specified `certificate ID`.
256    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
257        // Check if the certificate ID exists in storage.
258        self.certificates.read().contains_key(&certificate_id)
259    }
260
261    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
262    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
263        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
264    }
265
266    /// Returns `true` if the storage contains the specified `batch ID`.
267    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
268        // Check if the batch ID exists in storage.
269        self.batch_ids.read().contains_key(&batch_id)
270    }
271
272    /// Returns `true` if the storage contains the specified `transmission ID`.
273    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
274        self.transmissions.contains_transmission(transmission_id.into())
275    }
276
277    /// Returns the transmission for the given `transmission ID`.
278    /// If the transmission ID does not exist in storage, `None` is returned.
279    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
280        self.transmissions.get_transmission(transmission_id.into())
281    }
282
283    /// Returns the round for the given `certificate ID`.
284    /// If the certificate ID does not exist in storage, `None` is returned.
285    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
286        // Get the round.
287        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
288    }
289
290    /// Returns the round for the given `batch ID`.
291    /// If the batch ID does not exist in storage, `None` is returned.
292    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
293        // Get the round.
294        self.batch_ids.read().get(&batch_id).copied()
295    }
296
297    /// Returns the certificate round for the given `certificate ID`.
298    /// If the certificate ID does not exist in storage, `None` is returned.
299    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
300        // Get the batch certificate and return the round.
301        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
302    }
303
304    /// Returns the certificate for the given `certificate ID`.
305    /// If the certificate ID does not exist in storage, `None` is returned.
306    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
307        // Get the batch certificate.
308        self.certificates.read().get(&certificate_id).cloned()
309    }
310
311    /// Returns the certificate for the given `round` and `author`.
312    /// If the round does not exist in storage, `None` is returned.
313    /// If the author for the round does not exist in storage, `None` is returned.
314    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
315        // Retrieve the certificates.
316        if let Some(entries) = self.rounds.read().get(&round) {
317            let certificates = self.certificates.read();
318            entries.iter().find_map(
319                |(certificate_id, _, a)| {
320                    if a == &author { certificates.get(certificate_id).cloned() } else { None }
321                },
322            )
323        } else {
324            Default::default()
325        }
326    }
327
328    /// Returns the certificates for the given `round`.
329    /// If the round does not exist in storage, an empty set is returned.
330    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
331        // The genesis round does not have batch certificates.
332        if round == 0 {
333            return Default::default();
334        }
335        // Retrieve the certificates.
336        if let Some(entries) = self.rounds.read().get(&round) {
337            let certificates = self.certificates.read();
338            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
339        } else {
340            Default::default()
341        }
342    }
343
344    /// Returns the certificate IDs for the given `round`.
345    /// If the round does not exist in storage, an empty set is returned.
346    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
347        // The genesis round does not have batch certificates.
348        if round == 0 {
349            return Default::default();
350        }
351        // Retrieve the certificates.
352        if let Some(entries) = self.rounds.read().get(&round) {
353            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
354        } else {
355            Default::default()
356        }
357    }
358
359    /// Returns the certificate authors for the given `round`.
360    /// If the round does not exist in storage, an empty set is returned.
361    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
362        // The genesis round does not have batch certificates.
363        if round == 0 {
364            return Default::default();
365        }
366        // Retrieve the certificates.
367        if let Some(entries) = self.rounds.read().get(&round) {
368            entries.iter().map(|(_, _, author)| *author).collect()
369        } else {
370            Default::default()
371        }
372    }
373
374    /// Returns the certificates that have not yet been included in the ledger.
375    /// Note that the order of this set is by round and then insertion.
376    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
377        // Obtain the read locks.
378        let rounds = self.rounds.read();
379        let certificates = self.certificates.read();
380
381        // Iterate over the rounds.
382        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
383            .flat_map(|(_, certificates_for_round)| {
384                // Iterate over the certificates for the round.
385                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
386                    // Skip the certificate if it already exists in the ledger.
387                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
388                        None
389                    } else {
390                        // Add the certificate to the pending certificates.
391                        certificates.get(&certificate_id).cloned()
392                    }
393                })
394            })
395            .collect()
396    }
397
398    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
399    ///
400    /// This method ensures the following invariants:
401    /// - The batch ID does not already exist in storage.
402    /// - The author is a member of the committee for the batch round.
403    /// - The timestamp is within the allowed time range.
404    /// - None of the transmissions are from any past rounds (up to GC).
405    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
406    /// - All previous certificates declared in the certificate exist in storage (up to GC).
407    /// - All previous certificates are for the previous round (i.e. round - 1).
408    /// - All previous certificates contain a unique author.
409    /// - The previous certificates reached the quorum threshold (N - f).
410    pub fn check_batch_header(
411        &self,
412        batch_header: &BatchHeader<N>,
413        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
414        aborted_transmissions: HashSet<TransmissionID<N>>,
415    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
416        // Retrieve the round.
417        let round = batch_header.round();
418        // Retrieve the GC round.
419        let gc_round = self.gc_round();
420        // Construct a GC log message.
421        let gc_log = format!("(gc = {gc_round})");
422
423        // Ensure the batch ID does not already exist in storage.
424        if self.contains_batch(batch_header.batch_id()) {
425            bail!("Batch for round {round} already exists in storage {gc_log}")
426        }
427
428        // Retrieve the committee lookback for the batch round.
429        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
430            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
431        };
432        // Ensure the author is in the committee.
433        if !committee_lookback.is_committee_member(batch_header.author()) {
434            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
435        }
436
437        // Check the timestamp for liveness.
438        check_timestamp_for_liveness(batch_header.timestamp())?;
439
440        // Retrieve the missing transmissions in storage from the given transmissions.
441        let missing_transmissions = self
442            .transmissions
443            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
444            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
445
446        // Compute the previous round.
447        let previous_round = round.saturating_sub(1);
448        // Check if the previous round is within range of the GC round.
449        if previous_round > gc_round {
450            // Retrieve the committee lookback for the previous round.
451            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
452                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
453            };
454            // Ensure the previous round certificates exists in storage.
455            if !self.contains_certificates_for_round(previous_round) {
456                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
457            }
458            // Ensure the number of previous certificate IDs is at or below the number of committee members.
459            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
460                bail!("Too many previous certificates for round {round} {gc_log}")
461            }
462            // Initialize a set of the previous authors.
463            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
464            // Ensure storage contains all declared previous certificates (up to GC).
465            for previous_certificate_id in batch_header.previous_certificate_ids() {
466                // Retrieve the previous certificate.
467                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
468                    bail!(
469                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
470                        fmt_id(previous_certificate_id)
471                    )
472                };
473                // Ensure the previous certificate is for the previous round.
474                if previous_certificate.round() != previous_round {
475                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
476                }
477                // Ensure the previous author is new.
478                if previous_authors.contains(&previous_certificate.author()) {
479                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
480                }
481                // Insert the author of the previous certificate.
482                previous_authors.insert(previous_certificate.author());
483            }
484            // Ensure the previous certificates have reached the quorum threshold.
485            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
486                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
487            }
488        }
489        Ok(missing_transmissions)
490    }
491
492    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
493    ///
494    /// This method ensures the following invariants:
495    /// - The certificate ID does not already exist in storage.
496    /// - The batch ID does not already exist in storage.
497    /// - The author is a member of the committee for the batch round.
498    /// - The author has not already created a certificate for the batch round.
499    /// - The timestamp is within the allowed time range.
500    /// - None of the transmissions are from any past rounds (up to GC).
501    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
502    /// - All previous certificates declared in the certificate exist in storage (up to GC).
503    /// - All previous certificates are for the previous round (i.e. round - 1).
504    /// - The previous certificates reached the quorum threshold (N - f).
505    /// - The timestamps from the signers are all within the allowed time range.
506    /// - The signers have reached the quorum threshold (N - f).
507    pub fn check_certificate(
508        &self,
509        certificate: &BatchCertificate<N>,
510        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
511        aborted_transmissions: HashSet<TransmissionID<N>>,
512    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
513        // Retrieve the round.
514        let round = certificate.round();
515        // Retrieve the GC round.
516        let gc_round = self.gc_round();
517        // Construct a GC log message.
518        let gc_log = format!("(gc = {gc_round})");
519
520        // Ensure the certificate ID does not already exist in storage.
521        if self.contains_certificate(certificate.id()) {
522            bail!("Certificate for round {round} already exists in storage {gc_log}")
523        }
524
525        // Ensure the storage does not already contain a certificate for this author in this round.
526        if self.contains_certificate_in_round_from(round, certificate.author()) {
527            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
528        }
529
530        // Ensure the batch header is well-formed.
531        let missing_transmissions =
532            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
533
534        // Check the timestamp for liveness.
535        check_timestamp_for_liveness(certificate.timestamp())?;
536
537        // Retrieve the committee lookback for the batch round.
538        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
539            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
540        };
541
542        // Initialize a set of the signers.
543        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
544        // Append the batch author.
545        signers.insert(certificate.author());
546
547        // Iterate over the signatures.
548        for signature in certificate.signatures() {
549            // Retrieve the signer.
550            let signer = signature.to_address();
551            // Ensure the signer is in the committee.
552            if !committee_lookback.is_committee_member(signer) {
553                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
554            }
555            // Append the signer.
556            signers.insert(signer);
557        }
558
559        // Ensure the signatures have reached the quorum threshold.
560        if !committee_lookback.is_quorum_threshold_reached(&signers) {
561            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
562        }
563        Ok(missing_transmissions)
564    }
565
566    /// Inserts the given `certificate` into storage.
567    ///
568    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
569    ///
570    /// This method ensures the following invariants:
571    /// - The certificate ID does not already exist in storage.
572    /// - The batch ID does not already exist in storage.
573    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
574    /// - All previous certificates declared in the certificate exist in storage (up to GC).
575    /// - All previous certificates are for the previous round (i.e. round - 1).
576    /// - The previous certificates reached the quorum threshold (N - f).
577    pub fn insert_certificate(
578        &self,
579        certificate: BatchCertificate<N>,
580        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
581        aborted_transmissions: HashSet<TransmissionID<N>>,
582    ) -> Result<()> {
583        // Ensure the certificate round is above the GC round.
584        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
585        // Ensure the certificate and its transmissions are valid.
586        let missing_transmissions =
587            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
588        // Insert the certificate into storage.
589        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
590        Ok(())
591    }
592
593    /// Inserts the given `certificate` into storage.
594    ///
595    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
596    ///
597    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
598    fn insert_certificate_atomic(
599        &self,
600        certificate: BatchCertificate<N>,
601        aborted_transmission_ids: HashSet<TransmissionID<N>>,
602        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
603    ) {
604        // Retrieve the round.
605        let round = certificate.round();
606        // Retrieve the certificate ID.
607        let certificate_id = certificate.id();
608        // Retrieve the batch ID.
609        let batch_id = certificate.batch_id();
610        // Retrieve the author of the batch.
611        let author = certificate.author();
612
613        // Insert the round to certificate ID entry.
614        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
615        // Obtain the certificate's transmission ids.
616        let transmission_ids = certificate.transmission_ids().clone();
617        // Insert the certificate.
618        self.certificates.write().insert(certificate_id, certificate);
619        // Insert the batch ID.
620        self.batch_ids.write().insert(batch_id, round);
621        // Insert the certificate ID for each of the transmissions into storage.
622        self.transmissions.insert_transmissions(
623            certificate_id,
624            transmission_ids,
625            aborted_transmission_ids,
626            missing_transmissions,
627        );
628    }
629
630    /// Removes the given `certificate ID` from storage.
631    ///
632    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
633    ///
634    /// If the certificate was successfully removed, `true` is returned.
635    /// If the certificate did not exist in storage, `false` is returned.
636    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
637        let _guard = self.get_tracing_guard();
638        // Retrieve the certificate.
639        let Some(certificate) = self.get_certificate(certificate_id) else {
640            warn!("Certificate {certificate_id} does not exist in storage");
641            return false;
642        };
643        // Retrieve the round.
644        let round = certificate.round();
645        // Retrieve the batch ID.
646        let batch_id = certificate.batch_id();
647        // Compute the author of the batch.
648        let author = certificate.author();
649
650        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
651        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
652        //  It will be better to write tests that compare the union of the sets.
653
654        // Update the round.
655        match self.rounds.write().entry(round) {
656            Entry::Occupied(mut entry) => {
657                // Remove the round to certificate ID entry.
658                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
659                // If the round is empty, remove it.
660                if entry.get().is_empty() {
661                    entry.swap_remove();
662                }
663            }
664            Entry::Vacant(_) => {}
665        }
666        // Remove the certificate.
667        self.certificates.write().swap_remove(&certificate_id);
668        // Remove the batch ID.
669        self.batch_ids.write().swap_remove(&batch_id);
670        // Remove the transmission entries in the certificate from storage.
671        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
672        // Return successfully.
673        true
674    }
675}
676
677impl<N: Network> Storage<N> {
678    /// Syncs the current height with the block.
679    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
680        // If the block height is greater than the current height in storage, sync the height.
681        if next_height > self.current_height() {
682            // Update the current height in storage.
683            self.current_height.store(next_height, Ordering::SeqCst);
684        }
685    }
686
687    /// Syncs the current round with the block.
688    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
689        let _guard = self.get_tracing_guard();
690        // Retrieve the current round in the block.
691        let next_round = next_round.max(1);
692        // If the round in the block is greater than the current round in storage, sync the round.
693        if next_round > self.current_round() {
694            // Update the current round in storage.
695            self.update_current_round(next_round);
696            // Log the updated round.
697            info!("Synced to round {next_round}...");
698        }
699    }
700
701    /// Syncs the batch certificate with the block.
702    pub(crate) fn sync_certificate_with_block(
703        &self,
704        block: &Block<N>,
705        certificate: BatchCertificate<N>,
706        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
707    ) {
708        // Skip if the certificate round is below the GC round.
709        if certificate.round() <= self.gc_round() {
710            return;
711        }
712        // If the certificate ID already exists in storage, skip it.
713        if self.contains_certificate(certificate.id()) {
714            return;
715        }
716        // Retrieve the transmissions for the certificate.
717        let mut missing_transmissions = HashMap::new();
718
719        // Retrieve the aborted transmissions for the certificate.
720        let mut aborted_transmissions = HashSet::new();
721
722        // Track the block's aborted solutions and transactions.
723        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
724        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
725
726        let _guard = self.get_tracing_guard();
727
728        // Iterate over the transmission IDs.
729        for transmission_id in certificate.transmission_ids() {
730            // If the transmission ID already exists in the map, skip it.
731            if missing_transmissions.contains_key(transmission_id) {
732                continue;
733            }
734            // If the transmission ID exists in storage, skip it.
735            if self.contains_transmission(*transmission_id) {
736                continue;
737            }
738            // Retrieve the transmission.
739            match transmission_id {
740                TransmissionID::Ratification => (),
741                TransmissionID::Solution(solution_id, _) => {
742                    // Retrieve the solution.
743                    match block.get_solution(solution_id) {
744                        // Insert the solution.
745                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
746                        // Otherwise, try to load the solution from the ledger.
747                        None => match self.ledger.get_solution(solution_id) {
748                            // Insert the solution.
749                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
750                            // Check if the solution is in the aborted solutions.
751                            Err(_) => {
752                                // Insert the aborted solution if it exists in the block or ledger.
753                                match aborted_solutions.contains(solution_id)
754                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
755                                {
756                                    true => {
757                                        aborted_transmissions.insert(*transmission_id);
758                                    }
759                                    false => error!("Missing solution {solution_id} in block {}", block.height()),
760                                }
761                                continue;
762                            }
763                        },
764                    };
765                }
766                TransmissionID::Transaction(transaction_id, _) => {
767                    // Retrieve the transaction.
768                    match unconfirmed_transactions.get(transaction_id) {
769                        // Insert the transaction.
770                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
771                        // Otherwise, try to load the unconfirmed transaction from the ledger.
772                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
773                            // Insert the transaction.
774                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
775                            // Check if the transaction is in the aborted transactions.
776                            Err(_) => {
777                                // Insert the aborted transaction if it exists in the block or ledger.
778                                match aborted_transactions.contains(transaction_id)
779                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
780                                {
781                                    true => {
782                                        aborted_transmissions.insert(*transmission_id);
783                                    }
784                                    false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
785                                }
786                                continue;
787                            }
788                        },
789                    };
790                }
791            }
792        }
793        // Insert the batch certificate into storage.
794        let certificate_id = fmt_id(certificate.id());
795        debug!(
796            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
797            certificate.round(),
798            certificate.transmission_ids().len()
799        );
800        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
801            error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
802        }
803    }
804}
805
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns a reference to the ledger service that backs this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an owning iterator over a snapshot of the
    /// `(round, (certificate ID, batch ID, author))` entries.
    ///
    /// The map is cloned while the read lock is held, so the returned iterator holds no lock.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a snapshot of the `(certificate ID, certificate)` entries.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a snapshot of the `(batch ID, round)` entries.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over the
    /// `(transmission ID, (transmission, certificate IDs))` entries.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Inserts `certificate` directly into every storage map, without any validation.
    ///
    /// Each transmission referenced by the certificate is stored with a dummy, empty
    /// transaction buffer as its payload.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[cfg(test)]
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let round = certificate.round();
        let certificate_id = certificate.id();
        let batch_id = certificate.batch_id();
        let author = certificate.author();
        // Clone the transmission IDs up front, because the certificate is moved into the map below.
        let transmission_ids = certificate.transmission_ids().clone();

        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);

        // Pair every transmission ID with an empty placeholder payload.
        let mut missing_transmissions = HashMap::with_capacity(transmission_ids.len());
        for transmission_id in &transmission_ids {
            let placeholder =
                Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()));
            missing_transmissions.insert(*transmission_id, placeholder);
        }

        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
873
/// Unit tests for `Storage`: certificate insertion/removal round-trips and duplicate-insert
/// idempotence, checked against the exact layout of the underlying maps.
///
/// Test data is drawn from a deterministic `TestRng`, so the order of RNG calls below determines
/// which certificates and transmissions are exercised.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the storage matches the expected layout.
    ///
    /// `rounds`, `certificates`, and `batch_ids` are compared as ordered sequences (the collected
    /// `Vec` must equal the slice), so the expected entries must be given in storage order.
    /// `transmissions` is compared as a `HashMap`, so its order does not matter.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Ensure the rounds are well-formed.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Ensure the certificates are well-formed.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Ensure the batch IDs are well-formed.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Ensure the transmissions are well-formed.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }

    /// Samples a random transmission.
    ///
    /// A random `bool` picks the variant: a solution backed by 512 random bytes, or a transaction
    /// backed by 2048 random bytes. The payloads are opaque buffers, not valid encodings.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Sample random fake solution bytes.
        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample random fake transaction bytes.
        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample a random transmission.
        match rng.gen::<bool>() {
            true => Transmission::Solution(s(rng)),
            false => Transmission::Transaction(t(rng)),
        }
    }

    /// Samples the random transmissions, returning the missing transmissions and the transmissions.
    ///
    /// For every transmission ID in `certificate`, one random transmission is sampled. The first
    /// returned map (`transmission ID -> transmission`) is suitable as the `missing_transmissions`
    /// argument of an insert. The second map is the expected storage layout
    /// (`transmission ID -> (transmission, {certificate ID})`) after inserting the certificate.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            // Initialize the transmission.
            let transmission = sample_transmission(rng);
            // Update the missing transmissions.
            missing_transmissions.insert(*transmission_id, transmission.clone());
            // Update the transmissions map.
            transmissions
                .entry(*transmission_id)
                .or_insert((transmission, Default::default()))
                .1
                .insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    /// Inserts a single certificate and checks every lookup path plus the raw storage layout.
    /// Then it removes the certificate and checks that storage is empty again.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the certificate is stored in the correct round.
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        // Ensure the certificate is stored for the correct round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check that the underlying storage representation is correct.
        {
            // Construct the expected layout for 'rounds'.
            let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
            // Construct the expected layout for 'certificates'.
            let certificates = [(certificate_id, certificate.clone())];
            // Construct the expected layout for 'batch_ids'.
            let batch_ids = [(batch_id, round)];
            // Assert the storage is well-formed.
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // Retrieve the certificate.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        // Ensure the retrieved certificate is the same as the inserted certificate.
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate.
        assert!(storage.remove_certificate(certificate_id));
        // Ensure the certificate does not exist in storage.
        assert!(!storage.contains_certificate(certificate_id));
        // Ensure the certificate is no longer stored in the round.
        assert!(storage.get_certificates_for_round(round).is_empty());
        // Ensure the certificate is no longer stored for the round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    /// Inserts the same certificate three times: first with its transmissions, then with none,
    /// then with the original transmissions again. Asserts that the storage layout after the
    /// first insert is left unchanged by the repeats.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];
        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
}
1072
/// Property-based tests for `Storage`, built on `proptest` and `test_strategy`.
///
/// This module also provides reusable strategies (arbitrary transmissions, transmission IDs,
/// solution IDs, transaction IDs) and a batch-header signing helper.
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Generates a `Storage` backed by a `MockLedgerService` over an arbitrary committee and a
    /// fresh `BFTMemoryService`. The GC-round count is drawn from `0..MAX_GC_ROUNDS`.
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        // Samples both the committee and the GC-round count.
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        // Uses the supplied committee; only the GC-round count is sampled.
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
    // Every `RngCore` method forwards directly to the wrapped proptest RNG; the `CryptoRng` impl is
    // only a marker and adds no cryptographic strength (acceptable for tests only).
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        // `prop_perturb` hands over the runner's RNG, so the wrapper is seeded by proptest.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// Newtype so that `Transmission` can implement proptest's `Arbitrary` (see `any_transmission`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// Newtype so that `TransmissionID` can implement proptest's `Arbitrary` (see `any_transmission_id`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Strategy for a transmission: a solution backed by exactly 512 arbitrary bytes, or a
    /// transaction backed by exactly 2048 arbitrary bytes. The payloads are opaque buffers.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Strategy for a solution ID built from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Strategy for a transaction ID built from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Strategy for a transaction or solution transmission ID with a random checksum.
    /// Ratification IDs are never generated.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Has every validator in `validator_set` sign `batch_header`'s batch ID.
    ///
    /// Returns one signature per validator. Panics if signing fails (acceptable in tests).
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    /// Property test. It builds a round-0 certificate over 1–15 arbitrary transmissions, authored
    /// by a randomly selected validator and signed by all the others. It inserts the certificate
    /// three times (with its transmissions, with none, and with its transmissions again) and
    /// asserts the storage layout is identical after each insert.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let signer = selector.select(&validators);

        // Collapse any duplicate generated IDs; the last transmission for a given ID wins.
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Remove the author from the validator set passed to create the batch
        // certificate, the author should not sign their own batch.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Build the expected transmissions layout: each ID maps to (transmission, {certificate ID}).
        // For a duplicated ID, `or_insert` keeps the first transmission seen, whereas
        // `transmission_map` above keeps the last. NOTE(review): with duplicate generated IDs the
        // expected payload may therefore differ from the inserted one — confirm duplicates cannot occur.
        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
        }

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}