near_primitives/rand.rs

1use crate::types::Balance;
2use aliases::Aliases;
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_schema_checker_lib::ProtocolSchema;
5
6#[derive(
7    Default,
8    BorshSerialize,
9    BorshDeserialize,
10    serde::Serialize,
11    Clone,
12    Debug,
13    PartialEq,
14    Eq,
15    ProtocolSchema,
16)]
17pub struct StakeWeightedIndex {
18    stake_sum: Balance,
19    aliases: Vec<u64>,
20    no_alias_odds: Vec<Balance>,
21}
22
23// cspell:words bigs
24// TODO: move out of
25impl StakeWeightedIndex {
26    pub fn new(stakes: Vec<Balance>) -> Self {
27        let n = stakes.len() as u64;
28        let mut aliases = Aliases::new(stakes.len());
29
30        let mut no_alias_odds = stakes;
31        let mut stake_sum = Balance::ZERO;
32        for w in &mut no_alias_odds {
33            stake_sum = stake_sum.checked_add(*w).unwrap();
34            *w = w.checked_mul(u128::from(n)).unwrap();
35        }
36
37        for (index, &odds) in no_alias_odds.iter().enumerate() {
38            if odds < stake_sum {
39                aliases.push_small(index);
40            } else {
41                aliases.push_big(index);
42            }
43        }
44
45        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
46            let s = aliases.pop_small();
47            let b = aliases.pop_big();
48
49            aliases.set_alias(s, b);
50            no_alias_odds[b] = no_alias_odds[b]
51                .checked_sub(stake_sum)
52                .unwrap()
53                .checked_add(no_alias_odds[s])
54                .unwrap();
55
56            if no_alias_odds[b] < stake_sum {
57                aliases.push_small(b);
58            } else {
59                aliases.push_big(b);
60            }
61        }
62
63        while !aliases.smalls_is_empty() {
64            no_alias_odds[aliases.pop_small()] = stake_sum;
65        }
66
67        while !aliases.bigs_is_empty() {
68            no_alias_odds[aliases.pop_big()] = stake_sum;
69        }
70
71        Self { stake_sum, no_alias_odds, aliases: aliases.get_aliases() }
72    }
73
74    pub fn sample(&self, seed: [u8; 32]) -> usize {
75        let usize_seed = Self::copy_8_bytes(&seed[0..8]);
76        let balance_seed = Self::copy_16_bytes(&seed[8..24]);
77        let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
78        let uniform_stake = Balance::from_yoctonear(
79            u128::from_le_bytes(balance_seed) % self.stake_sum.as_yoctonear(),
80        );
81
82        if uniform_stake < self.no_alias_odds[uniform_index] {
83            uniform_index
84        } else {
85            self.aliases[uniform_index] as usize
86        }
87    }
88
89    pub fn get_aliases(&self) -> &[u64] {
90        &self.aliases
91    }
92
93    pub fn get_no_alias_odds(&self) -> &[Balance] {
94        &self.no_alias_odds
95    }
96
97    fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
98        let mut result = [0u8; 8];
99        result.clone_from_slice(arr);
100        result
101    }
102
103    fn copy_16_bytes(arr: &[u8]) -> [u8; 16] {
104        let mut result = [0u8; 16];
105        result.clone_from_slice(arr);
106        result
107    }
108}
109
/// Sub-module to encapsulate helper struct for managing aliases
mod aliases {
    /// Scratch state used while building an alias table.
    ///
    /// Holds the alias assigned to each index plus two LIFO worklists of indices:
    /// "smalls" and "bigs".
    pub struct Aliases {
        /// `alias_of[i]` is the alias chosen for index `i`; every entry starts at `0`.
        alias_of: Vec<usize>,
        /// Stack of indices waiting to be paired as the "small" side.
        small_stack: Vec<usize>,
        /// Stack of indices waiting to be paired as the "big" side.
        big_stack: Vec<usize>,
    }

