use crate::{
MAX_WORKERS,
ProposedBatch,
helpers::{Ready, Storage, fmt_id},
};
use amareleo_node_bft_ledger_service::LedgerService;
use snarkvm::{
console::prelude::*,
ledger::{
block::Transaction,
narwhal::{BatchHeader, Data, Transmission, TransmissionID},
puzzle::{Solution, SolutionID},
},
};
use colored::Colorize;
use indexmap::{IndexMap, IndexSet};
use std::sync::Arc;
/// A worker that validates incoming unconfirmed solutions and transactions and buffers the
/// accepted ones in a local `ready` queue, from which they can later be drained.
#[derive(Clone)]
pub struct Worker<N: Network> {
    /// The worker ID; `Worker::new` rejects any value that is not below `MAX_WORKERS`.
    id: u8,
    /// Storage consulted when checking for, and looking up, already-known transmissions.
    storage: Storage<N>,
    /// Ledger service used for existence checks and for basic solution/transaction validation.
    ledger: Arc<dyn LedgerService<N>>,
    /// The shared, lock-guarded proposed batch (if any); consulted for duplicate detection and lookups.
    proposed_batch: Arc<ProposedBatch<N>>,
    /// Transmissions that passed validation and are waiting to be drained.
    ready: Ready<N>,
}
impl<N: Network> Worker<N> {
pub fn new(
id: u8,
storage: Storage<N>,
ledger: Arc<dyn LedgerService<N>>,
proposed_batch: Arc<ProposedBatch<N>>,
) -> Result<Self> {
ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
Ok(Self { id, storage, ledger, proposed_batch, ready: Default::default() })
}
pub const fn id(&self) -> u8 {
self.id
}
}
impl<N: Network> Worker<N> {
    /// The per-batch transmission limit divided (integer division) across `MAX_WORKERS` workers.
    pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / (MAX_WORKERS as usize);
    /// One tenth (integer division) of the per-batch transmission limit; presumably the cap on
    /// transmissions advertised per worker ping — confirm against callers.
    pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize =
        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;

    /// Count of solutions currently held in the ready queue.
    pub fn num_solutions(&self) -> usize {
        self.ready.num_solutions()
    }

    /// Count of transactions currently held in the ready queue.
    pub fn num_transactions(&self) -> usize {
        self.ready.num_transactions()
    }

    /// Count of ratifications currently held in the ready queue.
    pub fn num_ratifications(&self) -> usize {
        self.ready.num_ratifications()
    }

    /// Count of all transmissions currently held in the ready queue.
    pub fn num_transmissions(&self) -> usize {
        self.ready.num_transmissions()
    }
}
impl<N: Network> Worker<N> {
    /// Iterates over the solutions in the ready queue, borrowing `self` for the iterator's lifetime.
    pub fn solutions(&self) -> impl Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> + '_ {
        self.ready.solutions()
    }

    /// Iterates over the transactions in the ready queue, borrowing `self` for the iterator's lifetime.
    pub fn transactions(&self) -> impl Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> + '_ {
        self.ready.transactions()
    }

    /// The set of transmission IDs in the ready queue, as reported by the queue.
    pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
        self.ready.transmission_ids()
    }

    /// The ready queue's transmissions keyed by their IDs, as reported by the queue.
    pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
        self.ready.transmissions()
    }
}
impl<N: Network> Worker<N> {
    /// Removes every solution from the ready queue, leaving other transmission kinds to the
    /// queue's own `clear_solutions` semantics.
    pub(super) fn clear_solutions(&self) {
        self.ready.clear_solutions();
    }
}
impl<N: Network> Worker<N> {
    /// Reports whether the transmission is already known to this node.
    ///
    /// Sources are consulted in order — ready queue, proposed batch, storage, ledger — and the
    /// first hit wins. A ledger lookup error is treated as "not present".
    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
        let id = transmission_id.into();
        if self.ready.contains(id) {
            return true;
        }
        // The read guard is a temporary of this statement, so the lock is released before the
        // storage and ledger are consulted.
        let in_proposal =
            self.proposed_batch.read().as_ref().map_or(false, |proposal| proposal.contains_transmission(id));
        if in_proposal {
            return true;
        }
        self.storage.contains_transmission(id) || self.ledger.contains_transmission(&id).unwrap_or(false)
    }

    /// Looks the transmission up in the ready queue, then storage, then the proposed batch,
    /// returning an owned copy of the first match. The ledger is not consulted.
    pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
        self.ready
            .get(transmission_id)
            .or_else(|| self.storage.get_transmission(transmission_id))
            .or_else(|| {
                self.proposed_batch
                    .read()
                    .as_ref()
                    .and_then(|proposal| proposal.get_transmission(transmission_id))
                    .cloned()
            })
    }

    /// Returns the transmission paired with its ID if it is available locally
    /// (see `get_transmission`).
    ///
    /// # Errors
    /// Fails when no local source holds the transmission.
    pub async fn get_or_fetch_transmission(
        &self,
        transmission_id: TransmissionID<N>,
    ) -> Result<(TransmissionID<N>, Transmission<N>)> {
        match self.get_transmission(transmission_id) {
            Some(transmission) => Ok((transmission_id, transmission)),
            None => bail!("Unable to fetch transmission"),
        }
    }

    /// Drains `num_transmissions` entries from the ready queue, per `Ready::drain` semantics.
    pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
        self.ready.drain(num_transmissions).into_iter()
    }

    /// Puts a transmission back into the ready queue unless it is already known anywhere
    /// (see `contains_transmission`). Returns the queue's insertion result, or `false` when skipped.
    pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
        // `&&` short-circuits, so the insert only happens for unknown transmissions.
        !self.contains_transmission(transmission_id) && self.ready.insert(transmission_id, transmission)
    }
}
impl<N: Network> Worker<N> {
    /// Validates an unconfirmed solution and, if it is new, adds it to the ready queue.
    ///
    /// # Errors
    /// Returns an error if the checksum cannot be computed, if the solution is already known
    /// (see `contains_transmission`), or if the ledger's basic check rejects it.
    pub(crate) async fn process_unconfirmed_solution(
        &self,
        solution_id: SolutionID<N>,
        solution: Data<Solution<N>>,
    ) -> Result<()> {
        let checksum = solution.to_checksum::<N>()?;
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        // Reject duplicates before doing any further (possibly expensive) work.
        if self.contains_transmission(transmission_id) {
            bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
        }
        // `solution` is moved into the ledger check below, so keep a copy for the ready queue.
        let transmission = Transmission::Solution(solution.clone());
        self.ledger.check_solution_basic(solution_id, solution).await?;
        if self.ready.insert(transmission_id, transmission) {
            trace!(
                "Worker {} - Added unconfirmed solution '{}.{}'",
                self.id,
                fmt_id(solution_id),
                fmt_id(checksum).dimmed()
            );
        }
        Ok(())
    }

