amareleo_node_bft/
worker.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    MAX_WORKERS,
18    ProposedBatch,
19    helpers::{Ready, Storage, fmt_id},
20};
21use amareleo_node_bft_ledger_service::LedgerService;
22use snarkvm::{
23    console::prelude::*,
24    ledger::{
25        block::Transaction,
26        narwhal::{BatchHeader, Data, Transmission, TransmissionID},
27        puzzle::{Solution, SolutionID},
28    },
29};
30
31use colored::Colorize;
32use indexmap::{IndexMap, IndexSet};
33use std::sync::Arc;
34
35#[derive(Clone)]
36pub struct Worker<N: Network> {
37    /// The worker ID.
38    id: u8,
39    /// The storage.
40    storage: Storage<N>,
41    /// The ledger service.
42    ledger: Arc<dyn LedgerService<N>>,
43    /// The proposed batch.
44    proposed_batch: Arc<ProposedBatch<N>>,
45    /// The ready queue.
46    ready: Ready<N>,
47}
48
49impl<N: Network> Worker<N> {
50    /// Initializes a new worker instance.
51    pub fn new(
52        id: u8,
53        storage: Storage<N>,
54        ledger: Arc<dyn LedgerService<N>>,
55        proposed_batch: Arc<ProposedBatch<N>>,
56    ) -> Result<Self> {
57        // Ensure the worker ID is valid.
58        ensure!(id < MAX_WORKERS, "Invalid worker ID '{id}'");
59        // Return the worker.
60        Ok(Self { id, storage, ledger, proposed_batch, ready: Default::default() })
61    }
62
63    /// Returns the worker ID.
64    pub const fn id(&self) -> u8 {
65        self.id
66    }
67}
68
69impl<N: Network> Worker<N> {
70    /// The maximum number of transmissions allowed in a worker.
71    pub const MAX_TRANSMISSIONS_PER_WORKER: usize =
72        BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / MAX_WORKERS as usize;
73    /// The maximum number of transmissions allowed in a worker ping.
74    pub const MAX_TRANSMISSIONS_PER_WORKER_PING: usize = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / 10;
75
76    // transmissions
77
78    /// Returns the number of transmissions in the ready queue.
79    pub fn num_transmissions(&self) -> usize {
80        self.ready.num_transmissions()
81    }
82
83    /// Returns the number of ratifications in the ready queue.
84    pub fn num_ratifications(&self) -> usize {
85        self.ready.num_ratifications()
86    }
87
88    /// Returns the number of solutions in the ready queue.
89    pub fn num_solutions(&self) -> usize {
90        self.ready.num_solutions()
91    }
92
93    /// Returns the number of transactions in the ready queue.
94    pub fn num_transactions(&self) -> usize {
95        self.ready.num_transactions()
96    }
97}
98
99impl<N: Network> Worker<N> {
100    /// Returns the transmission IDs in the ready queue.
101    pub fn transmission_ids(&self) -> IndexSet<TransmissionID<N>> {
102        self.ready.transmission_ids()
103    }
104
105    /// Returns the transmissions in the ready queue.
106    pub fn transmissions(&self) -> IndexMap<TransmissionID<N>, Transmission<N>> {
107        self.ready.transmissions()
108    }
109
110    /// Returns the solutions in the ready queue.
111    pub fn solutions(&self) -> impl '_ + Iterator<Item = (SolutionID<N>, Data<Solution<N>>)> {
112        self.ready.solutions()
113    }
114
115    /// Returns the transactions in the ready queue.
116    pub fn transactions(&self) -> impl '_ + Iterator<Item = (N::TransactionID, Data<Transaction<N>>)> {
117        self.ready.transactions()
118    }
119}
120
121impl<N: Network> Worker<N> {
122    /// Clears the solutions from the ready queue.
123    pub(super) fn clear_solutions(&self) {
124        self.ready.clear_solutions()
125    }
126}
127
128impl<N: Network> Worker<N> {
129    /// Returns `true` if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
130    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
131        let transmission_id = transmission_id.into();
132        // Check if the transmission ID exists in the ready queue, proposed batch, storage, or ledger.
133        self.ready.contains(transmission_id)
134            || self.proposed_batch.read().as_ref().map_or(false, |p| p.contains_transmission(transmission_id))
135            || self.storage.contains_transmission(transmission_id)
136            || self.ledger.contains_transmission(&transmission_id).unwrap_or(false)
137    }
138
139    /// Returns the transmission if it exists in the ready queue, proposed batch, storage.
140    ///
141    /// Note: We explicitly forbid retrieving a transmission from the ledger, as transmissions
142    /// in the ledger are not guaranteed to be invalid for the current batch.
143    pub fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
144        // Check if the transmission ID exists in the ready queue.
145        if let Some(transmission) = self.ready.get(transmission_id) {
146            return Some(transmission);
147        }
148        // Check if the transmission ID exists in storage.
149        if let Some(transmission) = self.storage.get_transmission(transmission_id) {
150            return Some(transmission);
151        }
152        // Check if the transmission ID exists in the proposed batch.
153        if let Some(transmission) =
154            self.proposed_batch.read().as_ref().and_then(|p| p.get_transmission(transmission_id))
155        {
156            return Some(transmission.clone());
157        }
158        None
159    }
160
161    /// Returns the transmissions if it exists in the worker, or requests it from the specified peer.
162    pub async fn get_or_fetch_transmission(
163        &self,
164        transmission_id: TransmissionID<N>,
165    ) -> Result<(TransmissionID<N>, Transmission<N>)> {
166        // Attempt to get the transmission from the worker.
167        if let Some(transmission) = self.get_transmission(transmission_id) {
168            return Ok((transmission_id, transmission));
169        }
170
171        bail!("Unable to fetch transmission");
172    }
173
174    /// Removes up to the specified number of transmissions from the ready queue, and returns them.
175    pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
176        self.ready.drain(num_transmissions).into_iter()
177    }
178
179    /// Reinserts the specified transmission into the ready queue.
180    pub(crate) fn reinsert(&self, transmission_id: TransmissionID<N>, transmission: Transmission<N>) -> bool {
181        // Check if the transmission ID exists.
182        if !self.contains_transmission(transmission_id) {
183            // Insert the transmission into the ready queue.
184            return self.ready.insert(transmission_id, transmission);
185        }
186        false
187    }
188}
189
190impl<N: Network> Worker<N> {
191    /// Handles the incoming unconfirmed solution.
192    /// Note: This method assumes the incoming solution is valid and does not exist in the ledger.
193    pub(crate) async fn process_unconfirmed_solution(
194        &self,
195        solution_id: SolutionID<N>,
196        solution: Data<Solution<N>>,
197    ) -> Result<()> {
198        // Construct the transmission.
199        let transmission = Transmission::Solution(solution.clone());
200        // Compute the checksum.
201        let checksum = solution.to_checksum::<N>()?;
202        // Construct the transmission ID.
203        let transmission_id = TransmissionID::Solution(solution_id, checksum);
204        // Check if the solution exists.
205        if self.contains_transmission(transmission_id) {
206            bail!("Solution '{}.{}' already exists.", fmt_id(solution_id), fmt_id(checksum).dimmed());
207        }
208        // Check that the solution is well-formed and unique.
209        self.ledger.check_solution_basic(solution_id, solution).await?;
210        // Adds the solution to the ready queue.
211        if self.ready.insert(transmission_id, transmission) {
212            trace!(
213                "Worker {} - Added unconfirmed solution '{}.{}'",
214                self.id,
215                fmt_id(solution_id),
216                fmt_id(checksum).dimmed()
217            );
218        }
219        Ok(())
220    }
221
222    /// Handles the incoming unconfirmed transaction.
223    pub(crate) async fn process_unconfirmed_transaction(
224        &self,
225        transaction_id: N::TransactionID,
226        transaction: Data<Transaction<N>>,
227    ) -> Result<()> {
228        // Construct the transmission.
229        let transmission = Transmission::Transaction(transaction.clone());
230        // Compute the checksum.
231        let checksum = transaction.to_checksum::<N>()?;
232        // Construct the transmission ID.
233        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
234        // Check if the transaction ID exists.
235        if self.contains_transmission(transmission_id) {
236            bail!("Transaction '{}.{}' already exists.", fmt_id(transaction_id), fmt_id(checksum).dimmed());
237        }
238        // Check that the transaction is well-formed and unique.
239        self.ledger.check_transaction_basic(transaction_id, transaction).await?;
240        // Adds the transaction to the ready queue.
241        if self.ready.insert(transmission_id, transmission) {
242            trace!(
243                "Worker {}.{} - Added unconfirmed transaction '{}'",
244                self.id,
245                fmt_id(transaction_id),
246                fmt_id(checksum).dimmed()
247            );
248        }
249        Ok(())
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    use amareleo_node_bft_ledger_service::LedgerService;
    use amareleo_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        console::{network::Network, types::Field},
        ledger::{
            block::Block,
            committee::Committee,
            narwhal::{BatchCertificate, Subdag, Transmission, TransmissionID},
        },
        prelude::Address,
    };

    use async_trait::async_trait;
    use bytes::Bytes;
    use indexmap::IndexMap;
    use mockall::mock;
    use std::ops::Range;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    const ITERATIONS: usize = 100;

    mock! {
        #[derive(Debug)]
        Ledger<N: Network> {}
        #[async_trait]
        impl<N: Network> LedgerService<N> for Ledger<N> {
            fn latest_round(&self) -> u64;
            fn latest_block_height(&self) -> u32;
            fn latest_block(&self) -> Block<N>;
            fn latest_restrictions_id(&self) -> Field<N>;
            fn latest_leader(&self) -> Option<(u64, Address<N>)>;
            fn update_latest_leader(&self, round: u64, leader: Address<N>);
            fn contains_block_height(&self, height: u32) -> bool;
            fn get_block_height(&self, hash: &N::BlockHash) -> Result<u32>;
            fn get_block_hash(&self, height: u32) -> Result<N::BlockHash>;
            fn get_block_round(&self, height: u32) -> Result<u64>;
            fn get_block(&self, height: u32) -> Result<Block<N>>;
            fn get_blocks(&self, heights: Range<u32>) -> Result<Vec<Block<N>>>;
            fn get_solution(&self, solution_id: &SolutionID<N>) -> Result<Solution<N>>;
            fn get_unconfirmed_transaction(&self, transaction_id: N::TransactionID) -> Result<Transaction<N>>;
            fn get_batch_certificate(&self, certificate_id: &Field<N>) -> Result<BatchCertificate<N>>;
            fn current_committee(&self) -> Result<Committee<N>>;
            fn get_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn get_committee_lookback_for_round(&self, round: u64) -> Result<Committee<N>>;
            fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
            fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
            fn ensure_transmission_is_well_formed(
                &self,
                transmission_id: TransmissionID<N>,
                transmission: &mut Transmission<N>,
            ) -> Result<()>;
            async fn check_solution_basic(
                &self,
                solution_id: SolutionID<N>,
                solution: Data<Solution<N>>,
            ) -> Result<()>;
            async fn check_transaction_basic(
                &self,
                transaction_id: N::TransactionID,
                transaction: Data<Transaction<N>>,
            ) -> Result<()>;
            fn check_next_block(&self, block: &Block<N>) -> Result<()>;
            fn prepare_advance_to_next_quorum_block(
                &self,
                subdag: Subdag<N>,
                transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
            ) -> Result<Block<N>>;
            fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
        }
    }

    /// Returns a mock ledger that:
    /// - serves a freshly sampled committee for `current_committee` and
    ///   `get_committee_lookback_for_round`, and
    /// - reports every transmission as absent from the ledger.
    fn sample_mock_ledger(rng: &mut TestRng) -> MockLedger<CurrentNetwork> {
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let committee_clone = committee.clone();

        let mut mock_ledger = MockLedger::default();
        mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
        mock_ledger.expect_get_committee_lookback_for_round().returning(move |_| Ok(committee_clone.clone()));
        mock_ledger.expect_contains_transmission().returning(|_| Ok(false));
        mock_ledger
    }

    /// Builds worker 0 on top of the given mock ledger and an in-memory storage with `max_gc_rounds = 1`.
    fn sample_worker(mock_ledger: MockLedger<CurrentNetwork>) -> Worker<CurrentNetwork> {
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
        let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);
        Worker::new(0, storage, ledger, Default::default()).unwrap()
    }

    #[tokio::test]
    async fn test_process_solution_ok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let solution_id = rng.gen::<u64>().into();
        let solution_checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, solution_checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_solution_nok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_solution_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        let solution_id = rng.gen::<u64>().into();
        let solution = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = solution.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Solution(solution_id, checksum);
        let result = worker.process_unconfirmed_solution(solution_id, solution).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_ok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Ok(()));
        let worker = sample_worker(mock_ledger);

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_ok());
        assert!(worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_process_transaction_nok() {
        let rng = &mut TestRng::default();
        let mut mock_ledger = sample_mock_ledger(rng);
        mock_ledger.expect_check_transaction_basic().returning(|_, _| Err(anyhow!("")));
        let worker = sample_worker(mock_ledger);

        let transaction_id: <CurrentNetwork as Network>::TransactionID = Field::<CurrentNetwork>::rand(rng).into();
        let transaction = Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
        let checksum = transaction.to_checksum::<CurrentNetwork>().unwrap();
        let transmission_id = TransmissionID::Transaction(transaction_id, checksum);
        let result = worker.process_unconfirmed_transaction(transaction_id, transaction).await;
        assert!(result.is_err());
        assert!(!worker.ready.contains(transmission_id));
    }

    #[tokio::test]
    async fn test_storage_gc_on_initialization() {
        let rng = &mut TestRng::default();

        for _ in 0..ITERATIONS {
            // Mock the ledger round.
            let max_gc_rounds = rng.gen_range(50..=100);
            let latest_ledger_round = rng.gen_range((max_gc_rounds + 1)..1000);
            let expected_gc_round = latest_ledger_round - max_gc_rounds;

            // Sample a committee.
            let committee =
                snarkvm::ledger::committee::test_helpers::sample_committee_for_round(latest_ledger_round, rng);

            let mut mock_ledger = MockLedger::default();
            mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));

            let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
            // Initialize the storage.
            let storage =
                Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);

            // Ensure that the storage GC round is correct.
            assert_eq!(storage.gc_round(), expected_gc_round);
        }
    }
}
464
#[cfg(test)]
mod prop_tests {
    use super::*;
    use amareleo_node_bft_ledger_service::MockLedgerService;
    use snarkvm::{
        console::account::Address,
        ledger::committee::{Committee, MIN_VALIDATOR_STAKE},
    };

    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds a deterministic test committee of `n` members, starting at round 1.
    ///
    /// Member `i` gets:
    /// - an address sampled from `TestRng::fixed(i)`,
    /// - `MIN_VALIDATOR_STAKE`,
    /// - `false` as the second tuple field, and
    /// - a value sampled from `0..100` as the third tuple field.
    fn new_test_committee(n: u16) -> Committee<CurrentNetwork> {
        let mut members = IndexMap::with_capacity(n as usize);
        for i in 0..n {
            // Sample the address.
            let rng = &mut TestRng::fixed(i as u64);
            let address = Address::new(rng.gen());
            info!("Validator {i}: {address}");
            members.insert(address, (MIN_VALIDATOR_STAKE, false, rng.gen_range(0..100)));
        }
        // Initialize the committee.
        Committee::<CurrentNetwork>::new(1u64, members).unwrap()
    }

    #[proptest]
    fn worker_initialization(#[strategy(0..MAX_WORKERS)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        let worker = Worker::new(id, storage, ledger, Default::default()).unwrap();
        assert_eq!(worker.id(), id);
    }

    #[proptest]
    fn invalid_worker_id(#[strategy(MAX_WORKERS..)] id: u8, storage: Storage<CurrentNetwork>) {
        let committee = new_test_committee(4);
        let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(MockLedgerService::new(committee));
        // `Worker` does not implement `Debug`, so `unwrap_err` is unavailable. Match explicitly so
        // that an unexpected `Ok` fails the test instead of passing silently.
        match Worker::new(id, storage, ledger, Default::default()) {
            Ok(_) => panic!("Worker ID '{id}' should have been rejected"),
            Err(error) => assert_eq!(error.to_string(), format!("Invalid worker ID '{id}'")),
        }
    }
}
510}