near_primitives/
rand.rs

1use crate::types::Balance;
2use aliases::Aliases;
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_schema_checker_lib::ProtocolSchema;
5
6#[derive(
7    Default,
8    BorshSerialize,
9    BorshDeserialize,
10    serde::Serialize,
11    Clone,
12    Debug,
13    PartialEq,
14    Eq,
15    ProtocolSchema,
16)]
17pub struct WeightedIndex {
18    weight_sum: Balance,
19    aliases: Vec<u64>,
20    no_alias_odds: Vec<Balance>,
21}
22
23// cspell:words bigs
24impl WeightedIndex {
25    pub fn new(weights: Vec<Balance>) -> Self {
26        let n = Balance::from(weights.len() as u64);
27        let mut aliases = Aliases::new(weights.len());
28
29        let mut no_alias_odds = weights;
30        let mut weight_sum: Balance = 0;
31        for w in &mut no_alias_odds {
32            weight_sum += *w;
33            *w *= n;
34        }
35
36        for (index, &odds) in no_alias_odds.iter().enumerate() {
37            if odds < weight_sum {
38                aliases.push_small(index);
39            } else {
40                aliases.push_big(index);
41            }
42        }
43
44        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
45            let s = aliases.pop_small();
46            let b = aliases.pop_big();
47
48            aliases.set_alias(s, b);
49            no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
50
51            if no_alias_odds[b] < weight_sum {
52                aliases.push_small(b);
53            } else {
54                aliases.push_big(b);
55            }
56        }
57
58        while !aliases.smalls_is_empty() {
59            no_alias_odds[aliases.pop_small()] = weight_sum;
60        }
61
62        while !aliases.bigs_is_empty() {
63            no_alias_odds[aliases.pop_big()] = weight_sum;
64        }
65
66        Self { weight_sum, no_alias_odds, aliases: aliases.get_aliases() }
67    }
68
69    pub fn sample(&self, seed: [u8; 32]) -> usize {
70        let usize_seed = Self::copy_8_bytes(&seed[0..8]);
71        let balance_seed = Self::copy_16_bytes(&seed[8..24]);
72        let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
73        let uniform_weight = Balance::from_le_bytes(balance_seed) % self.weight_sum;
74
75        if uniform_weight < self.no_alias_odds[uniform_index] {
76            uniform_index
77        } else {
78            self.aliases[uniform_index] as usize
79        }
80    }
81
82    pub fn get_aliases(&self) -> &[u64] {
83        &self.aliases
84    }
85
86    pub fn get_no_alias_odds(&self) -> &[Balance] {
87        &self.no_alias_odds
88    }
89
90    fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
91        let mut result = [0u8; 8];
92        result.clone_from_slice(arr);
93        result
94    }
95
96    fn copy_16_bytes(arr: &[u8]) -> [u8; 16] {
97        let mut result = [0u8; 16];
98        result.clone_from_slice(arr);
99        result
100    }
101}
102
/// Worklist bookkeeping used while building an alias table.
///
/// Kept in its own module so the construction code can only touch the worklists through the
/// methods below.
mod aliases {
    /// Alias table under construction, plus the "small" and "big" stacks of column indices that
    /// are still waiting to be paired.
    pub struct Aliases {
        aliases: Vec<usize>,
        smalls: Vec<usize>,
        bigs: Vec<usize>,
    }

    impl Aliases {
        /// Creates an all-zero alias table for `n` columns. Both stacks are pre-sized to hold
        /// every column.
        pub fn new(n: usize) -> Self {
            let aliases = vec![0; n];
            let smalls = Vec::with_capacity(n);
            let bigs = Vec::with_capacity(n);
            Self { aliases, smalls, bigs }
        }

        /// Pushes column `b` onto the "big" stack.
        pub fn push_big(&mut self, b: usize) {
            self.bigs.push(b)
        }

        /// Pops the most recently pushed "big" column; panics if the stack is empty.
        pub fn pop_big(&mut self) -> usize {
            let popped = self.bigs.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "big" column is waiting.
        pub fn bigs_is_empty(&self) -> bool {
            self.bigs.len() == 0
        }

        /// Pushes column `s` onto the "small" stack.
        pub fn push_small(&mut self, s: usize) {
            self.smalls.push(s)
        }

        /// Pops the most recently pushed "small" column; panics if the stack is empty.
        pub fn pop_small(&mut self) -> usize {
            let popped = self.smalls.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "small" column is waiting.
        pub fn smalls_is_empty(&self) -> bool {
            self.smalls.len() == 0
        }

        /// Records that column `index` falls back to column `alias`.
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            let slot = &mut self.aliases[index];
            *slot = alias;
        }

        /// Consumes the builder and returns the alias table widened to `u64`.
        pub fn get_aliases(self) -> Vec<u64> {
            let Self { aliases, .. } = self;
            aliases.iter().map(|&column| column as u64).collect()
        }
    }
}
149
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::WeightedIndex;

    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        // Example taken from https://www.keithschwarz.com/darts-dice-coins/
        let weights = vec![5, 8, 4, 10, 4, 4, 5];
        let weighted_index = WeightedIndex::new(weights);

        assert_eq!(weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(weighted_index.get_no_alias_odds(), &[35, 40, 28, 29, 28, 28, 35]);
    }

    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let weights = vec![5, 1, 1];
        let weighted_index = WeightedIndex::new(weights);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    #[test]
    fn test_single_weight_always_samples_it() {
        let weighted_index = WeightedIndex::new(vec![7]);
        assert_eq!(weighted_index.get_aliases(), &[0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[7]);

        let mut seed = hash(&[2; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 0);
            seed = hash(&seed);
        }
    }

    #[test]
    fn test_sample_never_returns_zero_weight_index() {
        // Column 0 has zero acceptance odds, so it always defers to its alias, column 1.
        let weighted_index = WeightedIndex::new(vec![0, 3]);
        assert_eq!(weighted_index.get_aliases(), &[1, 0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[0, 3]);

        let mut seed = hash(&[1; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 1);
            seed = hash(&seed);
        }
    }

    /// Asserts that `y` is within 0.5% of `x`, measured relative to `x`.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < 0.005,
            "expected {y} to be within 0.5% of {x}, but relative difference is {relative_diff}"
        );
    }

    /// Hashes `input` and returns the raw 32-byte digest, usable directly as a sampling seed.
    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}