amareleo_node_bft/helpers/storage.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21    ledger::{
22        block::{Block, Transaction},
23        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24    },
25    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26    utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::RwLock;
34use rayon::iter::{IntoParallelIterator, ParallelIterator};
35use std::{
36    collections::{HashMap, HashSet},
37    sync::{
38        Arc,
39        atomic::{AtomicU32, AtomicU64, Ordering},
40    },
41};
42use tracing::subscriber::DefaultGuard;
43
44#[derive(Clone, Debug)]
45pub struct Storage<N: Network>(Arc<StorageInner<N>>);
46
47impl<N: Network> std::ops::Deref for Storage<N> {
48    type Target = Arc<StorageInner<N>>;
49
50    fn deref(&self) -> &Self::Target {
51        &self.0
52    }
53}
54
55impl<N: Network> TracingHandlerGuard for Storage<N> {
56    /// Retruns tracing guard
57    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
58        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
59    }
60}
61
62/// The storage for the memory pool.
63///
64/// The storage is used to store the following:
65/// - `current_height` tracker.
66/// - `current_round` tracker.
67/// - `round` to `(certificate ID, batch ID, author)` entries.
68/// - `certificate ID` to `certificate` entries.
69/// - `batch ID` to `round` entries.
70/// - `transmission ID` to `(transmission, certificate IDs)` entries.
71///
72/// The chain of events is as follows:
73/// 1. A `transmission` is received.
74/// 2. After a `batch` is ready to be stored:
75///   - The `certificate` is inserted, triggering updates to the
76///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
77///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
78///   - The certificate ID is inserted into the `transmissions` map.
79/// 3. After a `round` reaches quorum threshold:
80///  - The next round is inserted into the `current_round`.
81#[derive(Debug)]
82pub struct StorageInner<N: Network> {
83    /// The ledger service.
84    ledger: Arc<dyn LedgerService<N>>,
85    /* Once per block */
86    /// The current height.
87    current_height: AtomicU32,
88    /* Once per round */
89    /// The current round.
90    ///
91    /// Invariant: current_round > 0.
92    /// This is established in [`Storage::new`], which sets it to at least 1 via [`Storage::update_current_round`].
93    /// The only callers of [`Storage::update_current_round`] are
94    /// [`Storage::increment_to_next_round`] and [`Storage::sync_round_with_block`],
95    /// both of which set it to at least 1.
96    current_round: AtomicU64,
97    /// The `round` for which garbage collection has occurred **up to** (inclusive).
98    gc_round: AtomicU64,
99    /// The maximum number of rounds to keep in storage.
100    max_gc_rounds: u64,
101    /* Once per batch */
102    /// The map of `round` to a list of `(certificate ID, author)` entries.
103    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Address<N>)>>>,
104    /// The map of `certificate ID` to `certificate`.
105    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
106    /// The map of `certificate ID` to `round`.
107    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
108    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
109    transmissions: Arc<dyn StorageService<N>>,
110    /// Tracing handle
111    tracing: Option<TracingHandler>,
112}
113
114impl<N: Network> Storage<N> {
115    /// Initializes a new instance of storage.
116    pub fn new(
117        ledger: Arc<dyn LedgerService<N>>,
118        transmissions: Arc<dyn StorageService<N>>,
119        max_gc_rounds: u64,
120        tracing: Option<TracingHandler>,
121    ) -> Self {
122        // Retrieve the latest committee bonded in the ledger
123        // (genesis committee if the ledger contains only the genesis block).
124        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
125        // Retrieve the round at which that committee was created, or 1 if it is the genesis committee.
126        let current_round = committee.starting_round().max(1);
127
128        // Create the storage.
129        let storage = Self(Arc::new(StorageInner {
130            ledger,
131            current_height: Default::default(),
132            current_round: Default::default(),
133            gc_round: Default::default(),
134            max_gc_rounds,
135            rounds: Default::default(),
136            certificates: Default::default(),
137            batch_ids: Default::default(),
138            transmissions,
139            tracing,
140        }));
141        // Update the storage to the current round.
142        storage.update_current_round(current_round);
143        // Perform GC on the current round.
144        // Since there are no certificates yet, this only sets `gc_round`.
145        storage.garbage_collect_certificates(current_round);
146        // Return the storage.
147        storage
148    }
149}
150
151impl<N: Network> Storage<N> {
152    /// Returns the current height.
153    pub fn current_height(&self) -> u32 {
154        // Get the current height.
155        self.current_height.load(Ordering::SeqCst)
156    }
157}
158
159impl<N: Network> Storage<N> {
160    /// Returns the current round.
161    pub fn current_round(&self) -> u64 {
162        // Get the current round.
163        self.current_round.load(Ordering::SeqCst)
164    }
165
166    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
167    pub fn gc_round(&self) -> u64 {
168        // Get the GC round.
169        self.gc_round.load(Ordering::SeqCst)
170    }
171
172    /// Returns the maximum number of rounds to keep in storage.
173    pub fn max_gc_rounds(&self) -> u64 {
174        self.max_gc_rounds
175    }
176
177    /// Increments storage to the next round, updating the current round.
178    /// Note: This method is only called once per round, upon certification of the primary's batch.
179    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
180        // Determine the next round.
181        let next_round = current_round + 1;
182
183        // Check if the next round is less than the current round in storage.
184        {
185            // Retrieve the storage round.
186            let storage_round = self.current_round();
187            // If the next round is less than the current round in storage, return early with the storage round.
188            if next_round < storage_round {
189                return Ok(storage_round);
190            }
191        }
192
193        // Retrieve the current committee.
194        let current_committee = self.ledger.current_committee()?;
195        // Retrieve the current committee's starting round.
196        let starting_round = current_committee.starting_round();
197        // If the primary is behind the current committee's starting round, sync with the latest block.
198        if next_round < starting_round {
199            // Retrieve the latest block round.
200            let latest_block_round = self.ledger.latest_round();
201            // Log the round sync.
202            guard_info!(
203                self,
204                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
205            );
206            // Sync the round with the latest block.
207            self.sync_round_with_block(latest_block_round);
208            // Return the latest block round.
209            return Ok(latest_block_round);
210        }
211
212        // Update the storage to the next round.
213        self.update_current_round(next_round);
214
215        #[cfg(feature = "metrics")]
216        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
217
218        // Retrieve the storage round.
219        let storage_round = self.current_round();
220        // Retrieve the GC round.
221        let gc_round = self.gc_round();
222        // Ensure the next round matches in storage.
223        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
224        // Ensure the next round is greater than or equal to the GC round.
225        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
226
227        // Log the updated round.
228        guard_info!(self, "Starting round {next_round}...");
229        Ok(next_round)
230    }
231
232    /// Updates the storage to the next round.
233    fn update_current_round(&self, next_round: u64) {
234        // Update the current round.
235        self.current_round.store(next_round, Ordering::SeqCst);
236    }
237
238    /// Update the storage by performing garbage collection based on the next round.
239    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
240        // Fetch the current GC round.
241        let current_gc_round = self.gc_round();
242        // Compute the next GC round.
243        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
244        // Check if storage needs to be garbage collected.
245        if next_gc_round > current_gc_round {
246            // Remove the GC round(s) from storage.
247            for gc_round in current_gc_round..=next_gc_round {
248                // Iterate over the certificates for the GC round.
249                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
250                    // Remove the certificate from storage.
251                    self.remove_certificate(id);
252                }
253            }
254            // Update the GC round.
255            self.gc_round.store(next_gc_round, Ordering::SeqCst);
256        }
257    }
258}
259
260impl<N: Network> Storage<N> {
261    /// Returns `true` if the storage contains the specified `round`.
262    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
263        // Check if the round exists in storage.
264        self.rounds.read().contains_key(&round)
265    }
266
267    /// Returns `true` if the storage contains the specified `certificate ID`.
268    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
269        // Check if the certificate ID exists in storage.
270        self.certificates.read().contains_key(&certificate_id)
271    }
272
273    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
274    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
275        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, a)| a == &author))
276    }
277
278    /// Returns `true` if the storage contains the specified `batch ID`.
279    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
280        // Check if the batch ID exists in storage.
281        self.batch_ids.read().contains_key(&batch_id)
282    }
283
284    /// Returns `true` if the storage contains the specified `transmission ID`.
285    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
286        self.transmissions.contains_transmission(transmission_id.into())
287    }
288
289    /// Returns the transmission for the given `transmission ID`.
290    /// If the transmission ID does not exist in storage, `None` is returned.
291    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
292        self.transmissions.get_transmission(transmission_id.into())
293    }
294
295    /// Returns the round for the given `certificate ID`.
296    /// If the certificate ID does not exist in storage, `None` is returned.
297    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
298        // Get the round.
299        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
300    }
301
302    /// Returns the round for the given `batch ID`.
303    /// If the batch ID does not exist in storage, `None` is returned.
304    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
305        // Get the round.
306        self.batch_ids.read().get(&batch_id).copied()
307    }
308
309    /// Returns the certificate round for the given `certificate ID`.
310    /// If the certificate ID does not exist in storage, `None` is returned.
311    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
312        // Get the batch certificate and return the round.
313        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
314    }
315
316    /// Returns the certificate for the given `certificate ID`.
317    /// If the certificate ID does not exist in storage, `None` is returned.
318    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
319        // Get the batch certificate.
320        self.certificates.read().get(&certificate_id).cloned()
321    }
322
323    /// Returns the certificate for the given `round` and `author`.
324    /// If the round does not exist in storage, `None` is returned.
325    /// If the author for the round does not exist in storage, `None` is returned.
326    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
327        // Retrieve the certificates.
328        if let Some(entries) = self.rounds.read().get(&round) {
329            let certificates = self.certificates.read();
330            entries.iter().find_map(
331                |(certificate_id, a)| if a == &author { certificates.get(certificate_id).cloned() } else { None },
332            )
333        } else {
334            Default::default()
335        }
336    }
337
338    /// Returns the certificates for the given `round`.
339    /// If the round does not exist in storage, an empty set is returned.
340    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
341        // The genesis round does not have batch certificates.
342        if round == 0 {
343            return Default::default();
344        }
345        // Retrieve the certificates.
346        if let Some(entries) = self.rounds.read().get(&round) {
347            let certificates = self.certificates.read();
348            entries.iter().flat_map(|(certificate_id, _)| certificates.get(certificate_id).cloned()).collect()
349        } else {
350            Default::default()
351        }
352    }
353
354    /// Returns the certificate IDs for the given `round`.
355    /// If the round does not exist in storage, an empty set is returned.
356    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
357        // The genesis round does not have batch certificates.
358        if round == 0 {
359            return Default::default();
360        }
361        // Retrieve the certificates.
362        if let Some(entries) = self.rounds.read().get(&round) {
363            entries.iter().map(|(certificate_id, _)| *certificate_id).collect()
364        } else {
365            Default::default()
366        }
367    }
368
369    /// Returns the certificate authors for the given `round`.
370    /// If the round does not exist in storage, an empty set is returned.
371    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
372        // The genesis round does not have batch certificates.
373        if round == 0 {
374            return Default::default();
375        }
376        // Retrieve the certificates.
377        if let Some(entries) = self.rounds.read().get(&round) {
378            entries.iter().map(|(_, author)| *author).collect()
379        } else {
380            Default::default()
381        }
382    }
383
384    /// Returns the certificates that have not yet been included in the ledger.
385    /// Note that the order of this set is by round and then insertion.
386    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
387        // Obtain the read locks.
388        let rounds = self.rounds.read();
389        let certificates = self.certificates.read();
390
391        // Iterate over the rounds.
392        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
393            .flat_map(|(_, certificates_for_round)| {
394                // Iterate over the certificates for the round.
395                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _)| {
396                    // Skip the certificate if it already exists in the ledger.
397                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
398                        None
399                    } else {
400                        // Add the certificate to the pending certificates.
401                        certificates.get(&certificate_id).cloned()
402                    }
403                })
404            })
405            .collect()
406    }
407
408    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
409    ///
410    /// This method ensures the following invariants:
411    /// - The batch ID does not already exist in storage.
412    /// - The author is a member of the committee for the batch round.
413    /// - The timestamp is within the allowed time range.
414    /// - None of the transmissions are from any past rounds (up to GC).
415    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
416    /// - All previous certificates declared in the certificate exist in storage (up to GC).
417    /// - All previous certificates are for the previous round (i.e. round - 1).
418    /// - All previous certificates contain a unique author.
419    /// - The previous certificates reached the quorum threshold (N - f).
420    pub fn check_batch_header(
421        &self,
422        batch_header: &BatchHeader<N>,
423        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
424        aborted_transmissions: HashSet<TransmissionID<N>>,
425    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
426        // Retrieve the round.
427        let round = batch_header.round();
428        // Retrieve the GC round.
429        let gc_round = self.gc_round();
430        // Construct a GC log message.
431        let gc_log = format!("(gc = {gc_round})");
432
433        // Ensure the batch ID does not already exist in storage.
434        if self.contains_batch(batch_header.batch_id()) {
435            bail!("Batch for round {round} already exists in storage {gc_log}")
436        }
437
438        // Retrieve the committee lookback for the batch round.
439        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
440            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
441        };
442        // Ensure the author is in the committee.
443        if !committee_lookback.is_committee_member(batch_header.author()) {
444            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
445        }
446
447        // Check the timestamp for liveness.
448        check_timestamp_for_liveness(batch_header.timestamp())?;
449
450        // Retrieve the missing transmissions in storage from the given transmissions.
451        let missing_transmissions = self
452            .transmissions
453            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
454            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
455
456        // Compute the previous round.
457        let previous_round = round.saturating_sub(1);
458        // Check if the previous round is within range of the GC round.
459        if previous_round > gc_round {
460            // Retrieve the committee lookback for the previous round.
461            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
462                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
463            };
464            // Ensure the previous round certificates exists in storage.
465            if !self.contains_certificates_for_round(previous_round) {
466                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
467            }
468            // Ensure the number of previous certificate IDs is at or below the number of committee members.
469            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
470                bail!("Too many previous certificates for round {round} {gc_log}")
471            }
472            // Initialize a set of the previous authors.
473            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
474            // Ensure storage contains all declared previous certificates (up to GC).
475            for previous_certificate_id in batch_header.previous_certificate_ids() {
476                // Retrieve the previous certificate.
477                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
478                    bail!(
479                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
480                        fmt_id(previous_certificate_id)
481                    )
482                };
483                // Ensure the previous certificate is for the previous round.
484                if previous_certificate.round() != previous_round {
485                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
486                }
487                // Ensure the previous author is new.
488                if previous_authors.contains(&previous_certificate.author()) {
489                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
490                }
491                // Insert the author of the previous certificate.
492                previous_authors.insert(previous_certificate.author());
493            }
494            // Ensure the previous certificates have reached the quorum threshold.
495            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
496                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
497            }
498        }
499        Ok(missing_transmissions)
500    }
501
502    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
503    ///
504    /// This method ensures the following invariants:
505    /// - The certificate ID does not already exist in storage.
506    /// - The batch ID does not already exist in storage.
507    /// - The author is a member of the committee for the batch round.
508    /// - The author has not already created a certificate for the batch round.
509    /// - The timestamp is within the allowed time range.
510    /// - None of the transmissions are from any past rounds (up to GC).
511    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
512    /// - All previous certificates declared in the certificate exist in storage (up to GC).
513    /// - All previous certificates are for the previous round (i.e. round - 1).
514    /// - The previous certificates reached the quorum threshold (N - f).
515    /// - The timestamps from the signers are all within the allowed time range.
516    /// - The signers have reached the quorum threshold (N - f).
517    pub fn check_certificate(
518        &self,
519        certificate: &BatchCertificate<N>,
520        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
521        aborted_transmissions: HashSet<TransmissionID<N>>,
522    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
523        // Retrieve the round.
524        let round = certificate.round();
525        // Retrieve the GC round.
526        let gc_round = self.gc_round();
527        // Construct a GC log message.
528        let gc_log = format!("(gc = {gc_round})");
529
530        // Ensure the certificate ID does not already exist in storage.
531        if self.contains_certificate(certificate.id()) {
532            bail!("Certificate for round {round} already exists in storage {gc_log}")
533        }
534
535        // Ensure the storage does not already contain a certificate for this author in this round.
536        if self.contains_certificate_in_round_from(round, certificate.author()) {
537            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
538        }
539
540        // Ensure the batch header is well-formed.
541        let missing_transmissions =
542            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
543
544        // Check the timestamp for liveness.
545        check_timestamp_for_liveness(certificate.timestamp())?;
546
547        // Retrieve the committee lookback for the batch round.
548        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
549            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
550        };
551
552        // Initialize a set of the signers.
553        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
554        // Append the batch author.
555        signers.insert(certificate.author());
556
557        // Iterate over the signatures.
558        for signature in certificate.signatures() {
559            // Retrieve the signer.
560            let signer = signature.to_address();
561            // Ensure the signer is in the committee.
562            if !committee_lookback.is_committee_member(signer) {
563                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
564            }
565            // Append the signer.
566            signers.insert(signer);
567        }
568
569        // Ensure the signatures have reached the quorum threshold.
570        if !committee_lookback.is_quorum_threshold_reached(&signers) {
571            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
572        }
573        Ok(missing_transmissions)
574    }
575
576    /// Inserts the given `certificate` into storage.
577    ///
578    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
579    ///
580    /// This method ensures the following invariants:
581    /// - The certificate ID does not already exist in storage.
582    /// - The batch ID does not already exist in storage.
583    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
584    /// - All previous certificates declared in the certificate exist in storage (up to GC).
585    /// - All previous certificates are for the previous round (i.e. round - 1).
586    /// - The previous certificates reached the quorum threshold (N - f).
587    pub fn insert_certificate(
588        &self,
589        certificate: BatchCertificate<N>,
590        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
591        aborted_transmissions: HashSet<TransmissionID<N>>,
592    ) -> Result<()> {
593        // Ensure the certificate round is above the GC round.
594        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
595        // Ensure the certificate and its transmissions are valid.
596        let missing_transmissions =
597            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
598        // Insert the certificate into storage.
599        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
600        Ok(())
601    }
602
603    /// Inserts the given `certificate` into storage.
604    ///
605    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
606    ///
607    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
608    fn insert_certificate_atomic(
609        &self,
610        certificate: BatchCertificate<N>,
611        aborted_transmission_ids: HashSet<TransmissionID<N>>,
612        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
613    ) {
614        // Retrieve the round.
615        let round = certificate.round();
616        // Retrieve the certificate ID.
617        let certificate_id = certificate.id();
618        // Retrieve the author of the batch.
619        let author = certificate.author();
620
621        // Insert the round to certificate ID entry.
622        self.rounds.write().entry(round).or_default().insert((certificate_id, author));
623        // Obtain the certificate's transmission ids.
624        let transmission_ids = certificate.transmission_ids().clone();
625        // Insert the certificate.
626        self.certificates.write().insert(certificate_id, certificate);
627        // Insert the batch ID.
628        self.batch_ids.write().insert(certificate_id, round);
629        // Insert the certificate ID for each of the transmissions into storage.
630        self.transmissions.insert_transmissions(
631            certificate_id,
632            transmission_ids,
633            aborted_transmission_ids,
634            missing_transmissions,
635        );
636    }
637
638    /// Removes the given `certificate ID` from storage.
639    ///
640    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
641    ///
642    /// If the certificate was successfully removed, `true` is returned.
643    /// If the certificate did not exist in storage, `false` is returned.
644    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
645        // Retrieve the certificate.
646        let Some(certificate) = self.get_certificate(certificate_id) else {
647            guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
648            return false;
649        };
650        // Retrieve the round.
651        let round = certificate.round();
652        // Compute the author of the batch.
653        let author = certificate.author();
654
655        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
656        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
657        //  It will be better to write tests that compare the union of the sets.
658
659        // Update the round.
660        match self.rounds.write().entry(round) {
661            Entry::Occupied(mut entry) => {
662                // Remove the round to certificate ID entry.
663                entry.get_mut().swap_remove(&(certificate_id, author));
664                // If the round is empty, remove it.
665                if entry.get().is_empty() {
666                    entry.swap_remove();
667                }
668            }
669            Entry::Vacant(_) => {}
670        }
671        // Remove the certificate.
672        self.certificates.write().swap_remove(&certificate_id);
673        // Remove the batch ID.
674        self.batch_ids.write().swap_remove(&certificate_id);
675        // Remove the transmission entries in the certificate from storage.
676        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
677        // Return successfully.
678        true
679    }
680}
681
682impl<N: Network> Storage<N> {
683    /// Syncs the current height with the block.
684    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
685        // If the block height is greater than the current height in storage, sync the height.
686        if next_height > self.current_height() {
687            // Update the current height in storage.
688            self.current_height.store(next_height, Ordering::SeqCst);
689        }
690    }
691
692    /// Syncs the current round with the block.
693    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
694        // Retrieve the current round in the block.
695        let next_round = next_round.max(1);
696        // If the round in the block is greater than the current round in storage, sync the round.
697        if next_round > self.current_round() {
698            // Update the current round in storage.
699            self.update_current_round(next_round);
700            // Log the updated round.
701            guard_info!(self, "Synced to round {next_round}...");
702        }
703    }
704
705    /// Syncs the batch certificate with the block.
706    pub(crate) fn sync_certificate_with_block(
707        &self,
708        block: &Block<N>,
709        certificate: BatchCertificate<N>,
710        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
711    ) {
712        // Skip if the certificate round is below the GC round.
713        if certificate.round() <= self.gc_round() {
714            return;
715        }
716        // If the certificate ID already exists in storage, skip it.
717        if self.contains_certificate(certificate.id()) {
718            return;
719        }
720        // Retrieve the transmissions for the certificate.
721        let mut missing_transmissions = HashMap::new();
722
723        // Retrieve the aborted transmissions for the certificate.
724        let mut aborted_transmissions = HashSet::new();
725
726        // Track the block's aborted solutions and transactions.
727        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
728        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
729
730        // Iterate over the transmission IDs.
731        for transmission_id in certificate.transmission_ids() {
732            // If the transmission ID already exists in the map, skip it.
733            if missing_transmissions.contains_key(transmission_id) {
734                continue;
735            }
736            // If the transmission ID exists in storage, skip it.
737            if self.contains_transmission(*transmission_id) {
738                continue;
739            }
740            // Retrieve the transmission.
741            match transmission_id {
742                TransmissionID::Ratification => (),
743                TransmissionID::Solution(solution_id, _) => {
744                    // Retrieve the solution.
745                    match block.get_solution(solution_id) {
746                        // Insert the solution.
747                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
748                        // Otherwise, try to load the solution from the ledger.
749                        None => match self.ledger.get_solution(solution_id) {
750                            // Insert the solution.
751                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
752                            // Check if the solution is in the aborted solutions.
753                            Err(_) => {
754                                // Insert the aborted solution if it exists in the block or ledger.
755                                match aborted_solutions.contains(solution_id)
756                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
757                                {
758                                    true => {
759                                        aborted_transmissions.insert(*transmission_id);
760                                    }
761                                    false => {
762                                        guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
763                                    }
764                                }
765                                continue;
766                            }
767                        },
768                    };
769                }
770                TransmissionID::Transaction(transaction_id, _) => {
771                    // Retrieve the transaction.
772                    match unconfirmed_transactions.get(transaction_id) {
773                        // Insert the transaction.
774                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
775                        // Otherwise, try to load the unconfirmed transaction from the ledger.
776                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
777                            // Insert the transaction.
778                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
779                            // Check if the transaction is in the aborted transactions.
780                            Err(_) => {
781                                // Insert the aborted transaction if it exists in the block or ledger.
782                                match aborted_transactions.contains(transaction_id)
783                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
784                                {
785                                    true => {
786                                        aborted_transmissions.insert(*transmission_id);
787                                    }
788                                    false => guard_warn!(
789                                        self,
790                                        "Missing transaction {transaction_id} in block {}",
791                                        block.height()
792                                    ),
793                                }
794                                continue;
795                            }
796                        },
797                    };
798                }
799            }
800        }
801        // Insert the batch certificate into storage.
802        let certificate_id = fmt_id(certificate.id());
803        guard_debug!(
804            self,
805            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
806            certificate.round(),
807            certificate.transmission_ids().len()
808        );
809        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
810            guard_error!(
811                self,
812                "Failed to insert certificate '{certificate_id}' from block {} - {error}",
813                block.height()
814            );
815        }
816    }
817}
818
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an iterator over the `(round, {(certificate ID, author)})` entries.
    ///
    /// The underlying map is cloned under a read lock, so the iterator is a snapshot.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Address<N>)>)> {
        self.rounds.read().clone().into_iter()
    }

    /// Returns an iterator over the `(certificate ID, certificate)` entries.
    ///
    /// The underlying map is cloned under a read lock, so the iterator is a snapshot.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        self.certificates.read().clone().into_iter()
    }

    /// Returns an iterator over the `(batch ID, round)` entries.
    ///
    /// The underlying map is cloned under a read lock, so the iterator is a snapshot.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        self.batch_ids.read().clone().into_iter()
    }

    /// Returns an iterator over the `(transmission ID, (transmission, certificate IDs))` entries.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        self.transmissions.as_hashmap().into_iter()
    }

    /// Inserts the given `certificate` into storage without any validation, pairing each of its
    /// transmission IDs with an empty placeholder transaction payload.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Insert the round to certificate ID entry.
        self.rounds.write().entry(round).or_default().insert((certificate_id, author));
        // Obtain the certificate's transmission ids.
        let transmission_ids = certificate.transmission_ids().clone();
        // Insert the certificate.
        self.certificates.write().insert(certificate_id, certificate);
        // Insert the batch ID.
        self.batch_ids.write().insert(certificate_id, round);

        // Construct the dummy missing transmissions (for testing purposes).
        let missing_transmissions = transmission_ids
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect::<HashMap<_, _>>();
        // Insert the certificate ID for each of the transmissions into storage.
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
884
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::{Data, batch_certificate::test_helpers::sample_batch_certificate_for_round_with_committee},
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the storage matches the expected layout.
    ///
    /// `rounds`, `certificates`, and `batch_ids` are compared in order against snapshots of the
    /// corresponding storage maps; `transmissions` is compared as an unordered map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Ensure the rounds are well-formed.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Ensure the certificates are well-formed.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Ensure the batch IDs are well-formed.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Ensure the transmissions are well-formed.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }

    /// Samples a random transmission: either 512 random bytes as a fake solution or 2048 random
    /// bytes as a fake transaction, chosen by a random `bool`.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Sample random fake solution bytes.
        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample random fake transaction bytes.
        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Sample a random transmission.
        match rng.gen::<bool>() {
            true => Transmission::Solution(s(rng)),
            false => Transmission::Transaction(t(rng)),
        }
    }

    /// Samples the random transmissions, returning the missing transmissions and the transmissions.
    ///
    /// The first map pairs each of `certificate`'s transmission IDs with a sampled transmission.
    /// The second holds the same transmissions, each paired with a set containing `certificate`'s
    /// ID, i.e. the layout `assert_storage` expects after inserting `certificate`.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            // Initialize the transmission.
            let transmission = sample_transmission(rng);
            // Update the missing transmissions.
            missing_transmissions.insert(*transmission_id, transmission.clone());
            // Update the transmissions map.
            transmissions
                .entry(*transmission_id)
                .or_insert((transmission, Default::default()))
                .1
                .insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    /// Inserts a single certificate, checks every lookup and the raw storage layout, then removes
    /// it again and checks that storage is empty.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the certificate is stored in the correct round.
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        // Ensure the certificate is stored for the correct round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check that the underlying storage representation is correct.
        {
            // Construct the expected layout for 'rounds'.
            let rounds = [(round, indexset! { (certificate_id, author) })];
            // Construct the expected layout for 'certificates'.
            let certificates = [(certificate_id, certificate.clone())];
            // Construct the expected layout for 'batch_ids'.
            let batch_ids = [(certificate_id, round)];
            // Assert the storage is well-formed.
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // Retrieve the certificate.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        // Ensure the retrieved certificate is the same as the inserted certificate.
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate.
        assert!(storage.remove_certificate(certificate_id));
        // Ensure the certificate does not exist in storage.
        assert!(!storage.contains_certificate(certificate_id));
        // Ensure the certificate is no longer stored in the round.
        assert!(storage.get_certificates_for_round(round).is_empty());
        // Ensure the certificate is no longer stored for the round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    /// Re-inserting a certificate that is already stored, with or without its transmissions, must
    /// leave the storage layout unchanged.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(certificate_id, round)];
        // Construct the sample 'transmissions'.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }

    /// Verify that `insert_certificate` rejects certs with less edges than required.
    ///
    /// Rounds 1-5 build a chain where every certificate references all certificates of the
    /// previous round and must be accepted; round 6 certificates reference only part of round 5
    /// and must be rejected.
    #[test]
    fn test_invalid_certificate_insufficient_previous_certs() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let (committee, private_keys) =
            snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Go through many rounds of valid certificates and ensure they're accepted.
        let mut previous_certs = IndexSet::default();

        for round in 1..=6 {
            let mut new_certs = IndexSet::default();

            // Generate one cert per validator
            for private_key in private_keys.iter() {
                let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();

                let certificate = sample_batch_certificate_for_round_with_committee(
                    round,
                    previous_certs.clone(),
                    private_key,
                    &other_keys,
                    rng,
                );

                // Construct the sample 'transmissions'.
                let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
                let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();

                if round <= 5 {
                    new_certs.insert(certificate.id());
                    storage
                        .insert_certificate(certificate, transmissions, Default::default())
                        .expect("Valid certificate rejected");
                } else {
                    assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
                }
            }

            if round < 5 {
                previous_certs = new_certs;
            } else {
                // Remove more than half of the previous certs: only the certificates after the
                // first six are kept (4 of 10, presumably, if the sampled committee has ten members).
                previous_certs = new_certs.into_iter().skip(6).collect();
            }
        }
    }

    /// Verify that `insert_certificate` rejects certs that do not increment the round number.
    ///
    /// Rounds 1-5 build a valid chain; in the sixth iteration every validator produces another
    /// round-5 certificate, which must be rejected.
    #[test]
    fn test_invalid_certificate_wrong_round_number() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let (committee, private_keys) =
            snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Go through many rounds of valid certificates and ensure they're accepted.
        let mut previous_certs = IndexSet::default();

        for round in 1..=6 {
            let mut new_certs = IndexSet::default();

            // Generate one cert per validator
            for private_key in private_keys.iter() {
                let cert_round = round.min(5); // In the sixth round, do not increment
                let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();

                let certificate = sample_batch_certificate_for_round_with_committee(
                    cert_round,
                    previous_certs.clone(),
                    private_key,
                    &other_keys,
                    rng,
                );

                // Construct the sample 'transmissions'.
                let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
                let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();

                if round <= 5 {
                    new_certs.insert(certificate.id());
                    storage
                        .insert_certificate(certificate, transmissions, Default::default())
                        .expect("Valid certificate rejected");
                } else {
                    assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
                }
            }

            if round < 5 {
                previous_certs = new_certs;
            } else {
                // Remove more than half of the previous certs.
                // NOTE(review): this branch mirrors the insufficient-previous-certs test above, so the
                // round-6 certificates here also reference a reduced previous set; the rejection may
                // therefore stem from a check other than the round-number check. Confirm intent.
                previous_certs = new_certs.into_iter().skip(6).collect();
            }
        }
    }
}
1188
1189#[cfg(test)]
1190pub mod prop_tests {
1191    use super::*;
1192    use crate::helpers::{now, storage::tests::assert_storage};
1193    use amareleo_node_bft_ledger_service::MockLedgerService;
1194    use amareleo_node_bft_storage_service::BFTMemoryService;
1195    use snarkvm::{
1196        ledger::{
1197            committee::prop_tests::{CommitteeContext, ValidatorSet},
1198            narwhal::{BatchHeader, Data},
1199            puzzle::SolutionID,
1200        },
1201        prelude::{Signature, Uniform},
1202    };
1203
1204    use ::bytes::Bytes;
1205    use indexmap::indexset;
1206    use proptest::{
1207        collection,
1208        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
1209        prop_oneof,
1210        sample::{Selector, size_range},
1211        test_runner::TestRng,
1212    };
1213    use rand::{CryptoRng, Error, Rng, RngCore};
1214    use std::fmt::Debug;
1215    use test_strategy::proptest;
1216
1217    type CurrentNetwork = snarkvm::prelude::MainnetV0;
1218
1219    impl Arbitrary for Storage<CurrentNetwork> {
1220        type Parameters = CommitteeContext;
1221        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;
1222
1223        fn arbitrary() -> Self::Strategy {
1224            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1225                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1226                    let ledger = Arc::new(MockLedgerService::new(committee));
1227                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
1228                })
1229                .boxed()
1230        }
1231
1232        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
1233            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1234                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1235                    let ledger = Arc::new(MockLedgerService::new(committee));
1236                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
1237                })
1238                .boxed()
1239        }
1240    }
1241
1242    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
1243    #[derive(Debug)]
1244    pub struct CryptoTestRng(TestRng);
1245
1246    impl Arbitrary for CryptoTestRng {
1247        type Parameters = ();
1248        type Strategy = BoxedStrategy<CryptoTestRng>;
1249
1250        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1251            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
1252        }
1253    }
1254    impl RngCore for CryptoTestRng {
1255        fn next_u32(&mut self) -> u32 {
1256            self.0.next_u32()
1257        }
1258
1259        fn next_u64(&mut self) -> u64 {
1260            self.0.next_u64()
1261        }
1262
1263        fn fill_bytes(&mut self, dest: &mut [u8]) {
1264            self.0.fill_bytes(dest);
1265        }
1266
1267        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
1268            self.0.try_fill_bytes(dest)
1269        }
1270    }
1271
1272    impl CryptoRng for CryptoTestRng {}
1273
1274    #[derive(Debug, Clone)]
1275    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);
1276
1277    impl Arbitrary for AnyTransmission {
1278        type Parameters = ();
1279        type Strategy = BoxedStrategy<AnyTransmission>;
1280
1281        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1282            any_transmission().prop_map(AnyTransmission).boxed()
1283        }
1284    }
1285
1286    #[derive(Debug, Clone)]
1287    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);
1288
1289    impl Arbitrary for AnyTransmissionID {
1290        type Parameters = ();
1291        type Strategy = BoxedStrategy<AnyTransmissionID>;
1292
1293        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1294            any_transmission_id().prop_map(AnyTransmissionID).boxed()
1295        }
1296    }
1297
1298    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
1299        prop_oneof![
1300            (collection::vec(any::<u8>(), 512..=512))
1301                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
1302            (collection::vec(any::<u8>(), 2048..=2048))
1303                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
1304        ]
1305        .boxed()
1306    }
1307
1308    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
1309        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
1310    }
1311
1312    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
1313        Just(0)
1314            .prop_perturb(|_, rng| {
1315                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
1316            })
1317            .boxed()
1318    }
1319
1320    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
1321        prop_oneof![
1322            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
1323                id,
1324                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1325            )),
1326            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
1327                id,
1328                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1329            )),
1330        ]
1331        .boxed()
1332    }
1333
1334    pub fn sign_batch_header<R: Rng + CryptoRng>(
1335        validator_set: &ValidatorSet,
1336        batch_header: &BatchHeader<CurrentNetwork>,
1337        rng: &mut R,
1338    ) -> IndexSet<Signature<CurrentNetwork>> {
1339        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
1340        for validator in validator_set.0.iter() {
1341            let private_key = validator.private_key;
1342            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
1343        }
1344        signatures
1345    }
1346
1347    #[proptest]
1348    fn test_certificate_duplicate(
1349        context: CommitteeContext,
1350        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
1351        mut rng: CryptoTestRng,
1352        selector: Selector,
1353    ) {
1354        let CommitteeContext(committee, ValidatorSet(validators)) = context;
1355        let committee_id = committee.id();
1356
1357        // Initialize the storage.
1358        let ledger = Arc::new(MockLedgerService::new(committee));
1359        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);
1360
1361        // Ensure the storage is empty.
1362        assert_storage(&storage, &[], &[], &[], &Default::default());
1363
1364        // Create a new certificate.
1365        let signer = selector.select(&validators);
1366
1367        let mut transmission_map = IndexMap::new();
1368
1369        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
1370            transmission_map.insert(*id, t.clone());
1371        }
1372
1373        let batch_header = BatchHeader::new(
1374            &signer.private_key,
1375            0,
1376            now(),
1377            committee_id,
1378            transmission_map.keys().cloned().collect(),
1379            Default::default(),
1380            &mut rng,
1381        )
1382        .unwrap();
1383
1384        // Remove the author from the validator set passed to create the batch
1385        // certificate, the author should not sign their own batch.
1386        let mut validators = validators.clone();
1387        validators.remove(signer);
1388
1389        let certificate = BatchCertificate::from(
1390            batch_header.clone(),
1391            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
1392        )
1393        .unwrap();
1394
1395        // Retrieve the certificate ID.
1396        let certificate_id = certificate.id();
1397        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1398        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
1399            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
1400        }
1401
1402        // Retrieve the round.
1403        let round = certificate.round();
1404        // Retrieve the author of the batch.
1405        let author = certificate.author();
1406
1407        // Construct the expected layout for 'rounds'.
1408        let rounds = [(round, indexset! { (certificate_id, author) })];
1409        // Construct the expected layout for 'certificates'.
1410        let certificates = [(certificate_id, certificate.clone())];
1411        // Construct the expected layout for 'batch_ids'.
1412        let batch_ids = [(certificate_id, round)];
1413
1414        // Insert the certificate.
1415        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
1416            transmission_map.into_iter().collect();
1417        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1418        // Ensure the certificate exists in storage.
1419        assert!(storage.contains_certificate(certificate_id));
1420        // Check that the underlying storage representation is correct.
1421        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1422
1423        // Insert the certificate again - without any missing transmissions.
1424        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1425        // Ensure the certificate exists in storage.
1426        assert!(storage.contains_certificate(certificate_id));
1427        // Check that the underlying storage representation remains unchanged.
1428        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1429
1430        // Insert the certificate again - with all of the original missing transmissions.
1431        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1432        // Ensure the certificate exists in storage.
1433        assert!(storage.contains_certificate(certificate_id));
1434        // Check that the underlying storage representation remains unchanged.
1435        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1436    }
1437}