amareleo_node_bft/
worker.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    MAX_WORKERS,
18    ProposedBatch,
19    helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_chain_tracing::{TracingHandler, TracingHandlerGuard};
22use amareleo_node_bft_ledger_service::LedgerService;
23use snarkvm::{
24    console::prelude::*,
25    ledger::{
26        block::Transaction,
27        narwhal::{BatchHeader, Data, Transmission, TransmissionID},
28        puzzle::{Solution, SolutionID},
29    },
30};
31
32use colored::Colorize;
33use indexmap::{IndexMap, IndexSet};
34use std::sync::Arc;
35use tracing::subscriber::DefaultGuard;
36
/// A BFT worker that holds pending transmissions in a ready queue and can look
/// transmissions up in its ready queue, the proposed batch, storage, and the ledger.
#[derive(Clone)]
pub struct Worker<N: Network> {
    /// The worker ID (validated against `MAX_WORKERS` in [`Worker::new`]).
    id: u8,
    /// The storage, consulted when looking up transmissions.
    storage: Storage<N>,
    /// The ledger service, used to look up known transmissions and to check incoming ones.
    ledger: Arc<dyn LedgerService<N>>,
    /// The proposed batch; it may hold no proposal at a given time.
    proposed_batch: Arc<ProposedBatch<N>>,
    /// The optional tracing handle from which the tracing guard is obtained.
    tracing: Option<TracingHandler>,
    /// The ready queue of pending transmissions.
    ready: Ready<N>,
}
52
53impl<N: Network> TracingHandlerGuard for Worker<N> {
54    /// Retruns tracing guard
55    fn get_tracing_guard(&self) -> Option<DefaultGuard> {
56        self.tracing.as_ref().and_then(|trace_handle| trace_handle.get_tracing_guard())
57    }
58}
59
60impl<N: Network> Worker<N> {
61    /// Initializes a new worker instance.
62    pub fn new(
63        id: u8,
64        storage: Storage<N>,
65        ledger: Arc<dyn LedgerService<N>>,
66        proposed_batch: Arc<ProposedBatch<N>>,
67        tracing: Option<TracingHandler>,
68    ) -> Result<Self> {
69        // Ensure the worker ID is valid.
70        ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
71        // Return the worker.
72        Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
73    }
74
75    /// Returns the worker ID.
76    pub const fn id(&self) -> u8 {
77        self.id
78    }
79}
80
impl<N: Network> Worker<N> {
    /// The maximum number of transmissions allowed in a worker:
    /// the per-batch maximum split (by integer division) across `MAX_WORKERS`.
    pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
    /// The maximum number of transmissions allowed in a worker ping:
    /// one tenth (by integer division) of the per-batch maximum.
    pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;

    // Ready-queue counters; each one delegates to the ready queue.

    /// Returns the number of transmissions in the ready queue.
    pub fn num_transmissions(&self) -> usize {
        self.ready.num_transmissions()
    }

    /// Returns the number of ratifications in the ready queue.
    pub fn num_ratifications(&self) -> usize {
        self.ready.num_ratifications()
    }

    /// Returns the number of solutions in the ready queue.
    pub fn num_solutions(&self) -> usize {
        self.ready.num_solutions()
    }

