1use aliases::Aliases;
2use borsh::{BorshDeserialize, BorshSerialize};
3use unc_primitives_core::types::Power;
4
/// Precomputed alias table for drawing an index with probability proportional
/// to the weights passed to `WeightedIndex::new`.
///
/// Drawing a sample (`WeightedIndex::sample`) involves no iteration: one
/// uniform bucket draw, one uniform weight draw, and a single comparison
/// against that bucket's threshold.
///
/// The derived `Default` value has `weight_sum == 0` and no buckets, so it
/// cannot be sampled from.
///
/// The derived Borsh encoding writes fields in declaration order, so reordering
/// the fields below would change the serialized form.
#[derive(
    Default, BorshSerialize, BorshDeserialize, serde::Serialize, Clone, Debug, PartialEq, Eq,
)]
pub struct WeightedIndex {
    /// Sum of the (unscaled) input weights. The per-sample weight draw is taken
    /// modulo this value, and every entry of `no_alias_odds` is compared
    /// against draws on that same scale.
    weight_sum: Power,
    /// `aliases[i]` is the index returned for bucket `i` when the weight draw
    /// is not below `no_alias_odds[i]`. Stored as `u64` and converted back to
    /// `usize` when sampling.
    aliases: Vec<u64>,
    /// Per-bucket threshold. As built by `WeightedIndex::new`, each value lies
    /// in `[0, weight_sum]`; a bucket whose value equals `weight_sum` never
    /// falls through to its alias.
    no_alias_odds: Vec<Power>,
}
13
14impl WeightedIndex {
15 pub fn new(weights: Vec<Power>) -> Self {
16 let n = Power::from(weights.len() as u64);
17 let mut aliases = Aliases::new(weights.len());
18
19 let mut no_alias_odds = weights;
20 let mut weight_sum: Power = 0;
21 for w in no_alias_odds.iter_mut() {
22 weight_sum += *w;
23
24 *w *= n;
25 }
26
27 for (index, &odds) in no_alias_odds.iter().enumerate() {
28 if odds < weight_sum {
29 aliases.push_small(index);
30 } else {
31 aliases.push_big(index);
32 }
33 }
34
35 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
36 let s = aliases.pop_small();
37 let b = aliases.pop_big();
38
39 aliases.set_alias(s, b);
40 no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
41
42 if no_alias_odds[b] < weight_sum {
43 aliases.push_small(b);
44 } else {
45 aliases.push_big(b);
46 }
47 }
48
49 while !aliases.smalls_is_empty() {
50 no_alias_odds[aliases.pop_small()] = weight_sum;
51 }
52
53 while !aliases.bigs_is_empty() {
54 no_alias_odds[aliases.pop_big()] = weight_sum;
55 }
56
57 Self { weight_sum, no_alias_odds, aliases: aliases.get_aliases() }
58 }
59
60 pub fn sample(&self, seed: [u8; 32]) -> usize {
61 let usize_seed = Self::copy_8_bytes(&seed[0..8]);
62 let mut power_seed = [0u8; 16];
63 power_seed[0..8].copy_from_slice(&seed[8..16]);
64 let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
65 let uniform_weight = Power::from_le_bytes(power_seed) % self.weight_sum;
66
67 if uniform_weight < self.no_alias_odds[uniform_index] {
68 uniform_index
69 } else {
70 self.aliases[uniform_index] as usize
71 }
72 }
73
74 pub fn get_aliases(&self) -> &[u64] {
75 &self.aliases
76 }
77
78 pub fn get_no_alias_odds(&self) -> &[Power] {
79 &self.no_alias_odds
80 }
81
82 fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
83 let mut result = [0u8; 8];
84 result.clone_from_slice(arr);
85 result
86 }
87
88 }
94
/// Scratch state used while building a `WeightedIndex`: two worklists of
/// bucket indices plus the alias chosen for each bucket.
mod aliases {
    /// Alias table under construction, together with its two LIFO worklists.
    pub struct Aliases {
        /// `alias_of[i]` is the alias recorded for bucket `i` (0 until set).
        alias_of: Vec<usize>,
        /// Buckets classified as below the threshold, popped last-in first-out.
        small_stack: Vec<usize>,
        /// Buckets classified as at/above the threshold, popped last-in first-out.
        big_stack: Vec<usize>,
    }

