1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21 ledger::{
22 block::{Block, Transaction},
23 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24 },
25 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26 utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::RwLock;
34use rayon::iter::{IntoParallelIterator, ParallelIterator};
35use std::{
36 collections::{HashMap, HashSet},
37 sync::{
38 Arc,
39 atomic::{AtomicU32, AtomicU64, Ordering},
40 },
41};
42use tracing::subscriber::DefaultGuard;
43
/// A cloneable handle to the shared storage state.
///
/// The handle wraps [`StorageInner`] in an [`Arc`], so every clone refers to the same
/// underlying state.
#[derive(Clone, Debug)]
pub struct Storage<N: Network>(Arc<StorageInner<N>>);
46
/// Lets a [`Storage`] handle be used wherever the shared inner state is expected.
impl<N: Network> std::ops::Deref for Storage<N> {
    type Target = Arc<StorageInner<N>>;

    /// Returns a reference to the wrapped `Arc<StorageInner<N>>`.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
54
55impl<N: Network> TracingHandlerGuard for Storage<N> {
56 fn get_tracing_guard(&self) -> Option<DefaultGuard> {
58 self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
59 }
60}
61
/// The shared state behind a [`Storage`] handle.
#[derive(Debug)]
pub struct StorageInner<N: Network> {
    /// The ledger service, consulted for committees, the latest round, and known
    /// certificates, solutions, and transactions.
    ledger: Arc<dyn LedgerService<N>>,
    /// The current height; only ever raised by `sync_height_with_block`.
    current_height: AtomicU32,
    /// The current round.
    current_round: AtomicU64,
    /// The garbage collection round; certificates must be for a later round to be inserted.
    gc_round: AtomicU64,
    /// The number of rounds subtracted from a round to derive the next GC round.
    max_gc_rounds: u64,
    /// The map of `round` to the set of `(certificate ID, author)` pairs stored for it.
    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Address<N>)>>>,
    /// The map of `certificate ID` to `certificate`.
    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
    /// The map of `batch ID` to `round`; populated with the certificate ID on insertion.
    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
    /// The storage service holding the transmissions referenced by stored certificates.
    transmissions: Arc<dyn StorageService<N>>,
    /// The optional tracing handler used for log output.
    tracing: Option<TracingHandler>,
}
113
114impl<N: Network> Storage<N> {
115 pub fn new(
117 ledger: Arc<dyn LedgerService<N>>,
118 transmissions: Arc<dyn StorageService<N>>,
119 max_gc_rounds: u64,
120 tracing: Option<TracingHandler>,
121 ) -> Self {
122 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
125 let current_round = committee.starting_round().max(1);
127
128 let storage = Self(Arc::new(StorageInner {
130 ledger,
131 current_height: Default::default(),
132 current_round: Default::default(),
133 gc_round: Default::default(),
134 max_gc_rounds,
135 rounds: Default::default(),
136 certificates: Default::default(),
137 batch_ids: Default::default(),
138 transmissions,
139 tracing,
140 }));
141 storage.update_current_round(current_round);
143 storage.garbage_collect_certificates(current_round);
146 storage
148 }
149}
150
impl<N: Network> Storage<N> {
    /// Returns the current height stored in this instance.
    pub fn current_height(&self) -> u32 {
        self.current_height.load(Ordering::SeqCst)
    }
}
158
159impl<N: Network> Storage<N> {
    /// Returns the current round stored in this instance.
    pub fn current_round(&self) -> u64 {
        self.current_round.load(Ordering::SeqCst)
    }
165
    /// Returns the current garbage collection round.
    pub fn gc_round(&self) -> u64 {
        self.gc_round.load(Ordering::SeqCst)
    }
171
    /// Returns the configured maximum number of GC rounds.
    pub fn max_gc_rounds(&self) -> u64 {
        self.max_gc_rounds
    }
176
177 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
180 let next_round = current_round + 1;
182
183 {
185 let storage_round = self.current_round();
187 if next_round < storage_round {
189 return Ok(storage_round);
190 }
191 }
192
193 let current_committee = self.ledger.current_committee()?;
195 let starting_round = current_committee.starting_round();
197 if next_round < starting_round {
199 let latest_block_round = self.ledger.latest_round();
201 guard_info!(
203 self,
204 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
205 );
206 self.sync_round_with_block(latest_block_round);
208 return Ok(latest_block_round);
210 }
211
212 self.update_current_round(next_round);
214
215 #[cfg(feature = "metrics")]
216 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
217
218 let storage_round = self.current_round();
220 let gc_round = self.gc_round();
222 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
224 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
226
227 guard_info!(self, "Starting round {next_round}...");
229 Ok(next_round)
230 }
231
    /// Overwrites the current round with `next_round`, without any checks.
    fn update_current_round(&self, next_round: u64) {
        self.current_round.store(next_round, Ordering::SeqCst);
    }