    /// Returns the number of transactions in the ready queue.
    pub fn num_transactions(&self) -> usize {
        self.ready.num_transactions()
    }
}
110
impl<N: Network> Worker<N> {
    /// Returns the transmission IDs in the ready queue, as an owned set.
    pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
        self.ready.transmission_ids()
    }

    /// Returns the transmissions in the ready queue, as an owned map keyed by transmission ID.
    pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
        self.ready.transmissions()
    }

    /// Returns the solutions in the ready queue as `(solution ID, solution data)` pairs.
    ///
    /// The returned iterator borrows from the worker.
    pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
        self.ready.solutions()
    }

    /// Returns the transactions in the ready queue as `(transaction ID, transaction data)` pairs.
    ///
    /// The returned iterator borrows from the worker.
    pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
        self.ready.transactions()
    }
}
132
impl<N: Network> Worker<N> {
    /// Clears the solutions from the ready queue.
    ///
    /// Only visible to the parent module (`pub(super)`); delegates to the ready queue.
    pub(super) fn clear_solutions(&self) {
        self.ready.clear_solutions()
    }
}
139
140impl<N: Network> Worker<N> {
141    /// Returns `true` if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
142    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
143        let transmission_id = transmission_id.into();
144        // Check if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
145        self.ready.contains(transmission_id)
146            || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
147            || self.storage.contains_transmission(transmission_id)
148            || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
149    }
150
151    /// Returns the transmission if it exists in the ready queue, proposed batch, storage.
152    ///
153    /// Note: We explicitly forbid retrieving a transmission from the ledger, as transmissions
154    /// in the ledger are not guaranteed to be invalid for the current batch.
155    pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
156        // Check if the transmission ID exists in the ready queue.
157        if let Some(transmission) = self.ready.get(transmission_id) {
158            return Some(transmission);
159        }
160        // Check if the transmission ID exists in storage.
161        if let Some(transmission) = self.storage.get_transmission(transmission_id) {
162            return Some(transmission);
163        }
164        // Check if the transmission ID exists in the proposed batch.
165        if let Some(transmission) =
166            self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
167        {
168            return Some(transmission.clone());
169        }
170        None
171    }
172
173    /// Returns the transmissions if it exists in the worker, or requests it from the specified peer.
174    pub async fn get_or_fetch_transmission(
175        &self,
176        transmission_id: TransmissionID<N>,
177    ) -> Result<(TransmissionID<N>, Transmission<N>)> {
178        // Attempt to get the transmission from the worker.
179        if let Some(transmission) = self.get_transmission(transmission_id) {
180            return Ok((transmission_id, transmission));
181        }
182
183        bail!("Unable to fetch transmission");
184    }
185
186    /// Removes up to the specified number of transmissions from the ready queue, and returns them.
187    pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
188        self.ready.drain(num_transmissions).into_iter()
189    }
190
191    /// Reinserts the specified transmission into the ready queue.
192    pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
193        // Check if the transmission ID exists.
194        if !self.contains_transmission(transmission_id) {
195            // Insert the transmission into the ready queue.
196            return self.ready.insert(transmission_id, transmission);
197        }
198        false
199    }
200}
201
202impl<N: Network> Worker<N> {
203    /// Handles the incoming unconfirmed solution.
204    /// Note: This method assumes the incoming solution is valid and does not exist in the ledger.
205    pub(crate) async fn process_unconfirmed_solution(
206        &self,
207        solution_id: SolutionID<N>,
208        solution: Data<Solution<N>>,
209    ) -> Result<()> {
210        // Construct the transmission.
211        let transmission = Transmission::Solution(solution.clone());
212        // Compute the checksum.
213        let checksum = solution.to_checksum::<N>()?;
214        // Construct the transmission ID.
215        let transmission_id = TransmissionID::Solution(solution_id, checksum);
216        // Check if the solution exists.
217        if self.contains_transmission(transmission_id) {
218            bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
219        }
220        // Check that the solution is well-formed and unique.
221        self.ledger.check_solution_basic(solution_id, solution).await?;
222        // Adds the solution to the ready queue.
223        if self.ready.insert(transmission_id, transmission) {
224            guard_trace!(
225                self,
226                "Worker {} - Added unconfirmed solution '{}.{}'",
227                self.id,
228                fmt_id(solution_id),
229                fmt_id(checksum).dimmed()
230            );
231        }
232        Ok(())
233    }
234
235    /// Handles the incoming unconfirmed transaction.
236    pub(crate) async fn process_unconfirmed_transaction(
237        &self,
238        transaction_id: N::TransactionID,
239        transaction: Data<Transaction<N>>,
240    ) -> Result<()> {
241        // Construct the transmission.
242        let transmission = Transmission::Transaction(transaction.clone());
243        // Compute the checksum.
244        let checksum = transaction.to_checksum::<N>()?;
245        // Construct the transmission ID.
246        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
247        // Check if the transaction ID exists.
248        if self.contains_transmission(transmission_id) {
249            bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
250        }
251        // Check that the transaction is well-formed and unique.
252        self.ledger.check_transaction_basic(transaction_id, transaction).await?;
253        // Adds the transaction to the ready queue.
254        if self.ready.insert(transmission_id, transmission) {
255            guard_trace!(
256                self,
257                "Worker {}.{} - Added unconfirmed transaction '{}'",
258                self.id,
259                fmt_id(transaction_id),
260                fmt_id(checksum).dimmed()
261            );
262        }
263        Ok(())
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    const ITERATIONS: usize = 100;

    // A mock of the ledger service, so that each test can script the ledger's answers.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    /// Builds worker `0` on top of a mocked ledger and in-memory storage.
    ///
    /// The mock always serves a freshly sampled committee and reports every transmission
    /// as unknown; `configure` adds the expectations that are specific to a test.
    fn sample_worker(
        rng: &mut TestRng,
        configure: impl FnOnce(&mut MockLedger<CurrentNetwork>),
    ) -> Worker<CurrentNetwork> {
        // Sample a committee.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::<CurrentNetwork>::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        configure(&mut mock_ledger);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        // Initialize the storage.
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);

        // Create the Worker.
        Worker::new(0, storage, ledger, Default::default(), None).unwrap()
    }

    /// Samples 512 random bytes to serve as an opaque transmission payload.
    fn sample_payload(rng: &mut TestRng) -> Bytes {
        Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>())
    }

    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        // The ledger accepts the solution.
        let worker = sample_worker(rng, |ledger| {
            ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        });

        let solution = Data::Buffer(sample_payload(rng));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        // The ledger rejects the solution.
        let worker = sample_worker(rng, |ledger| {
            ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        });

        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(sample_payload(rng));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();
        // The ledger accepts the transaction.
        let worker = sample_worker(rng, |ledger| {
            ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        });

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(sample_payload(rng));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_nok() {
        let rng = &mut TestRng::default();
        // The ledger rejects the transaction.
        let worker = sample_worker(rng, |ledger| {
            ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        });

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(sample_payload(rng));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Mock the ledger round.
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            // Ensure that the storage GC round is correct.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
478
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Initializes a new test committee with `n` members for round `1`.
    ///
    /// Each member is sampled from an RNG seeded with its index.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(usize::from(n));
        for i in 0..n {
            // Sample the address.
            let rng = &mut TestRng::fixed(u64::from(i));
            let address = Address::new(rng.gen());
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        // Initialize the committee.
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Every ID below `MAX_WORKERS` yields a worker that reports that ID.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Every ID at or above `MAX_WORKERS` is rejected with the expected error message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None);
        // TODO once Worker implements Debug, simplify this with `unwrap_err`
        // Fail loudly if construction succeeds; a plain `if let Err(..)` would let that case pass.
        let Err(error) = worker else {
            panic!("Worker ID '{id}' should have been rejected");
        };
        assert_eq!(error.to_string(), format!("Invalid worker ID '{id}'"));
    }
}