// amareleo_node_bft/helpers/storage.rs

// Copyright 2024 Aleo Network Foundation
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21    ledger::{
22        block::{Block, Transaction},
23        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24    },
25    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26    utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30use parking_lot::RwLock;
31use rayon::iter::{IntoParallelIterator, ParallelIterator};
32use std::{
33    collections::{HashMap, HashSet},
34    sync::{
35        Arc,
36        atomic::{AtomicU32, AtomicU64, Ordering},
37    },
38};
39use tracing::subscriber::DefaultGuard;
40
41#[derive(Clone, Debug)]
42pub struct Storage<N: Network>(Arc<StorageInner<N>>);
43
44impl<N: Network> std::ops::Deref for Storage<N> {
45    type Target = Arc<StorageInner<N>>;
46
47    fn deref(&self) -> &Self::Target {
48        &self.0
49    }
50}
51
52impl<N: Network> TracingHandlerGuard for Storage<N> {
53    /// Retruns tracing guard
54    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
55        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
56    }
57}
58
59/// The storage for the memory pool.
60///
61/// The storage is used to store the following:
62/// - `current_height` tracker.
63/// - `current_round` tracker.
64/// - `round` to `(certificate ID, batch ID, author)` entries.
65/// - `certificate ID` to `certificate` entries.
66/// - `batch ID` to `round` entries.
67/// - `transmission ID` to `(transmission, certificate IDs)` entries.
68///
69/// The chain of events is as follows:
70/// 1. A `transmission` is received.
71/// 2. After a `batch` is ready to be stored:
72///   - The `certificate` is inserted, triggering updates to the
73///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
74///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
75///   - The certificate ID is inserted into the `transmissions` map.
76/// 3. After a `round` reaches quorum threshold:
77///  - The next round is inserted into the `current_round`.
78#[derive(Debug)]
79pub struct StorageInner<N: Network> {
80    /// The ledger service.
81    ledger: Arc<dyn LedgerService<N>>,
82    /* Once per block */
83    /// The current height.
84    current_height: AtomicU32,
85    /* Once per round */
86    /// The current round.
87    current_round: AtomicU64,
88    /// The `round` for which garbage collection has occurred **up to** (inclusive).
89    gc_round: AtomicU64,
90    /// The maximum number of rounds to keep in storage.
91    max_gc_rounds: u64,
92    /* Once per batch */
93    /// The map of `round` to a list of `(certificate ID, batch ID, author)` entries.
94    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
95    /// The map of `certificate ID` to `certificate`.
96    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
97    /// The map of `batch ID` to `round`.
98    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
99    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
100    transmissions: Arc<dyn StorageService<N>>,
101    /// Tracing handle
102    tracing: Option<TracingHandler>,
103}
104
105impl<N: Network> Storage<N> {
106    /// Initializes a new instance of storage.
107    pub fn new(
108        ledger: Arc<dyn LedgerService<N>>,
109        transmissions: Arc<dyn StorageService<N>>,
110        max_gc_rounds: u64,
111        tracing: Option<TracingHandler>,
112    ) -> Self {
113        // Retrieve the current committee.
114        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
115        // Retrieve the current round.
116        let current_round = committee.starting_round().max(1);
117
118        // Return the storage.
119        let storage = Self(Arc::new(StorageInner {
120            ledger,
121            current_height: Default::default(),
122            current_round: Default::default(),
123            gc_round: Default::default(),
124            max_gc_rounds,
125            rounds: Default::default(),
126            certificates: Default::default(),
127            batch_ids: Default::default(),
128            transmissions,
129            tracing,
130        }));
131        // Update the storage to the current round.
132        storage.update_current_round(current_round);
133        // Perform GC on the current round.
134        storage.garbage_collect_certificates(current_round);
135        // Return the storage.
136        storage
137    }
138}
139
140impl<N: Network> Storage<N> {
141    /// Returns the current height.
142    pub fn current_height(&self) -> u32 {
143        // Get the current height.
144        self.current_height.load(Ordering::SeqCst)
145    }
146}
147
148impl<N: Network> Storage<N> {
149    /// Returns the current round.
150    pub fn current_round(&self) -> u64 {
151        // Get the current round.
152        self.current_round.load(Ordering::SeqCst)
153    }
154
155    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
156    pub fn gc_round(&self) -> u64 {
157        // Get the GC round.
158        self.gc_round.load(Ordering::SeqCst)
159    }
160
161    /// Returns the maximum number of rounds to keep in storage.
162    pub fn max_gc_rounds(&self) -> u64 {
163        self.max_gc_rounds
164    }
165
166    /// Increments storage to the next round, updating the current round.
167    /// Note: This method is only called once per round, upon certification of the primary's batch.
168    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
169        // Determine the next round.
170        let next_round = current_round + 1;
171
172        // Check if the next round is less than the current round in storage.
173        {
174            // Retrieve the storage round.
175            let storage_round = self.current_round();
176            // If the next round is less than the current round in storage, return early with the storage round.
177            if next_round < storage_round {
178                return Ok(storage_round);
179            }
180        }
181
182        // Retrieve the current committee.
183        let current_committee = self.ledger.current_committee()?;
184        // Retrieve the current committee's starting round.
185        let starting_round = current_committee.starting_round();
186        // If the primary is behind the current committee's starting round, sync with the latest block.
187        if next_round < starting_round {
188            // Retrieve the latest block round.
189            let latest_block_round = self.ledger.latest_round();
190            // Log the round sync.
191            guard_info!(
192                self,
193                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
194            );
195            // Sync the round with the latest block.
196            self.sync_round_with_block(latest_block_round);
197            // Return the latest block round.
198            return Ok(latest_block_round);
199        }
200
201        // Update the storage to the next round.
202        self.update_current_round(next_round);
203
204        #[cfg(feature = "metrics")]
205        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
206
207        // Retrieve the storage round.
208        let storage_round = self.current_round();
209        // Retrieve the GC round.
210        let gc_round = self.gc_round();
211        // Ensure the next round matches in storage.
212        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
213        // Ensure the next round is greater than or equal to the GC round.
214        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
215
216        // Log the updated round.
217        guard_info!(self, "Starting round {next_round}...");
218        Ok(next_round)
219    }
220
221    /// Updates the storage to the next round.
222    fn update_current_round(&self, next_round: u64) {
223        // Update the current round.
224        self.current_round.store(next_round, Ordering::SeqCst);
225    }
226
227    /// Update the storage by performing garbage collection based on the next round.
228    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
229        // Fetch the current GC round.
230        let current_gc_round = self.gc_round();
231        // Compute the next GC round.
232        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
233        // Check if storage needs to be garbage collected.
234        if next_gc_round > current_gc_round {
235            // Remove the GC round(s) from storage.
236            for gc_round in current_gc_round..=next_gc_round {
237                // Iterate over the certificates for the GC round.
238                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
239                    // Remove the certificate from storage.
240                    self.remove_certificate(id);
241                }
242            }
243            // Update the GC round.
244            self.gc_round.store(next_gc_round, Ordering::SeqCst);
245        }
246    }
247}
248
249impl<N: Network> Storage<N> {
250    /// Returns `true` if the storage contains the specified `round`.
251    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
252        // Check if the round exists in storage.
253        self.rounds.read().contains_key(&round)
254    }
255
256    /// Returns `true` if the storage contains the specified `certificate ID`.
257    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
258        // Check if the certificate ID exists in storage.
259        self.certificates.read().contains_key(&certificate_id)
260    }
261
262    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
263    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
264        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
265    }
266
267    /// Returns `true` if the storage contains the specified `batch ID`.
268    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
269        // Check if the batch ID exists in storage.
270        self.batch_ids.read().contains_key(&batch_id)
271    }
272
273    /// Returns `true` if the storage contains the specified `transmission ID`.
274    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
275        self.transmissions.contains_transmission(transmission_id.into())
276    }
277
278    /// Returns the transmission for the given `transmission ID`.
279    /// If the transmission ID does not exist in storage, `None` is returned.
280    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
281        self.transmissions.get_transmission(transmission_id.into())
282    }
283
284    /// Returns the round for the given `certificate ID`.
285    /// If the certificate ID does not exist in storage, `None` is returned.
286    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
287        // Get the round.
288        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
289    }
290
291    /// Returns the round for the given `batch ID`.
292    /// If the batch ID does not exist in storage, `None` is returned.
293    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
294        // Get the round.
295        self.batch_ids.read().get(&batch_id).copied()
296    }
297
298    /// Returns the certificate round for the given `certificate ID`.
299    /// If the certificate ID does not exist in storage, `None` is returned.
300    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
301        // Get the batch certificate and return the round.
302        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
303    }
304
305    /// Returns the certificate for the given `certificate ID`.
306    /// If the certificate ID does not exist in storage, `None` is returned.
307    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
308        // Get the batch certificate.
309        self.certificates.read().get(&certificate_id).cloned()
310    }
311
312    /// Returns the certificate for the given `round` and `author`.
313    /// If the round does not exist in storage, `None` is returned.
314    /// If the author for the round does not exist in storage, `None` is returned.
315    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
316        // Retrieve the certificates.
317        if let Some(entries) = self.rounds.read().get(&round) {
318            let certificates = self.certificates.read();
319            entries.iter().find_map(
320                |(certificate_id, _, a)| {
321                    if a == &author { certificates.get(certificate_id).cloned() } else { None }
322                },
323            )
324        } else {
325            Default::default()
326        }
327    }
328
329    /// Returns the certificates for the given `round`.
330    /// If the round does not exist in storage, an empty set is returned.
331    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
332        // The genesis round does not have batch certificates.
333        if round == 0 {
334            return Default::default();
335        }
336        // Retrieve the certificates.
337        if let Some(entries) = self.rounds.read().get(&round) {
338            let certificates = self.certificates.read();
339            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
340        } else {
341            Default::default()
342        }
343    }
344
345    /// Returns the certificate IDs for the given `round`.
346    /// If the round does not exist in storage, an empty set is returned.
347    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
348        // The genesis round does not have batch certificates.
349        if round == 0 {
350            return Default::default();
351        }
352        // Retrieve the certificates.
353        if let Some(entries) = self.rounds.read().get(&round) {
354            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
355        } else {
356            Default::default()
357        }
358    }
359
360    /// Returns the certificate authors for the given `round`.
361    /// If the round does not exist in storage, an empty set is returned.
362    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
363        // The genesis round does not have batch certificates.
364        if round == 0 {
365            return Default::default();
366        }
367        // Retrieve the certificates.
368        if let Some(entries) = self.rounds.read().get(&round) {
369            entries.iter().map(|(_, _, author)| *author).collect()
370        } else {
371            Default::default()
372        }
373    }
374
375    /// Returns the certificates that have not yet been included in the ledger.
376    /// Note that the order of this set is by round and then insertion.
377    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
378        // Obtain the read locks.
379        let rounds = self.rounds.read();
380        let certificates = self.certificates.read();
381
382        // Iterate over the rounds.
383        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
384            .flat_map(|(_, certificates_for_round)| {
385                // Iterate over the certificates for the round.
386                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
387                    // Skip the certificate if it already exists in the ledger.
388                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
389                        None
390                    } else {
391                        // Add the certificate to the pending certificates.
392                        certificates.get(&certificate_id).cloned()
393                    }
394                })
395            })
396            .collect()
397    }
398
399    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
400    ///
401    /// This method ensures the following invariants:
402    /// - The batch ID does not already exist in storage.
403    /// - The author is a member of the committee for the batch round.
404    /// - The timestamp is within the allowed time range.
405    /// - None of the transmissions are from any past rounds (up to GC).
406    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
407    /// - All previous certificates declared in the certificate exist in storage (up to GC).
408    /// - All previous certificates are for the previous round (i.e. round - 1).
409    /// - All previous certificates contain a unique author.
410    /// - The previous certificates reached the quorum threshold (N - f).
411    pub fn check_batch_header(
412        &self,
413        batch_header: &BatchHeader<N>,
414        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
415        aborted_transmissions: HashSet<TransmissionID<N>>,
416    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
417        // Retrieve the round.
418        let round = batch_header.round();
419        // Retrieve the GC round.
420        let gc_round = self.gc_round();
421        // Construct a GC log message.
422        let gc_log = format!("(gc = {gc_round})");
423
424        // Ensure the batch ID does not already exist in storage.
425        if self.contains_batch(batch_header.batch_id()) {
426            bail!("Batch for round {round} already exists in storage {gc_log}")
427        }
428
429        // Retrieve the committee lookback for the batch round.
430        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
431            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
432        };
433        // Ensure the author is in the committee.
434        if !committee_lookback.is_committee_member(batch_header.author()) {
435            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
436        }
437
438        // Check the timestamp for liveness.
439        check_timestamp_for_liveness(batch_header.timestamp())?;
440
441        // Retrieve the missing transmissions in storage from the given transmissions.
442        let missing_transmissions = self
443            .transmissions
444            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
445            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
446
447        // Compute the previous round.
448        let previous_round = round.saturating_sub(1);
449        // Check if the previous round is within range of the GC round.
450        if previous_round > gc_round {
451            // Retrieve the committee lookback for the previous round.
452            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
453                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
454            };
455            // Ensure the previous round certificates exists in storage.
456            if !self.contains_certificates_for_round(previous_round) {
457                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
458            }
459            // Ensure the number of previous certificate IDs is at or below the number of committee members.
460            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
461                bail!("Too many previous certificates for round {round} {gc_log}")
462            }
463            // Initialize a set of the previous authors.
464            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
465            // Ensure storage contains all declared previous certificates (up to GC).
466            for previous_certificate_id in batch_header.previous_certificate_ids() {
467                // Retrieve the previous certificate.
468                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
469                    bail!(
470                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
471                        fmt_id(previous_certificate_id)
472                    )
473                };
474                // Ensure the previous certificate is for the previous round.
475                if previous_certificate.round() != previous_round {
476                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
477                }
478                // Ensure the previous author is new.
479                if previous_authors.contains(&previous_certificate.author()) {
480                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
481                }
482                // Insert the author of the previous certificate.
483                previous_authors.insert(previous_certificate.author());
484            }
485            // Ensure the previous certificates have reached the quorum threshold.
486            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
487                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
488            }
489        }
490        Ok(missing_transmissions)
491    }
492
493    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
494    ///
495    /// This method ensures the following invariants:
496    /// - The certificate ID does not already exist in storage.
497    /// - The batch ID does not already exist in storage.
498    /// - The author is a member of the committee for the batch round.
499    /// - The author has not already created a certificate for the batch round.
500    /// - The timestamp is within the allowed time range.
501    /// - None of the transmissions are from any past rounds (up to GC).
502    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
503    /// - All previous certificates declared in the certificate exist in storage (up to GC).
504    /// - All previous certificates are for the previous round (i.e. round - 1).
505    /// - The previous certificates reached the quorum threshold (N - f).
506    /// - The timestamps from the signers are all within the allowed time range.
507    /// - The signers have reached the quorum threshold (N - f).
508    pub fn check_certificate(
509        &self,
510        certificate: &BatchCertificate<N>,
511        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
512        aborted_transmissions: HashSet<TransmissionID<N>>,
513    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
514        // Retrieve the round.
515        let round = certificate.round();
516        // Retrieve the GC round.
517        let gc_round = self.gc_round();
518        // Construct a GC log message.
519        let gc_log = format!("(gc = {gc_round})");
520
521        // Ensure the certificate ID does not already exist in storage.
522        if self.contains_certificate(certificate.id()) {
523            bail!("Certificate for round {round} already exists in storage {gc_log}")
524        }
525
526        // Ensure the storage does not already contain a certificate for this author in this round.
527        if self.contains_certificate_in_round_from(round, certificate.author()) {
528            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
529        }
530
531        // Ensure the batch header is well-formed.
532        let missing_transmissions =
533            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
534
535        // Check the timestamp for liveness.
536        check_timestamp_for_liveness(certificate.timestamp())?;
537
538        // Retrieve the committee lookback for the batch round.
539        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
540            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
541        };
542
543        // Initialize a set of the signers.
544        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
545        // Append the batch author.
546        signers.insert(certificate.author());
547
548        // Iterate over the signatures.
549        for signature in certificate.signatures() {
550            // Retrieve the signer.
551            let signer = signature.to_address();
552            // Ensure the signer is in the committee.
553            if !committee_lookback.is_committee_member(signer) {
554                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
555            }
556            // Append the signer.
557            signers.insert(signer);
558        }
559
560        // Ensure the signatures have reached the quorum threshold.
561        if !committee_lookback.is_quorum_threshold_reached(&signers) {
562            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
563        }
564        Ok(missing_transmissions)
565    }
566
567    /// Inserts the given `certificate` into storage.
568    ///
569    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
570    ///
571    /// This method ensures the following invariants:
572    /// - The certificate ID does not already exist in storage.
573    /// - The batch ID does not already exist in storage.
574    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
575    /// - All previous certificates declared in the certificate exist in storage (up to GC).
576    /// - All previous certificates are for the previous round (i.e. round - 1).
577    /// - The previous certificates reached the quorum threshold (N - f).
578    pub fn insert_certificate(
579        &self,
580        certificate: BatchCertificate<N>,
581        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
582        aborted_transmissions: HashSet<TransmissionID<N>>,
583    ) -> Result<()> {
584        // Ensure the certificate round is above the GC round.
585        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
586        // Ensure the certificate and its transmissions are valid.
587        let missing_transmissions =
588            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
589        // Insert the certificate into storage.
590        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
591        Ok(())
592    }
593
594    /// Inserts the given `certificate` into storage.
595    ///
596    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
597    ///
598    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
599    fn insert_certificate_atomic(
600        &self,
601        certificate: BatchCertificate<N>,
602        aborted_transmission_ids: HashSet<TransmissionID<N>>,
603        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
604    ) {
605        // Retrieve the round.
606        let round = certificate.round();
607        // Retrieve the certificate ID.
608        let certificate_id = certificate.id();
609        // Retrieve the batch ID.
610        let batch_id = certificate.batch_id();
611        // Retrieve the author of the batch.
612        let author = certificate.author();
613
614        // Insert the round to certificate ID entry.
615        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
616        // Obtain the certificate's transmission ids.
617        let transmission_ids = certificate.transmission_ids().clone();
618        // Insert the certificate.
619        self.certificates.write().insert(certificate_id, certificate);
620        // Insert the batch ID.
621        self.batch_ids.write().insert(batch_id, round);
622        // Insert the certificate ID for each of the transmissions into storage.
623        self.transmissions.insert_transmissions(
624            certificate_id,
625            transmission_ids,
626            aborted_transmission_ids,
627            missing_transmissions,
628        );
629    }
630
631    /// Removes the given `certificate ID` from storage.
632    ///
633    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
634    ///
635    /// If the certificate was successfully removed, `true` is returned.
636    /// If the certificate did not exist in storage, `false` is returned.
637    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
638        // Retrieve the certificate.
639        let Some(certificate) = self.get_certificate(certificate_id) else {
640            guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
641            return false;
642        };
643        // Retrieve the round.
644        let round = certificate.round();
645        // Retrieve the batch ID.
646        let batch_id = certificate.batch_id();
647        // Compute the author of the batch.
648        let author = certificate.author();
649
650        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
651        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
652        //  It will be better to write tests that compare the union of the sets.
653
654        // Update the round.
655        match self.rounds.write().entry(round) {
656            Entry::Occupied(mut entry) => {
657                // Remove the round to certificate ID entry.
658                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
659                // If the round is empty, remove it.
660                if entry.get().is_empty() {
661                    entry.swap_remove();
662                }
663            }
664            Entry::Vacant(_) => {}
665        }
666        // Remove the certificate.
667        self.certificates.write().swap_remove(&certificate_id);
668        // Remove the batch ID.
669        self.batch_ids.write().swap_remove(&batch_id);
670        // Remove the transmission entries in the certificate from storage.
671        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
672        // Return successfully.
673        true
674    }
675}
676
677impl<N: Network> Storage<N> {
678    /// Syncs the current height with the block.
679    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
680        // If the block height is greater than the current height in storage, sync the height.
681        if next_height > self.current_height() {
682            // Update the current height in storage.
683            self.current_height.store(next_height, Ordering::SeqCst);
684        }
685    }
686
687    /// Syncs the current round with the block.
688    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
689        // Retrieve the current round in the block.
690        let next_round = next_round.max(1);
691        // If the round in the block is greater than the current round in storage, sync the round.
692        if next_round > self.current_round() {
693            // Update the current round in storage.
694            self.update_current_round(next_round);
695            // Log the updated round.
696            guard_info!(self, "Synced to round {next_round}...");
697        }
698    }
699
700    /// Syncs the batch certificate with the block.
701    pub(crate) fn sync_certificate_with_block(
702        &self,
703        block: &Block<N>,
704        certificate: BatchCertificate<N>,
705        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
706    ) {
707        // Skip if the certificate round is below the GC round.
708        if certificate.round() <= self.gc_round() {
709            return;
710        }
711        // If the certificate ID already exists in storage, skip it.
712        if self.contains_certificate(certificate.id()) {
713            return;
714        }
715        // Retrieve the transmissions for the certificate.
716        let mut missing_transmissions = HashMap::new();
717
718        // Retrieve the aborted transmissions for the certificate.
719        let mut aborted_transmissions = HashSet::new();
720
721        // Track the block's aborted solutions and transactions.
722        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
723        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
724
725        // Iterate over the transmission IDs.
726        for transmission_id in certificate.transmission_ids() {
727            // If the transmission ID already exists in the map, skip it.
728            if missing_transmissions.contains_key(transmission_id) {
729                continue;
730            }
731            // If the transmission ID exists in storage, skip it.
732            if self.contains_transmission(*transmission_id) {
733                continue;
734            }
735            // Retrieve the transmission.
736            match transmission_id {
737                TransmissionID::Ratification => (),
738                TransmissionID::Solution(solution_id, _) => {
739                    // Retrieve the solution.
740                    match block.get_solution(solution_id) {
741                        // Insert the solution.
742                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
743                        // Otherwise, try to load the solution from the ledger.
744                        None => match self.ledger.get_solution(solution_id) {
745                            // Insert the solution.
746                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
747                            // Check if the solution is in the aborted solutions.
748                            Err(_) => {
749                                // Insert the aborted solution if it exists in the block or ledger.
750                                match aborted_solutions.contains(solution_id)
751                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
752                                {
753                                    true => {
754                                        aborted_transmissions.insert(*transmission_id);
755                                    }
756                                    false => {
757                                        guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
758                                    }
759                                }
760                                continue;
761                            }
762                        },
763                    };
764                }
765                TransmissionID::Transaction(transaction_id, _) => {
766                    // Retrieve the transaction.
767                    match unconfirmed_transactions.get(transaction_id) {
768                        // Insert the transaction.
769                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
770                        // Otherwise, try to load the unconfirmed transaction from the ledger.
771                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
772                            // Insert the transaction.
773                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
774                            // Check if the transaction is in the aborted transactions.
775                            Err(_) => {
776                                // Insert the aborted transaction if it exists in the block or ledger.
777                                match aborted_transactions.contains(transaction_id)
778                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
779                                {
780                                    true => {
781                                        aborted_transmissions.insert(*transmission_id);
782                                    }
783                                    false => guard_warn!(
784                                        self,
785                                        "Missing transaction {transaction_id} in block {}",
786                                        block.height()
787                                    ),
788                                }
789                                continue;
790                            }
791                        },
792                    };
793                }
794            }
795        }
796        // Insert the batch certificate into storage.
797        let certificate_id = fmt_id(certificate.id());
798        guard_debug!(
799            self,
800            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
801            certificate.round(),
802            certificate.transmission_ids().len()
803        );
804        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
805            guard_error!(
806                self,
807                "Failed to insert certificate '{certificate_id}' from block {} - {error}",
808                block.height()
809            );
810        }
811    }
812}
813
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns a reference to the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an owning iterator over a snapshot of the `(round, (certificate ID, batch ID, author))` entries.
    ///
    /// The map is cloned under the read lock; the returned iterator owns the clone.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a snapshot of the `(certificate ID, certificate)` entries.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over a snapshot of the `(batch ID, round)` entries.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Returns an owning iterator over the `(transmission ID, (transmission, certificate IDs))` entries,
    /// as produced by the transmissions service's `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Inserts the given `certificate` into storage, pairing every one of its transmission IDs
    /// with an empty placeholder transaction payload.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[cfg(test)]
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        // Capture everything needed from the certificate before it is moved into the map.
        let certificate_id = certificate.id();
        let round = certificate.round();
        let batch_id = certificate.batch_id();
        let author = certificate.author();
        let transmission_ids = certificate.transmission_ids().clone();

        // Every transmission gets the same dummy (empty) transaction buffer, for testing purposes.
        let mut missing_transmissions = HashMap::new();
        for transmission_id in &transmission_ids {
            missing_transmissions.insert(
                *transmission_id,
                Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new())),
            );
        }

        // Record the certificate under its round, by its ID, and by its batch ID.
        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);

        // Associate the certificate ID with each of its transmissions in storage.
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
881
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the storage matches the expected layout.
    ///
    /// `rounds`, `certificates`, and `batch_ids` must match the storage contents element-by-element
    /// in iteration order; `transmissions` is compared as a map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        let actual_rounds: Vec<_> = storage.rounds_iter().collect();
        assert_eq!(actual_rounds, *rounds);
        let actual_certificates: Vec<_> = storage.certificates_iter().collect();
        assert_eq!(actual_certificates, *certificates);
        let actual_batch_ids: Vec<_> = storage.batch_ids_iter().collect();
        assert_eq!(actual_batch_ids, *batch_ids);
        let actual_transmissions: HashMap<_, _> = storage.transmissions_iter().collect();
        assert_eq!(actual_transmissions, *transmissions);
    }

    /// Samples a random transmission: either a fake solution of 512 random bytes or a fake
    /// transaction of 2048 random bytes, chosen at random.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Draws `len` random bytes from `rng`.
        let random_bytes = |rng: &mut TestRng, len: usize| {
            Bytes::from((0..len).map(|_| rng.gen::<u8>()).collect::<Vec<_>>())
        };
        if rng.gen::<bool>() {
            Transmission::Solution(Data::Buffer(random_bytes(rng, 512)))
        } else {
            Transmission::Transaction(Data::Buffer(random_bytes(rng, 2048)))
        }
    }

    /// Samples a random transmission for every transmission ID of `certificate`.
    ///
    /// Returns the missing transmissions (to pass on insertion) and the expected storage layout,
    /// which maps each ID to its transmission and the set containing this certificate's ID.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for &transmission_id in certificate.transmission_ids() {
            let transmission = sample_transmission(rng);
            missing_transmissions.insert(transmission_id, transmission.clone());
            let (_, certificate_ids) =
                transmissions.entry(transmission_id).or_insert((transmission, Default::default()));
            certificate_ids.insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    /// Samples a committee and returns an empty storage backed by a mock ledger over that committee.
    fn sample_storage(rng: &mut TestRng) -> Storage<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None)
    }

    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.

    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        let storage = sample_storage(rng);
        // A freshly initialized storage must be empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate along with its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());
        // Sample the transmissions referenced by the certificate.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate; it must be reachable by ID, by round, and by (round, author).
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // The underlying maps must hold exactly this certificate.
        assert_storage(
            &storage,
            &[(round, indexset! { (certificate_id, batch_id, author) })],
            &[(certificate_id, certificate.clone())],
            &[(batch_id, round)],
            &transmissions,
        );

        // Looking the certificate up by ID returns the inserted certificate.
        assert_eq!(storage.get_certificate(certificate_id).unwrap(), certificate);

        // Remove the certificate; every lookup path must now miss.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // The storage is empty again.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        let storage = sample_storage(rng);
        // A freshly initialized storage must be empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate along with its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());

        // The expected layout after the first insertion; re-insertions must leave it unchanged.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(batch_id, round)];
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate three times: with all of its missing transmissions, then with none,
        // then with all of the original missing transmissions again.
        let attempts = [missing_transmissions.clone(), Default::default(), missing_transmissions];
        for attempt in attempts {
            storage.insert_certificate_atomic(certificate.clone(), Default::default(), attempt);
            assert!(storage.contains_certificate(certificate_id));
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }
    }
}
1080
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Generates a storage over a mock ledger for a committee, with a GC-rounds value drawn from
    /// `0..MAX_GC_ROUNDS`.
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// A proptest-generated transmission (see `any_transmission`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// A proptest-generated transmission ID (see `any_transmission_id`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Generates either a 512-byte fake solution or a 2048-byte fake transaction.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Generates a solution ID from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Generates a transaction ID from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Generates a transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with the private key of every validator in `validator_set`.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage is empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Create a new certificate.
        let signer = selector.select(&validators);

        // Deduplicate the sampled transmissions by ID. For a repeated ID, `IndexMap::insert` keeps
        // the last sampled transmission.
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Remove the author from the validator set passed to create the batch
        // certificate, the author should not sign their own batch.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Build the expected 'transmissions' layout from `transmission_map`, the same deduplicated
        // source as the missing transmissions inserted below. Building it from the raw `transmissions`
        // with `or_insert` would keep the *first* value of a repeated ID, while storage receives the
        // *last*, so the layouts would disagree.
        let internal_transmissions: HashMap<_, (_, IndexSet<Field<CurrentNetwork>>)> = transmission_map
            .iter()
            .map(|(id, t)| (*id, (t.clone(), indexset! { certificate_id })))
            .collect();

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Construct the expected layout for 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Construct the expected layout for 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Construct the expected layout for 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation is correct.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again - with all of the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check that the underlying storage representation remains unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}