1use crate::types::Balance;
2use aliases::Aliases;
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_schema_checker_lib::ProtocolSchema;
5
/// Precomputed alias table for sampling an index with probability proportional to
/// its stake, using a constant amount of work per sample.
///
/// NOTE: this type derives `BorshSerialize`/`BorshDeserialize`, `serde::Serialize` and
/// `ProtocolSchema`, so the field names and their declaration order are part of its
/// serialized form. Do not rename or reorder fields.
#[derive(
    Default,
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    Clone,
    Debug,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
pub struct StakeWeightedIndex {
    /// Sum of all input stakes (before scaling by the number of entries).
    stake_sum: Balance,
    /// `aliases[i]` is the index returned when column `i` is drawn but the uniform
    /// stake draw is not below `no_alias_odds[i]`.
    aliases: Vec<u64>,
    /// Per-column threshold in `[0, stake_sum]`: a uniform draw in `[0, stake_sum)`
    /// below this value keeps column `i`, otherwise its alias is returned.
    no_alias_odds: Vec<Balance>,
}
22
23impl StakeWeightedIndex {
26 pub fn new(stakes: Vec<Balance>) -> Self {
27 let n = stakes.len() as u64;
28 let mut aliases = Aliases::new(stakes.len());
29
30 let mut no_alias_odds = stakes;
31 let mut stake_sum = Balance::ZERO;
32 for w in &mut no_alias_odds {
33 stake_sum = stake_sum.checked_add(*w).unwrap();
34 *w = w.checked_mul(u128::from(n)).unwrap();
35 }
36
37 for (index, &odds) in no_alias_odds.iter().enumerate() {
38 if odds < stake_sum {
39 aliases.push_small(index);
40 } else {
41 aliases.push_big(index);
42 }
43 }
44
45 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
46 let s = aliases.pop_small();
47 let b = aliases.pop_big();
48
49 aliases.set_alias(s, b);
50 no_alias_odds[b] = no_alias_odds[b]
51 .checked_sub(stake_sum)
52 .unwrap()
53 .checked_add(no_alias_odds[s])
54 .unwrap();
55
56 if no_alias_odds[b] < stake_sum {
57 aliases.push_small(b);
58 } else {
59 aliases.push_big(b);
60 }
61 }
62
63 while !aliases.smalls_is_empty() {
64 no_alias_odds[aliases.pop_small()] = stake_sum;
65 }
66
67 while !aliases.bigs_is_empty() {
68 no_alias_odds[aliases.pop_big()] = stake_sum;
69 }
70
71 Self { stake_sum, no_alias_odds, aliases: aliases.get_aliases() }
72 }
73
74 pub fn sample(&self, seed: [u8; 32]) -> usize {
75 let usize_seed = Self::copy_8_bytes(&seed[0..8]);
76 let balance_seed = Self::copy_16_bytes(&seed[8..24]);
77 let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
78 let uniform_stake = Balance::from_yoctonear(
79 u128::from_le_bytes(balance_seed) % self.stake_sum.as_yoctonear(),
80 );
81
82 if uniform_stake < self.no_alias_odds[uniform_index] {
83 uniform_index
84 } else {
85 self.aliases[uniform_index] as usize
86 }
87 }
88
89 pub fn get_aliases(&self) -> &[u64] {
90 &self.aliases
91 }
92
93 pub fn get_no_alias_odds(&self) -> &[Balance] {
94 &self.no_alias_odds
95 }
96
97 fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
98 let mut result = [0u8; 8];
99 result.clone_from_slice(arr);
100 result
101 }
102
103 fn copy_16_bytes(arr: &[u8]) -> [u8; 16] {
104 let mut result = [0u8; 16];
105 result.clone_from_slice(arr);
106 result
107 }
108}
109
mod aliases {
    /// Working state used while building an alias table: the alias recorded for each
    /// column, plus two LIFO worklists of column indices ("small" and "big"). The
    /// caller decides which worklist a column belongs on.
    pub struct Aliases {
        alias_of: Vec<usize>,
        small_worklist: Vec<usize>,
        big_worklist: Vec<usize>,
    }

    impl Aliases {
        /// Creates state for `n` columns. Every alias starts at `0` and both
        /// worklists start empty, with room for `n` entries each.
        pub fn new(n: usize) -> Self {
            let alias_of = vec![0; n];
            let small_worklist = Vec::with_capacity(n);
            let big_worklist = Vec::with_capacity(n);
            Self { alias_of, small_worklist, big_worklist }
        }

        /// Pushes column `b` onto the "big" worklist.
        pub fn push_big(&mut self, b: usize) {
            let worklist = &mut self.big_worklist;
            worklist.push(b);
        }

        /// Pops the most recently pushed "big" column.
        ///
        /// # Panics
        ///
        /// Panics if the "big" worklist is empty.
        pub fn pop_big(&mut self) -> usize {
            let popped = self.big_worklist.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "big" columns are pending.
        pub fn bigs_is_empty(&self) -> bool {
            self.big_worklist.len() == 0
        }

        /// Pushes column `s` onto the "small" worklist.
        pub fn push_small(&mut self, s: usize) {
            let worklist = &mut self.small_worklist;
            worklist.push(s);
        }

        /// Pops the most recently pushed "small" column.
        ///
        /// # Panics
        ///
        /// Panics if the "small" worklist is empty.
        pub fn pop_small(&mut self) -> usize {
            let popped = self.small_worklist.pop();
            popped.unwrap()
        }

        /// Returns `true` when no "small" columns are pending.
        pub fn smalls_is_empty(&self) -> bool {
            self.small_worklist.len() == 0
        }

        /// Records `alias` as the alias of column `index`.
        ///
        /// # Panics
        ///
        /// Panics if `index` is not below the `n` passed to [`Aliases::new`].
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            let slot = &mut self.alias_of[index];
            *slot = alias;
        }

        /// Consumes the state and returns the alias of every column as `u64`.
        pub fn get_aliases(self) -> Vec<u64> {
            let Self { alias_of, .. } = self;
            alias_of.iter().map(|&column| column as u64).collect()
        }
    }
}
156
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::StakeWeightedIndex;
    use near_primitives_core::types::Balance;

    /// Maximum relative difference (relative to the first argument) accepted by
    /// `assert_relative_closeness`.
    const MAX_RELATIVE_DIFF: f64 = 0.005;

    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        let stakes = vec![
            Balance::from_yoctonear(5),
            Balance::from_yoctonear(8),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(10),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(4),
            Balance::from_yoctonear(5),
        ];
        let stake_weighted_index = StakeWeightedIndex::new(stakes);

        assert_eq!(stake_weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(
            stake_weighted_index.get_no_alias_odds(),
            &[
                Balance::from_yoctonear(35),
                Balance::from_yoctonear(40),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(29),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(28),
                Balance::from_yoctonear(35)
            ]
        );
    }

    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let stakes = vec![
            Balance::from_yoctonear(5),
            Balance::from_yoctonear(1),
            Balance::from_yoctonear(1),
        ];
        let stake_weighted_index = StakeWeightedIndex::new(stakes);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = stake_weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    /// Asserts that `y` differs from `x` by less than `MAX_RELATIVE_DIFF`, measured
    /// relative to `x`. The failure message reports both values and the observed
    /// relative difference.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < MAX_RELATIVE_DIFF,
            "expected {y} to be within {MAX_RELATIVE_DIFF} (relative) of {x}, \
             got relative difference {relative_diff}"
        );
    }

    /// Hashes `input` and returns the raw 32 hash bytes, used as the next seed.
    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}