    impl Aliases {
        /// Creates state for `n` indices: all aliases are `0` and both stacks are empty,
        /// each preallocated for `n` entries.
        pub fn new(n: usize) -> Self {
            let alias_of = vec![0; n];
            let small_stack = Vec::with_capacity(n);
            let big_stack = Vec::with_capacity(n);
            Self { alias_of, small_stack, big_stack }
        }

        /// Pushes `b` onto the "big" stack.
        pub fn push_big(&mut self, b: usize) {
            self.big_stack.push(b)
        }

        /// Pops the most recently pushed "big" index; panics if the stack is empty.
        pub fn pop_big(&mut self) -> usize {
            let popped = self.big_stack.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "big" index is waiting.
        pub fn bigs_is_empty(&self) -> bool {
            self.big_stack.is_empty()
        }

        /// Pushes `s` onto the "small" stack.
        pub fn push_small(&mut self, s: usize) {
            self.small_stack.push(s)
        }

        /// Pops the most recently pushed "small" index; panics if the stack is empty.
        pub fn pop_small(&mut self) -> usize {
            let popped = self.small_stack.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "small" index is waiting.
        pub fn smalls_is_empty(&self) -> bool {
            self.small_stack.is_empty()
        }

        /// Records `alias` as the alias of `index`; panics if `index` is out of range.
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            let slot = &mut self.alias_of[index];
            *slot = alias;
        }

        /// Consumes the helper and returns the alias table as `u64` values.
        pub fn get_aliases(self) -> Vec<u64> {
            self.alias_of.iter().map(|&alias| alias as u64).collect()
        }
    }
}
156
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::StakeWeightedIndex;
    use near_primitives_core::types::Balance;

    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        // Example taken from https://www.keithschwarz.com/darts-dice-coins/
        let stakes = vec![
            Balance::from_yoctonear(5),
            Balance::from_yoctonear(8),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(10),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(5),
        ];
        let stake_weighted_index = StakeWeightedIndex::new(stakes);

        assert_eq!(stake_weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(
            stake_weighted_index.get_no_alias_odds(),
            &[
                Balance::from_yoctonear(35),
                Balance::from_yoctonear(40),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(29),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(35)
            ]
        );
    }

    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let stakes = vec![
            Balance::from_yoctonear(5),
            Balance::from_yoctonear(1),
            Balance::from_yoctonear(1),
        ];
        let stake_weighted_index = StakeWeightedIndex::new(stakes);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = stake_weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    /// An index with zero stake must never be returned: its odds are zero and it is always
    /// redirected to its alias.
    #[test]
    fn test_zero_stake_is_never_sampled() {
        let stake_weighted_index =
            StakeWeightedIndex::new(vec![Balance::ZERO, Balance::from_yoctonear(1)]);

        assert_eq!(stake_weighted_index.get_aliases(), &[1, 0]);
        assert_eq!(
            stake_weighted_index.get_no_alias_odds(),
            &[Balance::ZERO, Balance::from_yoctonear(1)]
        );

        let mut seed = hash(&[1; 32]);
        for _ in 0..1_000 {
            assert_eq!(stake_weighted_index.sample(seed), 1);
            seed = hash(&seed);
        }
    }

    /// With a single entry the table is trivial and sampling always returns index 0.
    #[test]
    fn test_single_validator_is_always_sampled() {
        let stake_weighted_index = StakeWeightedIndex::new(vec![Balance::from_yoctonear(7)]);

        assert_eq!(stake_weighted_index.get_aliases(), &[0]);
        assert_eq!(stake_weighted_index.get_no_alias_odds(), &[Balance::from_yoctonear(7)]);

        let mut seed = hash(&[2; 32]);
        for _ in 0..1_000 {
            assert_eq!(stake_weighted_index.sample(seed), 0);
            seed = hash(&seed);
        }
    }

    /// Sampling from an empty table has no valid answer and panics.
    #[test]
    #[should_panic]
    fn test_sample_on_empty_index_panics() {
        StakeWeightedIndex::default().sample([0; 32]);
    }

    /// Asserts that `y` is within 0.5% of `x`, measured relative to `x`.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < 0.005,
            "expected {} to be within 0.5% of {}, but the relative difference is {}",
            y,
            x,
            relative_diff
        );
    }

    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}