1use crate::types::Balance;
2use aliases::Aliases;
3use borsh::{BorshDeserialize, BorshSerialize};
4use near_schema_checker_lib::ProtocolSchema;
5
/// Precomputed alias table for drawing an index with probability proportional
/// to its weight, using a constant amount of work per draw (Walker/Vose alias
/// method).
///
/// The table is built by [`WeightedIndex::new`] using only `Balance` integer
/// arithmetic, and indices are drawn deterministically from a 32-byte seed by
/// [`WeightedIndex::sample`].
///
/// NOTE(review): this type derives `BorshSerialize`/`BorshDeserialize` and
/// `ProtocolSchema`, so the field order and types presumably form part of a
/// serialized/protocol format — confirm before reordering or changing them.
#[derive(
    Default,
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    Clone,
    Debug,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
pub struct WeightedIndex {
    /// Sum of all input weights. A 16-byte chunk of the sampling seed is
    /// reduced modulo this value to obtain the per-draw threshold.
    weight_sum: Balance,
    /// For each slot, the index returned when the drawn threshold is not below
    /// that slot's `no_alias_odds` entry.
    aliases: Vec<u64>,
    /// For each slot, a threshold in `[0, weight_sum]` below which the slot
    /// returns its own index instead of its alias.
    no_alias_odds: Vec<Balance>,
}
22
23impl WeightedIndex {
25 pub fn new(weights: Vec<Balance>) -> Self {
26 let n = Balance::from(weights.len() as u64);
27 let mut aliases = Aliases::new(weights.len());
28
29 let mut no_alias_odds = weights;
30 let mut weight_sum: Balance = 0;
31 for w in &mut no_alias_odds {
32 weight_sum += *w;
33 *w *= n;
34 }
35
36 for (index, &odds) in no_alias_odds.iter().enumerate() {
37 if odds < weight_sum {
38 aliases.push_small(index);
39 } else {
40 aliases.push_big(index);
41 }
42 }
43
44 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
45 let s = aliases.pop_small();
46 let b = aliases.pop_big();
47
48 aliases.set_alias(s, b);
49 no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
50
51 if no_alias_odds[b] < weight_sum {
52 aliases.push_small(b);
53 } else {
54 aliases.push_big(b);
55 }
56 }
57
58 while !aliases.smalls_is_empty() {
59 no_alias_odds[aliases.pop_small()] = weight_sum;
60 }
61
62 while !aliases.bigs_is_empty() {
63 no_alias_odds[aliases.pop_big()] = weight_sum;
64 }
65
66 Self { weight_sum, no_alias_odds, aliases: aliases.get_aliases() }
67 }
68
69 pub fn sample(&self, seed: [u8; 32]) -> usize {
70 let usize_seed = Self::copy_8_bytes(&seed[0..8]);
71 let balance_seed = Self::copy_16_bytes(&seed[8..24]);
72 let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
73 let uniform_weight = Balance::from_le_bytes(balance_seed) % self.weight_sum;
74
75 if uniform_weight < self.no_alias_odds[uniform_index] {
76 uniform_index
77 } else {
78 self.aliases[uniform_index] as usize
79 }
80 }
81
82 pub fn get_aliases(&self) -> &[u64] {
83 &self.aliases
84 }
85
86 pub fn get_no_alias_odds(&self) -> &[Balance] {
87 &self.no_alias_odds
88 }
89
90 fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
91 let mut result = [0u8; 8];
92 result.clone_from_slice(arr);
93 result
94 }
95
96 fn copy_16_bytes(arr: &[u8]) -> [u8; 16] {
97 let mut result = [0u8; 16];
98 result.clone_from_slice(arr);
99 result
100 }
101}
102
mod aliases {
    /// Scratch state used while building an alias table.
    ///
    /// Holds a per-slot alias index plus two LIFO worklists of slot indices
    /// that the caller classifies as "small" and "big".
    pub struct Aliases {
        /// Alias index for each slot. Slots never passed to `set_alias` keep 0.
        aliases: Vec<usize>,
        /// LIFO worklist of slot indices the caller classified as small.
        smalls: Vec<usize>,
        /// LIFO worklist of slot indices the caller classified as big.
        bigs: Vec<usize>,
    }

    impl Aliases {
        /// Creates state for `n` slots, with every alias set to 0.
        ///
        /// Both worklists start empty, with capacity for `n` entries, so
        /// pushes do not reallocate.
        pub fn new(n: usize) -> Self {
            Self { aliases: vec![0; n], smalls: Vec::with_capacity(n), bigs: Vec::with_capacity(n) }
        }

        /// Pushes slot `b` onto the big worklist.
        pub fn push_big(&mut self, b: usize) {
            self.bigs.push(b);
        }

        /// Pops the most recently pushed slot from the big worklist.
        ///
        /// # Panics
        ///
        /// Panics if the big worklist is empty. Check `bigs_is_empty` first.
        pub fn pop_big(&mut self) -> usize {
            self.bigs.pop().expect("pop_big called on an empty big worklist")
        }

        /// Returns `true` if the big worklist is empty.
        pub fn bigs_is_empty(&self) -> bool {
            self.bigs.is_empty()
        }

        /// Pushes slot `s` onto the small worklist.
        pub fn push_small(&mut self, s: usize) {
            self.smalls.push(s);
        }

        /// Pops the most recently pushed slot from the small worklist.
        ///
        /// # Panics
        ///
        /// Panics if the small worklist is empty. Check `smalls_is_empty` first.
        pub fn pop_small(&mut self) -> usize {
            self.smalls.pop().expect("pop_small called on an empty small worklist")
        }

        /// Returns `true` if the small worklist is empty.
        pub fn smalls_is_empty(&self) -> bool {
            self.smalls.is_empty()
        }

        /// Records `alias` as the alias of slot `index`.
        ///
        /// Panics if `index` is out of range.
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            self.aliases[index] = alias;
        }

        /// Consumes the builder and returns the alias table widened to `u64`.
        pub fn get_aliases(self) -> Vec<u64> {
            // `usize` is at most 64 bits on supported targets, so this cast is lossless.
            self.aliases.into_iter().map(|a| a as u64).collect()
        }
    }
}
149
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::WeightedIndex;

    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        let weights = vec![5, 8, 4, 10, 4, 4, 5];
        let weighted_index = WeightedIndex::new(weights);

        assert_eq!(weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(weighted_index.get_no_alias_odds(), &[35, 40, 28, 29, 28, 28, 35]);
    }

    #[test]
    fn test_single_weight_is_always_sampled() {
        let weighted_index = WeightedIndex::new(vec![7]);

        assert_eq!(weighted_index.get_aliases(), &[0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[7]);

        let mut seed = hash(&[0; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 0);
            seed = hash(&seed);
        }
    }

    #[test]
    fn test_zero_weight_is_never_sampled() {
        let weighted_index = WeightedIndex::new(vec![0, 3]);

        // Slot 0 has no mass of its own, so it always defers to its alias (slot 1).
        assert_eq!(weighted_index.get_aliases(), &[1, 0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[0, 3]);

        let mut seed = hash(&[0; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 1);
            seed = hash(&seed);
        }
    }

    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let weights = vec![5, 1, 1];
        let weighted_index = WeightedIndex::new(weights);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    /// Asserts that `y` is within 0.5% of `x`, measured relative to `x`.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < 0.005,
            "expected {} and {} to be within 0.5% of each other, relative diff was {}",
            x,
            y,
            relative_diff
        );
    }

    /// Hashes `input` and returns the raw 32-byte digest, for use as a sampling seed.
    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}
195}