    impl Aliases {
        /// Creates state for `n` buckets: every alias starts at 0 and both
        /// worklists start empty, with room reserved for `n` entries each.
        pub fn new(n: usize) -> Self {
            let alias_of = vec![0; n];
            let small_stack = Vec::with_capacity(n);
            let big_stack = Vec::with_capacity(n);
            Aliases { alias_of, small_stack, big_stack }
        }

        /// Queues bucket `s` on the small worklist.
        pub fn push_small(&mut self, s: usize) {
            self.small_stack.push(s);
        }

        /// Removes and returns the most recently queued small bucket.
        /// Panics if the small worklist is empty.
        pub fn pop_small(&mut self) -> usize {
            self.small_stack.pop().unwrap()
        }

        /// Reports whether the small worklist has no entries.
        pub fn smalls_is_empty(&self) -> bool {
            self.small_stack.is_empty()
        }

        /// Queues bucket `b` on the big worklist.
        pub fn push_big(&mut self, b: usize) {
            self.big_stack.push(b);
        }

        /// Removes and returns the most recently queued big bucket.
        /// Panics if the big worklist is empty.
        pub fn pop_big(&mut self) -> usize {
            self.big_stack.pop().unwrap()
        }

        /// Reports whether the big worklist has no entries.
        pub fn bigs_is_empty(&self) -> bool {
            self.big_stack.is_empty()
        }

        /// Records `alias` as the alias of bucket `index`.
        /// Panics if `index` is out of range.
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            self.alias_of[index] = alias;
        }

        /// Consumes the state and returns the alias table widened to `u64`.
        pub fn get_aliases(self) -> Vec<u64> {
            let Aliases { alias_of, .. } = self;
            alias_of.iter().map(|&idx| idx as u64).collect()
        }
    }
}
141
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::WeightedIndex;

    /// Pins the exact alias table for a hand-checked input. With
    /// `weight_sum = 40` and `n = 7`, the scaled odds start as
    /// `[35, 56, 28, 70, 28, 28, 35]` before pairing.
    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        let weights = vec![5, 8, 4, 10, 4, 4, 5];
        let weighted_index = WeightedIndex::new(weights);

        assert_eq!(weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(weighted_index.get_no_alias_odds(), &[35, 40, 28, 29, 28, 28, 35]);
    }

    /// Draws many samples along a deterministic hash chain of seeds and checks
    /// that the observed frequencies follow the 5:1:1 weight ratio.
    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let weights = vec![5, 1, 1];
        let weighted_index = WeightedIndex::new(weights);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    /// An index with no buckets has nothing to draw from, so sampling must
    /// panic rather than return an out-of-range index.
    #[test]
    #[should_panic]
    fn test_sample_on_empty_index_panics() {
        WeightedIndex::new(vec![]).sample([0; 32]);
    }

    /// With every weight zero the total weight is zero, so sampling must panic.
    #[test]
    #[should_panic]
    fn test_sample_on_zero_total_weight_panics() {
        WeightedIndex::new(vec![0, 0]).sample([0; 32]);
    }

    /// Asserts that `y` is within 0.5% of `x`, measured relative to `x`.
    /// The failure message reports both values and the observed difference.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < 0.005,
            "expected {} and {} to be within 0.5% of each other, got relative difference {}",
            x,
            y,
            relative_diff
        );
    }

    /// Hashes `input` and returns the raw 32-byte digest. It is used both for
    /// the first seed and to advance the seed chain.
    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}