amareleo_node_bft/helpers/storage.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21    ledger::{
22        block::{Block, Transaction},
23        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24    },
25    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26    utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::RwLock;
34use rayon::iter::{IntoParallelIterator, ParallelIterator};
35use std::{
36    collections::{HashMap, HashSet},
37    sync::{
38        Arc,
39        atomic::{AtomicU32, AtomicU64, Ordering},
40    },
41};
42use tracing::subscriber::DefaultGuard;
43
44#[derive(Clone, Debug)]
45pub struct Storage<N: Network>(Arc<StorageInner<N>>);
46
47impl<N: Network> std::ops::Deref for Storage<N> {
48    type Target = Arc<StorageInner<N>>;
49
50    fn deref(&self) -> &Self::Target {
51        &self.0
52    }
53}
54
55impl<N: Network> TracingHandlerGuard for Storage<N> {
56    /// Retruns tracing guard
57    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
58        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
59    }
60}
61
62/// The storage for the memory pool.
63///
64/// The storage is used to store the following:
65/// - `current_height` tracker.
66/// - `current_round` tracker.
67/// - `round` to `(certificate ID, batch ID, author)` entries.
68/// - `certificate ID` to `certificate` entries.
69/// - `batch ID` to `round` entries.
70/// - `transmission ID` to `(transmission, certificate IDs)` entries.
71///
72/// The chain of events is as follows:
73/// 1. A `transmission` is received.
74/// 2. After a `batch` is ready to be stored:
75///   - The `certificate` is inserted, triggering updates to the
76///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
77///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
78///   - The certificate ID is inserted into the `transmissions` map.
79/// 3. After a `round` reaches quorum threshold:
80///  - The next round is inserted into the `current_round`.
81#[derive(Debug)]
82pub struct StorageInner<N: Network> {
83    /// The ledger service.
84    ledger: Arc<dyn LedgerService<N>>,
85    /* Once per block */
86    /// The current height.
87    current_height: AtomicU32,
88    /* Once per round */
89    /// The current round.
90    current_round: AtomicU64,
91    /// The `round` for which garbage collection has occurred **up to** (inclusive).
92    gc_round: AtomicU64,
93    /// The maximum number of rounds to keep in storage.
94    max_gc_rounds: u64,
95    /* Once per batch */
96    /// The map of `round` to a list of `(certificate ID, batch ID, author)` entries.
97    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
98    /// The map of `certificate ID` to `certificate`.
99    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
100    /// The map of `batch ID` to `round`.
101    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
102    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
103    transmissions: Arc<dyn StorageService<N>>,
104    /// Tracing handle
105    tracing: Option<TracingHandler>,
106}
107
108impl<N: Network> Storage<N> {
109    /// Initializes a new instance of storage.
110    pub fn new(
111        ledger: Arc<dyn LedgerService<N>>,
112        transmissions: Arc<dyn StorageService<N>>,
113        max_gc_rounds: u64,
114        tracing: Option<TracingHandler>,
115    ) -> Self {
116        // Retrieve the latest committee bonded in the ledger
117        // (genesis committee if the ledger contains only the genesis block).
118        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
119        // Retrieve the round at which that committee was created, or 1 if it is the genesis committee.
120        let current_round = committee.starting_round().max(1);
121
122        // Create the storage.
123        let storage = Self(Arc::new(StorageInner {
124            ledger,
125            current_height: Default::default(),
126            current_round: Default::default(),
127            gc_round: Default::default(),
128            max_gc_rounds,
129            rounds: Default::default(),
130            certificates: Default::default(),
131            batch_ids: Default::default(),
132            transmissions,
133            tracing,
134        }));
135        // Update the storage to the current round.
136        storage.update_current_round(current_round);
137        // Perform GC on the current round.
138        // Since there are no certificates yet, this only sets `gc_round`.
139        storage.garbage_collect_certificates(current_round);
140        // Return the storage.
141        storage
142    }
143}
144
145impl<N: Network> Storage<N> {
146    /// Returns the current height.
147    pub fn current_height(&self) -> u32 {
148        // Get the current height.
149        self.current_height.load(Ordering::SeqCst)
150    }
151}
152
153impl<N: Network> Storage<N> {
154    /// Returns the current round.
155    pub fn current_round(&self) -> u64 {
156        // Get the current round.
157        self.current_round.load(Ordering::SeqCst)
158    }
159
160    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
161    pub fn gc_round(&self) -> u64 {
162        // Get the GC round.
163        self.gc_round.load(Ordering::SeqCst)
164    }
165
166    /// Returns the maximum number of rounds to keep in storage.
167    pub fn max_gc_rounds(&self) -> u64 {
168        self.max_gc_rounds
169    }
170
171    /// Increments storage to the next round, updating the current round.
172    /// Note: This method is only called once per round, upon certification of the primary's batch.
173    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
174        // Determine the next round.
175        let next_round = current_round + 1;
176
177        // Check if the next round is less than the current round in storage.
178        {
179            // Retrieve the storage round.
180            let storage_round = self.current_round();
181            // If the next round is less than the current round in storage, return early with the storage round.
182            if next_round < storage_round {
183                return Ok(storage_round);
184            }
185        }
186
187        // Retrieve the current committee.
188        let current_committee = self.ledger.current_committee()?;
189        // Retrieve the current committee's starting round.
190        let starting_round = current_committee.starting_round();
191        // If the primary is behind the current committee's starting round, sync with the latest block.
192        if next_round < starting_round {
193            // Retrieve the latest block round.
194            let latest_block_round = self.ledger.latest_round();
195            // Log the round sync.
196            guard_info!(
197                self,
198                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
199            );
200            // Sync the round with the latest block.
201            self.sync_round_with_block(latest_block_round);
202            // Return the latest block round.
203            return Ok(latest_block_round);
204        }
205
206        // Update the storage to the next round.
207        self.update_current_round(next_round);
208
209        #[cfg(feature = "metrics")]
210        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
211
212        // Retrieve the storage round.
213        let storage_round = self.current_round();
214        // Retrieve the GC round.
215        let gc_round = self.gc_round();
216        // Ensure the next round matches in storage.
217        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
218        // Ensure the next round is greater than or equal to the GC round.
219        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
220
221        // Log the updated round.
222        guard_info!(self, "Starting round {next_round}...");
223        Ok(next_round)
224    }
225
226    /// Updates the storage to the next round.
227    fn update_current_round(&self, next_round: u64) {
228        // Update the current round.
229        self.current_round.store(next_round, Ordering::SeqCst);
230    }
231
232    /// Update the storage by performing garbage collection based on the next round.
233    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
234        // Fetch the current GC round.
235        let current_gc_round = self.gc_round();
236        // Compute the next GC round.
237        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
238        // Check if storage needs to be garbage collected.
239        if next_gc_round > current_gc_round {
240            // Remove the GC round(s) from storage.
241            for gc_round in current_gc_round..=next_gc_round {
242                // Iterate over the certificates for the GC round.
243                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
244                    // Remove the certificate from storage.
245                    self.remove_certificate(id);
246                }
247            }
248            // Update the GC round.
249            self.gc_round.store(next_gc_round, Ordering::SeqCst);
250        }
251    }
252}
253
254impl<N: Network> Storage<N> {
255    /// Returns `true` if the storage contains the specified `round`.
256    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
257        // Check if the round exists in storage.
258        self.rounds.read().contains_key(&round)
259    }
260
261    /// Returns `true` if the storage contains the specified `certificate ID`.
262    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
263        // Check if the certificate ID exists in storage.
264        self.certificates.read().contains_key(&certificate_id)
265    }
266
267    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
268    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
269        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
270    }
271
272    /// Returns `true` if the storage contains the specified `batch ID`.
273    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
274        // Check if the batch ID exists in storage.
275        self.batch_ids.read().contains_key(&batch_id)
276    }
277
278    /// Returns `true` if the storage contains the specified `transmission ID`.
279    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
280        self.transmissions.contains_transmission(transmission_id.into())
281    }
282
283    /// Returns the transmission for the given `transmission ID`.
284    /// If the transmission ID does not exist in storage, `None` is returned.
285    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
286        self.transmissions.get_transmission(transmission_id.into())
287    }
288
289    /// Returns the round for the given `certificate ID`.
290    /// If the certificate ID does not exist in storage, `None` is returned.
291    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
292        // Get the round.
293        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
294    }
295
296    /// Returns the round for the given `batch ID`.
297    /// If the batch ID does not exist in storage, `None` is returned.
298    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
299        // Get the round.
300        self.batch_ids.read().get(&batch_id).copied()
301    }
302
303    /// Returns the certificate round for the given `certificate ID`.
304    /// If the certificate ID does not exist in storage, `None` is returned.
305    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
306        // Get the batch certificate and return the round.
307        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
308    }
309
310    /// Returns the certificate for the given `certificate ID`.
311    /// If the certificate ID does not exist in storage, `None` is returned.
312    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
313        // Get the batch certificate.
314        self.certificates.read().get(&certificate_id).cloned()
315    }
316
317    /// Returns the certificate for the given `round` and `author`.
318    /// If the round does not exist in storage, `None` is returned.
319    /// If the author for the round does not exist in storage, `None` is returned.
320    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
321        // Retrieve the certificates.
322        if let Some(entries) = self.rounds.read().get(&round) {
323            let certificates = self.certificates.read();
324            entries.iter().find_map(
325                |(certificate_id, _, a)| {
326                    if a == &author { certificates.get(certificate_id).cloned() } else { None }
327                },
328            )
329        } else {
330            Default::default()
331        }
332    }
333
334    /// Returns the certificates for the given `round`.
335    /// If the round does not exist in storage, an empty set is returned.
336    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
337        // The genesis round does not have batch certificates.
338        if round == 0 {
339            return Default::default();
340        }
341        // Retrieve the certificates.
342        if let Some(entries) = self.rounds.read().get(&round) {
343            let certificates = self.certificates.read();
344            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
345        } else {
346            Default::default()
347        }
348    }
349
350    /// Returns the certificate IDs for the given `round`.
351    /// If the round does not exist in storage, an empty set is returned.
352    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
353        // The genesis round does not have batch certificates.
354        if round == 0 {
355            return Default::default();
356        }
357        // Retrieve the certificates.
358        if let Some(entries) = self.rounds.read().get(&round) {
359            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
360        } else {
361            Default::default()
362        }
363    }
364
365    /// Returns the certificate authors for the given `round`.
366    /// If the round does not exist in storage, an empty set is returned.
367    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
368        // The genesis round does not have batch certificates.
369        if round == 0 {
370            return Default::default();
371        }
372        // Retrieve the certificates.
373        if let Some(entries) = self.rounds.read().get(&round) {
374            entries.iter().map(|(_, _, author)| *author).collect()
375        } else {
376            Default::default()
377        }
378    }
379
380    /// Returns the certificates that have not yet been included in the ledger.
381    /// Note that the order of this set is by round and then insertion.
382    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
383        // Obtain the read locks.
384        let rounds = self.rounds.read();
385        let certificates = self.certificates.read();
386
387        // Iterate over the rounds.
388        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
389            .flat_map(|(_, certificates_for_round)| {
390                // Iterate over the certificates for the round.
391                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
392                    // Skip the certificate if it already exists in the ledger.
393                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
394                        None
395                    } else {
396                        // Add the certificate to the pending certificates.
397                        certificates.get(&certificate_id).cloned()
398                    }
399                })
400            })
401            .collect()
402    }
403
404    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
405    ///
406    /// This method ensures the following invariants:
407    /// - The batch ID does not already exist in storage.
408    /// - The author is a member of the committee for the batch round.
409    /// - The timestamp is within the allowed time range.
410    /// - None of the transmissions are from any past rounds (up to GC).
411    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
412    /// - All previous certificates declared in the certificate exist in storage (up to GC).
413    /// - All previous certificates are for the previous round (i.e. round - 1).
414    /// - All previous certificates contain a unique author.
415    /// - The previous certificates reached the quorum threshold (N - f).
416    pub fn check_batch_header(
417        &self,
418        batch_header: &BatchHeader<N>,
419        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
420        aborted_transmissions: HashSet<TransmissionID<N>>,
421    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
422        // Retrieve the round.
423        let round = batch_header.round();
424        // Retrieve the GC round.
425        let gc_round = self.gc_round();
426        // Construct a GC log message.
427        let gc_log = format!("(gc = {gc_round})");
428
429        // Ensure the batch ID does not already exist in storage.
430        if self.contains_batch(batch_header.batch_id()) {
431            bail!("Batch for round {round} already exists in storage {gc_log}")
432        }
433
434        // Retrieve the committee lookback for the batch round.
435        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
436            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
437        };
438        // Ensure the author is in the committee.
439        if !committee_lookback.is_committee_member(batch_header.author()) {
440            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
441        }
442
443        // Check the timestamp for liveness.
444        check_timestamp_for_liveness(batch_header.timestamp())?;
445
446        // Retrieve the missing transmissions in storage from the given transmissions.
447        let missing_transmissions = self
448            .transmissions
449            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
450            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
451
452        // Compute the previous round.
453        let previous_round = round.saturating_sub(1);
454        // Check if the previous round is within range of the GC round.
455        if previous_round > gc_round {
456            // Retrieve the committee lookback for the previous round.
457            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
458                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
459            };
460            // Ensure the previous round certificates exists in storage.
461            if !self.contains_certificates_for_round(previous_round) {
462                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
463            }
464            // Ensure the number of previous certificate IDs is at or below the number of committee members.
465            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
466                bail!("Too many previous certificates for round {round} {gc_log}")
467            }
468            // Initialize a set of the previous authors.
469            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
470            // Ensure storage contains all declared previous certificates (up to GC).
471            for previous_certificate_id in batch_header.previous_certificate_ids() {
472                // Retrieve the previous certificate.
473                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
474                    bail!(
475                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
476                        fmt_id(previous_certificate_id)
477                    )
478                };
479                // Ensure the previous certificate is for the previous round.
480                if previous_certificate.round() != previous_round {
481                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
482                }
483                // Ensure the previous author is new.
484                if previous_authors.contains(&previous_certificate.author()) {
485                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
486                }
487                // Insert the author of the previous certificate.
488                previous_authors.insert(previous_certificate.author());
489            }
490            // Ensure the previous certificates have reached the quorum threshold.
491            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
492                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
493            }
494        }
495        Ok(missing_transmissions)
496    }
497
498    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
499    ///
500    /// This method ensures the following invariants:
501    /// - The certificate ID does not already exist in storage.
502    /// - The batch ID does not already exist in storage.
503    /// - The author is a member of the committee for the batch round.
504    /// - The author has not already created a certificate for the batch round.
505    /// - The timestamp is within the allowed time range.
506    /// - None of the transmissions are from any past rounds (up to GC).
507    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
508    /// - All previous certificates declared in the certificate exist in storage (up to GC).
509    /// - All previous certificates are for the previous round (i.e. round - 1).
510    /// - The previous certificates reached the quorum threshold (N - f).
511    /// - The timestamps from the signers are all within the allowed time range.
512    /// - The signers have reached the quorum threshold (N - f).
513    pub fn check_certificate(
514        &self,
515        certificate: &BatchCertificate<N>,
516        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
517        aborted_transmissions: HashSet<TransmissionID<N>>,
518    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
519        // Retrieve the round.
520        let round = certificate.round();
521        // Retrieve the GC round.
522        let gc_round = self.gc_round();
523        // Construct a GC log message.
524        let gc_log = format!("(gc = {gc_round})");
525
526        // Ensure the certificate ID does not already exist in storage.
527        if self.contains_certificate(certificate.id()) {
528            bail!("Certificate for round {round} already exists in storage {gc_log}")
529        }
530
531        // Ensure the storage does not already contain a certificate for this author in this round.
532        if self.contains_certificate_in_round_from(round, certificate.author()) {
533            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
534        }
535
536        // Ensure the batch header is well-formed.
537        let missing_transmissions =
538            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
539
540        // Check the timestamp for liveness.
541        check_timestamp_for_liveness(certificate.timestamp())?;
542
543        // Retrieve the committee lookback for the batch round.
544        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
545            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
546        };
547
548        // Initialize a set of the signers.
549        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
550        // Append the batch author.
551        signers.insert(certificate.author());
552
553        // Iterate over the signatures.
554        for signature in certificate.signatures() {
555            // Retrieve the signer.
556            let signer = signature.to_address();
557            // Ensure the signer is in the committee.
558            if !committee_lookback.is_committee_member(signer) {
559                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
560            }
561            // Append the signer.
562            signers.insert(signer);
563        }
564
565        // Ensure the signatures have reached the quorum threshold.
566        if !committee_lookback.is_quorum_threshold_reached(&signers) {
567            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
568        }
569        Ok(missing_transmissions)
570    }
571
572    /// Inserts the given `certificate` into storage.
573    ///
574    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
575    ///
576    /// This method ensures the following invariants:
577    /// - The certificate ID does not already exist in storage.
578    /// - The batch ID does not already exist in storage.
579    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
580    /// - All previous certificates declared in the certificate exist in storage (up to GC).
581    /// - All previous certificates are for the previous round (i.e. round - 1).
582    /// - The previous certificates reached the quorum threshold (N - f).
583    pub fn insert_certificate(
584        &self,
585        certificate: BatchCertificate<N>,
586        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
587        aborted_transmissions: HashSet<TransmissionID<N>>,
588    ) -> Result<()> {
589        // Ensure the certificate round is above the GC round.
590        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
591        // Ensure the certificate and its transmissions are valid.
592        let missing_transmissions =
593            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
594        // Insert the certificate into storage.
595        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
596        Ok(())
597    }
598
599    /// Inserts the given `certificate` into storage.
600    ///
601    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
602    ///
603    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
604    fn insert_certificate_atomic(
605        &self,
606        certificate: BatchCertificate<N>,
607        aborted_transmission_ids: HashSet<TransmissionID<N>>,
608        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
609    ) {
610        // Retrieve the round.
611        let round = certificate.round();
612        // Retrieve the certificate ID.
613        let certificate_id = certificate.id();
614        // Retrieve the batch ID.
615        let batch_id = certificate.batch_id();
616        // Retrieve the author of the batch.
617        let author = certificate.author();
618
619        // Insert the round to certificate ID entry.
620        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
621        // Obtain the certificate's transmission ids.
622        let transmission_ids = certificate.transmission_ids().clone();
623        // Insert the certificate.
624        self.certificates.write().insert(certificate_id, certificate);
625        // Insert the batch ID.
626        self.batch_ids.write().insert(batch_id, round);
627        // Insert the certificate ID for each of the transmissions into storage.
628        self.transmissions.insert_transmissions(
629            certificate_id,
630            transmission_ids,
631            aborted_transmission_ids,
632            missing_transmissions,
633        );
634    }
635
636    /// Removes the given `certificate ID` from storage.
637    ///
638    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
639    ///
640    /// If the certificate was successfully removed, `true` is returned.
641    /// If the certificate did not exist in storage, `false` is returned.
642    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
643        // Retrieve the certificate.
644        let Some(certificate) = self.get_certificate(certificate_id) else {
645            guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
646            return false;
647        };
648        // Retrieve the round.
649        let round = certificate.round();
650        // Retrieve the batch ID.
651        let batch_id = certificate.batch_id();
652        // Compute the author of the batch.
653        let author = certificate.author();
654
655        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
656        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
657        //  It will be better to write tests that compare the union of the sets.
658
659        // Update the round.
660        match self.rounds.write().entry(round) {
661            Entry::Occupied(mut entry) => {
662                // Remove the round to certificate ID entry.
663                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
664                // If the round is empty, remove it.
665                if entry.get().is_empty() {
666                    entry.swap_remove();
667                }
668            }
669            Entry::Vacant(_) => {}
670        }
671        // Remove the certificate.
672        self.certificates.write().swap_remove(&certificate_id);
673        // Remove the batch ID.
674        self.batch_ids.write().swap_remove(&batch_id);
675        // Remove the transmission entries in the certificate from storage.
676        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
677        // Return successfully.
678        true
679    }
680}
681
682impl<N: Network> Storage<N> {
683    /// Syncs the current height with the block.
684    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
685        // If the block height is greater than the current height in storage, sync the height.
686        if next_height > self.current_height() {
687            // Update the current height in storage.
688            self.current_height.store(next_height, Ordering::SeqCst);
689        }
690    }
691
692    /// Syncs the current round with the block.
693    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
694        // Retrieve the current round in the block.
695        let next_round = next_round.max(1);
696        // If the round in the block is greater than the current round in storage, sync the round.
697        if next_round > self.current_round() {
698            // Update the current round in storage.
699            self.update_current_round(next_round);
700            // Log the updated round.
701            guard_info!(self, "Synced to round {next_round}...");
702        }
703    }
704
705    /// Syncs the batch certificate with the block.
706    pub(crate) fn sync_certificate_with_block(
707        &self,
708        block: &Block<N>,
709        certificate: BatchCertificate<N>,
710        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
711    ) {
712        // Skip if the certificate round is below the GC round.
713        if certificate.round() <= self.gc_round() {
714            return;
715        }
716        // If the certificate ID already exists in storage, skip it.
717        if self.contains_certificate(certificate.id()) {
718            return;
719        }
720        // Retrieve the transmissions for the certificate.
721        let mut missing_transmissions = HashMap::new();
722
723        // Retrieve the aborted transmissions for the certificate.
724        let mut aborted_transmissions = HashSet::new();
725
726        // Track the block's aborted solutions and transactions.
727        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
728        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
729
730        // Iterate over the transmission IDs.
731        for transmission_id in certificate.transmission_ids() {
732            // If the transmission ID already exists in the map, skip it.
733            if missing_transmissions.contains_key(transmission_id) {
734                continue;
735            }
736            // If the transmission ID exists in storage, skip it.
737            if self.contains_transmission(*transmission_id) {
738                continue;
739            }
740            // Retrieve the transmission.
741            match transmission_id {
742                TransmissionID::Ratification => (),
743                TransmissionID::Solution(solution_id, _) => {
744                    // Retrieve the solution.
745                    match block.get_solution(solution_id) {
746                        // Insert the solution.
747                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
748                        // Otherwise, try to load the solution from the ledger.
749                        None => match self.ledger.get_solution(solution_id) {
750                            // Insert the solution.
751                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
752                            // Check if the solution is in the aborted solutions.
753                            Err(_) => {
754                                // Insert the aborted solution if it exists in the block or ledger.
755                                match aborted_solutions.contains(solution_id)
756                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
757                                {
758                                    true => {
759                                        aborted_transmissions.insert(*transmission_id);
760                                    }
761                                    false => {
762                                        guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
763                                    }
764                                }
765                                continue;
766                            }
767                        },
768                    };
769                }
770                TransmissionID::Transaction(transaction_id, _) => {
771                    // Retrieve the transaction.
772                    match unconfirmed_transactions.get(transaction_id) {
773                        // Insert the transaction.
774                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
775                        // Otherwise, try to load the unconfirmed transaction from the ledger.
776                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
777                            // Insert the transaction.
778                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
779                            // Check if the transaction is in the aborted transactions.
780                            Err(_) => {
781                                // Insert the aborted transaction if it exists in the block or ledger.
782                                match aborted_transactions.contains(transaction_id)
783                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
784                                {
785                                    true => {
786                                        aborted_transmissions.insert(*transmission_id);
787                                    }
788                                    false => guard_warn!(
789                                        self,
790                                        "Missing transaction {transaction_id} in block {}",
791                                        block.height()
792                                    ),
793                                }
794                                continue;
795                            }
796                        },
797                    };
798                }
799            }
800        }
801        // Insert the batch certificate into storage.
802        let certificate_id = fmt_id(certificate.id());
803        guard_debug!(
804            self,
805            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
806            certificate.round(),
807            certificate.transmission_ids().len()
808        );
809        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
810            guard_error!(
811                self,
812                "Failed to insert certificate '{certificate_id}' from block {} - {error}",
813                block.height()
814            );
815        }
816    }
817}
818
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Gives access to the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Yields a snapshot of every `(round, {(certificate ID, batch ID, author)})` entry.
    ///
    /// The map is cloned while the read lock is held, and the iterator walks the clone.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Yields a snapshot of every `(certificate ID, certificate)` entry.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Yields a snapshot of every `(batch ID, round)` entry.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Yields every `(transmission ID, (transmission, certificate IDs))` entry, as exposed by the
    /// storage service's `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Inserts `certificate` into storage without any validation.
    ///
    /// Every transmission ID of the certificate is registered with a placeholder transmission
    /// (an empty transaction buffer), and no transmissions are marked as aborted.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let certificate_id = certificate.id();
        let round = certificate.round();
        let batch_id = certificate.batch_id();
        let author = certificate.author();
        // Cloned up front because `certificate` is moved into the certificates map below.
        let transmission_ids = certificate.transmission_ids().clone();

        // Stand in an empty transaction buffer for every transmission (for testing purposes).
        let mut missing_transmissions = HashMap::with_capacity(transmission_ids.len());
        for transmission_id in transmission_ids.iter().copied() {
            let placeholder =
                Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()));
            missing_transmissions.insert(transmission_id, placeholder);
        }

        // Register the certificate in each index, then link its transmissions to it.
        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