237
238 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
240 let current_gc_round = self.gc_round();
242 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
244 if next_gc_round > current_gc_round {
246 for gc_round in current_gc_round..=next_gc_round {
248 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
250 self.remove_certificate(id);
252 }
253 }
254 self.gc_round.store(next_gc_round, Ordering::SeqCst);
256 }
257 }
258}
259
260impl<N: Network> Storage<N> {
    /// Returns `true` if storage has an entry for the given `round`.
    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
        self.rounds.read().contains_key(&round)
    }
266
    /// Returns `true` if storage contains a certificate with the given `certificate_id`.
    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
        self.certificates.read().contains_key(&certificate_id)
    }
272
273 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
275 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, a)| a == &author))
276 }
277
    /// Returns `true` if storage contains an entry for the given `batch_id`.
    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
        self.batch_ids.read().contains_key(&batch_id)
    }
283
    /// Returns whether the transmission service reports containing the given `transmission_id`.
    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
        self.transmissions.contains_transmission(transmission_id.into())
    }
288
    /// Returns the transmission for the given `transmission_id`, as reported by the
    /// transmission service, or `None` if it has none.
    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
        self.transmissions.get_transmission(transmission_id.into())
    }
294
    /// Returns the round of the stored certificate with the given `certificate_id`,
    /// or `None` if no such certificate is stored.
    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
    }
301
    /// Returns the round recorded for the given `batch_id`, or `None` if it is not recorded.
    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
        self.batch_ids.read().get(&batch_id).copied()
    }
308
309 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
312 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
314 }
315
    /// Returns a clone of the stored certificate with the given `certificate_id`,
    /// or `None` if no such certificate is stored.
    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
        self.certificates.read().get(&certificate_id).cloned()
    }
322
323 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
327 if let Some(entries) = self.rounds.read().get(&round) {
329 let certificates = self.certificates.read();
330 entries.iter().find_map(
331 |(certificate_id, a)| if a == &author { certificates.get(certificate_id).cloned() } else { None },
332 )
333 } else {
334 Default::default()
335 }
336 }
337
338 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
341 if round == 0 {
343 return Default::default();
344 }
345 if let Some(entries) = self.rounds.read().get(&round) {
347 let certificates = self.certificates.read();
348 entries.iter().flat_map(|(certificate_id, _)| certificates.get(certificate_id).cloned()).collect()
349 } else {
350 Default::default()
351 }
352 }
353
354 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
357 if round == 0 {
359 return Default::default();
360 }
361 if let Some(entries) = self.rounds.read().get(&round) {
363 entries.iter().map(|(certificate_id, _)| *certificate_id).collect()
364 } else {
365 Default::default()
366 }
367 }
368
369 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
372 if round == 0 {
374 return Default::default();
375 }
376 if let Some(entries) = self.rounds.read().get(&round) {
378 entries.iter().map(|(_, author)| *author).collect()
379 } else {
380 Default::default()
381 }
382 }
383
    /// Returns clones of the stored certificates that the ledger does not report as contained.
    ///
    /// A ledger lookup that fails is treated as "not contained", so the certificate is
    /// included in the result.
    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
        // Both read locks are held for the whole traversal.
        let rounds = self.rounds.read();
        let certificates = self.certificates.read();

        // Sort a copy of the round map by round number, then walk the entries of each round.
        cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
            .flat_map(|(_, certificates_for_round)| {
                cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _)| {
                    // Skip the certificate if the ledger already contains it.
                    if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
                        None
                    } else {
                        // Otherwise return the stored certificate, if there is one.
                        certificates.get(&certificate_id).cloned()
                    }
                })
            })
            .collect()
    }
