1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_node_bft_ledger_service::LedgerService;
18use amareleo_node_bft_storage_service::StorageService;
19use snarkvm::{
20 ledger::{
21 block::{Block, Transaction},
22 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
23 },
24 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
25};
26
27use indexmap::{IndexMap, IndexSet, map::Entry};
28use parking_lot::RwLock;
29use std::{
30 collections::{HashMap, HashSet},
31 sync::{
32 Arc,
33 atomic::{AtomicU32, AtomicU64, Ordering},
34 },
35};
36
37#[derive(Clone, Debug)]
38pub struct Storage<N: Network>(Arc<StorageInner<N>>);
39
40impl<N: Network> std::ops::Deref for Storage<N> {
41 type Target = Arc<StorageInner<N>>;
42
43 fn deref(&self) -> &Self::Target {
44 &self.0
45 }
46}
47
48#[derive(Debug)]
68pub struct StorageInner<N: Network> {
69 ledger: Arc<dyn LedgerService<N>>,
71 current_height: AtomicU32,
74 current_round: AtomicU64,
77 gc_round: AtomicU64,
79 max_gc_rounds: u64,
81 rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
84 certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
86 batch_ids: RwLock<IndexMap<Field<N>, u64>>,
88 transmissions: Arc<dyn StorageService<N>>,
90}
91
92impl<N: Network> Storage<N> {
93 pub fn new(
95 ledger: Arc<dyn LedgerService<N>>,
96 transmissions: Arc<dyn StorageService<N>>,
97 max_gc_rounds: u64,
98 ) -> Self {
99 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
101 let current_round = committee.starting_round().max(1);
103
104 let storage = Self(Arc::new(StorageInner {
106 ledger,
107 current_height: Default::default(),
108 current_round: Default::default(),
109 gc_round: Default::default(),
110 max_gc_rounds,
111 rounds: Default::default(),
112 certificates: Default::default(),
113 batch_ids: Default::default(),
114 transmissions,
115 }));
116 storage.update_current_round(current_round);
118 storage.garbage_collect_certificates(current_round);
120 storage
122 }
123}
124
125impl<N: Network> Storage<N> {
126 pub fn current_height(&self) -> u32 {
128 self.current_height.load(Ordering::SeqCst)
130 }
131}
132
133impl<N: Network> Storage<N> {
134 pub fn current_round(&self) -> u64 {
136 self.current_round.load(Ordering::SeqCst)
138 }
139
140 pub fn gc_round(&self) -> u64 {
142 self.gc_round.load(Ordering::SeqCst)
144 }
145
146 pub fn max_gc_rounds(&self) -> u64 {
148 self.max_gc_rounds
149 }
150
151 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
154 let next_round = current_round + 1;
156
157 {
159 let storage_round = self.current_round();
161 if next_round < storage_round {
163 return Ok(storage_round);
164 }
165 }
166
167 let current_committee = self.ledger.current_committee()?;
169 let starting_round = current_committee.starting_round();
171 if next_round < starting_round {
173 let latest_block_round = self.ledger.latest_round();
175 info!(
177 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
178 );
179 self.sync_round_with_block(latest_block_round);
181 return Ok(latest_block_round);
183 }
184
185 self.update_current_round(next_round);
187
188 #[cfg(feature = "metrics")]
189 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
190
191 let storage_round = self.current_round();
193 let gc_round = self.gc_round();
195 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
197 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
199
200 info!("Starting round {next_round}...");
202 Ok(next_round)
203 }
204
205 fn update_current_round(&self, next_round: u64) {
207 self.current_round.store(next_round, Ordering::SeqCst);
209 }
210
211 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
213 let current_gc_round = self.gc_round();
215 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
217 if next_gc_round > current_gc_round {
219 for gc_round in current_gc_round..=next_gc_round {
221 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
223 self.remove_certificate(id);
225 }
226 }
227 self.gc_round.store(next_gc_round, Ordering::SeqCst);
229 }
230 }
231}
232
233impl<N: Network> Storage<N> {
234 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
236 self.rounds.read().contains_key(&round)
238 }
239
240 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
242 self.certificates.read().contains_key(&certificate_id)
244 }
245
246 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
248 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
249 }
250
251 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
253 self.batch_ids.read().contains_key(&batch_id)
255 }
256
257 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
259 self.transmissions.contains_transmission(transmission_id.into())
260 }
261
262 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
265 self.transmissions.get_transmission(transmission_id.into())
266 }
267
268 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
271 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
273 }
274
275 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
278 self.batch_ids.read().get(&batch_id).copied()
280 }
281
282 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
285 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
287 }
288
289 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
292 self.certificates.read().get(&certificate_id).cloned()
294 }
295
296 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
300 if let Some(entries) = self.rounds.read().get(&round) {
302 let certificates = self.certificates.read();
303 entries.iter().find_map(
304 |(certificate_id, _, a)| {
305 if a == &author { certificates.get(certificate_id).cloned() } else { None }
306 },
307 )
308 } else {
309 Default::default()
310 }
311 }
312
313 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
316 if round == 0 {
318 return Default::default();
319 }
320 if let Some(entries) = self.rounds.read().get(&round) {
322 let certificates = self.certificates.read();
323 entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
324 } else {
325 Default::default()
326 }
327 }
328
329 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
332 if round == 0 {
334 return Default::default();
335 }
336 if let Some(entries) = self.rounds.read().get(&round) {
338 entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
339 } else {
340 Default::default()
341 }
342 }
343
344 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
347 if round == 0 {
349 return Default::default();
350 }
351 if let Some(entries) = self.rounds.read().get(&round) {
353 entries.iter().map(|(_, _, author)| *author).collect()
354 } else {
355 Default::default()
356 }
357 }
358
359 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
362 let mut pending_certificates = IndexSet::new();
363
364 let rounds = self.rounds.read();
366 let certificates = self.certificates.read();
367
368 for (_, certificates_for_round) in rounds.clone().sorted_by(|a, _, b, _| a.cmp(b)) {
370 for (certificate_id, _, _) in certificates_for_round {
372 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
374 continue;
375 }
376
377 match certificates.get(&certificate_id).cloned() {
379 Some(certificate) => pending_certificates.insert(certificate),
380 None => continue,
381 };
382 }
383 }
384
385 pending_certificates
386 }
387
388 pub fn check_batch_header(
401 &self,
402 batch_header: &BatchHeader<N>,
403 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
404 aborted_transmissions: HashSet<TransmissionID<N>>,
405 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
406 let round = batch_header.round();
408 let gc_round = self.gc_round();
410 let gc_log = format!("(gc = {gc_round})");
412
413 if self.contains_batch(batch_header.batch_id()) {
415 bail!("Batch for round {round} already exists in storage {gc_log}")
416 }
417
418 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
420 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
421 };
422 if !committee_lookback.is_committee_member(batch_header.author()) {
424 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
425 }
426
427 check_timestamp_for_liveness(batch_header.timestamp())?;
429
430 let missing_transmissions = self
432 .transmissions
433 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
434 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
435
436 let previous_round = round.saturating_sub(1);
438 if previous_round > gc_round {
440 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
442 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
443 };
444 if !self.contains_certificates_for_round(previous_round) {
446 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
447 }
448 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
450 bail!("Too many previous certificates for round {round} {gc_log}")
451 }
452 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
454 for previous_certificate_id in batch_header.previous_certificate_ids() {
456 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
458 bail!(
459 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
460 fmt_id(previous_certificate_id)
461 )
462 };
463 if previous_certificate.round() != previous_round {
465 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
466 }
467 if previous_authors.contains(&previous_certificate.author()) {
469 bail!("Round {round} certificate contains a duplicate author {gc_log}")
470 }
471 previous_authors.insert(previous_certificate.author());
473 }
474 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
476 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
477 }
478 }
479 Ok(missing_transmissions)
480 }
481
482 pub fn check_certificate(
498 &self,
499 certificate: &BatchCertificate<N>,
500 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
501 aborted_transmissions: HashSet<TransmissionID<N>>,
502 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
503 let round = certificate.round();
505 let gc_round = self.gc_round();
507 let gc_log = format!("(gc = {gc_round})");
509
510 if self.contains_certificate(certificate.id()) {
512 bail!("Certificate for round {round} already exists in storage {gc_log}")
513 }
514
515 if self.contains_certificate_in_round_from(round, certificate.author()) {
517 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
518 }
519
520 let missing_transmissions =
522 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
523
524 check_timestamp_for_liveness(certificate.timestamp())?;
526
527 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
529 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
530 };
531
532 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
534 signers.insert(certificate.author());
536
537 for signature in certificate.signatures() {
539 let signer = signature.to_address();
541 if !committee_lookback.is_committee_member(signer) {
543 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
544 }
545 signers.insert(signer);
547 }
548
549 if !committee_lookback.is_quorum_threshold_reached(&signers) {
551 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
552 }
553 Ok(missing_transmissions)
554 }
555
556 pub fn insert_certificate(
568 &self,
569 certificate: BatchCertificate<N>,
570 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
571 aborted_transmissions: HashSet<TransmissionID<N>>,
572 ) -> Result<()> {
573 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
575 let missing_transmissions =
577 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
578 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
580 Ok(())
581 }
582
583 fn insert_certificate_atomic(
589 &self,
590 certificate: BatchCertificate<N>,
591 aborted_transmission_ids: HashSet<TransmissionID<N>>,
592 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
593 ) {
594 let round = certificate.round();
596 let certificate_id = certificate.id();
598 let batch_id = certificate.batch_id();
600 let author = certificate.author();
602
603 self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
605 let transmission_ids = certificate.transmission_ids().clone();
607 self.certificates.write().insert(certificate_id, certificate);
609 self.batch_ids.write().insert(batch_id, round);
611 self.transmissions.insert_transmissions(
613 certificate_id,
614 transmission_ids,
615 aborted_transmission_ids,
616 missing_transmissions,
617 );
618 }
619
620 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
627 let Some(certificate) = self.get_certificate(certificate_id) else {
629 warn!("Certificate {certificate_id} does not exist in storage");
630 return false;
631 };
632 let round = certificate.round();
634 let batch_id = certificate.batch_id();
636 let author = certificate.author();
638
639 match self.rounds.write().entry(round) {
645 Entry::Occupied(mut entry) => {
646 entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
648 if entry.get().is_empty() {
650 entry.swap_remove();
651 }
652 }
653 Entry::Vacant(_) => {}
654 }
655 self.certificates.write().swap_remove(&certificate_id);
657 self.batch_ids.write().swap_remove(&batch_id);
659 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
661 true
663 }
664}
665
666impl<N: Network> Storage<N> {
667 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
669 if next_height > self.current_height() {
671 self.current_height.store(next_height, Ordering::SeqCst);
673 }
674 }
675
676 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
678 let next_round = next_round.max(1);
680 if next_round > self.current_round() {
682 self.update_current_round(next_round);
684 info!("Synced to round {next_round}...");
686 }
687 }
688
689 pub(crate) fn sync_certificate_with_block(
691 &self,
692 block: &Block<N>,
693 certificate: BatchCertificate<N>,
694 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
695 ) {
696 if certificate.round() <= self.gc_round() {
698 return;
699 }
700 if self.contains_certificate(certificate.id()) {
702 return;
703 }
704 let mut missing_transmissions = HashMap::new();
706
707 let mut aborted_transmissions = HashSet::new();
709
710 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
712 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
713
714 for transmission_id in certificate.transmission_ids() {
716 if missing_transmissions.contains_key(transmission_id) {
718 continue;
719 }
720 if self.contains_transmission(*transmission_id) {
722 continue;
723 }
724 match transmission_id {
726 TransmissionID::Ratification => (),
727 TransmissionID::Solution(solution_id, _) => {
728 match block.get_solution(solution_id) {
730 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
732 None => match self.ledger.get_solution(solution_id) {
734 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
736 Err(_) => {
738 match aborted_solutions.contains(solution_id)
740 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
741 {
742 true => {
743 aborted_transmissions.insert(*transmission_id);
744 }
745 false => error!("Missing solution {solution_id} in block {}", block.height()),
746 }
747 continue;
748 }
749 },
750 };
751 }
752 TransmissionID::Transaction(transaction_id, _) => {
753 match unconfirmed_transactions.get(transaction_id) {
755 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
757 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
759 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
761 Err(_) => {
763 match aborted_transactions.contains(transaction_id)
765 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
766 {
767 true => {
768 aborted_transmissions.insert(*transmission_id);
769 }
770 false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
771 }
772 continue;
773 }
774 },
775 };
776 }
777 }
778 }
779 let certificate_id = fmt_id(certificate.id());
781 debug!(
782 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
783 certificate.round(),
784 certificate.transmission_ids().len()
785 );
786 if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
787 error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
788 }
789 }
790}
791
792#[cfg(test)]
793impl<N: Network> Storage<N> {
794 pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
796 &self.ledger
797 }
798
799 pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
801 self.rounds.read().clone().into_iter()
802 }
803
804 pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
806 self.certificates.read().clone().into_iter()
807 }
808
809 pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
811 self.batch_ids.read().clone().into_iter()
812 }
813
814 pub fn transmissions_iter(
816 &self,
817 ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
818 self.transmissions.as_hashmap().into_iter()
819 }
820
#[cfg(test)]
#[doc(hidden)]
/// Test-only: stores `certificate` without validation, giving every referenced transmission an
/// empty transaction buffer as its payload.
pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
    let (round, certificate_id, batch_id, author) =
        (certificate.round(), certificate.id(), certificate.batch_id(), certificate.author());

    self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
    let transmission_ids = certificate.transmission_ids().clone();
    self.certificates.write().insert(certificate_id, certificate);
    self.batch_ids.write().insert(batch_id, round);

    // Placeholder payload for every transmission this certificate references.
    let placeholder = || Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()));
    let missing_transmissions: HashMap<_, _> = transmission_ids.iter().map(|id| (*id, placeholder())).collect();
    self.transmissions.insert_transmissions(
        certificate_id,
        transmission_ids,
        Default::default(),
        missing_transmissions,
    );
}
858}
859
860#[cfg(test)]
861pub(crate) mod tests {
862 use super::*;
863 use amareleo_node_bft_ledger_service::MockLedgerService;
864 use amareleo_node_bft_storage_service::BFTMemoryService;
865 use snarkvm::{
866 ledger::narwhal::Data,
867 prelude::{Rng, TestRng},
868 };
869
870 use ::bytes::Bytes;
871 use indexmap::indexset;
872
873 type CurrentNetwork = snarkvm::prelude::MainnetV0;
874
875 pub fn assert_storage<N: Network>(
877 storage: &Storage<N>,
878 rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
879 certificates: &[(Field<N>, BatchCertificate<N>)],
880 batch_ids: &[(Field<N>, u64)],
881 transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
882 ) {
883 assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
885 assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
887 assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
889 assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
891 }
892
893 fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
895 let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
897 let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
899 match rng.gen::<bool>() {
901 true => Transmission::Solution(s(rng)),
902 false => Transmission::Transaction(t(rng)),
903 }
904 }
905
906 pub(crate) fn sample_transmissions(
908 certificate: &BatchCertificate<CurrentNetwork>,
909 rng: &mut TestRng,
910 ) -> (
911 HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
912 HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
913 ) {
914 let certificate_id = certificate.id();
916
917 let mut missing_transmissions = HashMap::new();
918 let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
919 for transmission_id in certificate.transmission_ids() {
920 let transmission = sample_transmission(rng);
922 missing_transmissions.insert(*transmission_id, transmission.clone());
924 transmissions
926 .entry(*transmission_id)
927 .or_insert((transmission, Default::default()))
928 .1
929 .insert(certificate_id);
930 }
931 (missing_transmissions, transmissions)
932 }
933
#[test]
/// Inserting a certificate makes it visible through every lookup; removing it clears every trace.
fn test_certificate_insert_remove() {
    let rng = &mut TestRng::default();

    let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
    let ledger = Arc::new(MockLedgerService::new(committee));
    let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

    // A fresh storage is empty.
    assert_storage(&storage, &[], &[], &[], &Default::default());

    let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
    let (certificate_id, round, batch_id, author) =
        (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());
    let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

    // Insert and check every lookup path.
    storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
    assert!(storage.contains_certificate(certificate_id));
    assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
    assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));
    assert_storage(
        &storage,
        &[(round, indexset! { (certificate_id, batch_id, author) })],
        &[(certificate_id, certificate.clone())],
        &[(batch_id, round)],
        &transmissions,
    );
    let fetched = storage.get_certificate(certificate_id).unwrap();
    assert_eq!(fetched, certificate);

    // Remove and check that nothing is left behind.
    assert!(storage.remove_certificate(certificate_id));
    assert!(!storage.contains_certificate(certificate_id));
    assert!(storage.get_certificates_for_round(round).is_empty());
    assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
    assert_storage(&storage, &[], &[], &[], &Default::default());
}
1001
#[test]
/// Re-inserting the same certificate (with or without payloads) leaves storage unchanged.
fn test_certificate_duplicate() {
    let rng = &mut TestRng::default();

    let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
    let ledger = Arc::new(MockLedgerService::new(committee));
    let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
    assert_storage(&storage, &[], &[], &[], &Default::default());

    let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
    let (certificate_id, round, batch_id, author) =
        (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());

    // Expected storage contents after the first insertion.
    let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
    let certificates = [(certificate_id, certificate.clone())];
    let batch_ids = [(batch_id, round)];
    let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

    // Insert with payloads, then without, then with payloads again.
    for payload in [missing_transmissions.clone(), HashMap::new(), missing_transmissions] {
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), payload);
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
}
1057}
1058
1059#[cfg(test)]
1060pub mod prop_tests {
1061 use super::*;
1062 use crate::helpers::{now, storage::tests::assert_storage};
1063 use amareleo_node_bft_ledger_service::MockLedgerService;
1064 use amareleo_node_bft_storage_service::BFTMemoryService;
1065 use snarkvm::{
1066 ledger::{
1067 committee::prop_tests::{CommitteeContext, ValidatorSet},
1068 narwhal::{BatchHeader, Data},
1069 puzzle::SolutionID,
1070 },
1071 prelude::{Signature, Uniform},
1072 };
1073
1074 use ::bytes::Bytes;
1075 use indexmap::indexset;
1076 use proptest::{
1077 collection,
1078 prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
1079 prop_oneof,
1080 sample::{Selector, size_range},
1081 test_runner::TestRng,
1082 };
1083 use rand::{CryptoRng, Error, Rng, RngCore};
1084 use std::fmt::Debug;
1085 use test_strategy::proptest;
1086
1087 type CurrentNetwork = snarkvm::prelude::MainnetV0;
1088
1089 impl Arbitrary for Storage<CurrentNetwork> {
1090 type Parameters = CommitteeContext;
1091 type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;
1092
1093 fn arbitrary() -> Self::Strategy {
1094 (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1095 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1096 let ledger = Arc::new(MockLedgerService::new(committee));
1097 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1098 })
1099 .boxed()
1100 }
1101
1102 fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
1103 (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1104 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1105 let ledger = Arc::new(MockLedgerService::new(committee));
1106 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1107 })
1108 .boxed()
1109 }
1110 }
1111
1112 #[derive(Debug)]
1114 pub struct CryptoTestRng(TestRng);
1115
1116 impl Arbitrary for CryptoTestRng {
1117 type Parameters = ();
1118 type Strategy = BoxedStrategy<CryptoTestRng>;
1119
1120 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1121 Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
1122 }
1123 }
1124 impl RngCore for CryptoTestRng {
1125 fn next_u32(&mut self) -> u32 {
1126 self.0.next_u32()
1127 }
1128
1129 fn next_u64(&mut self) -> u64 {
1130 self.0.next_u64()
1131 }
1132
1133 fn fill_bytes(&mut self, dest: &mut [u8]) {
1134 self.0.fill_bytes(dest);
1135 }
1136
1137 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
1138 self.0.try_fill_bytes(dest)
1139 }
1140 }
1141
1142 impl CryptoRng for CryptoTestRng {}
1143
1144 #[derive(Debug, Clone)]
1145 pub struct AnyTransmission(pub Transmission<CurrentNetwork>);
1146
1147 impl Arbitrary for AnyTransmission {
1148 type Parameters = ();
1149 type Strategy = BoxedStrategy<AnyTransmission>;
1150
1151 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1152 any_transmission().prop_map(AnyTransmission).boxed()
1153 }
1154 }
1155
1156 #[derive(Debug, Clone)]
1157 pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);
1158
1159 impl Arbitrary for AnyTransmissionID {
1160 type Parameters = ();
1161 type Strategy = BoxedStrategy<AnyTransmissionID>;
1162
1163 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1164 any_transmission_id().prop_map(AnyTransmissionID).boxed()
1165 }
1166 }
1167
1168 fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
1169 prop_oneof![
1170 (collection::vec(any::<u8>(), 512..=512))
1171 .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
1172 (collection::vec(any::<u8>(), 2048..=2048))
1173 .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
1174 ]
1175 .boxed()
1176 }
1177
1178 pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
1179 Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
1180 }
1181
1182 pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
1183 Just(0)
1184 .prop_perturb(|_, rng| {
1185 <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
1186 })
1187 .boxed()
1188 }
1189
1190 pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
1191 prop_oneof![
1192 any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
1193 id,
1194 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1195 )),
1196 any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
1197 id,
1198 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1199 )),
1200 ]
1201 .boxed()
1202 }
1203
1204 pub fn sign_batch_header<R: Rng + CryptoRng>(
1205 validator_set: &ValidatorSet,
1206 batch_header: &BatchHeader<CurrentNetwork>,
1207 rng: &mut R,
1208 ) -> IndexSet<Signature<CurrentNetwork>> {
1209 let mut signatures = IndexSet::with_capacity(validator_set.0.len());
1210 for validator in validator_set.0.iter() {
1211 let private_key = validator.private_key;
1212 signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
1213 }
1214 signatures
1215 }
1216
1217 #[proptest]
1218 fn test_certificate_duplicate(
1219 context: CommitteeContext,
1220 #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
1221 mut rng: CryptoTestRng,
1222 selector: Selector,
1223 ) {
1224 let CommitteeContext(committee, ValidatorSet(validators)) = context;
1225 let committee_id = committee.id();
1226
1227 let ledger = Arc::new(MockLedgerService::new(committee));
1229 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1230
1231 assert_storage(&storage, &[], &[], &[], &Default::default());
1233
1234 let signer = selector.select(&validators);
1236
1237 let mut transmission_map = IndexMap::new();
1238
1239 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
1240 transmission_map.insert(*id, t.clone());
1241 }
1242
1243 let batch_header = BatchHeader::new(
1244 &signer.private_key,
1245 0,
1246 now(),
1247 committee_id,
1248 transmission_map.keys().cloned().collect(),
1249 Default::default(),
1250 &mut rng,
1251 )
1252 .unwrap();
1253
1254 let mut validators = validators.clone();
1257 validators.remove(signer);
1258
1259 let certificate = BatchCertificate::from(
1260 batch_header.clone(),
1261 sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
1262 )
1263 .unwrap();
1264
1265 let certificate_id = certificate.id();
1267 let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1268 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
1269 internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
1270 }
1271
1272 let round = certificate.round();
1274 let batch_id = certificate.batch_id();
1276 let author = certificate.author();
1278
1279 let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1281 let certificates = [(certificate_id, certificate.clone())];
1283 let batch_ids = [(batch_id, round)];
1285
1286 let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
1288 transmission_map.into_iter().collect();
1289 storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1290 assert!(storage.contains_certificate(certificate_id));
1292 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1294
1295 storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1297 assert!(storage.contains_certificate(certificate_id));
1299 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1301
1302 storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1304 assert!(storage.contains_certificate(certificate_id));
1306 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1308 }
1309}