1use crate::{
17 MAX_WORKERS,
18 ProposedBatch,
19 helpers::{Ready, Storage, fmt_id},
20 spawn_blocking,
21};
22use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
23use amareleo_node_bft_ledger_service::LedgerService;
24use snarkvm::{
25 console::prelude::*,
26 ledger::{
27 block::Transaction,
28 narwhal::{BatchHeader, Data, Transmission, TransmissionID},
29 puzzle::{Solution, SolutionID},
30 },
31};
32
33use colored::Colorize;
34use indexmap::{IndexMap, IndexSet};
35#[cfg(feature = "locktick")]
36use locktick::parking_lot::RwLock;
37#[cfg(not(feature = "locktick"))]
38use parking_lot::RwLock;
39use std::sync::Arc;
40use tracing::subscriber::DefaultGuard;
41
42#[derive(Clone)]
43pub struct Worker<N: Network> {
44 id: u8,
46 storage: Storage<N>,
48 ledger: Arc<dyn LedgerService<N>>,
50 proposed_batch: Arc<ProposedBatch<N>>,
52 tracing: Option<TracingHandler>,
54 ready: Arc<RwLock<Ready<N>>>,
56}
57
58impl<N: Network> TracingHandlerGuard for Worker<N> {
59 fn get_tracing_guard(&self) -> Option<DefaultGuard> {
61 self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
62 }
63}
64
65impl<N: Network> Worker<N> {
66 pub fn new(
68 id: u8,
69 storage: Storage<N>,
70 ledger: Arc<dyn LedgerService<N>>,
71 proposed_batch: Arc<ProposedBatch<N>>,
72 tracing: Option<TracingHandler>,
73 ) -> Result<Self> {
74 ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
76 Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
78 }
79
80 pub const fn id(&self) -> u8 {
82 self.id
83 }
84}
85
86impl<N: Network> Worker<N> {
87 pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
89 BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
90 pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
92
93 pub fn num_transmissions(&self) -> usize {
97 self.ready.read().num_transmissions()
98 }
99
100 pub fn num_ratifications(&self) -> usize {
102 self.ready.read().num_ratifications()
103 }
104
105 pub fn num_solutions(&self) -> usize {
107 self.ready.read().num_solutions()
108 }
109
110 pub fn num_transactions(&self) -> usize {
112 self.ready.read().num_transactions()
113 }
114}
115
116impl<N: Network> Worker<N> {
117 pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
119 self.ready.read().transmission_ids()
120 }
121
122 pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
124 self.ready.read().transmissions()
125 }
126
127 pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
129 self.ready.read().solutions().into_iter()
130 }
131
132 pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
134 self.ready.read().transactions().into_iter()
135 }
136}
137
138impl<N: Network> Worker<N> {
139 pub(super) fn clear_solutions(&self) {
141 self.ready.write().clear_solutions()
142 }
143}
144
145impl<N: Network> Worker<N> {
146 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
148 let transmission_id = transmission_id.into();
149 self.ready.read().contains(transmission_id)
151 || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
152 || self.storage.contains_transmission(transmission_id)
153 || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
154 }
155
156 pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
161 if let Some(transmission) = self.ready.read().get(transmission_id) {
163 return Some(transmission);
164 }
165 if let Some(transmission) = self.storage.get_transmission(transmission_id) {
167 return Some(transmission);
168 }
169 if let Some(transmission) =
171 self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
172 {
173 return Some(transmission.clone());
174 }
175 None
176 }
177
178 pub async fn get_or_fetch_transmission(
180 &self,
181 transmission_id: TransmissionID<N>,
182 ) -> Result<(TransmissionID<N>, Transmission<N>)> {
183 if let Some(transmission) = self.get_transmission(transmission_id) {
185 return Ok((transmission_id, transmission));
186 }
187
188 bail!("Unable to fetch transmission");
189 }
190
191 pub(crate) fn insert_front(&self, key: TransmissionID<N>, value: Transmission<N>) {
193 self.ready.write().insert_front(key, value);
194 }
195
196 pub(crate) fn remove_front(&self) -> Option<(TransmissionID<N>, Transmission<N>)> {
198 self.ready.write().remove_front()
199 }
200
201 pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
203 if !self.contains_transmission(transmission_id) {
205 return self.ready.write().insert(transmission_id, transmission);
207 }
208 false
209 }
210}
211
212impl<N: Network> Worker<N> {
213 pub(crate) async fn process_unconfirmed_solution(
216 &self,
217 solution_id: SolutionID<N>,
218 solution: Data<Solution<N>>,
219 ) -> Result<()> {
220 let transmission = Transmission::Solution(solution.clone());
222 let checksum = solution.to_checksum::<N>()?;
224 let transmission_id = TransmissionID::Solution(solution_id, checksum);
226 if self.contains_transmission(transmission_id) {
228 bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
229 }
230 self.ledger.check_solution_basic(solution_id, solution).await?;
232 if self.ready.write().insert(transmission_id, transmission) {
234 guard_trace!(
235 self,
236 "Worker {} - Added unconfirmed solution '{}.{}'",
237 self.id,
238 fmt_id(solution_id),
239 fmt_id(checksum).dimmed()
240 );
241 }
242 Ok(())
243 }
244
245 pub(crate) async fn process_unconfirmed_transaction(
247 &self,
248 transaction_id: N::TransactionID,
249 transaction: Data<Transaction<N>>,
250 ) -> Result<()> {
251 let transmission = Transmission::Transaction(transaction.clone());
253 let checksum = transaction.to_checksum::<N>()?;
255 let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
257 if self.contains_transmission(transmission_id) {
259 bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
260 }
261 let transaction = spawn_blocking!({
263 match transaction {
264 Data::Object(transaction) => Ok(transaction),
265 Data::Buffer(bytes) => Ok(Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?),
266 }
267 })?;
268
269 self.ledger.check_transaction_basic(transaction_id, transaction).await?;
271 if self.ready.write().insert(transmission_id, transmission) {
273 guard_trace!(
274 self,
275 "Worker {}.{} - Added unconfirmed transaction '{}'",
276 self.id,
277 fmt_id(transaction_id),
278 fmt_id(checksum).dimmed()
279 );
280 }
281 Ok(())
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            ledger_test_helpers::sample_execution_transaction_with_fee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    const ITERATIONS: usize = 100;

    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Transaction<N>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
            fn transaction_spent_cost_in_microcredits(&self, transaction_id: N::TransactionID, transaction: Transaction<N>) -> Result<u64>;
        }
    }

    /// Builds a mock ledger with the expectations shared by the `process_*` tests:
    /// - both the current and the lookback committee return `committee`;
    /// - every `contains_transmission` query returns `Ok(false)`.
    ///
    /// Callers add their own `check_*_basic` expectation.
    fn mock_ledger_with_committee(committee: Committee<CurrentNetwork>) -> MockLedger<CurrentNetwork> {
        let lookback_committee = committee.clone();
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(lookback_committee.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger
    }

    /// Builds worker 0 over `mock_ledger`. It uses in-memory BFT storage with a max GC round count
    /// of 1 and a default proposed batch.
    fn sample_worker(mock_ledger: MockLedger<CurrentNetwork>) -> Worker<CurrentNetwork> {
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);
        Worker::new(0, storage, ledger, Default::default(), None).unwrap()
    }

    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);

        let mut mock_ledger = mock_ledger_with_committee(committee);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.read().contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);

        let mut mock_ledger = mock_ledger_with_committee(committee);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.read().contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);

        let mut mock_ledger = mock_ledger_with_committee(committee);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let transaction = sample_execution_transaction_with_fee(false, rng);
        let transaction_id = transaction.id();
        let transaction_data = Data::Object(transaction);
        let checksum = transaction_data.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction_data).await;
        assert!(result.is_ok());
        assert!(worker.ready.read().contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_nok() {
        let rng = &mut TestRng::default();
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);

        let mut mock_ledger = mock_ledger_with_committee(committee);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        // `rng` is already a `&mut TestRng`; pass it directly (implicit reborrow) instead of `&mut rng`.
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.read().contains(transmission_id));
    }

    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
499
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a committee of `n` members, passing `1u64` as the first argument of
    /// `Committee::new` (presumably the starting round).
    ///
    /// Member `i` draws everything from `TestRng::fixed(i)`, so the result is deterministic for a
    /// given `n`. Each member gets:
    /// - its address;
    /// - `MIN_VALIDATOR_STAKE`;
    /// - `false`;
    /// - a value from `0..100`.
    ///
    /// NOTE(review): the last two are presumably `is_open` and commission — confirm against
    /// snarkVM's `Committee`.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        // `Worker` only derives `Clone` (not `Debug`), so `unwrap_err` is unavailable. Match
        // explicitly, so that an unexpected `Ok` fails the property instead of passing silently.
        match Worker::new(id, storage, ledger, Default::default(), None) {
            Ok(_) => panic!("Worker::new accepted out-of-range worker ID '{id}'"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{}'", id)),
        }
    }
}