1use crate::{
17 MAX_WORKERS,
18 ProposedBatch,
19 helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
22use amareleo_node_bft_ledger_service::LedgerService;
23use snarkvm::{
24 console::prelude::*,
25 ledger::{
26 block::Transaction,
27 narwhal::{BatchHeader, Data, Transmission, TransmissionID},
28 puzzle::{Solution, SolutionID},
29 },
30};
31
32use colored::Colorize;
33use indexmap::{IndexMap, IndexSet};
34use std::sync::Arc;
35use tracing::subscriber::DefaultGuard;
36
/// A BFT worker: owns a queue of "ready" transmissions and answers lookups for
/// transmissions across the ready queue, the proposed batch, storage, and the ledger.
#[derive(Clone)]
pub struct Worker<N: Network> {
    /// The worker ID; `Worker::new` rejects any value that is not below `MAX_WORKERS`.
    id: u8,
    /// The storage, consulted for transmission lookups.
    storage: Storage<N>,
    /// The ledger service, used for containment checks and basic solution/transaction checks.
    ledger: Arc<dyn LedgerService<N>>,
    /// The shared, lock-protected proposed batch (read when looking up transmissions).
    proposed_batch: Arc<ProposedBatch<N>>,
    /// Optional tracing handler used to scope this worker's log output.
    tracing: Option<TracingHandler>,
    /// The queue of transmissions that are ready to be drained into a batch.
    ready: Ready<N>,
}
52
53impl<N: Network> TracingHandlerGuard for Worker<N> {
54 fn get_tracing_guard(&self) -> Option<DefaultGuard> {
56 self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
57 }
58}
59
60impl<N: Network> Worker<N> {
61 pub fn new(
63 id: u8,
64 storage: Storage<N>,
65 ledger: Arc<dyn LedgerService<N>>,
66 proposed_batch: Arc<ProposedBatch<N>>,
67 tracing: Option<TracingHandler>,
68 ) -> Result<Self> {
69 ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
71 Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
73 }
74
75 pub const fn id(&self) -> u8 {
77 self.id
78 }
79}
80
81impl<N: Network> Worker<N> {
82 pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
84 BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
85 pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
87
88 pub fn num_transmissions(&self) -> usize {
92 self.ready.num_transmissions()
93 }
94
95 pub fn num_ratifications(&self) -> usize {
97 self.ready.num_ratifications()
98 }
99
100 pub fn num_solutions(&self) -> usize {
102 self.ready.num_solutions()
103 }
104
105 pub fn num_transactions(&self) -> usize {
107 self.ready.num_transactions()
108 }
109}
110
111impl<N: Network> Worker<N> {
112 pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
114 self.ready.transmission_ids()
115 }
116
117 pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
119 self.ready.transmissions()
120 }
121
122 pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
124 self.ready.solutions()
125 }
126
127 pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
129 self.ready.transactions()
130 }
131}
132
133impl<N: Network> Worker<N> {
134 pub(super) fn clear_solutions(&self) {
136 self.ready.clear_solutions()
137 }
138}
139
140impl<N: Network> Worker<N> {
141 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
143 let transmission_id = transmission_id.into();
144 self.ready.contains(transmission_id)
146 || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
147 || self.storage.contains_transmission(transmission_id)
148 || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
149 }
150
151 pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
156 if let Some(transmission) = self.ready.get(transmission_id) {
158 return Some(transmission);
159 }
160 if let Some(transmission) = self.storage.get_transmission(transmission_id) {
162 return Some(transmission);
163 }
164 if let Some(transmission) =
166 self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
167 {
168 return Some(transmission.clone());
169 }
170 None
171 }
172
173 pub async fn get_or_fetch_transmission(
175 &self,
176 transmission_id: TransmissionID<N>,
177 ) -> Result<(TransmissionID<N>, Transmission<N>)> {
178 if let Some(transmission) = self.get_transmission(transmission_id) {
180 return Ok((transmission_id, transmission));
181 }
182
183 bail!("Unable to fetch transmission");
184 }
185
186 pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
188 self.ready.drain(num_transmissions).into_iter()
189 }
190
191 pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
193 if !self.contains_transmission(transmission_id) {
195 return self.ready.insert(transmission_id, transmission);
197 }
198 false
199 }
200}
201
202impl<N: Network> Worker<N> {
203 pub(crate) async fn process_unconfirmed_solution(
206 &self,
207 solution_id: SolutionID<N>,
208 solution: Data<Solution<N>>,
209 ) -> Result<()> {
210 let transmission = Transmission::Solution(solution.clone());
212 let checksum = solution.to_checksum::<N>()?;
214 let transmission_id = TransmissionID::Solution(solution_id, checksum);
216 if self.contains_transmission(transmission_id) {
218 bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
219 }
220 self.ledger.check_solution_basic(solution_id, solution).await?;
222 if self.ready.insert(transmission_id, transmission) {
224 guard_trace!(
225 self,
226 "Worker {} - Added unconfirmed solution '{}.{}'",
227 self.id,
228 fmt_id(solution_id),
229 fmt_id(checksum).dimmed()
230 );
231 }
232 Ok(())
233 }
234
235 pub(crate) async fn process_unconfirmed_transaction(
237 &self,
238 transaction_id: N::TransactionID,
239 transaction: Data<Transaction<N>>,
240 ) -> Result<()> {
241 let transmission = Transmission::Transaction(transaction.clone());
243 let checksum = transaction.to_checksum::<N>()?;
245 let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
247 if self.contains_transmission(transmission_id) {
249 bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
250 }
251 self.ledger.check_transaction_basic(transaction_id, transaction).await?;
253 if self.ready.insert(transmission_id, transmission) {
255 guard_trace!(
256 self,
257 "Worker {}.{} - Added unconfirmed transaction '{}'",
258 self.id,
259 fmt_id(transaction_id),
260 fmt_id(checksum).dimmed()
261 );
262 }
263 Ok(())
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Number of randomized rounds run by `test_storage_gc_on_initialization`.
    const ITERATIONS: usize = 100;

    // Mock implementation of `LedgerService`. Each test registers only the
    // expectations it needs; the signatures must mirror the trait exactly.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    /// A solution accepted by the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The mock ledger knows no transmissions and accepts every solution.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // In-memory storage with max GC rounds set to 1.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the worker and submit a random 512-byte solution payload.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    /// A solution rejected by the ledger's basic check is not queued and yields an error.
    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The mock ledger knows no transmissions and rejects every solution.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // In-memory storage with max GC rounds set to 1.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the worker and submit a random 512-byte solution payload.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    /// A transaction accepted by the ledger's basic check ends up in the ready queue.
    #[tokio::test]
    async fn test_process_transaction_ok() {
        let mut rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The mock ledger knows no transmissions and accepts every transaction.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // In-memory storage with max GC rounds set to 1.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the worker and submit a random 512-byte transaction payload.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    /// A transaction rejected by the ledger's basic check is not queued and yields an error.
    #[tokio::test]
    async fn test_process_transaction_nok() {
        let mut rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The mock ledger knows no transmissions and rejects every transaction.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // In-memory storage with max GC rounds set to 1.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the worker and submit a random 512-byte transaction payload.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    /// Storage initialized against a ledger whose current committee is at round R, with
    /// `max_gc_rounds` M, must report a GC round of R - M.
    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Pick `max_gc_rounds` in [50, 100] and a latest round strictly above it.
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee for the latest ledger round.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            // Only `current_committee` is mocked for this test.
            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            // Ensure the storage GC round is derived from the committee's round.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
478
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a deterministic committee of `n` members, each staking `MIN_VALIDATOR_STAKE`.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(usize::from(n));
        for i in 0..n {
            // Seed each member's RNG with its index so the committee is reproducible.
            let rng = &mut TestRng::fixed(u64::from(i));
            let address = Address::new(rng.gen());
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Any ID below `MAX_WORKERS` yields a worker that reports that same ID.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Any ID at or above `MAX_WORKERS` must be rejected with the documented message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        // An `Ok` here is a failure; previously it would have passed silently.
        match Worker::new(id, storage, ledger, Default::default(), None) {
            Ok(_) => panic!("Worker ID '{id}' should have been rejected"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{}'", id)),
        }
    }
}