    /// Validates an unconfirmed transaction and, if it is new, adds it to the ready queue.
    ///
    /// # Errors
    /// Returns an error if the checksum cannot be computed, if the transaction is already known
    /// (see `contains_transmission`), or if the ledger's basic check rejects it.
    pub(crate) async fn process_unconfirmed_transaction(
        &self,
        transaction_id: N::TransactionID,
        transaction: Data<Transaction<N>>,
    ) -> Result<()> {
        let checksum = transaction.to_checksum::<N>()?;
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        // Reject duplicates before doing any further (possibly expensive) work.
        if self.contains_transmission(transmission_id) {
            bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
        }
        // `transaction` is moved into the ledger check below, so keep a copy for the ready queue.
        let transmission = Transmission::Transaction(transaction.clone());
        self.ledger.check_transaction_basic(transaction_id, transaction).await?;
        if self.ready.insert(transmission_id, transmission) {
            // Same layout as the solution message: worker ID, then '<id>.<checksum>'.
            trace!(
                "Worker {} - Added unconfirmed transaction '{}.{}'",
                self.id,
                fmt_id(transaction_id),
                fmt_id(checksum).dimmed()
            );
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    // Number of randomized iterations for the looped GC test below.
    const ITERATIONS: usize = 100;

    // Generates `MockLedger`, a mockall implementation of `LedgerService`, so each test can
    // stub only the ledger calls it needs. The signatures below must match the trait exactly.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    // A new solution that passes the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        // The ledger never knows the transmission, so it is not treated as a duplicate.
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Create a storage backed by an in-memory service (last argument is the max GC rounds).
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default()).unwrap();
        // Random 512-byte buffer standing in for a serialized solution.
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    // A solution rejected by the ledger's basic check yields an error and is not queued.
    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        // Force the basic check to fail.
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Create a storage backed by an in-memory service (last argument is the max GC rounds).
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default()).unwrap();
        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    // A new transaction that passes the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_transaction_ok() {
        let mut rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Create a storage backed by an in-memory service (last argument is the max GC rounds).
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default()).unwrap();
        // `rng` is a `&mut TestRng`; `&mut rng` is why the binding is declared `mut`.
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        // Random 512-byte buffer standing in for a serialized transaction.
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    // A transaction rejected by the ledger's basic check yields an error and is not queued.
    #[tokio::test]
    async fn test_process_transaction_nok() {
        let mut rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        // Force the basic check to fail.
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Create a storage backed by an in-memory service (last argument is the max GC rounds).
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default()).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    // Storage's GC round right after construction equals the committee round minus `max_gc_rounds`.
    // NOTE(review): presumably `Storage::new` derives this from `current_committee()`'s starting
    // round — the only ledger call stubbed here; confirm in `Storage::new`.
    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Mock the ledger round.
            let max_gc_rounds = rng.gen_range(50..=100);
            // Keep the round strictly above `max_gc_rounds` so the subtraction below cannot underflow.
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            // Set up the mock ledger.
            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);

            // Ensure that the storage GC round is correct.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a committee of `n` members, each with `MIN_VALIDATOR_STAKE`, `false` as the second
    /// member-tuple field, and a value in `0..100` as the third. Addresses come from a per-index
    /// fixed-seed RNG, so the committee is deterministic for a given `n`.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            // Seed each member's RNG with its index so addresses are reproducible.
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            info!("Validator {i}: {address}");
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        // Starting round 1.
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Every in-range ID is accepted and reported back by `Worker::id`.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default()).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Every ID at or above `MAX_WORKERS` must be rejected with the documented error message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        // `Worker` does not implement `Debug`, so `unwrap_err` is unavailable. Match exhaustively
        // so that an unexpected `Ok` fails the property instead of passing vacuously.
        match Worker::new(id, storage, ledger, Default::default()) {
            Ok(_) => panic!("Worker ID '{id}' should have been rejected"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{}'", id)),
        }
    }
}