886
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for certificate insertion and removal in `Storage`.

    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the storage matches the expected layout.
    ///
    /// `rounds`, `certificates`, and `batch_ids` are compared as ordered sequences against the
    /// storage iterators, so the expected slices must list entries in the order storage yields
    /// them. `transmissions` is compared as a `HashMap`, so its ordering does not matter.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Ensure the rounds are well-formed.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Ensure the certificates are well-formed.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Ensure the batch IDs are well-formed.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Ensure the transmissions are well-formed.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }

    /// Samples a random transmission.
    ///
    /// A coin flip picks either a solution wrapping 512 random bytes or a transaction wrapping
    /// 2048 random bytes. The bytes are fake: they are not a valid serialized solution or
    /// transaction. Note that the coin flip is drawn from `rng` before the payload bytes.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Sample random fake solution bytes.
        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample random fake transaction bytes.
        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample a random transmission.
        match rng.gen::<bool>() {
            true => Transmission::Solution(s(rng)),
            false => Transmission::Transaction(t(rng)),
        }
    }

    /// Samples the random transmissions, returning the missing transmissions and the transmissions.
    ///
    /// For every transmission ID of `certificate` (in the certificate's order), one random
    /// transmission is sampled. Returns:
    /// - the `missing_transmissions` map (`transmission ID -> transmission`) to pass to insertion;
    /// - the expected internal layout (`transmission ID -> (transmission, {certificate ID})`) to
    ///   compare against with `assert_storage`.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            // Initialize the transmission.
            let transmission = sample_transmission(rng);
            // Update the missing transmissions.
            missing_transmissions.insert(*transmission_id, transmission.clone());
            // Update the transmissions map.
            transmissions
                .entry(*transmission_id)
                .or_insert((transmission, Default::default()))
                .1
                .insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    /// Inserts a single certificate, checks every lookup path and the raw storage layout, then
    /// removes it and checks that storage is empty again.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        // NOTE(review): the `1` argument appears to be the GC-rounds setting; confirm against `Storage::new`.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate (with no aborted transmissions).
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the certificate is stored in the correct round.
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        // Ensure the certificate is stored for the correct round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check that the underlying storage representation is correct.
        {
            // Construct the expected layout for 'rounds'.
            let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
            // Construct the expected layout for 'certificates'.
            let certificates = [(certificate_id, certificate.clone())];
            // Construct the expected layout for 'batch_ids'.
            let batch_ids = [(batch_id, round)];
            // Assert the storage is well-formed.
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // Retrieve the certificate.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        // Ensure the retrieved certificate is the same as the inserted certificate.
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate.
        assert!(storage.remove_certificate(certificate_id));
        // Ensure the certificate does not exist in storage.
        assert!(!storage.contains_certificate(certificate_id));
        // Ensure the certificate is no longer stored in the round.
        assert!(storage.get_certificates_for_round(round).is_empty());
        // Ensure the certificate is no longer stored for the round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    /// Inserts the same certificate three times: first with its transmissions, then with none,
    /// then with its transmissions again. Asserts the storage layout is identical after each
    /// insertion.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];
        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
}
1085
#[cfg(test)]
pub mod prop_tests {
    //! Property-based tests and proptest strategies for `Storage`.

    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Generates a `Storage` backed by a `MockLedgerService` for a committee and an in-memory
    /// BFT storage service. The third `Storage::new` argument is drawn from
    /// `0..MAX_GC_ROUNDS`; it is presumably the GC-rounds setting (confirm against `Storage::new`).
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        /// Draws both the committee context and the GC-rounds value at random.
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        /// Uses the given committee context and draws only the GC-rounds value at random.
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
    // that forwards every `RngCore` method to the inner proptest RNG and adds the `CryptoRng` marker.
    // It is NOT cryptographically secure; it exists only to satisfy trait bounds in tests.
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        /// Captures the test runner's RNG via `prop_perturb`.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// A newtype over `Transmission` so it can implement `Arbitrary` (see `any_transmission`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// A newtype over `TransmissionID` so it can implement `Arbitrary` (see `any_transmission_id`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Strategy for a fake transmission: either a solution wrapping exactly 512 arbitrary bytes or
    /// a transaction wrapping exactly 2048 arbitrary bytes (not valid serialized objects).
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Strategy for a solution ID built from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Strategy for a transaction ID built from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Strategy for a transaction or solution transmission ID with a random checksum.
    /// Ratification IDs are never generated.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with the private key of every validator in
    /// `validator_set`. Panics if any signing operation fails.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    // Property: inserting the same certificate repeatedly (with its transmissions, without any,
    // then with them again) leaves the storage layout unchanged after the first insertion.
    // The certificate is authored by a randomly selected validator at round 0 and signed by all
    // remaining validators, over 1 to 15 random transmissions.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let signer = selector.select(&validators);

        // Deduplicate the generated transmissions by ID (later entries overwrite earlier ones).
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Remove the author from the validator set passed to create the batch
        // certificate, the author should not sign their own batch.
        // (A clone is needed because `signer` still borrows the original set.)
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Build the expected internal transmissions layout: each ID maps to its first generated
        // transmission and the set `{certificate_id}`.
        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
        }

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}