1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::TracingHandler;
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21 ledger::{
22 block::{Block, Transaction},
23 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24 },
25 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26 utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30use parking_lot::RwLock;
31use rayon::iter::{IntoParallelIterator, ParallelIterator};
32use std::{
33 collections::{HashMap, HashSet},
34 sync::{
35 Arc,
36 atomic::{AtomicU32, AtomicU64, Ordering},
37 },
38};
39use tracing::subscriber::DefaultGuard;
40
41#[derive(Clone, Debug)]
42pub struct Storage<N: Network>(Arc<StorageInner<N>>);
43
44impl<N: Network> std::ops::Deref for Storage<N> {
45 type Target = Arc<StorageInner<N>>;
46
47 fn deref(&self) -> &Self::Target {
48 &self.0
49 }
50}
51
/// Shared state behind a [`Storage`] handle.
#[derive(Debug)]
pub struct StorageInner<N: Network> {
    /// The ledger service consulted for committees, the latest round, and ledger contents.
    ledger: Arc<dyn LedgerService<N>>,
    /// The current block height, advanced by `sync_height_with_block`.
    current_height: AtomicU32,
    /// The current round.
    current_round: AtomicU64,
    /// The garbage-collection round; `garbage_collect_certificates` removes certificates up to it.
    gc_round: AtomicU64,
    /// How many rounds the GC round trails the round passed to `garbage_collect_certificates`.
    max_gc_rounds: u64,
    /// Round index: `round -> {(certificate ID, batch ID, author)}`.
    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
    /// Stored certificates keyed by certificate ID.
    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
    /// Round of each stored certificate, keyed by its batch ID.
    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
    /// Backing service holding the transmissions referenced by stored certificates.
    transmissions: Arc<dyn StorageService<N>>,
    /// Optional tracing handler; `get_tracing_guard` subscribes the calling thread to it.
    tracing: Option<TracingHandler>,
}
97
98impl<N: Network> Storage<N> {
99 pub fn new(
101 ledger: Arc<dyn LedgerService<N>>,
102 transmissions: Arc<dyn StorageService<N>>,
103 max_gc_rounds: u64,
104 tracing: Option<TracingHandler>,
105 ) -> Self {
106 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
108 let current_round = committee.starting_round().max(1);
110
111 let storage = Self(Arc::new(StorageInner {
113 ledger,
114 current_height: Default::default(),
115 current_round: Default::default(),
116 gc_round: Default::default(),
117 max_gc_rounds,
118 rounds: Default::default(),
119 certificates: Default::default(),
120 batch_ids: Default::default(),
121 transmissions,
122 tracing,
123 }));
124 storage.update_current_round(current_round);
126 storage.garbage_collect_certificates(current_round);
128 storage
130 }
131}
132
133impl<N: Network> Storage<N> {
134 pub fn current_height(&self) -> u32 {
136 self.current_height.load(Ordering::SeqCst)
138 }
139}
140
141impl<N: Network> Storage<N> {
142 pub fn current_round(&self) -> u64 {
144 self.current_round.load(Ordering::SeqCst)
146 }
147
148 pub fn gc_round(&self) -> u64 {
150 self.gc_round.load(Ordering::SeqCst)
152 }
153
154 pub fn max_gc_rounds(&self) -> u64 {
156 self.max_gc_rounds
157 }
158
159 pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
161 self.tracing.clone().map(|trace_handle| trace_handle.subscribe_thread())
162 }
163
164 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
167 let _guard = self.get_tracing_guard();
168
169 let next_round = current_round + 1;
171
172 {
174 let storage_round = self.current_round();
176 if next_round < storage_round {
178 return Ok(storage_round);
179 }
180 }
181
182 let current_committee = self.ledger.current_committee()?;
184 let starting_round = current_committee.starting_round();
186 if next_round < starting_round {
188 let latest_block_round = self.ledger.latest_round();
190 info!(
192 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
193 );
194 self.sync_round_with_block(latest_block_round);
196 return Ok(latest_block_round);
198 }
199
200 self.update_current_round(next_round);
202
203 #[cfg(feature = "metrics")]
204 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
205
206 let storage_round = self.current_round();
208 let gc_round = self.gc_round();
210 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
212 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
214
215 info!("Starting round {next_round}...");
217 Ok(next_round)
218 }
219
220 fn update_current_round(&self, next_round: u64) {
222 self.current_round.store(next_round, Ordering::SeqCst);
224 }
225
226 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
228 let current_gc_round = self.gc_round();
230 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
232 if next_gc_round > current_gc_round {
234 for gc_round in current_gc_round..=next_gc_round {
236 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
238 self.remove_certificate(id);
240 }
241 }
242 self.gc_round.store(next_gc_round, Ordering::SeqCst);
244 }
245 }
246}
247
248impl<N: Network> Storage<N> {
249 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
251 self.rounds.read().contains_key(&round)
253 }
254
255 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
257 self.certificates.read().contains_key(&certificate_id)
259 }
260
261 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
263 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
264 }
265
266 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
268 self.batch_ids.read().contains_key(&batch_id)
270 }
271
272 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
274 self.transmissions.contains_transmission(transmission_id.into())
275 }
276
277 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
280 self.transmissions.get_transmission(transmission_id.into())
281 }
282
283 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
286 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
288 }
289
290 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
293 self.batch_ids.read().get(&batch_id).copied()
295 }
296
297 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
300 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
302 }
303
304 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
307 self.certificates.read().get(&certificate_id).cloned()
309 }
310
311 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
315 if let Some(entries) = self.rounds.read().get(&round) {
317 let certificates = self.certificates.read();
318 entries.iter().find_map(
319 |(certificate_id, _, a)| {
320 if a == &author { certificates.get(certificate_id).cloned() } else { None }
321 },
322 )
323 } else {
324 Default::default()
325 }
326 }
327
328 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
331 if round == 0 {
333 return Default::default();
334 }
335 if let Some(entries) = self.rounds.read().get(&round) {
337 let certificates = self.certificates.read();
338 entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
339 } else {
340 Default::default()
341 }
342 }
343
344 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
347 if round == 0 {
349 return Default::default();
350 }
351 if let Some(entries) = self.rounds.read().get(&round) {
353 entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
354 } else {
355 Default::default()
356 }
357 }
358
359 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
362 if round == 0 {
364 return Default::default();
365 }
366 if let Some(entries) = self.rounds.read().get(&round) {
368 entries.iter().map(|(_, _, author)| *author).collect()
369 } else {
370 Default::default()
371 }
372 }
373
374 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
377 let rounds = self.rounds.read();
379 let certificates = self.certificates.read();
380
381 cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
383 .flat_map(|(_, certificates_for_round)| {
384 cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
386 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
388 None
389 } else {
390 certificates.get(&certificate_id).cloned()
392 }
393 })
394 })
395 .collect()
396 }
397
398 pub fn check_batch_header(
411 &self,
412 batch_header: &BatchHeader<N>,
413 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
414 aborted_transmissions: HashSet<TransmissionID<N>>,
415 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
416 let round = batch_header.round();
418 let gc_round = self.gc_round();
420 let gc_log = format!("(gc = {gc_round})");
422
423 if self.contains_batch(batch_header.batch_id()) {
425 bail!("Batch for round {round} already exists in storage {gc_log}")
426 }
427
428 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
430 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
431 };
432 if !committee_lookback.is_committee_member(batch_header.author()) {
434 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
435 }
436
437 check_timestamp_for_liveness(batch_header.timestamp())?;
439
440 let missing_transmissions = self
442 .transmissions
443 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
444 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
445
446 let previous_round = round.saturating_sub(1);
448 if previous_round > gc_round {
450 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
452 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
453 };
454 if !self.contains_certificates_for_round(previous_round) {
456 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
457 }
458 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
460 bail!("Too many previous certificates for round {round} {gc_log}")
461 }
462 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
464 for previous_certificate_id in batch_header.previous_certificate_ids() {
466 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
468 bail!(
469 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
470 fmt_id(previous_certificate_id)
471 )
472 };
473 if previous_certificate.round() != previous_round {
475 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
476 }
477 if previous_authors.contains(&previous_certificate.author()) {
479 bail!("Round {round} certificate contains a duplicate author {gc_log}")
480 }
481 previous_authors.insert(previous_certificate.author());
483 }
484 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
486 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
487 }
488 }
489 Ok(missing_transmissions)
490 }
491
492 pub fn check_certificate(
508 &self,
509 certificate: &BatchCertificate<N>,
510 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
511 aborted_transmissions: HashSet<TransmissionID<N>>,
512 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
513 let round = certificate.round();
515 let gc_round = self.gc_round();
517 let gc_log = format!("(gc = {gc_round})");
519
520 if self.contains_certificate(certificate.id()) {
522 bail!("Certificate for round {round} already exists in storage {gc_log}")
523 }
524
525 if self.contains_certificate_in_round_from(round, certificate.author()) {
527 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
528 }
529
530 let missing_transmissions =
532 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
533
534 check_timestamp_for_liveness(certificate.timestamp())?;
536
537 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
539 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
540 };
541
542 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
544 signers.insert(certificate.author());
546
547 for signature in certificate.signatures() {
549 let signer = signature.to_address();
551 if !committee_lookback.is_committee_member(signer) {
553 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
554 }
555 signers.insert(signer);
557 }
558
559 if !committee_lookback.is_quorum_threshold_reached(&signers) {
561 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
562 }
563 Ok(missing_transmissions)
564 }
565
566 pub fn insert_certificate(
578 &self,
579 certificate: BatchCertificate<N>,
580 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
581 aborted_transmissions: HashSet<TransmissionID<N>>,
582 ) -> Result<()> {
583 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
585 let missing_transmissions =
587 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
588 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
590 Ok(())
591 }
592
593 fn insert_certificate_atomic(
599 &self,
600 certificate: BatchCertificate<N>,
601 aborted_transmission_ids: HashSet<TransmissionID<N>>,
602 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
603 ) {
604 let round = certificate.round();
606 let certificate_id = certificate.id();
608 let batch_id = certificate.batch_id();
610 let author = certificate.author();
612
613 self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
615 let transmission_ids = certificate.transmission_ids().clone();
617 self.certificates.write().insert(certificate_id, certificate);
619 self.batch_ids.write().insert(batch_id, round);
621 self.transmissions.insert_transmissions(
623 certificate_id,
624 transmission_ids,
625 aborted_transmission_ids,
626 missing_transmissions,
627 );
628 }
629
630 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
637 let _guard = self.get_tracing_guard();
638 let Some(certificate) = self.get_certificate(certificate_id) else {
640 warn!("Certificate {certificate_id} does not exist in storage");
641 return false;
642 };
643 let round = certificate.round();
645 let batch_id = certificate.batch_id();
647 let author = certificate.author();
649
650 match self.rounds.write().entry(round) {
656 Entry::Occupied(mut entry) => {
657 entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
659 if entry.get().is_empty() {
661 entry.swap_remove();
662 }
663 }
664 Entry::Vacant(_) => {}
665 }
666 self.certificates.write().swap_remove(&certificate_id);
668 self.batch_ids.write().swap_remove(&batch_id);
670 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
672 true
674 }
675}
676
677impl<N: Network> Storage<N> {
678 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
680 if next_height > self.current_height() {
682 self.current_height.store(next_height, Ordering::SeqCst);
684 }
685 }
686
687 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
689 let _guard = self.get_tracing_guard();
690 let next_round = next_round.max(1);
692 if next_round > self.current_round() {
694 self.update_current_round(next_round);
696 info!("Synced to round {next_round}...");
698 }
699 }
700
701 pub(crate) fn sync_certificate_with_block(
703 &self,
704 block: &Block<N>,
705 certificate: BatchCertificate<N>,
706 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
707 ) {
708 if certificate.round() <= self.gc_round() {
710 return;
711 }
712 if self.contains_certificate(certificate.id()) {
714 return;
715 }
716 let mut missing_transmissions = HashMap::new();
718
719 let mut aborted_transmissions = HashSet::new();
721
722 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
724 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
725
726 let _guard = self.get_tracing_guard();
727
728 for transmission_id in certificate.transmission_ids() {
730 if missing_transmissions.contains_key(transmission_id) {
732 continue;
733 }
734 if self.contains_transmission(*transmission_id) {
736 continue;
737 }
738 match transmission_id {
740 TransmissionID::Ratification => (),
741 TransmissionID::Solution(solution_id, _) => {
742 match block.get_solution(solution_id) {
744 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
746 None => match self.ledger.get_solution(solution_id) {
748 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
750 Err(_) => {
752 match aborted_solutions.contains(solution_id)
754 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
755 {
756 true => {
757 aborted_transmissions.insert(*transmission_id);
758 }
759 false => error!("Missing solution {solution_id} in block {}", block.height()),
760 }
761 continue;
762 }
763 },
764 };
765 }
766 TransmissionID::Transaction(transaction_id, _) => {
767 match unconfirmed_transactions.get(transaction_id) {
769 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
771 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
773 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
775 Err(_) => {
777 match aborted_transactions.contains(transaction_id)
779 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
780 {
781 true => {
782 aborted_transmissions.insert(*transmission_id);
783 }
784 false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
785 }
786 continue;
787 }
788 },
789 };
790 }
791 }
792 }
793 let certificate_id = fmt_id(certificate.id());
795 debug!(
796 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
797 certificate.round(),
798 certificate.transmission_ids().len()
799 );
800 if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
801 error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
802 }
803 }
804}
805
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an owned snapshot of the round index, cloned under its read lock.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        self.rounds.read().clone().into_iter()
    }

    /// Returns an owned snapshot of the stored certificates, cloned under their read lock.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        self.certificates.read().clone().into_iter()
    }

    /// Returns an owned snapshot of the batch-ID index, cloned under its read lock.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        self.batch_ids.read().clone().into_iter()
    }

    /// Returns the transmission service's contents as reported by its `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        self.transmissions.as_hashmap().into_iter()
    }

    /// Inserts `certificate` without validation, supplying an empty transaction buffer for every
    /// transmission it references.
    ///
    /// Delegates to `insert_certificate_atomic` so the test path uses the same insertion sequence
    /// as production code. The surrounding impl is already `#[cfg(test)]`.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let missing_transmissions = certificate
            .transmission_ids()
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect::<HashMap<_, _>>();
        self.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
    }
}
873
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the contents of `storage` exactly match the expected layout.
    ///
    /// Rounds, certificates, and batch IDs are compared as ordered sequences; transmissions are
    /// compared as a map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        // Check the round index.
        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
        // Check the stored certificates.
        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
        // Check the batch-ID index.
        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
        // Check the transmissions.
        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
    }

    /// Samples a random transmission: a 512-byte solution buffer or a 2048-byte transaction buffer.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        // Random solution bytes.
        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Random transaction bytes.
        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        // Pick one of the two kinds at random.
        match rng.gen::<bool>() {
            true => Transmission::Solution(s(rng)),
            false => Transmission::Transaction(t(rng)),
        }
    }

    /// Samples one random transmission for each transmission ID in `certificate`.
    ///
    /// Returns `(missing_transmissions, transmissions)`. The first map is the input for insertion;
    /// the second is the expected storage view, in which each transmission is tagged with the
    /// certificate ID.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            // Sample the transmission.
            let transmission = sample_transmission(rng);
            // Record it as an insertion input.
            missing_transmissions.insert(*transmission_id, transmission.clone());
            // Record the expected stored form: the transmission plus the referencing certificate ID.
            transmissions
                .entry(*transmission_id)
                .or_insert((transmission, Default::default()))
                .1
                .insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    /// Inserting then removing a certificate updates every index and leaves storage empty afterwards.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage starts empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Sample the transmissions.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the certificate is stored in the correct round.
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        // Ensure the certificate is found by round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check the underlying storage layout.
        {
            // Expected layout of 'rounds'.
            let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
            // Expected layout of 'certificates'.
            let certificates = [(certificate_id, certificate.clone())];
            // Expected layout of 'batch_ids'.
            let batch_ids = [(batch_id, round)];
            // Assert the storage is well-formed.
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // Retrieve the certificate.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        // Ensure the retrieved certificate matches the inserted one.
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate.
        assert!(storage.remove_certificate(certificate_id));
        // Ensure the certificate no longer exists in storage.
        assert!(!storage.contains_certificate(certificate_id));
        // Ensure the certificate is no longer stored in the round.
        assert!(storage.get_certificates_for_round(round).is_empty());
        // Ensure the certificate is no longer found by round and author.
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        // Ensure the storage is empty again.
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    /// Re-inserting the same certificate, with or without transmissions, leaves storage unchanged.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        // Initialize the ledger.
        let ledger = Arc::new(MockLedgerService::new(committee));
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage starts empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        // Retrieve the certificate ID.
        let certificate_id = certificate.id();
        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Expected layout of 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Expected layout of 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Expected layout of 'batch_ids'.
        let batch_ids = [(batch_id, round)];
        // Sample the transmissions.
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check the underlying storage layout.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again, without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate still exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the layout is unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Insert the certificate again, with the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate still exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the layout is unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
}
1072
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Generates storage backed by a mock ledger for a committee, with `max_gc_rounds` sampled
    /// from `0..MAX_GC_ROUNDS`.
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        // Samples both the committee and the GC window.
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        // Uses the given committee and samples only the GC window.
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    /// Wraps proptest's `TestRng` so it can be used where a `CryptoRng` is required.
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        // Hands out proptest's per-case RNG via `prop_perturb`.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    // Every method forwards to the wrapped `TestRng`.
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    // Marker impl only; this RNG is for tests and is not cryptographically secure.
    impl CryptoRng for CryptoTestRng {}

    /// Newtype for generating arbitrary transmissions.
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// Newtype for generating arbitrary transmission IDs.
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Strategy for a random transmission: a 512-byte solution or a 2048-byte transaction buffer.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Strategy for a solution ID built from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Strategy for a transaction ID built from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Strategy for a transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with the private key of every validator in
    /// `validator_set`.
    ///
    /// Panics (via `unwrap`) if any signing operation fails.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    /// Inserts the same certificate three times: with its transmissions, without them, and with
    /// them again. After the first insertion, storage must stay unchanged.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        // Ensure the storage starts empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Pick the batch author.
        let signer = selector.select(&validators);

        let mut transmission_map = IndexMap::new();

        // Build the transmission map; a later duplicate ID overwrites an earlier one.
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        // Create a round-0 batch header authored by the selected signer.
        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Clone, because `signer` still borrows `validators`, then drop the author so the
        // signatures come from the remaining validators.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Expected transmissions: each one tagged with the certificate ID.
        let certificate_id = certificate.id();
        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
        }

        // Retrieve the round.
        let round = certificate.round();
        // Retrieve the batch ID.
        let batch_id = certificate.batch_id();
        // Retrieve the author of the batch.
        let author = certificate.author();

        // Expected layout of 'rounds'.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        // Expected layout of 'certificates'.
        let certificates = [(certificate_id, certificate.clone())];
        // Expected layout of 'batch_ids'.
        let batch_ids = [(batch_id, round)];

        // Insert the certificate with all of its transmissions.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        // Ensure the certificate exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Check the underlying storage layout.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again, without any missing transmissions.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        // Ensure the certificate still exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the layout is unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Insert the certificate again, with the original missing transmissions.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        // Ensure the certificate still exists in storage.
        assert!(storage.contains_certificate(certificate_id));
        // Ensure the layout is unchanged.
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}