1use crate::{
17 MAX_WORKERS,
18 ProposedBatch,
19 helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_chain_tracing::TracingHandler;
22use amareleo_node_bft_ledger_service::LedgerService;
23use snarkvm::{
24 console::prelude::*,
25 ledger::{
26 block::Transaction,
27 narwhal::{BatchHeader, Data, Transmission, TransmissionID},
28 puzzle::{Solution, SolutionID},
29 },
30};
31
32use colored::Colorize;
33use indexmap::{IndexMap, IndexSet};
34use std::sync::Arc;
35use tracing::subscriber::DefaultGuard;
36
/// A worker that holds unconfirmed transmissions (solutions and transactions)
/// in its `ready` queue, from which they can later be drained.
#[derive(Clone)]
pub struct Worker<N: Network> {
    /// The worker ID; `Worker::new` rejects values that are not below `MAX_WORKERS`.
    id: u8,
    /// The storage, consulted when checking for or looking up transmissions.
    storage: Storage<N>,
    /// The ledger service, used for basic solution/transaction checks and duplicate detection.
    ledger: Arc<dyn LedgerService<N>>,
    /// The currently proposed batch (if any), read through `read()` when looking up transmissions.
    proposed_batch: Arc<ProposedBatch<N>>,
    /// Optional tracing handler; see `get_tracing_guard`.
    tracing: Option<TracingHandler>,
    /// The queue of transmissions accepted by this worker.
    ready: Ready<N>,
}
52
53impl<N: Network> Worker<N> {
54 pub fn new(
56 id: u8,
57 storage: Storage<N>,
58 ledger: Arc<dyn LedgerService<N>>,
59 proposed_batch: Arc<ProposedBatch<N>>,
60 tracing: Option<TracingHandler>,
61 ) -> Result<Self> {
62 ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
64 Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
66 }
67
68 pub const fn id(&self) -> u8 {
70 self.id
71 }
72}
73
74impl<N: Network> Worker<N> {
75 pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
77 BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
78 pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
80
81 pub fn num_transmissions(&self) -> usize {
85 self.ready.num_transmissions()
86 }
87
88 pub fn num_ratifications(&self) -> usize {
90 self.ready.num_ratifications()
91 }
92
93 pub fn num_solutions(&self) -> usize {
95 self.ready.num_solutions()
96 }
97
98 pub fn num_transactions(&self) -> usize {
100 self.ready.num_transactions()
101 }
102}
103
104impl<N: Network> Worker<N> {
105 pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
107 self.ready.transmission_ids()
108 }
109
110 pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
112 self.ready.transmissions()
113 }
114
115 pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
117 self.ready.solutions()
118 }
119
120 pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
122 self.ready.transactions()
123 }
124}
125
126impl<N: Network> Worker<N> {
127 pub(super) fn clear_solutions(&self) {
129 self.ready.clear_solutions()
130 }
131}
132
133impl<N: Network> Worker<N> {
134 pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
136 self.tracing.clone().map(|trace_handle| trace_handle.subscribe_thread())
137 }
138
139 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
141 let transmission_id = transmission_id.into();
142 self.ready.contains(transmission_id)
144 || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
145 || self.storage.contains_transmission(transmission_id)
146 || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
147 }
148
149 pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
154 if let Some(transmission) = self.ready.get(transmission_id) {
156 return Some(transmission);
157 }
158 if let Some(transmission) = self.storage.get_transmission(transmission_id) {
160 return Some(transmission);
161 }
162 if let Some(transmission) =
164 self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
165 {
166 return Some(transmission.clone());
167 }
168 None
169 }
170
171 pub async fn get_or_fetch_transmission(
173 &self,
174 transmission_id: TransmissionID<N>,
175 ) -> Result<(TransmissionID<N>, Transmission<N>)> {
176 if let Some(transmission) = self.get_transmission(transmission_id) {
178 return Ok((transmission_id, transmission));
179 }
180
181 bail!("Unable to fetch transmission");
182 }
183
184 pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
186 self.ready.drain(num_transmissions).into_iter()
187 }
188
189 pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
191 if !self.contains_transmission(transmission_id) {
193 return self.ready.insert(transmission_id, transmission);
195 }
196 false
197 }
198}
199
200impl<N: Network> Worker<N> {
201 pub(crate) async fn process_unconfirmed_solution(
204 &self,
205 solution_id: SolutionID<N>,
206 solution: Data<Solution<N>>,
207 ) -> Result<()> {
208 let _guard = self.get_tracing_guard();
209 let transmission = Transmission::Solution(solution.clone());
211 let checksum = solution.to_checksum::<N>()?;
213 let transmission_id = TransmissionID::Solution(solution_id, checksum);
215 if self.contains_transmission(transmission_id) {
217 bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
218 }
219 self.ledger.check_solution_basic(solution_id, solution).await?;
221 if self.ready.insert(transmission_id, transmission) {
223 trace!(
224 "Worker {} - Added unconfirmed solution '{}.{}'",
225 self.id,
226 fmt_id(solution_id),
227 fmt_id(checksum).dimmed()
228 );
229 }
230 Ok(())
231 }
232
233 pub(crate) async fn process_unconfirmed_transaction(
235 &self,
236 transaction_id: N::TransactionID,
237 transaction: Data<Transaction<N>>,
238 ) -> Result<()> {
239 let _guard = self.get_tracing_guard();
240 let transmission = Transmission::Transaction(transaction.clone());
242 let checksum = transaction.to_checksum::<N>()?;
244 let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
246 if self.contains_transmission(transmission_id) {
248 bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
249 }
250 self.ledger.check_transaction_basic(transaction_id, transaction).await?;
252 if self.ready.insert(transmission_id, transmission) {
254 trace!(
255 "Worker {}.{} - Added unconfirmed transaction '{}'",
256 self.id,
257 fmt_id(transaction_id),
258 fmt_id(checksum).dimmed()
259 );
260 }
261 Ok(())
262 }
263}
264
#[cfg(test)]
mod tests {
    // Unit tests for `Worker`, driven by a mockall-generated `LedgerService`.
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    // Number of randomized rounds in `test_storage_gc_on_initialization`.
    const ITERATIONS: usize = 100;

