casper_node/components/consensus/protocols/common.rs

1//! Utilities common to different consensus algorithms.
2
3use itertools::Itertools;
4use num_rational::Ratio;
5use std::collections::{BTreeMap, HashSet};
6
7use num_traits::AsPrimitive;
8
9use crate::components::consensus::{
10    traits::Context,
11    utils::{ValidatorMap, Validators, Weight},
12};
13use casper_types::U512;
14
15/// Computes the validator set given the stakes and the faulty and inactive
16/// reports from the previous eras.
17pub fn validators<C: Context>(
18    faulty: &HashSet<C::ValidatorId>,
19    inactive: &HashSet<C::ValidatorId>,
20    validator_stakes: BTreeMap<C::ValidatorId, U512>,
21) -> Validators<C::ValidatorId> {
22    let sum_stakes = safe_sum(validator_stakes.values().copied()).expect("should not overflow");
23    // We use u64 weights. Scale down by floor(sum / u64::MAX) + 1.
24    // This guarantees that the resulting sum is greater than 0 and less than u64::MAX.
25    #[allow(clippy::arithmetic_side_effects)] // Divisor isn't 0 and addition can't overflow.
26    let scaling_factor: U512 = sum_stakes / U512::from(u64::MAX) + 1;
27
28    // TODO sort validators by descending weight
29    #[allow(clippy::arithmetic_side_effects)] // Divisor isn't 0.
30    let mut validators: Validators<C::ValidatorId> = validator_stakes
31        .into_iter()
32        .map(|(key, stake)| (key, AsPrimitive::<u64>::as_(stake / scaling_factor)))
33        .collect();
34
35    for vid in faulty {
36        validators.ban(vid);
37    }
38
39    for vid in inactive {
40        validators.set_cannot_propose(vid);
41    }
42
43    assert!(
44        validators.ensure_nonzero_proposing_stake(),
45        "cannot start era with total weight 0"
46    );
47
48    validators
49}
50
51/// Compute the validator weight map from the set of validators.
52pub(crate) fn validator_weights<C: Context>(
53    validators: &Validators<C::ValidatorId>,
54) -> ValidatorMap<Weight> {
55    ValidatorMap::from(validators.iter().map(|v| v.weight()).collect_vec())
56}
57
58/// Computes the fault tolerance threshold for the protocol instance
59pub(crate) fn ftt<C: Context>(
60    finality_threshold_fraction: Ratio<u64>,
61    validators: &Validators<C::ValidatorId>,
62) -> Weight {
63    let total_weight = u128::from(validators.total_weight());
64    assert!(
65        finality_threshold_fraction < 1.into(),
66        "finality threshold must be less than 100%"
67    );
68    #[allow(clippy::arithmetic_side_effects)] // FTT is less than 1, so this can't overflow
69    let ftt = total_weight * *finality_threshold_fraction.numer() as u128
70        / *finality_threshold_fraction.denom() as u128;
71    (ftt as u64).into()
72}
73
74/// A U512 sum implementation that check for overflow.
75fn safe_sum<I>(mut iterator: I) -> Option<U512>
76where
77    I: Iterator<Item = U512>,
78{
79    iterator.try_fold(U512::zero(), |acc, n| acc.checked_add(n))
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::components::consensus::ClContext;
    use casper_types::{testing::TestRng, PublicKey};
    use rand::Rng;

    /// The stakes sum to more than `U512::MAX`. `validators` must therefore panic in its overflow
    /// check, not in some later assertion; the `expected` message pins that.
    #[test]
    #[should_panic(expected = "should not overflow")]
    fn validators_panics_on_stake_overflow() {
        let rng = &mut TestRng::new();
        let mut validator_stakes = BTreeMap::new();
        validator_stakes.insert(PublicKey::random(rng), U512::MAX);
        validator_stakes.insert(PublicKey::random(rng), U512::from(1_u32));

        validators::<ClContext>(&Default::default(), &Default::default(), validator_stakes);
    }

    /// Stakes far above `u64::MAX` must be scaled so the total weight still fits below
    /// `u64::MAX`.
    #[test]
    fn total_weights_less_than_u64_max() {
        let mut rng = TestRng::new();

        // Each pair (a, b) produces two stakes: `a * u128::MAX` and `b * u128::MAX`. The first
        // pair is random; the others are values known to have caused issues before.
        let random_pair: (u64, u64) = (rng.gen(), rng.gen());
        for &(a, b) in &[random_pair, (514, 771), (668, 614)] {
            let mut stakes = BTreeMap::new();
            stakes.insert(
                PublicKey::random(&mut rng),
                U512::from(a) * U512::from(u128::MAX),
            );
            stakes.insert(
                PublicKey::random(&mut rng),
                U512::from(b) * U512::from(u128::MAX),
            );
            let weights =
                validators::<ClContext>(&Default::default(), &Default::default(), stakes);
            assert!(
                weights.total_weight().0 < u64::MAX,
                "total weight not below u64::MAX for stake multipliers ({}, {})",
                a,
                b
            );
        }
    }

    /// With small stakes no scaling happens (the divisor is 1). The FTT is then the requested
    /// fraction of the raw total, rounded down.
    #[test]
    fn ftt_is_rounded_down_fraction_of_total_weight() {
        let mut rng = TestRng::new();
        let mut stakes = BTreeMap::new();
        stakes.insert(PublicKey::random(&mut rng), U512::from(100_u32));
        stakes.insert(PublicKey::random(&mut rng), U512::from(201_u32));
        let validator_set =
            validators::<ClContext>(&Default::default(), &Default::default(), stakes);

        // Total weight is 301, and floor(301 / 3) = 100.
        assert_eq!(ftt::<ClContext>(Ratio::new(1, 3), &validator_set).0, 100);
    }

    /// A threshold fraction of one (100%) is rejected.
    #[test]
    #[should_panic(expected = "finality threshold must be less than 100%")]
    fn ftt_panics_for_threshold_of_one() {
        let mut rng = TestRng::new();
        let mut stakes = BTreeMap::new();
        stakes.insert(PublicKey::random(&mut rng), U512::from(10_u32));
        let validator_set =
            validators::<ClContext>(&Default::default(), &Default::default(), stakes);

        ftt::<ClContext>(Ratio::new(1, 1), &validator_set);
    }
}