amareleo_node_bft/
worker.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    MAX_WORKERS,
18    ProposedBatch,
19    helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_chain_tracing::TracingHandler;
22use amareleo_node_bft_ledger_service::LedgerService;
23use snarkvm::{
24    console::prelude::*,
25    ledger::{
26        block::Transaction,
27        narwhal::{BatchHeader, Data, Transmission, TransmissionID},
28        puzzle::{Solution, SolutionID},
29    },
30};
31
32use colored::Colorize;
33use indexmap::{IndexMap, IndexSet};
34use std::sync::Arc;
35use tracing::subscriber::DefaultGuard;
36
37#[derive(Clone)]
38pub struct Worker<N: Network> {
39    /// The worker ID.
40    id: u8,
41    /// The storage.
42    storage: Storage<N>,
43    /// The ledger service.
44    ledger: Arc<dyn LedgerService<N>>,
45    /// The proposed batch.
46    proposed_batch: Arc<ProposedBatch<N>>,
47    /// Tracing handle
48    tracing: Option<TracingHandler>,
49    /// The ready queue.
50    ready: Ready<N>,
51}
52
53impl<N: Network> Worker<N> {
54    /// Initializes a new worker instance.
55    pub fn new(
56        id: u8,
57        storage: Storage<N>,
58        ledger: Arc<dyn LedgerService<N>>,
59        proposed_batch: Arc<ProposedBatch<N>>,
60        tracing: Option<TracingHandler>,
61    ) -> Result<Self> {
62        // Ensure the worker ID is valid.
63        ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
64        // Return the worker.
65        Ok(Self { id, storage, ledger, proposed_batch, tracing, ready: Default::default() })
66    }
67
68    /// Returns the worker ID.
69    pub const fn id(&self) -> u8 {
70        self.id
71    }
72}
73
74impl<N: Network> Worker<N> {
75    /// The maximum number of transmissions allowed in a worker.
76    pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
77        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
78    /// The maximum number of transmissions allowed in a worker ping.
79    pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
80
81    // transmissions
82
83    /// Returns the number of transmissions in the ready queue.
84    pub fn num_transmissions(&self) -> usize {
85        self.ready.num_transmissions()
86    }
87
88    /// Returns the number of ratifications in the ready queue.
89    pub fn num_ratifications(&self) -> usize {
90        self.ready.num_ratifications()
91    }
92
93    /// Returns the number of solutions in the ready queue.
94    pub fn num_solutions(&self) -> usize {
95        self.ready.num_solutions()
96    }
97
98    /// Returns the number of transactions in the ready queue.
99    pub fn num_transactions(&self) -> usize {
100        self.ready.num_transactions()
101    }
102}
103
104impl<N: Network> Worker<N> {
105    /// Returns the transmission IDs in the ready queue.
106    pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
107        self.ready.transmission_ids()
108    }
109
110    /// Returns the transmissions in the ready queue.
111    pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
112        self.ready.transmissions()
113    }
114
115    /// Returns the solutions in the ready queue.
116    pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
117        self.ready.solutions()
118    }
119
120    /// Returns the transactions in the ready queue.
121    pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
122        self.ready.transactions()
123    }
124}
125
126impl<N: Network> Worker<N> {
127    /// Clears the solutions from the ready queue.
128    pub(super) fn clear_solutions(&self) {
129        self.ready.clear_solutions()
130    }
131}
132
133impl<N: Network> Worker<N> {
134    /// Retruns tracing guard
135    pub fn get_tracing_guard(&self) -> Option<DefaultGuard> {
136        self.tracing.clone().map(|trace_handle| trace_handle.subscribe_thread())
137    }
138
139    /// Returns `true` if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
140    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
141        let transmission_id = transmission_id.into();
142        // Check if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
143        self.ready.contains(transmission_id)
144            || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
145            || self.storage.contains_transmission(transmission_id)
146            || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
147    }
148
149    /// Returns the transmission if it exists in the ready queue, proposed batch, storage.
150    ///
151    /// Note: We explicitly forbid retrieving a transmission from the ledger, as transmissions
152    /// in the ledger are not guaranteed to be invalid for the current batch.
153    pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
154        // Check if the transmission ID exists in the ready queue.
155        if let Some(transmission) = self.ready.get(transmission_id) {
156            return Some(transmission);
157        }
158        // Check if the transmission ID exists in storage.
159        if let Some(transmission) = self.storage.get_transmission(transmission_id) {
160            return Some(transmission);
161        }
162        // Check if the transmission ID exists in the proposed batch.
163        if let Some(transmission) =
164            self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
165        {
166            return Some(transmission.clone());
167        }
168        None
169    }
170
171    /// Returns the transmissions if it exists in the worker, or requests it from the specified peer.
172    pub async fn get_or_fetch_transmission(
173        &self,
174        transmission_id: TransmissionID<N>,
175    ) -> Result<(TransmissionID<N>, Transmission<N>)> {
176        // Attempt to get the transmission from the worker.
177        if let Some(transmission) = self.get_transmission(transmission_id) {
178            return Ok((transmission_id, transmission));
179        }
180
181        bail!("Unable to fetch transmission");
182    }
183
184    /// Removes up to the specified number of transmissions from the ready queue, and returns them.
185    pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
186        self.ready.drain(num_transmissions).into_iter()
187    }
188
189    /// Reinserts the specified transmission into the ready queue.
190    pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
191        // Check if the transmission ID exists.
192        if !self.contains_transmission(transmission_id) {
193            // Insert the transmission into the ready queue.
194            return self.ready.insert(transmission_id, transmission);
195        }
196        false
197    }
198}
199
200impl<N: Network> Worker<N> {
201    /// Handles the incoming unconfirmed solution.
202    /// Note: This method assumes the incoming solution is valid and does not exist in the ledger.
203    pub(crate) async fn process_unconfirmed_solution(
204        &self,
205        solution_id: SolutionID<N>,
206        solution: Data<Solution<N>>,
207    ) -> Result<()> {
208        let _guard = self.get_tracing_guard();
209        // Construct the transmission.
210        let transmission = Transmission::Solution(solution.clone());
211        // Compute the checksum.
212        let checksum = solution.to_checksum::<N>()?;
213        // Construct the transmission ID.
214        let transmission_id = TransmissionID::Solution(solution_id, checksum);
215        // Check if the solution exists.
216        if self.contains_transmission(transmission_id) {
217            bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
218        }
219        // Check that the solution is well-formed and unique.
220        self.ledger.check_solution_basic(solution_id, solution).await?;
221        // Adds the solution to the ready queue.
222        if self.ready.insert(transmission_id, transmission) {
223            trace!(
224                "Worker {} - Added unconfirmed solution '{}.{}'",
225                self.id,
226                fmt_id(solution_id),
227                fmt_id(checksum).dimmed()
228            );
229        }
230        Ok(())
231    }
232
233    /// Handles the incoming unconfirmed transaction.
234    pub(crate) async fn process_unconfirmed_transaction(
235        &self,
236        transaction_id: N::TransactionID,
237        transaction: Data<Transaction<N>>,
238    ) -> Result<()> {
239        let _guard = self.get_tracing_guard();
240        // Construct the transmission.
241        let transmission = Transmission::Transaction(transaction.clone());
242        // Compute the checksum.
243        let checksum = transaction.to_checksum::<N>()?;
244        // Construct the transmission ID.
245        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
246        // Check if the transaction ID exists.
247        if self.contains_transmission(transmission_id) {
248            bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
249        }
250        // Check that the transaction is well-formed and unique.
251        self.ledger.check_transaction_basic(transaction_id, transaction).await?;
252        // Adds the transaction to the ready queue.
253        if self.ready.insert(transmission_id, transmission) {
254            trace!(
255                "Worker {}.{} - Added unconfirmed transaction '{}'",
256                self.id,
257                fmt_id(transaction_id),
258                fmt_id(checksum).dimmed()
259            );
260        }
261        Ok(())
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Number of rounds of the storage GC test.
    const ITERATIONS: usize = 100;

    /// Number of random bytes in each sampled solution/transaction payload.
    const PAYLOAD_SIZE: usize = 512;

    // A mock of the ledger service, so each test can script the ledger's answers.
    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    /// Builds a mock ledger that serves a freshly sampled committee and reports
    /// every transmission as not contained.
    fn sample_mock_ledger(rng: &mut TestRng) -> MockLedger<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let lookback_committee = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(lookback_committee.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger
    }

    /// Wraps the given mock ledger in a worker with ID 0, backed by in-memory storage.
    fn sample_worker(mock_ledger: MockLedger<CurrentNetwork>) -> Worker<CurrentNetwork> {
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1, None);
        Worker::new(0, storage, ledger, Default::default(), None).unwrap()
    }

    /// Samples `PAYLOAD_SIZE` random bytes to stand in for a serialized payload.
    fn sample_bytes(rng: &mut TestRng) -> Bytes {
        Bytes::from((0..PAYLOAD_SIZE).map(|_| rng.gen::<u8>()).collect::<Vec<_>>())
    }

    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();

        // The ledger accepts the solution.
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        // Sample the solution, and derive its transmission ID.
        let solution = Data::Buffer(sample_bytes(rng));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);

        // The solution must be accepted and queued.
        let outcome = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(outcome.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();

        // The ledger rejects the solution.
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        // Sample the solution, and derive its transmission ID.
        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(sample_bytes(rng));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);

        // The solution must be refused and stay out of the ready queue.
        let outcome = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(outcome.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();

        // The ledger accepts the transaction.
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        // Sample the transaction, and derive its transmission ID.
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(sample_bytes(rng));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);

        // The transaction must be accepted and queued.
        let outcome = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(outcome.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_nok() {
        let rng = &mut TestRng::default();

        // The ledger rejects the transaction.
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        // Sample the transaction, and derive its transmission ID.
        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(sample_bytes(rng));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);

        // The transaction must be refused and stay out of the ready queue.
        let outcome = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(outcome.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Pick a GC window and a ledger round that lies beyond it.
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee for that ledger round.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            // The ledger only needs to answer with the current committee.
            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);

            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds, None);

            // The GC round must trail the ledger round by exactly the GC window.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
476
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Initializes a new test committee with `n` members.
    ///
    /// Each member is derived from an RNG seeded with its index.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            // Sample the address from an RNG seeded with the member index.
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            info!("Validator {i}: {address}");
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        // Initialize the committee.
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    /// Any worker ID below `MAX_WORKERS` must be accepted and reported back unchanged.
    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None).unwrap();
        assert_eq!(worker.id(), id);
    }

    /// Any worker ID of `MAX_WORKERS` or above must be rejected with the expected message.
    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default(), None);
        // `Worker` does not implement `Debug`, so `unwrap_err` is unavailable. Going through
        // `err()` still fails the test if construction unexpectedly succeeds.
        // TODO once Worker implements Debug, simplify this with `unwrap_err`
        let error = worker.err().expect("constructing a worker with an out-of-range ID must fail");
        assert_eq!(error.to_string(), format!("Invalid worker ID '{id}'"));
    }
}