1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21 ledger::{
22 block::{Block, Transaction},
23 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24 },
25 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26 utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::RwLock;
34use rayon::iter::{IntoParallelIterator, ParallelIterator};
35use std::{
36 collections::{HashMap, HashSet},
37 sync::{
38 Arc,
39 atomic::{AtomicU32, AtomicU64, Ordering},
40 },
41};
42use tracing::subscriber::DefaultGuard;
43
/// Shared handle to the BFT storage.
///
/// `Storage` wraps its state in an `Arc`, so the derived `Clone` only bumps a reference count:
/// every clone observes and mutates the same [`StorageInner`].
#[derive(Clone, Debug)]
pub struct Storage<N: Network>(Arc<StorageInner<N>>);
46
47impl<N: Network> std::ops::Deref for Storage<N> {
48 type Target = Arc<StorageInner<N>>;
49
50 fn deref(&self) -> &Self::Target {
51 &self.0
52 }
53}
54
55impl<N: Network> TracingHandlerGuard for Storage<N> {
56 fn get_tracing_guard(&self) -> Option<DefaultGuard> {
58 self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
59 }
60}
61
/// The state shared by all clones of a [`Storage`] handle.
#[derive(Debug)]
pub struct StorageInner<N: Network> {
    /// The ledger service, queried for committees, the latest round, solutions, transactions and
    /// already-committed certificates.
    ledger: Arc<dyn LedgerService<N>>,
    /// The block height tracked by storage; `sync_height_with_block` only ever raises it.
    current_height: AtomicU32,
    /// The current round.
    current_round: AtomicU64,
    /// The most recent garbage-collected round; `insert_certificate` rejects certificates at or
    /// below it.
    gc_round: AtomicU64,
    /// How far the GC round trails the round passed to `garbage_collect_certificates`.
    max_gc_rounds: u64,
    /// Round -> set of `(certificate ID, batch ID, author)` for every certificate stored in that round.
    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
    /// Certificate ID -> certificate.
    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
    /// Batch ID -> round of the batch.
    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
    /// Backing store for the transmissions referenced by stored certificates.
    transmissions: Arc<dyn StorageService<N>>,
    /// Optional tracing handler used for scoped logging.
    tracing: Option<TracingHandler>,
}
107
108impl<N: Network> Storage<N> {
109 pub fn new(
111 ledger: Arc<dyn LedgerService<N>>,
112 transmissions: Arc<dyn StorageService<N>>,
113 max_gc_rounds: u64,
114 tracing: Option<TracingHandler>,
115 ) -> Self {
116 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
118 let current_round = committee.starting_round().max(1);
120
121 let storage = Self(Arc::new(StorageInner {
123 ledger,
124 current_height: Default::default(),
125 current_round: Default::default(),
126 gc_round: Default::default(),
127 max_gc_rounds,
128 rounds: Default::default(),
129 certificates: Default::default(),
130 batch_ids: Default::default(),
131 transmissions,
132 tracing,
133 }));
134 storage.update_current_round(current_round);
136 storage.garbage_collect_certificates(current_round);
138 storage
140 }
141}
142
143impl<N: Network> Storage<N> {
144 pub fn current_height(&self) -> u32 {
146 self.current_height.load(Ordering::SeqCst)
148 }
149}
150
151impl<N: Network> Storage<N> {
152 pub fn current_round(&self) -> u64 {
154 self.current_round.load(Ordering::SeqCst)
156 }
157
158 pub fn gc_round(&self) -> u64 {
160 self.gc_round.load(Ordering::SeqCst)
162 }
163
164 pub fn max_gc_rounds(&self) -> u64 {
166 self.max_gc_rounds
167 }
168
169 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
172 let next_round = current_round + 1;
174
175 {
177 let storage_round = self.current_round();
179 if next_round < storage_round {
181 return Ok(storage_round);
182 }
183 }
184
185 let current_committee = self.ledger.current_committee()?;
187 let starting_round = current_committee.starting_round();
189 if next_round < starting_round {
191 let latest_block_round = self.ledger.latest_round();
193 guard_info!(
195 self,
196 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
197 );
198 self.sync_round_with_block(latest_block_round);
200 return Ok(latest_block_round);
202 }
203
204 self.update_current_round(next_round);
206
207 #[cfg(feature = "metrics")]
208 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
209
210 let storage_round = self.current_round();
212 let gc_round = self.gc_round();
214 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
216 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
218
219 guard_info!(self, "Starting round {next_round}...");
221 Ok(next_round)
222 }
223
224 fn update_current_round(&self, next_round: u64) {
226 self.current_round.store(next_round, Ordering::SeqCst);
228 }
229
230 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
232 let current_gc_round = self.gc_round();
234 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
236 if next_gc_round > current_gc_round {
238 for gc_round in current_gc_round..=next_gc_round {
240 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
242 self.remove_certificate(id);
244 }
245 }
246 self.gc_round.store(next_gc_round, Ordering::SeqCst);
248 }
249 }
250}
251
252impl<N: Network> Storage<N> {
253 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
255 self.rounds.read().contains_key(&round)
257 }
258
259 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
261 self.certificates.read().contains_key(&certificate_id)
263 }
264
265 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
267 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
268 }
269
270 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
272 self.batch_ids.read().contains_key(&batch_id)
274 }
275
276 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
278 self.transmissions.contains_transmission(transmission_id.into())
279 }
280
281 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
284 self.transmissions.get_transmission(transmission_id.into())
285 }
286
287 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
290 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
292 }
293
294 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
297 self.batch_ids.read().get(&batch_id).copied()
299 }
300
301 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
304 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
306 }
307
308 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
311 self.certificates.read().get(&certificate_id).cloned()
313 }
314
315 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
319 if let Some(entries) = self.rounds.read().get(&round) {
321 let certificates = self.certificates.read();
322 entries.iter().find_map(
323 |(certificate_id, _, a)| {
324 if a == &author { certificates.get(certificate_id).cloned() } else { None }
325 },
326 )
327 } else {
328 Default::default()
329 }
330 }
331
332 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
335 if round == 0 {
337 return Default::default();
338 }
339 if let Some(entries) = self.rounds.read().get(&round) {
341 let certificates = self.certificates.read();
342 entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
343 } else {
344 Default::default()
345 }
346 }
347
348 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
351 if round == 0 {
353 return Default::default();
354 }
355 if let Some(entries) = self.rounds.read().get(&round) {
357 entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
358 } else {
359 Default::default()
360 }
361 }
362
363 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
366 if round == 0 {
368 return Default::default();
369 }
370 if let Some(entries) = self.rounds.read().get(&round) {
372 entries.iter().map(|(_, _, author)| *author).collect()
373 } else {
374 Default::default()
375 }
376 }
377
378 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
381 let rounds = self.rounds.read();
383 let certificates = self.certificates.read();
384
385 cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
387 .flat_map(|(_, certificates_for_round)| {
388 cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
390 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
392 None
393 } else {
394 certificates.get(&certificate_id).cloned()
396 }
397 })
398 })
399 .collect()
400 }
401
402 pub fn check_batch_header(
415 &self,
416 batch_header: &BatchHeader<N>,
417 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
418 aborted_transmissions: HashSet<TransmissionID<N>>,
419 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
420 let round = batch_header.round();
422 let gc_round = self.gc_round();
424 let gc_log = format!("(gc = {gc_round})");
426
427 if self.contains_batch(batch_header.batch_id()) {
429 bail!("Batch for round {round} already exists in storage {gc_log}")
430 }
431
432 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
434 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
435 };
436 if !committee_lookback.is_committee_member(batch_header.author()) {
438 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
439 }
440
441 check_timestamp_for_liveness(batch_header.timestamp())?;
443
444 let missing_transmissions = self
446 .transmissions
447 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
448 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
449
450 let previous_round = round.saturating_sub(1);
452 if previous_round > gc_round {
454 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
456 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
457 };
458 if !self.contains_certificates_for_round(previous_round) {
460 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
461 }
462 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
464 bail!("Too many previous certificates for round {round} {gc_log}")
465 }
466 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
468 for previous_certificate_id in batch_header.previous_certificate_ids() {
470 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
472 bail!(
473 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
474 fmt_id(previous_certificate_id)
475 )
476 };
477 if previous_certificate.round() != previous_round {
479 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
480 }
481 if previous_authors.contains(&previous_certificate.author()) {
483 bail!("Round {round} certificate contains a duplicate author {gc_log}")
484 }
485 previous_authors.insert(previous_certificate.author());
487 }
488 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
490 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
491 }
492 }
493 Ok(missing_transmissions)
494 }
495
496 pub fn check_certificate(
512 &self,
513 certificate: &BatchCertificate<N>,
514 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
515 aborted_transmissions: HashSet<TransmissionID<N>>,
516 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
517 let round = certificate.round();
519 let gc_round = self.gc_round();
521 let gc_log = format!("(gc = {gc_round})");
523
524 if self.contains_certificate(certificate.id()) {
526 bail!("Certificate for round {round} already exists in storage {gc_log}")
527 }
528
529 if self.contains_certificate_in_round_from(round, certificate.author()) {
531 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
532 }
533
534 let missing_transmissions =
536 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
537
538 check_timestamp_for_liveness(certificate.timestamp())?;
540
541 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
543 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
544 };
545
546 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
548 signers.insert(certificate.author());
550
551 for signature in certificate.signatures() {
553 let signer = signature.to_address();
555 if !committee_lookback.is_committee_member(signer) {
557 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
558 }
559 signers.insert(signer);
561 }
562
563 if !committee_lookback.is_quorum_threshold_reached(&signers) {
565 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
566 }
567 Ok(missing_transmissions)
568 }
569
570 pub fn insert_certificate(
582 &self,
583 certificate: BatchCertificate<N>,
584 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
585 aborted_transmissions: HashSet<TransmissionID<N>>,
586 ) -> Result<()> {
587 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
589 let missing_transmissions =
591 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
592 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
594 Ok(())
595 }
596
597 fn insert_certificate_atomic(
603 &self,
604 certificate: BatchCertificate<N>,
605 aborted_transmission_ids: HashSet<TransmissionID<N>>,
606 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
607 ) {
608 let round = certificate.round();
610 let certificate_id = certificate.id();
612 let batch_id = certificate.batch_id();
614 let author = certificate.author();
616
617 self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
619 let transmission_ids = certificate.transmission_ids().clone();
621 self.certificates.write().insert(certificate_id, certificate);
623 self.batch_ids.write().insert(batch_id, round);
625 self.transmissions.insert_transmissions(
627 certificate_id,
628 transmission_ids,
629 aborted_transmission_ids,
630 missing_transmissions,
631 );
632 }
633
634 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
641 let Some(certificate) = self.get_certificate(certificate_id) else {
643 guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
644 return false;
645 };
646 let round = certificate.round();
648 let batch_id = certificate.batch_id();
650 let author = certificate.author();
652
653 match self.rounds.write().entry(round) {
659 Entry::Occupied(mut entry) => {
660 entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
662 if entry.get().is_empty() {
664 entry.swap_remove();
665 }
666 }
667 Entry::Vacant(_) => {}
668 }
669 self.certificates.write().swap_remove(&certificate_id);
671 self.batch_ids.write().swap_remove(&batch_id);
673 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
675 true
677 }
678}
679
680impl<N: Network> Storage<N> {
681 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
683 if next_height > self.current_height() {
685 self.current_height.store(next_height, Ordering::SeqCst);
687 }
688 }
689
690 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
692 let next_round = next_round.max(1);
694 if next_round > self.current_round() {
696 self.update_current_round(next_round);
698 guard_info!(self, "Synced to round {next_round}...");
700 }
701 }
702
703 pub(crate) fn sync_certificate_with_block(
705 &self,
706 block: &Block<N>,
707 certificate: BatchCertificate<N>,
708 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
709 ) {
710 if certificate.round() <= self.gc_round() {
712 return;
713 }
714 if self.contains_certificate(certificate.id()) {
716 return;
717 }
718 let mut missing_transmissions = HashMap::new();
720
721 let mut aborted_transmissions = HashSet::new();
723
724 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
726 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
727
728 for transmission_id in certificate.transmission_ids() {
730 if missing_transmissions.contains_key(transmission_id) {
732 continue;
733 }
734 if self.contains_transmission(*transmission_id) {
736 continue;
737 }
738 match transmission_id {
740 TransmissionID::Ratification => (),
741 TransmissionID::Solution(solution_id, _) => {
742 match block.get_solution(solution_id) {
744 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
746 None => match self.ledger.get_solution(solution_id) {
748 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
750 Err(_) => {
752 match aborted_solutions.contains(solution_id)
754 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
755 {
756 true => {
757 aborted_transmissions.insert(*transmission_id);
758 }
759 false => {
760 guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
761 }
762 }
763 continue;
764 }
765 },
766 };
767 }
768 TransmissionID::Transaction(transaction_id, _) => {
769 match unconfirmed_transactions.get(transaction_id) {
771 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
773 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
775 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
777 Err(_) => {
779 match aborted_transactions.contains(transaction_id)
781 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
782 {
783 true => {
784 aborted_transmissions.insert(*transmission_id);
785 }
786 false => guard_warn!(
787 self,
788 "Missing transaction {transaction_id} in block {}",
789 block.height()
790 ),
791 }
792 continue;
793 }
794 },
795 };
796 }
797 }
798 }
799 let certificate_id = fmt_id(certificate.id());
801 guard_debug!(
802 self,
803 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
804 certificate.round(),
805 certificate.transmission_ids().len()
806 );
807 if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
808 guard_error!(
809 self,
810 "Failed to insert certificate '{certificate_id}' from block {} - {error}",
811 block.height()
812 );
813 }
814 }
815}
816
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Iterates over a snapshot of the round index, in insertion order.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over a snapshot of the certificate index, in insertion order.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over a snapshot of the batch-ID index, in insertion order.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over the transmission store's contents as exposed by `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let contents = self.transmissions.as_hashmap();
        contents.into_iter()
    }

    /// Inserts `certificate` without any validation. Every referenced transmission is stored as an
    /// empty transaction placeholder.
    #[cfg(test)]
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let (round, certificate_id, batch_id, author) =
            (certificate.round(), certificate.id(), certificate.batch_id(), certificate.author());

        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        let transmission_ids = certificate.transmission_ids().clone();
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);

        let mut placeholders = HashMap::with_capacity(transmission_ids.len());
        for transmission_id in &transmission_ids {
            placeholders.insert(
                *transmission_id,
                Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new())),
            );
        }
        self.transmissions.insert_transmissions(certificate_id, transmission_ids, Default::default(), placeholders);
    }
}
884
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that `storage` holds exactly the given rounds, certificates, batch IDs and transmissions.
    ///
    /// Rounds, certificates and batch IDs are compared as ordered sequences. Transmissions are compared
    /// as a map.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        let stored_rounds: Vec<_> = storage.rounds_iter().collect();
        assert_eq!(stored_rounds, *rounds);
        let stored_certificates: Vec<_> = storage.certificates_iter().collect();
        assert_eq!(stored_certificates, *certificates);
        let stored_batch_ids: Vec<_> = storage.batch_ids_iter().collect();
        assert_eq!(stored_batch_ids, *batch_ids);
        let stored_transmissions: HashMap<_, _> = storage.transmissions_iter().collect();
        assert_eq!(stored_transmissions, *transmissions);
    }

    /// Samples a random transmission: either a 512-byte solution or a 2048-byte transaction.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        let random_bytes =
            |rng: &mut TestRng, len: usize| Bytes::from((0..len).map(|_| rng.gen::<u8>()).collect::<Vec<_>>());
        if rng.gen::<bool>() {
            Transmission::Solution(Data::Buffer(random_bytes(rng, 512)))
        } else {
            Transmission::Transaction(Data::Buffer(random_bytes(rng, 2048)))
        }
    }

    /// Samples a payload for every transmission referenced by `certificate`.
    ///
    /// Returns the payloads to hand to storage, together with the transmission-store contents expected
    /// afterwards (each payload mapped to the set containing the certificate ID).
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        let certificate_id = certificate.id();

        let mut missing_transmissions = HashMap::new();
        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for transmission_id in certificate.transmission_ids() {
            let transmission = sample_transmission(rng);
            missing_transmissions.insert(*transmission_id, transmission.clone());
            let (_, certificate_ids) =
                transmissions.entry(*transmission_id).or_insert_with(|| (transmission, IndexSet::new()));
            certificate_ids.insert(certificate_id);
        }
        (missing_transmissions, transmissions)
    }

    /// Builds an empty storage over a sampled committee and an in-memory transmission store, with a
    /// GC window of 1.
    fn sample_storage(rng: &mut TestRng) -> Storage<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None)
    }

    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);
        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // After insertion, every lookup path must find the certificate.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));
        assert_storage(
            &storage,
            &[(round, indexset! { (certificate_id, batch_id, author) })],
            &[(certificate_id, certificate.clone())],
            &[(batch_id, round)],
            &transmissions,
        );
        assert_eq!(storage.get_certificate(certificate_id), Some(certificate));

        // After removal, storage must be empty again.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);
        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let certificate_id = certificate.id();
        let round = certificate.round();
        let batch_id = certificate.batch_id();
        let author = certificate.author();

        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(batch_id, round)];
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert with payloads, re-insert without payloads, then re-insert with payloads again.
        // Storage must look the same after every step.
        for payloads in [missing_transmissions.clone(), Default::default(), missing_transmissions] {
            storage.insert_certificate_atomic(certificate.clone(), Default::default(), payloads);
            assert!(storage.contains_certificate(certificate_id));
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }
    }
}
1083
/// Property-based tests and `proptest` strategies for [`Storage`].
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Generates a storage instance over a mock ledger for a committee, with a GC window drawn from
    /// `0..BatchHeader::MAX_GC_ROUNDS`.
    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        // Uses an arbitrary committee.
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }

        // Uses the committee from the supplied context.
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
                    let ledger = Arc::new(MockLedgerService::new(committee));
                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
                })
                .boxed()
        }
    }

    /// Wraps proptest's `TestRng` so it can be passed where a `CryptoRng` is required.
    ///
    /// NOTE: the `CryptoRng` impl below is an empty marker impl; the property is asserted, not
    /// verified. Only suitable for tests.
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        // Takes the RNG proptest hands to `prop_perturb`, so generated values stay reproducible.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }
    // Every method forwards to the wrapped `TestRng`.
    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// A random transmission payload (see `any_transmission`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// A random transaction or solution transmission ID (see `any_transmission_id`).
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Generates either a 512-byte solution buffer or a 2048-byte transaction buffer of random bytes.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// Generates a solution ID from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
    }

    /// Generates a transaction ID from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// Generates a transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with every validator in `validator_set`.
    ///
    /// Panics if any signing operation fails (`unwrap`).
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
        for validator in validator_set.0.iter() {
            let private_key = validator.private_key;
            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
        }
        signatures
    }

    /// Re-inserting the same certificate, with or without payloads, must leave storage unchanged.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);

        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Pick the batch author from the validator set.
        let signer = selector.select(&validators);

        // Deduplicate generated IDs (the last payload for a repeated ID wins).
        let mut transmission_map = IndexMap::new();

        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        // A round-0 batch with no previous certificates.
        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Exclude the author from the additional signers. Presumably `BatchCertificate::from` expects
        // signatures only from validators other than the batch author; confirm.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        // Expected transmission-store contents: each generated payload mapped to this certificate's ID.
        let certificate_id = certificate.id();
        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
        }

        let round = certificate.round();
        let batch_id = certificate.batch_id();
        let author = certificate.author();

        // Expected index contents after any number of insertions.
        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(batch_id, round)];

        // Insert with payloads.
        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Re-insert without payloads.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Re-insert with payloads again.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}