1use crate::{
17 MAX_WORKERS,
18 ProposedBatch,
19 helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_node_bft_ledger_service::LedgerService;
22use snarkvm::{
23 console::prelude::*,
24 ledger::{
25 block::Transaction,
26 narwhal::{BatchHeader, Data, Transmission, TransmissionID},
27 puzzle::{Solution, SolutionID},
28 },
29};
30
31use colored::Colorize;
32use indexmap::{IndexMap, IndexSet};
33use std::sync::Arc;
34
35#[derive(Clone)]
36pub struct Worker<N: Network> {
37 id: u8,
39 storage: Storage<N>,
41 ledger: Arc<dyn LedgerService<N>>,
43 proposed_batch: Arc<ProposedBatch<N>>,
45 ready: Ready<N>,
47}
48
49impl<N: Network> Worker<N> {
50 pub fn new(
52 id: u8,
53 storage: Storage<N>,
54 ledger: Arc<dyn LedgerService<N>>,
55 proposed_batch: Arc<ProposedBatch<N>>,
56 ) -> Result<Self> {
57 ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
59 Ok(Self { id, storage, ledger, proposed_batch, ready: Default::default() })
61 }
62
63 pub const fn id(&self) -> u8 {
65 self.id
66 }
67}
68
69impl<N: Network> Worker<N> {
70 pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
72 BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
73 pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
75
76 pub fn num_transmissions(&self) -> usize {
80 self.ready.num_transmissions()
81 }
82
83 pub fn num_ratifications(&self) -> usize {
85 self.ready.num_ratifications()
86 }
87
88 pub fn num_solutions(&self) -> usize {
90 self.ready.num_solutions()
91 }
92
93 pub fn num_transactions(&self) -> usize {
95 self.ready.num_transactions()
96 }
97}
98
99impl<N: Network> Worker<N> {
100 pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
102 self.ready.transmission_ids()
103 }
104
105 pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
107 self.ready.transmissions()
108 }
109
110 pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
112 self.ready.solutions()
113 }
114
115 pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
117 self.ready.transactions()
118 }
119}
120
121impl<N: Network> Worker<N> {
122 pub(super) fn clear_solutions(&self) {
124 self.ready.clear_solutions()
125 }
126}
127
128impl<N: Network> Worker<N> {
129 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
131 let transmission_id = transmission_id.into();
132 self.ready.contains(transmission_id)
134 || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
135 || self.storage.contains_transmission(transmission_id)
136 || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
137 }
138
139 pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
144 if let Some(transmission) = self.ready.get(transmission_id) {
146 return Some(transmission);
147 }
148 if let Some(transmission) = self.storage.get_transmission(transmission_id) {
150 return Some(transmission);
151 }
152 if let Some(transmission) =
154 self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
155 {
156 return Some(transmission.clone());
157 }
158 None
159 }
160
161 pub async fn get_or_fetch_transmission(
163 &self,
164 transmission_id: TransmissionID<N>,
165 ) -> Result<(TransmissionID<N>, Transmission<N>)> {
166 if let Some(transmission) = self.get_transmission(transmission_id) {
168 return Ok((transmission_id, transmission));
169 }
170
171 bail!("Unable to fetch transmission");
172 }
173
174 pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
176 self.ready.drain(num_transmissions).into_iter()
177 }
178
179 pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
181 if !self.contains_transmission(transmission_id) {
183 return self.ready.insert(transmission_id, transmission);
185 }
186 false
187 }
188}
189
190impl<N: Network> Worker<N> {
191 pub(crate) async fn process_unconfirmed_solution(
194 &self,
195 solution_id: SolutionID<N>,
196 solution: Data<Solution<N>>,
197 ) -> Result<()> {
198 let transmission = Transmission::Solution(solution.clone());
200 let checksum = solution.to_checksum::<N>()?;
202 let transmission_id = TransmissionID::Solution(solution_id, checksum);
204 if self.contains_transmission(transmission_id) {
206 bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
207 }
208 self.ledger.check_solution_basic(solution_id, solution).await?;
210 if self.ready.insert(transmission_id, transmission) {
212 trace!(
213 "Worker {} - Added unconfirmed solution '{}.{}'",
214 self.id,
215 fmt_id(solution_id),
216 fmt_id(checksum).dimmed()
217 );
218 }
219 Ok(())
220 }
221
222 pub(crate) async fn process_unconfirmed_transaction(
224 &self,
225 transaction_id: N::TransactionID,
226 transaction: Data<Transaction<N>>,
227 ) -> Result<()> {
228 let transmission = Transmission::Transaction(transaction.clone());
230 let checksum = transaction.to_checksum::<N>()?;
232 let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
234 if self.contains_transmission(transmission_id) {
236 bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
237 }
238 self.ledger.check_transaction_basic(transaction_id, transaction).await?;
240 if self.ready.insert(transmission_id, transmission) {
242 trace!(
243 "Worker {}.{} - Added unconfirmed transaction '{}'",
244 self.id,
245 fmt_id(transaction_id),
246 fmt_id(checksum).dimmed()
247 );
248 }
249 Ok(())
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    const ITERATIONS: usize = 100;

    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    /// Returns a mock ledger whose committee queries answer with a freshly sampled committee and
    /// which reports every transmission as absent. Tests add the `check_*_basic` expectation
    /// they need on top of this.
    fn sample_mock_ledger(rng: &mut TestRng) -> MockLedger<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger
    }

    /// Builds worker 0 on top of `mock_ledger`, backed by in-memory BFT storage with a
    /// max-GC-rounds argument of 1.
    fn sample_worker(mock_ledger: MockLedger<CurrentNetwork>) -> Worker<CurrentNetwork> {
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);
        Worker::new(0, storage, ledger, Default::default()).unwrap()
    }

    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_nok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    /// Storage created against a ledger at round `r` with `max_gc_rounds = g` must start with
    /// a GC round of `r - g`.
    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);

            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
464
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a committee (round 1) of `n` members, each staked with `MIN_VALIDATOR_STAKE`.
    /// Member `i` draws its address and last tuple field from an RNG seeded with `i`, so the
    /// committee is deterministic for a given `n`.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            info!("Validator {i}: {address}");
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Every ID below `MAX_WORKERS` is accepted and reported back by `id()`.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default()).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Every ID at or above `MAX_WORKERS` is rejected with the expected error message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        // An `Ok` here is a failure of the property, not a case to skip.
        match Worker::new(id, storage, ledger, Default::default()) {
            Ok(_) => panic!("Worker::new accepted out-of-range worker ID '{id}'"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{}'", id)),
        }
    }
}