    // Generates `MockLedger`, a mock of the `LedgerService` trait. The method list must
    // mirror the trait's signatures exactly; each test sets `expect_*` for the calls it needs.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    // A solution that passes the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        // Random bytes stand in for a serialized solution; only its checksum is inspected here.
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    // A solution rejected by the ledger's basic check yields an error and is not queued.
    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    // A transaction that passes the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_transaction_ok() {
        let mut rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    // A transaction rejected by the ledger's basic check yields an error and is not queued.
    #[tokio::test]
    async fn test_process_transaction_nok() {
        let mut rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    // Checks that a freshly constructed `Storage` reports
    // `gc_round == latest_ledger_round - max_gc_rounds`. The latest round presumably comes from
    // the mocked `current_committee` (sampled for `latest_ledger_round`) — confirm in `Storage::new`.
    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            let max_gc_rounds = rng.gen_range(50..=100);
            // Keep the latest round strictly above `max_gc_rounds` so the subtraction is positive.
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
476
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds an `n`-member committee at round 1, with each member at `MIN_VALIDATOR_STAKE`.
    /// Addresses come from an RNG seeded by the member index, so the committee is reproducible.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            info!("Validator {i}: {address}");
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Every in-range ID constructs a worker that reports that ID.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Every out-of-range ID must be rejected with the documented error message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None);
        // The strategy only yields IDs >= MAX_WORKERS, so an `Ok` is itself a failure; the
        // previous `if let Err(..)` let that case pass silently.
        match worker {
            Ok(_) => panic!("Worker::new accepted out-of-range worker ID '{id}'"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{}'", id)),
        }
    }
}