amareleo_node_bft/
worker.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    MAX_WORKERS,
18    ProposedBatch,
19    helpers::{Ready, Storage, fmt_id},
20    spawn_blocking,
21};
22use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
23use amareleo_node_bft_ledger_service::LedgerService;
24use snarkvm::{
25    console::prelude::*,
26    ledger::{
27        block::Transaction,
28        narwhal::{BatchHeader, Data, Transmission, TransmissionID},
29        puzzle::{Solution, SolutionID},
30    },
31};
32
33use colored::Colorize;
34use indexmap::{IndexMap, IndexSet};
35#[cfg(feature = "locktick")]
36use locktick::parking_lot::RwLock;
37#[cfg(not(feature = "locktick"))]
38use parking_lot::RwLock;
39use std::sync::Arc;
40use tracing::subscriber::DefaultGuard;
41
/// A worker that keeps pending transmissions in a ready queue.
///
/// The type derives `Clone`. The `ledger`, `proposed_batch`, and `ready` fields are
/// wrapped in `Arc`, so a clone refers to the same underlying instances of those fields.
#[derive(Clone)]
pub struct Worker<N: Network> {
    /// The worker ID (checked to be below `MAX_WORKERS` in `Worker::new`).
    id: u8,
    /// The storage.
    storage: Storage<N>,
    /// The ledger service.
    ledger: Arc<dyn LedgerService<N>>,
    /// The proposed batch.
    proposed_batch: Arc<ProposedBatch<N>>,
    /// The optional tracing handler, used to obtain a tracing guard.
    tracing: Option<TracingHandler>,
    /// The ready queue, guarded by a read-write lock.
    ready: Arc<RwLock<Ready<N>>>,
}
57
58impl<N: Network> TracingHandlerGuard for Worker<N> {
59    /// Retruns tracing guard
60    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
61        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
62    }
63}
64
65impl<N: Network> Worker<N> {
66    /// Initializes a new worker instance.
67    pub fn new(
68        id: u8,
69        storage: Storage<N>,
70        ledger: Arc<dyn LedgerService<N>>,
71        proposed_batch: Arc<ProposedBatch<N>>,
72        tracing: Option<TracingHandler>,
73    ) -> Result<Self> {
74        // Ensure the worker ID is valid.
75        ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
76        // Return the worker.
77        Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
78    }
79
80    /// Returns the worker ID.
81    pub const fn id(&self) -> u8 {
82        self.id
83    }
84}
85
86impl<N: Network> Worker<N> {
87    /// The maximum number of transmissions allowed in a worker.
88    pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
89        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
90    /// The maximum number of transmissions allowed in a worker ping.
91    pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
92
93    // transmissions
94
95    /// Returns the number of transmissions in the ready queue.
96    pub fn num_transmissions(&self) -> usize {
97        self.ready.read().num_transmissions()
98    }
99
100    /// Returns the number of ratifications in the ready queue.
101    pub fn num_ratifications(&self) -> usize {
102        self.ready.read().num_ratifications()
103    }
104
105    /// Returns the number of solutions in the ready queue.
106    pub fn num_solutions(&self) -> usize {
107        self.ready.read().num_solutions()
108    }
109
110    /// Returns the number of transactions in the ready queue.
111    pub fn num_transactions(&self) -> usize {
112        self.ready.read().num_transactions()
113    }
114}
115
116impl<N: Network> Worker<N> {
117    /// Returns the transmission IDs in the ready queue.
118    pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
119        self.ready.read().transmission_ids()
120    }
121
122    /// Returns the transmissions in the ready queue.
123    pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
124        self.ready.read().transmissions()
125    }
126
127    /// Returns the solutions in the ready queue.
128    pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
129        self.ready.read().solutions().into_iter()
130    }
131
132    /// Returns the transactions in the ready queue.
133    pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
134        self.ready.read().transactions().into_iter()
135    }
136}
137
138impl<N: Network> Worker<N> {
139    /// Clears the solutions from the ready queue.
140    pub(super) fn clear_solutions(&self) {
141        self.ready.write().clear_solutions()
142    }
143}
144
145impl<N: Network> Worker<N> {
146    /// Returns `true` if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
147    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
148        let transmission_id = transmission_id.into();
149        // Check if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
150        self.ready.read().contains(transmission_id)
151            || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
152            || self.storage.contains_transmission(transmission_id)
153            || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
154    }
155
156    /// Returns the transmission if it exists in the ready queue, proposed batch, storage.
157    ///
158    /// Note: We explicitly forbid retrieving a transmission from the ledger, as transmissions
159    /// in the ledger are not guaranteed to be invalid for the current batch.
160    pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
161        // Check if the transmission ID exists in the ready queue.
162        if let Some(transmission) = self.ready.read().get(transmission_id) {
163            return Some(transmission);
164        }
165        // Check if the transmission ID exists in storage.
166        if let Some(transmission) = self.storage.get_transmission(transmission_id) {
167            return Some(transmission);
168        }
169        // Check if the transmission ID exists in the proposed batch.
170        if let Some(transmission) =
171            self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
172        {
173            return Some(transmission.clone());
174        }
175        None
176    }
177
178    /// Returns the transmissions if it exists in the worker, or requests it from the specified peer.
179    pub async fn get_or_fetch_transmission(
180        &self,
181        transmission_id: TransmissionID<N>,
182    ) -> Result<(TransmissionID<N>, Transmission<N>)> {
183        // Attempt to get the transmission from the worker.
184        if let Some(transmission) = self.get_transmission(transmission_id) {
185            return Ok((transmission_id, transmission));
186        }
187
188        bail!("Unable to fetch transmission");
189    }
190
191    /// Inserts the transmission at the front of the ready queue.
192    pub(crate) fn insert_front(&self, key: TransmissionID<N>, value: Transmission<N>) {
193        self.ready.write().insert_front(key, value);
194    }
195
196    /// Removes and returns the transmission at the front of the ready queue.
197    pub(crate) fn remove_front(&self) -> Option<(TransmissionID<N>, Transmission<N>)> {
198        self.ready.write().remove_front()
199    }
200
201    /// Reinserts the specified transmission into the ready queue.
202    pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
203        // Check if the transmission ID exists.
204        if !self.contains_transmission(transmission_id) {
205            // Insert the transmission into the ready queue.
206            return self.ready.write().insert(transmission_id, transmission);
207        }
208        false
209    }
210}
211
212impl<N: Network> Worker<N> {
213    /// Handles the incoming unconfirmed solution.
214    /// Note: This method assumes the incoming solution is valid and does not exist in the ledger.
215    pub(crate) async fn process_unconfirmed_solution(
216        &self,
217        solution_id: SolutionID<N>,
218        solution: Data<Solution<N>>,
219    ) -> Result<()> {
220        // Construct the transmission.
221        let transmission = Transmission::Solution(solution.clone());
222        // Compute the checksum.
223        let checksum = solution.to_checksum::<N>()?;
224        // Construct the transmission ID.
225        let transmission_id = TransmissionID::Solution(solution_id, checksum);
226        // Check if the solution exists.
227        if self.contains_transmission(transmission_id) {
228            bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
229        }
230        // Check that the solution is well-formed and unique.
231        self.ledger.check_solution_basic(solution_id, solution).await?;
232        // Adds the solution to the ready queue.
233        if self.ready.write().insert(transmission_id, transmission) {
234            guard_trace!(
235                self,
236                "Worker {} - Added unconfirmed solution '{}.{}'",
237                self.id,
238                fmt_id(solution_id),
239                fmt_id(checksum).dimmed()
240            );
241        }
242        Ok(())
243    }
244
245    /// Handles the incoming unconfirmed transaction.
246    pub(crate) async fn process_unconfirmed_transaction(
247        &self,
248        transaction_id: N::TransactionID,
249        transaction: Data<Transaction<N>>,
250    ) -> Result<()> {
251        // Construct the transmission.
252        let transmission = Transmission::Transaction(transaction.clone());
253        // Compute the checksum.
254        let checksum = transaction.to_checksum::<N>()?;
255        // Construct the transmission ID.
256        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
257        // Check if the transaction ID exists.
258        if self.contains_transmission(transmission_id) {
259            bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
260        }
261        // Deserialize the transaction. If the transaction exceeds the maximum size, then return an error.
262        let transaction = spawn_blocking!({
263            match transaction {
264                Data::Object(transaction) => Ok(transaction),
265                Data::Buffer(bytes) => Ok(Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?),
266            }
267        })?;
268
269        // Check that the transaction is well-formed and unique.
270        self.ledger.check_transaction_basic(transaction_id, transaction).await?;
271        // Adds the transaction to the ready queue.
272        if self.ready.write().insert(transmission_id, transmission) {
273            guard_trace!(
274                self,
275                "Worker {}.{} - Added unconfirmed transaction '{}'",
276                self.id,
277                fmt_id(transaction_id),
278                fmt_id(checksum).dimmed()
279            );
280        }
281        Ok(())
282    }
283}
284
// Unit tests for the worker, run against a mocked ledger service.
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            ledger_test_helpers::sample_execution_transaction_with_fee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    // The number of randomized rounds run by `test_storage_gc_on_initialization`.
    const ITERATIONS: usize = 100;

    // Generates `MockLedger`, a mock implementation of `LedgerService`.
    // Each test only sets expectations for the methods it needs.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Transaction<N>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
            fn transaction_spent_cost_in_microcredits(&self, transaction_id: N::TransactionID, transaction: Transaction<N>) -> Result<u64>;
        }
    }

    // A solution that passes the (mocked) basic ledger check must end up in the ready queue.
    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The ledger reports the transmission as unknown and accepts the solution.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        // The solution payload is 512 random bytes; the mocked check does not inspect it.
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.read().contains(transmission_id));
    }

    // A solution rejected by the (mocked) basic ledger check must yield an error and stay out of the ready queue.
    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The ledger reports the transmission as unknown but rejects the solution.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.read().contains(transmission_id));
    }

    // A transaction that passes the (mocked) basic ledger check must end up in the ready queue.
    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The ledger reports the transmission as unknown and accepts the transaction.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        // Pass the transaction as an already-deserialized object.
        let transaction = sample_execution_transaction_with_fee(false, rng);
        let transaction_id = transaction.id();
        let transaction_data = Data::Object(transaction);
        let checksum = transaction_data.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction_data).await;
        assert!(result.is_ok());
        assert!(worker.ready.read().contains(transmission_id));
    }

    // A failing transaction must yield an error and stay out of the ready queue.
    #[tokio::test]
    async fn test_process_transaction_nok() {
        let mut rng = &mut TestRng::default();
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        // The ledger reports the transmission as unknown and rejects any transaction it is asked to check.
        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the Worker.
        let worker = Worker::new(0, storage, ledger, Default::default(), None).unwrap();
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(&mut rng).into();
        // The payload is 512 random bytes.
        // NOTE(review): the error may already come from deserializing these bytes, before the
        // mocked ledger check is reached - either way the call must fail.
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.read().contains(transmission_id));
    }

    // A freshly created storage must report a GC round of `latest_ledger_round - max_gc_rounds`.
    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Mock the ledger round.
            let max_gc_rounds = rng.gen_range(50..=100);
            // The latest round always exceeds `max_gc_rounds`, so the subtraction below cannot underflow.
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            // Only `current_committee` is mocked here.
            // NOTE(review): presumably the storage derives its round from this committee - confirm.
            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            // Ensure that the storage GC round is correct.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
499
// Property tests for worker construction.
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    // Initializes a new test committee with `n` members.
    // Each member's address is derived from an RNG seeded with the member's index.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(usize::from(n));
        for i in 0..n {
            // Sample the address.
            let rng = &mut TestRng::fixed(u64::from(i));
            let address = Address::new(rng.gen());
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        // Initialize the committee.
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    // Every ID below `MAX_WORKERS` must produce a worker that reports that ID.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    // Every ID at or above `MAX_WORKERS` must be rejected with the expected error message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None);
        // TODO once Worker implements Debug, simplify this with `unwrap_err`
        // Fail explicitly on `Ok`, so the test cannot pass without checking anything.
        match worker {
            Ok(_) => panic!("Worker ID '{id}' should have been rejected"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{id}'")),
        }
    }
}
545}