407
408 pub fn check_batch_header(
421 &self,
422 batch_header: &BatchHeader<N>,
423 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
424 aborted_transmissions: HashSet<TransmissionID<N>>,
425 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
426 let round = batch_header.round();
428 let gc_round = self.gc_round();
430 let gc_log = format!("(gc = {gc_round})");
432
433 if self.contains_batch(batch_header.batch_id()) {
435 bail!("Batch for round {round} already exists in storage {gc_log}")
436 }
437
438 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
440 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
441 };
442 if !committee_lookback.is_committee_member(batch_header.author()) {
444 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
445 }
446
447 check_timestamp_for_liveness(batch_header.timestamp())?;
449
450 let missing_transmissions = self
452 .transmissions
453 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
454 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
455
456 let previous_round = round.saturating_sub(1);
458 if previous_round > gc_round {
460 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
462 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
463 };
464 if !self.contains_certificates_for_round(previous_round) {
466 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
467 }
468 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
470 bail!("Too many previous certificates for round {round} {gc_log}")
471 }
472 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
474 for previous_certificate_id in batch_header.previous_certificate_ids() {
476 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
478 bail!(
479 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
480 fmt_id(previous_certificate_id)
481 )
482 };
483 if previous_certificate.round() != previous_round {
485 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
486 }
487 if previous_authors.contains(&previous_certificate.author()) {
489 bail!("Round {round} certificate contains a duplicate author {gc_log}")
490 }
491 previous_authors.insert(previous_certificate.author());
493 }
494 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
496 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
497 }
498 }
499 Ok(missing_transmissions)
500 }
501
502 pub fn check_certificate(
518 &self,
519 certificate: &BatchCertificate<N>,
520 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
521 aborted_transmissions: HashSet<TransmissionID<N>>,
522 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
523 let round = certificate.round();
525 let gc_round = self.gc_round();
527 let gc_log = format!("(gc = {gc_round})");
529
530 if self.contains_certificate(certificate.id()) {
532 bail!("Certificate for round {round} already exists in storage {gc_log}")
533 }
534
535 if self.contains_certificate_in_round_from(round, certificate.author()) {
537 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
538 }
539
540 let missing_transmissions =
542 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
543
544 check_timestamp_for_liveness(certificate.timestamp())?;
546
547 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
549 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
550 };
551
552 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
554 signers.insert(certificate.author());
556
557 for signature in certificate.signatures() {
559 let signer = signature.to_address();
561 if !committee_lookback.is_committee_member(signer) {
563 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
564 }
565 signers.insert(signer);
567 }
568
569 if !committee_lookback.is_quorum_threshold_reached(&signers) {
571 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
572 }
573 Ok(missing_transmissions)
574 }
575
576 pub fn insert_certificate(
588 &self,
589 certificate: BatchCertificate<N>,
590 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
591 aborted_transmissions: HashSet<TransmissionID<N>>,
592 ) -> Result<()> {
593 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
595 let missing_transmissions =
597 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
598 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
600 Ok(())
601 }
602
603 fn insert_certificate_atomic(
609 &self,
610 certificate: BatchCertificate<N>,
611 aborted_transmission_ids: HashSet<TransmissionID<N>>,
612 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
613 ) {
614 let round = certificate.round();
616 let certificate_id = certificate.id();
618 let author = certificate.author();
620
621 self.rounds.write().entry(round).or_default().insert((certificate_id, author));
623 let transmission_ids = certificate.transmission_ids().clone();
625 self.certificates.write().insert(certificate_id, certificate);
627 self.batch_ids.write().insert(certificate_id, round);
629 self.transmissions.insert_transmissions(
631 certificate_id,
632 transmission_ids,
633 aborted_transmission_ids,
634 missing_transmissions,
635 );
636 }
637
638 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
645 let Some(certificate) = self.get_certificate(certificate_id) else {
647 guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
648 return false;
649 };
650 let round = certificate.round();
652 let author = certificate.author();
654
655 match self.rounds.write().entry(round) {
661 Entry::Occupied(mut entry) => {
662 entry.get_mut().swap_remove(&(certificate_id, author));
664 if entry.get().is_empty() {
666 entry.swap_remove();
667 }
668 }
669 Entry::Vacant(_) => {}
670 }
671 self.certificates.write().swap_remove(&certificate_id);
673 self.batch_ids.write().swap_remove(&certificate_id);
675 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
677 true
679 }
680}
681
682impl<N: Network> Storage<N> {
683 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
685 if next_height > self.current_height() {
687 self.current_height.store(next_height, Ordering::SeqCst);
689 }
690 }
691
692 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
694 let next_round = next_round.max(1);
696 if next_round > self.current_round() {
698 self.update_current_round(next_round);
700 guard_info!(self, "Synced to round {next_round}...");
702 }
703 }
704
    /// Inserts the given certificate into storage, collecting its transmissions from the
    /// given block, the given unconfirmed transactions, and the ledger.
    ///
    /// Certificates at or below the GC round, and certificates already in storage, are
    /// ignored. A transmission that cannot be found is recorded as aborted if the block lists
    /// it as aborted or the ledger reports containing it; otherwise it is only logged.
    /// A failure to insert the certificate is logged, not returned.
    pub(crate) fn sync_certificate_with_block(
        &self,
        block: &Block<N>,
        certificate: BatchCertificate<N>,
        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
    ) {
        // Skip certificates at or below the GC round.
        if certificate.round() <= self.gc_round() {
            return;
        }
        // Skip certificates that are already in storage.
        if self.contains_certificate(certificate.id()) {
            return;
        }
        // The transmissions that were found and need to be stored.
        let mut missing_transmissions = HashMap::new();

        // The transmissions that could not be found but are known to be aborted.
        let mut aborted_transmissions = HashSet::new();

        // The aborted solution and transaction IDs listed by the block.
        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();

        for transmission_id in certificate.transmission_ids() {
            // Skip transmissions that were already collected.
            if missing_transmissions.contains_key(transmission_id) {
                continue;
            }
            // Skip transmissions that are already in storage.
            if self.contains_transmission(*transmission_id) {
                continue;
            }
            match transmission_id {
                // Nothing is collected for ratifications.
                TransmissionID::Ratification => (),
                TransmissionID::Solution(solution_id, _) => {
                    // Look for the solution in the block first, then in the ledger.
                    match block.get_solution(solution_id) {
                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
                        None => match self.ledger.get_solution(solution_id) {
                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
                            Err(_) => {
                                // Record the solution as aborted if the block or the ledger knows it.
                                match aborted_solutions.contains(solution_id)
                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
                                {
                                    true => {
                                        aborted_transmissions.insert(*transmission_id);
                                    }
                                    false => {
                                        guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
                                    }
                                }
                                continue;
                            }
                        },
                    };
                }
                TransmissionID::Transaction(transaction_id, _) => {
                    // Look for the transaction in the unconfirmed transactions first, then in the ledger.
                    match unconfirmed_transactions.get(transaction_id) {
                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
                            Err(_) => {
                                // Record the transaction as aborted if the block or the ledger knows it.
                                match aborted_transactions.contains(transaction_id)
                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
                                {
                                    true => {
                                        aborted_transmissions.insert(*transmission_id);
                                    }
                                    false => guard_warn!(
                                        self,
                                        "Missing transaction {transaction_id} in block {}",
                                        block.height()
                                    ),
                                }
                                continue;
                            }
                        },
                    };
                }
            }
        }
        // Format the ID up front, since the certificate is moved into `insert_certificate`.
        let certificate_id = fmt_id(certificate.id());
        guard_debug!(
            self,
            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
            certificate.round(),
            certificate.transmission_ids().len()
        );
        // Insertion failures are logged rather than propagated.
        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
            guard_error!(
                self,
                "Failed to insert certificate '{certificate_id}' from block {} - {error}",
                block.height()
            );
        }
    }
