1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
18use amareleo_node_bft_ledger_service::LedgerService;
19use amareleo_node_bft_storage_service::StorageService;
20use snarkvm::{
21 ledger::{
22 block::{Block, Transaction},
23 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
24 },
25 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
26 utilities::{cfg_into_iter, cfg_sorted_by},
27};
28
29use indexmap::{IndexMap, IndexSet, map::Entry};
30use parking_lot::RwLock;
31use rayon::iter::{IntoParallelIterator, ParallelIterator};
32use std::{
33 collections::{HashMap, HashSet},
34 sync::{
35 Arc,
36 atomic::{AtomicU32, AtomicU64, Ordering},
37 },
38};
39use tracing::subscriber::DefaultGuard;
40
41#[derive(Clone, Debug)]
42pub struct Storage<N: Network>(Arc<StorageInner<N>>);
43
44impl<N: Network> std::ops::Deref for Storage<N> {
45 type Target = Arc<StorageInner<N>>;
46
47 fn deref(&self) -> &Self::Target {
48 &self.0
49 }
50}
51
52impl<N: Network> TracingHandlerGuard for Storage<N> {
53 fn get_tracing_guard(&self) -> Option<DefaultGuard> {
55 self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
56 }
57}
58
59#[derive(Debug)]
79pub struct StorageInner<N: Network> {
80 ledger: Arc<dyn LedgerService<N>>,
82 current_height: AtomicU32,
85 current_round: AtomicU64,
88 gc_round: AtomicU64,
90 max_gc_rounds: u64,
92 rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
95 certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
97 batch_ids: RwLock<IndexMap<Field<N>, u64>>,
99 transmissions: Arc<dyn StorageService<N>>,
101 tracing: Option<TracingHandler>,
103}
104
105impl<N: Network> Storage<N> {
106 pub fn new(
108 ledger: Arc<dyn LedgerService<N>>,
109 transmissions: Arc<dyn StorageService<N>>,
110 max_gc_rounds: u64,
111 tracing: Option<TracingHandler>,
112 ) -> Self {
113 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
115 let current_round = committee.starting_round().max(1);
117
118 let storage = Self(Arc::new(StorageInner {
120 ledger,
121 current_height: Default::default(),
122 current_round: Default::default(),
123 gc_round: Default::default(),
124 max_gc_rounds,
125 rounds: Default::default(),
126 certificates: Default::default(),
127 batch_ids: Default::default(),
128 transmissions,
129 tracing,
130 }));
131 storage.update_current_round(current_round);
133 storage.garbage_collect_certificates(current_round);
135 storage
137 }
138}
139
140impl<N: Network> Storage<N> {
141 pub fn current_height(&self) -> u32 {
143 self.current_height.load(Ordering::SeqCst)
145 }
146}
147
148impl<N: Network> Storage<N> {
149 pub fn current_round(&self) -> u64 {
151 self.current_round.load(Ordering::SeqCst)
153 }
154
155 pub fn gc_round(&self) -> u64 {
157 self.gc_round.load(Ordering::SeqCst)
159 }
160
161 pub fn max_gc_rounds(&self) -> u64 {
163 self.max_gc_rounds
164 }
165
166 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
169 let next_round = current_round + 1;
171
172 {
174 let storage_round = self.current_round();
176 if next_round < storage_round {
178 return Ok(storage_round);
179 }
180 }
181
182 let current_committee = self.ledger.current_committee()?;
184 let starting_round = current_committee.starting_round();
186 if next_round < starting_round {
188 let latest_block_round = self.ledger.latest_round();
190 guard_info!(
192 self,
193 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
194 );
195 self.sync_round_with_block(latest_block_round);
197 return Ok(latest_block_round);
199 }
200
201 self.update_current_round(next_round);
203
204 #[cfg(feature = "metrics")]
205 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
206
207 let storage_round = self.current_round();
209 let gc_round = self.gc_round();
211 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
213 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
215
216 guard_info!(self, "Starting round {next_round}...");
218 Ok(next_round)
219 }
220
221 fn update_current_round(&self, next_round: u64) {
223 self.current_round.store(next_round, Ordering::SeqCst);
225 }
226
227 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
229 let current_gc_round = self.gc_round();
231 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
233 if next_gc_round > current_gc_round {
235 for gc_round in current_gc_round..=next_gc_round {
237 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
239 self.remove_certificate(id);
241 }
242 }
243 self.gc_round.store(next_gc_round, Ordering::SeqCst);
245 }
246 }
247}
248
249impl<N: Network> Storage<N> {
250 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
252 self.rounds.read().contains_key(&round)
254 }
255
256 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
258 self.certificates.read().contains_key(&certificate_id)
260 }
261
262 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
264 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
265 }
266
267 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
269 self.batch_ids.read().contains_key(&batch_id)
271 }
272
273 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
275 self.transmissions.contains_transmission(transmission_id.into())
276 }
277
278 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
281 self.transmissions.get_transmission(transmission_id.into())
282 }
283
284 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
287 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
289 }
290
291 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
294 self.batch_ids.read().get(&batch_id).copied()
296 }
297
298 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
301 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
303 }
304
305 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
308 self.certificates.read().get(&certificate_id).cloned()
310 }
311
312 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
316 if let Some(entries) = self.rounds.read().get(&round) {
318 let certificates = self.certificates.read();
319 entries.iter().find_map(
320 |(certificate_id, _, a)| {
321 if a == &author { certificates.get(certificate_id).cloned() } else { None }
322 },
323 )
324 } else {
325 Default::default()
326 }
327 }
328
329 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
332 if round == 0 {
334 return Default::default();
335 }
336 if let Some(entries) = self.rounds.read().get(&round) {
338 let certificates = self.certificates.read();
339 entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
340 } else {
341 Default::default()
342 }
343 }
344
345 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
348 if round == 0 {
350 return Default::default();
351 }
352 if let Some(entries) = self.rounds.read().get(&round) {
354 entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
355 } else {
356 Default::default()
357 }
358 }
359
360 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
363 if round == 0 {
365 return Default::default();
366 }
367 if let Some(entries) = self.rounds.read().get(&round) {
369 entries.iter().map(|(_, _, author)| *author).collect()
370 } else {
371 Default::default()
372 }
373 }
374
375 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
378 let rounds = self.rounds.read();
380 let certificates = self.certificates.read();
381
382 cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
384 .flat_map(|(_, certificates_for_round)| {
385 cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _, _)| {
387 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
389 None
390 } else {
391 certificates.get(&certificate_id).cloned()
393 }
394 })
395 })
396 .collect()
397 }
398
399 pub fn check_batch_header(
412 &self,
413 batch_header: &BatchHeader<N>,
414 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
415 aborted_transmissions: HashSet<TransmissionID<N>>,
416 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
417 let round = batch_header.round();
419 let gc_round = self.gc_round();
421 let gc_log = format!("(gc = {gc_round})");
423
424 if self.contains_batch(batch_header.batch_id()) {
426 bail!("Batch for round {round} already exists in storage {gc_log}")
427 }
428
429 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
431 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
432 };
433 if !committee_lookback.is_committee_member(batch_header.author()) {
435 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
436 }
437
438 check_timestamp_for_liveness(batch_header.timestamp())?;
440
441 let missing_transmissions = self
443 .transmissions
444 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
445 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
446
447 let previous_round = round.saturating_sub(1);
449 if previous_round > gc_round {
451 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
453 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
454 };
455 if !self.contains_certificates_for_round(previous_round) {
457 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
458 }
459 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
461 bail!("Too many previous certificates for round {round} {gc_log}")
462 }
463 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
465 for previous_certificate_id in batch_header.previous_certificate_ids() {
467 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
469 bail!(
470 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
471 fmt_id(previous_certificate_id)
472 )
473 };
474 if previous_certificate.round() != previous_round {
476 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
477 }
478 if previous_authors.contains(&previous_certificate.author()) {
480 bail!("Round {round} certificate contains a duplicate author {gc_log}")
481 }
482 previous_authors.insert(previous_certificate.author());
484 }
485 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
487 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
488 }
489 }
490 Ok(missing_transmissions)
491 }
492
493 pub fn check_certificate(
509 &self,
510 certificate: &BatchCertificate<N>,
511 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
512 aborted_transmissions: HashSet<TransmissionID<N>>,
513 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
514 let round = certificate.round();
516 let gc_round = self.gc_round();
518 let gc_log = format!("(gc = {gc_round})");
520
521 if self.contains_certificate(certificate.id()) {
523 bail!("Certificate for round {round} already exists in storage {gc_log}")
524 }
525
526 if self.contains_certificate_in_round_from(round, certificate.author()) {
528 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
529 }
530
531 let missing_transmissions =
533 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
534
535 check_timestamp_for_liveness(certificate.timestamp())?;
537
538 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
540 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
541 };
542
543 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
545 signers.insert(certificate.author());
547
548 for signature in certificate.signatures() {
550 let signer = signature.to_address();
552 if !committee_lookback.is_committee_member(signer) {
554 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
555 }
556 signers.insert(signer);
558 }
559
560 if !committee_lookback.is_quorum_threshold_reached(&signers) {
562 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
563 }
564 Ok(missing_transmissions)
565 }
566
567 pub fn insert_certificate(
579 &self,
580 certificate: BatchCertificate<N>,
581 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
582 aborted_transmissions: HashSet<TransmissionID<N>>,
583 ) -> Result<()> {
584 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
586 let missing_transmissions =
588 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
589 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
591 Ok(())
592 }
593
594 fn insert_certificate_atomic(
600 &self,
601 certificate: BatchCertificate<N>,
602 aborted_transmission_ids: HashSet<TransmissionID<N>>,
603 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
604 ) {
605 let round = certificate.round();
607 let certificate_id = certificate.id();
609 let batch_id = certificate.batch_id();
611 let author = certificate.author();
613
614 self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
616 let transmission_ids = certificate.transmission_ids().clone();
618 self.certificates.write().insert(certificate_id, certificate);
620 self.batch_ids.write().insert(batch_id, round);
622 self.transmissions.insert_transmissions(
624 certificate_id,
625 transmission_ids,
626 aborted_transmission_ids,
627 missing_transmissions,
628 );
629 }
630
631 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
638 let Some(certificate) = self.get_certificate(certificate_id) else {
640 guard_warn!(self, "Certificate {certificate_id} does not exist in storage");
641 return false;
642 };
643 let round = certificate.round();
645 let batch_id = certificate.batch_id();
647 let author = certificate.author();
649
650 match self.rounds.write().entry(round) {
656 Entry::Occupied(mut entry) => {
657 entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
659 if entry.get().is_empty() {
661 entry.swap_remove();
662 }
663 }
664 Entry::Vacant(_) => {}
665 }
666 self.certificates.write().swap_remove(&certificate_id);
668 self.batch_ids.write().swap_remove(&batch_id);
670 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
672 true
674 }
675}
676
677impl<N: Network> Storage<N> {
678 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
680 if next_height > self.current_height() {
682 self.current_height.store(next_height, Ordering::SeqCst);
684 }
685 }
686
687 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
689 let next_round = next_round.max(1);
691 if next_round > self.current_round() {
693 self.update_current_round(next_round);
695 guard_info!(self, "Synced to round {next_round}...");
697 }
698 }
699
700 pub(crate) fn sync_certificate_with_block(
702 &self,
703 block: &Block<N>,
704 certificate: BatchCertificate<N>,
705 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
706 ) {
707 if certificate.round() <= self.gc_round() {
709 return;
710 }
711 if self.contains_certificate(certificate.id()) {
713 return;
714 }
715 let mut missing_transmissions = HashMap::new();
717
718 let mut aborted_transmissions = HashSet::new();
720
721 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
723 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
724
725 for transmission_id in certificate.transmission_ids() {
727 if missing_transmissions.contains_key(transmission_id) {
729 continue;
730 }
731 if self.contains_transmission(*transmission_id) {
733 continue;
734 }
735 match transmission_id {
737 TransmissionID::Ratification => (),
738 TransmissionID::Solution(solution_id, _) => {
739 match block.get_solution(solution_id) {
741 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
743 None => match self.ledger.get_solution(solution_id) {
745 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
747 Err(_) => {
749 match aborted_solutions.contains(solution_id)
751 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
752 {
753 true => {
754 aborted_transmissions.insert(*transmission_id);
755 }
756 false => {
757 guard_error!(self, "Missing solution {solution_id} in block {}", block.height())
758 }
759 }
760 continue;
761 }
762 },
763 };
764 }
765 TransmissionID::Transaction(transaction_id, _) => {
766 match unconfirmed_transactions.get(transaction_id) {
768 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
770 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
772 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
774 Err(_) => {
776 match aborted_transactions.contains(transaction_id)
778 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
779 {
780 true => {
781 aborted_transmissions.insert(*transmission_id);
782 }
783 false => guard_warn!(
784 self,
785 "Missing transaction {transaction_id} in block {}",
786 block.height()
787 ),
788 }
789 continue;
790 }
791 },
792 };
793 }
794 }
795 }
796 let certificate_id = fmt_id(certificate.id());
798 guard_debug!(
799 self,
800 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
801 certificate.round(),
802 certificate.transmission_ids().len()
803 );
804 if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
805 guard_error!(
806 self,
807 "Failed to insert certificate '{certificate_id}' from block {} - {error}",
808 block.height()
809 );
810 }
811 }
812}
813
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an iterator over a cloned snapshot of the `round -> {(certificate ID, batch ID, author)}` map.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        self.rounds.read().clone().into_iter()
    }

    /// Returns an iterator over a cloned snapshot of the `certificate ID -> certificate` map.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        self.certificates.read().clone().into_iter()
    }

    /// Returns an iterator over a cloned snapshot of the `batch ID -> round` map.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        self.batch_ids.read().clone().into_iter()
    }

    /// Returns an iterator over the transmission store's `as_hashmap` view.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        self.transmissions.as_hashmap().into_iter()
    }

    /// Inserts `certificate` directly, bypassing every validity check.
    ///
    /// Each of its transmission IDs is stored with a placeholder (an empty transaction buffer).
    /// This delegates to `insert_certificate_atomic`, so the test path cannot drift from the
    /// real insertion path.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let placeholder_transmissions = certificate
            .transmission_ids()
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect::<HashMap<_, _>>();
        self.insert_certificate_atomic(certificate, Default::default(), placeholder_transmissions);
    }
}
881
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::narwhal::Data,
        prelude::{Rng, TestRng},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Asserts that the full contents of `storage` equal the expected snapshots.
    pub fn assert_storage<N: Network>(
        storage: &Storage<N>,
        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
        certificates: &[(Field<N>, BatchCertificate<N>)],
        batch_ids: &[(Field<N>, u64)],
        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
    ) {
        let actual_rounds: Vec<_> = storage.rounds_iter().collect();
        assert_eq!(actual_rounds, *rounds);
        let actual_certificates: Vec<_> = storage.certificates_iter().collect();
        assert_eq!(actual_certificates, *certificates);
        let actual_batch_ids: Vec<_> = storage.batch_ids_iter().collect();
        assert_eq!(actual_batch_ids, *batch_ids);
        let actual_transmissions: HashMap<_, _> = storage.transmissions_iter().collect();
        assert_eq!(actual_transmissions, *transmissions);
    }

    /// Samples a random transmission: a 512-byte solution buffer or a 2048-byte transaction buffer.
    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
        fn random_bytes(rng: &mut TestRng, len: usize) -> Bytes {
            Bytes::from((0..len).map(|_| rng.gen::<u8>()).collect::<Vec<_>>())
        }
        if rng.gen::<bool>() {
            Transmission::Solution(Data::Buffer(random_bytes(rng, 512)))
        } else {
            Transmission::Transaction(Data::Buffer(random_bytes(rng, 2048)))
        }
    }

    /// Samples one transmission per transmission ID of `certificate`.
    ///
    /// Returns the map to pass as `missing_transmissions`, and the expected storage view in which
    /// each transmission is referenced by `certificate`'s ID.
    pub(crate) fn sample_transmissions(
        certificate: &BatchCertificate<CurrentNetwork>,
        rng: &mut TestRng,
    ) -> (
        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
    ) {
        let certificate_id = certificate.id();

        let mut missing = HashMap::new();
        let mut expected = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
        for id in certificate.transmission_ids() {
            let transmission = sample_transmission(rng);
            missing.insert(*id, transmission.clone());
            let (_, referencing_certificates) = expected.entry(*id).or_insert_with(|| (transmission, IndexSet::new()));
            referencing_certificates.insert(certificate_id);
        }
        (missing, expected)
    }

    /// Builds an empty storage (one retained GC round) over a mock ledger with a sampled committee.
    fn sample_storage(rng: &mut TestRng) -> Storage<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None)
    }

    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);

        // A freshly built storage holds nothing.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));
        assert_storage(
            &storage,
            &[(round, indexset! { (certificate_id, batch_id, author) })],
            &[(certificate_id, certificate.clone())],
            &[(batch_id, round)],
            &transmissions,
        );

        assert_eq!(storage.get_certificate(certificate_id).unwrap(), certificate);

        // Removal must leave storage empty again.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }

    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();
        let storage = sample_storage(rng);

        assert_storage(&storage, &[], &[], &[], &Default::default());

        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let (certificate_id, round, batch_id, author) =
            (certificate.id(), certificate.round(), certificate.batch_id(), certificate.author());

        let expected_rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let expected_certificates = [(certificate_id, certificate.clone())];
        let expected_batch_ids = [(batch_id, round)];
        let (missing_transmissions, expected_transmissions) = sample_transmissions(&certificate, rng);

        // Insert with transmissions, then without, then with again: the state must not change after the first.
        for missing in [missing_transmissions.clone(), HashMap::new(), missing_transmissions] {
            storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing);
            assert!(storage.contains_certificate(certificate_id));
            assert_storage(&storage, &expected_rounds, &expected_certificates, &expected_batch_ids, &expected_transmissions);
        }
    }
}
1080
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a storage over a mock ledger for the committee in `context`, retaining `gc_rounds` rounds.
    fn storage_for_committee(context: CommitteeContext, gc_rounds: u64) -> Storage<CurrentNetwork> {
        let CommitteeContext(committee, _) = context;
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds, None)
    }

    /// Exclusive upper bound for the sampled GC depth.
    fn gc_rounds_bound() -> u64 {
        BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64
    }

    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..gc_rounds_bound())
                .prop_map(|(context, gc_rounds)| storage_for_committee(context, gc_rounds))
                .boxed()
        }

        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..gc_rounds_bound())
                .prop_map(|(context, gc_rounds)| storage_for_committee(context, gc_rounds))
                .boxed()
        }
    }

    /// Wraps proptest's RNG so it can be used where a `CryptoRng` is required (tests only).
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }

    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            let Self(inner) = self;
            inner.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            let Self(inner) = self;
            inner.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            let Self(inner) = self;
            inner.fill_bytes(dest)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            let Self(inner) = self;
            inner.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// Arbitrary solution or transaction transmission.
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// Arbitrary solution or transaction transmission ID.
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Strategy for a 512-byte solution buffer or a 2048-byte transaction buffer.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        let solution = collection::vec(any::<u8>(), 512..=512)
            .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes))));
        let transaction = collection::vec(any::<u8>(), 2048..=2048)
            .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes))));
        prop_oneof![solution, transaction].boxed()
    }

    /// Strategy for a solution ID derived from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0)
            .prop_perturb(|_, rng| {
                let mut crypto_rng = CryptoTestRng(rng);
                crypto_rng.gen::<u64>().into()
            })
            .boxed()
    }

    /// Strategy for a transaction ID derived from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                let mut crypto_rng = CryptoTestRng(rng);
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut crypto_rng))
            })
            .boxed()
    }

    /// Strategy for a transaction or solution transmission ID with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        type Checksum = <CurrentNetwork as Network>::TransmissionChecksum;
        let transaction =
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(id, rng.gen::<Checksum>()));
        let solution =
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(id, rng.gen::<Checksum>()));
        prop_oneof![transaction, solution].boxed()
    }

    /// Signs `batch_header`'s batch ID with every validator in `validator_set`, in iteration order.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let batch_id = batch_header.batch_id();
        validator_set.0.iter().map(|validator| validator.private_key.sign(&[batch_id], rng).unwrap()).collect()
    }

    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Start from an empty storage that retains a single GC round.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1, None);
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // The selected validator authors the batch.
        let signer = selector.select(&validators);

        // Repeated IDs overwrite the value but keep their first position.
        let transmission_map: IndexMap<_, _> =
            transmissions.iter().map(|(AnyTransmissionID(id), AnyTransmission(t))| (*id, t.clone())).collect();

        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Every validator other than the author co-signs the batch.
        let mut co_signers = validators.clone();
        co_signers.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(co_signers), &batch_header, &mut rng),
        )
        .unwrap();

        let certificate_id = certificate.id();
        let expected_transmissions = transmissions.iter().cloned().fold(
            HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new(),
            |mut expected, (AnyTransmissionID(id), AnyTransmission(t))| {
                expected.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
                expected
            },
        );

        let (round, batch_id, author) = (certificate.round(), certificate.batch_id(), certificate.author());
        let expected_rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
        let expected_certificates = [(certificate_id, certificate.clone())];
        let expected_batch_ids = [(batch_id, round)];

        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();

        // Insert with transmissions, then without, then with again: the state must not change after the first.
        for missing in [missing_transmissions.clone(), HashMap::new(), missing_transmissions] {
            storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing);
            assert!(storage.contains_certificate(certificate_id));
            assert_storage(&storage, &expected_rounds, &expected_certificates, &expected_batch_ids, &expected_transmissions);
        }
    }
}