817}
818
/// Test-only accessors and helpers for inspecting and populating storage.
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an iterator over a snapshot of the `(round, (certificate ID, author))` entries.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Address<N>)>)> {
        self.rounds.read().clone().into_iter()
    }

    /// Returns an iterator over a snapshot of the `(certificate ID, certificate)` entries.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        self.certificates.read().clone().into_iter()
    }

    /// Returns an iterator over a snapshot of the `(batch ID, round)` entries.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        self.batch_ids.read().clone().into_iter()
    }

    /// Returns an iterator over the `(transmission ID, (transmission, certificate IDs))`
    /// entries reported by the transmission service.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        self.transmissions.as_hashmap().into_iter()
    }

    /// Inserts the given certificate into storage without any validation, pairing each of
    /// its transmission IDs with an empty placeholder transaction.
    ///
    /// The insertion itself is delegated to `insert_certificate_atomic`, so this helper
    /// cannot drift from the production insertion path.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        // Every transmission ID gets an empty transaction buffer as its payload.
        let missing_transmissions = certificate
            .transmission_ids()
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect::<HashMap<_, _>>();
        self.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
    }
}
884
885#[cfg(test)]
886pub(crate) mod tests {
887 use super::*;
888 use amareleo_node_bft_ledger_service::MockLedgerService;
889 use amareleo_node_bft_storage_service::BFTMemoryService;
890 use snarkvm::{
891 ledger::narwhal::{Data, batch_certificate::test_helpers::sample_batch_certificate_for_round_with_committee},
892 prelude::{Rng, TestRng},
893 };
894
895 use ::bytes::Bytes;
896 use indexmap::indexset;
897
898 type CurrentNetwork = snarkvm::prelude::MainnetV0;
899
    /// Asserts that the storage contains exactly the given `rounds`, `certificates`,
    /// `batch_ids`, and `transmissions`.
    ///
    /// Rounds, certificates, and batch IDs are compared in order; transmissions are compared
    /// as a map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Check the rounds.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Check the certificates.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Check the batch IDs.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Check the transmissions.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }
917
918 fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
920 let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
922 let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
924 match rng.gen::<bool>() {
926 true => Transmission::Solution(s(rng)),
927 false => Transmission::Transaction(t(rng)),
928 }
929 }
930
931 pub(crate) fn sample_transmissions(
933 certificate: &BatchCertificate<CurrentNetwork>,
934 rng: &mut TestRng,
935 ) -> (
936 HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
937 HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
938 ) {
939 let certificate_id = certificate.id();
941
942 let mut missing_transmissions = HashMap::new();
943 let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
944 for transmission_id in certificate.transmission_ids() {
945 let transmission = sample_transmission(rng);
947 missing_transmissions.insert(*transmission_id, transmission.clone());
949 transmissions
951 .entry(*transmission_id)
952 .or_insert((transmission, Default::default()))
953 .1
954 .insert(certificate_id);
955 }
956 (missing_transmissions, transmissions)
957 }
958
    /// Inserts a certificate, checks every view of storage, then removes it and checks that
    /// storage is empty again.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Set up storage over a mock ledger with a sampled committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Storage starts out empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate and note its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let certificate_id = certificate.id();
        let round = certificate.round();
        let author = certificate.author();

        // Sample transmissions for the certificate.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate and check the lookups.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check the full contents of storage.
        {
            let rounds = [(round, indexset! { (certificate_id, author) })];
            let certificates = [(certificate_id, certificate.clone())];
            let batch_ids = [(certificate_id, round)];
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // The stored certificate equals the inserted one.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate and check that storage is empty again.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }
1024
    /// Inserts the same certificate three times and checks that storage is unchanged by the
    /// repeated insertions.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Set up storage over a mock ledger with a sampled committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Storage starts out empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate and note its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let certificate_id = certificate.id();
        let round = certificate.round();
        let author = certificate.author();

        // The expected contents of storage after insertion.
        let rounds = [(round, indexset! { (certificate_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(certificate_id, round)];
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // First insertion.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Second insertion, without transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Third insertion, with the original transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
1078
1079 #[test]
1081 fn test_invalid_certificate_insufficient_previous_certs() {
1082 let rng = &mut TestRng::default();
1083
1084 let (committee, private_keys) =
1086 snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
1087 let ledger = Arc::new(MockLedgerService::new(committee));
1089 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);
1091
1092 let mut previous_certs = IndexSet::default();
1094
1095 for round in 1..=6 {
1096 let mut new_certs = IndexSet::default();
1097
1098 for private_key in private_keys.iter() {
1100 let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();
1101
1102 let certificate = sample_batch_certificate_for_round_with_committee(
1103 round,
1104 previous_certs.clone(),
1105 private_key,
1106 &other_keys,
1107 rng,
1108 );
1109
1110 let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
1112 let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();
1113
1114 if round <= 5 {
1115 new_certs.insert(certificate.id());
1116 storage
1117 .insert_certificate(certificate, transmissions, Default::default())
1118 .expect("Valid certificate rejected");
1119 } else {
1120 assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
1121 }
1122 }
1123
1124 if round < 5 {
1125 previous_certs = new_certs;
1126 } else {
1127 previous_certs = new_certs.into_iter().skip(6).collect();
1129 }
1130 }
1131 }
1132
1133 #[test]
1135 fn test_invalid_certificate_wrong_round_number() {
1136 let rng = &mut TestRng::default();
1137
1138 let (committee, private_keys) =
1140 snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
1141 let ledger = Arc::new(MockLedgerService::new(committee));
1143 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);
1145
1146 let mut previous_certs = IndexSet::default();
1148
1149 for round in 1..=6 {
1150 let mut new_certs = IndexSet::default();
1151
1152 for private_key in private_keys.iter() {
1154 let cert_round = round.min(5); let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();
1156
1157 let certificate = sample_batch_certificate_for_round_with_committee(
1158 cert_round,
1159 previous_certs.clone(),
1160 private_key,
1161 &other_keys,
1162 rng,
1163 );
1164
1165 let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
1167 let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();
1168
1169 if round <= 5 {
1170 new_certs.insert(certificate.id());
1171 storage
1172 .insert_certificate(certificate, transmissions, Default::default())
1173 .expect("Valid certificate rejected");
1174 } else {
1175 assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
1176 }
1177 }
1178
1179 if round < 5 {
1180 previous_certs = new_certs;
1181 } else {
1182 previous_certs = new_certs.into_iter().skip(6).collect();
1184 }
1185 }
1186 }
1187}
1188
1189#[cfg(test)]
1190pub mod prop_tests {
1191 use super::*;
1192 use crate::helpers::{now, storage::tests::assert_storage};
1193 use amareleo_node_bft_ledger_service::MockLedgerService;
1194 use amareleo_node_bft_storage_service::BFTMemoryService;
1195 use snarkvm::{
1196 ledger::{
1197 committee::prop_tests::{CommitteeContext, ValidatorSet},
1198 narwhal::{BatchHeader, Data},
1199 puzzle::SolutionID,
1200 },
1201 prelude::{Signature, Uniform},
1202 };
1203
1204 use ::bytes::Bytes;
1205 use indexmap::indexset;
1206 use proptest::{
1207 collection,
1208 prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
1209 prop_oneof,
1210 sample::{Selector, size_range},
1211 test_runner::TestRng,
1212 };
1213 use rand::{CryptoRng, Error, Rng, RngCore};
1214 use std::fmt::Debug;
1215 use test_strategy::proptest;
1216
1217 type CurrentNetwork = snarkvm::prelude::MainnetV0;
1218
1219 impl Arbitrary for Storage<CurrentNetwork> {
1220 type Parameters = CommitteeContext;
1221 type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;
1222
1223 fn arbitrary() -> Self::Strategy {
1224 (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1225 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1226 let ledger = Arc::new(MockLedgerService::new(committee));
1227 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
1228 })
1229 .boxed()
1230 }
1231
1232 fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
1233 (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1234 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1235 let ledger = Arc::new(MockLedgerService::new(committee));
1236 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
1237 })
1238 .boxed()
1239 }
1240 }
1241
    /// A wrapper around proptest's [`TestRng`] that additionally implements [`CryptoRng`],
    /// so it can be passed where a `CryptoRng` bound is required.
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        /// Wraps the RNG that proptest hands to `prop_perturb`; the `Just(0)` value is unused.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    /// Every method delegates to the wrapped [`TestRng`].
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    /// Marker impl only; it adds no behavior to the wrapped RNG.
    impl CryptoRng for CryptoTestRng {}
1273
1274 #[derive(Debug, Clone)]
1275 pub struct AnyTransmission(pub Transmission<CurrentNetwork>);
1276
1277 impl Arbitrary for AnyTransmission {
1278 type Parameters = ();
1279 type Strategy = BoxedStrategy<AnyTransmission>;
1280
1281 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1282 any_transmission().prop_map(AnyTransmission).boxed()
1283 }
1284 }
1285
1286 #[derive(Debug, Clone)]
1287 pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);
1288
1289 impl Arbitrary for AnyTransmissionID {
1290 type Parameters = ();
1291 type Strategy = BoxedStrategy<AnyTransmissionID>;
1292
1293 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1294 any_transmission_id().prop_map(AnyTransmissionID).boxed()
1295 }
1296 }
1297
1298 fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
1299 prop_oneof![
1300 (collection::vec(any::<u8>(), 512..=512))
1301 .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
1302 (collection::vec(any::<u8>(), 2048..=2048))
1303 .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
1304 ]
1305 .boxed()
1306 }
1307
1308 pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
1309 Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
1310 }
1311
1312 pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
1313 Just(0)
1314 .prop_perturb(|_, rng| {
1315 <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
1316 })
1317 .boxed()
1318 }
1319
1320 pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
1321 prop_oneof![
1322 any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
1323 id,
1324 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1325 )),
1326 any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
1327 id,
1328 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1329 )),
1330 ]
1331 .boxed()
1332 }
1333
1334 pub fn sign_batch_header<R: Rng + CryptoRng>(
1335 validator_set: &ValidatorSet,
1336 batch_header: &BatchHeader<CurrentNetwork>,
1337 rng: &mut R,
1338 ) -> IndexSet<Signature<CurrentNetwork>> {
1339 let mut signatures = IndexSet::with_capacity(validator_set.0.len());
1340 for validator in validator_set.0.iter() {
1341 let private_key = validator.private_key;
1342 signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
1343 }
1344 signatures
1345 }
1346
1347 #[proptest]
1348 fn test_certificate_duplicate(
1349 context: CommitteeContext,
1350 #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
1351 mut rng: CryptoTestRng,
1352 selector: Selector,
1353 ) {
1354 let CommitteeContext(committee, ValidatorSet(validators)) = context;
1355 let committee_id = committee.id();
1356
1357 let ledger = Arc::new(MockLedgerService::new(committee));
1359 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);
1360
1361 assert_storage(&storage, &[], &[], &[], &Default::default());
1363
1364 let signer = selector.select(&validators);
1366
1367 let mut transmission_map = IndexMap::new();
1368
1369 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
1370 transmission_map.insert(*id, t.clone());
1371 }
1372
1373 let batch_header = BatchHeader::new(
1374 &signer.private_key,
1375 0,
1376 now(),
1377 committee_id,
1378 transmission_map.keys().cloned().collect(),
1379 Default::default(),
1380 &mut rng,
1381 )
1382 .unwrap();
1383
1384 let mut validators = validators.clone();
1387 validators.remove(signer);
1388
1389 let certificate = BatchCertificate::from(
1390 batch_header.clone(),
1391 sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
1392 )
1393 .unwrap();
1394
1395 let certificate_id = certificate.id();
1397 let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1398 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
1399 internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
1400 }
1401
1402 let round = certificate.round();
1404 let author = certificate.author();
1406
1407 let rounds = [(round, indexset! { (certificate_id, author) })];
1409 let certificates = [(certificate_id, certificate.clone())];
1411 let batch_ids = [(certificate_id, round)];
1413
1414 let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
1416 transmission_map.into_iter().collect();
1417 storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1418 assert!(storage.contains_certificate(certificate_id));
1420 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1422
1423 storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1425 assert!(storage.contains_certificate(certificate_id));
1427 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1429
1430 storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1432 assert!(storage.contains_certificate(certificate_id));
1434 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1436 }
1437}