unc_primitives/
rand.rs

1use aliases::Aliases;
2use borsh::{BorshDeserialize, BorshSerialize};
3use unc_primitives_core::types::Power;
4
5#[derive(
6    Default, BorshSerialize, BorshDeserialize, serde::Serialize, Clone, Debug, PartialEq, Eq,
7)]
8pub struct WeightedIndex {
9    weight_sum: Power,
10    aliases: Vec<u64>,
11    no_alias_odds: Vec<Power>,
12}
13
14impl WeightedIndex {
15    pub fn new(weights: Vec<Power>) -> Self {
16        let n = Power::from(weights.len() as u64);
17        let mut aliases = Aliases::new(weights.len());
18
19        let mut no_alias_odds = weights;
20        let mut weight_sum: Power = 0;
21        for w in no_alias_odds.iter_mut() {
22            weight_sum += *w;
23
24            *w *= n;
25        }
26
27        for (index, &odds) in no_alias_odds.iter().enumerate() {
28            if odds < weight_sum {
29                aliases.push_small(index);
30            } else {
31                aliases.push_big(index);
32            }
33        }
34
35        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
36            let s = aliases.pop_small();
37            let b = aliases.pop_big();
38
39            aliases.set_alias(s, b);
40            no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
41
42            if no_alias_odds[b] < weight_sum {
43                aliases.push_small(b);
44            } else {
45                aliases.push_big(b);
46            }
47        }
48
49        while !aliases.smalls_is_empty() {
50            no_alias_odds[aliases.pop_small()] = weight_sum;
51        }
52
53        while !aliases.bigs_is_empty() {
54            no_alias_odds[aliases.pop_big()] = weight_sum;
55        }
56
57        Self { weight_sum, no_alias_odds, aliases: aliases.get_aliases() }
58    }
59
60    pub fn sample(&self, seed: [u8; 32]) -> usize {
61        let usize_seed = Self::copy_8_bytes(&seed[0..8]);
62        let mut power_seed = [0u8; 16];
63        power_seed[0..8].copy_from_slice(&seed[8..16]);
64        let uniform_index = usize::from_le_bytes(usize_seed) % self.aliases.len();
65        let uniform_weight = Power::from_le_bytes(power_seed) % self.weight_sum;
66
67        if uniform_weight < self.no_alias_odds[uniform_index] {
68            uniform_index
69        } else {
70            self.aliases[uniform_index] as usize
71        }
72    }
73
74    pub fn get_aliases(&self) -> &[u64] {
75        &self.aliases
76    }
77
78    pub fn get_no_alias_odds(&self) -> &[Power] {
79        &self.no_alias_odds
80    }
81
82    fn copy_8_bytes(arr: &[u8]) -> [u8; 8] {
83        let mut result = [0u8; 8];
84        result.clone_from_slice(arr);
85        result
86    }
87
88    // fn copy_16_bytes(arr: &[u8]) -> [u8; 16] {
89    //     let mut result = [0u8; 16];
90    //     result.clone_from_slice(arr);
91    //     result
92    // }
93}
94
/// Sub-module to encapsulate helper struct for managing aliases
mod aliases {
    /// Scratch state used while building an alias table.
    ///
    /// It holds the alias assigned to each column and two LIFO worklists: the
    /// "small" columns and the "big" columns that are still waiting to be paired.
    pub struct Aliases {
        aliases: Vec<usize>,
        smalls: Vec<usize>,
        bigs: Vec<usize>,
    }

    impl Aliases {
        /// Creates state for `n` columns.
        ///
        /// Every alias starts at `0`. Both worklists start empty but are pre-sized
        /// for `n` entries.
        pub fn new(n: usize) -> Self {
            Self { aliases: vec![0; n], smalls: Vec::with_capacity(n), bigs: Vec::with_capacity(n) }
        }

        /// Pushes column `b` onto the "big" worklist.
        pub fn push_big(&mut self, b: usize) {
            self.bigs.push(b);
        }

        /// Pops the most recently pushed "big" column.
        ///
        /// # Panics
        ///
        /// Panics if the "big" worklist is empty. Callers must check
        /// [`Aliases::bigs_is_empty`] first.
        pub fn pop_big(&mut self) -> usize {
            self.bigs.pop().expect("pop_big called on an empty bigs worklist")
        }

        /// Returns `true` if no "big" columns are waiting to be paired.
        pub fn bigs_is_empty(&self) -> bool {
            self.bigs.is_empty()
        }

        /// Pushes column `s` onto the "small" worklist.
        pub fn push_small(&mut self, s: usize) {
            self.smalls.push(s);
        }

        /// Pops the most recently pushed "small" column.
        ///
        /// # Panics
        ///
        /// Panics if the "small" worklist is empty. Callers must check
        /// [`Aliases::smalls_is_empty`] first.
        pub fn pop_small(&mut self) -> usize {
            self.smalls.pop().expect("pop_small called on an empty smalls worklist")
        }

        /// Returns `true` if no "small" columns are waiting to be paired.
        pub fn smalls_is_empty(&self) -> bool {
            self.smalls.is_empty()
        }

        /// Records that column `index` redirects to column `alias`.
        ///
        /// # Panics
        ///
        /// Panics if `index` is not below the `n` passed to [`Aliases::new`].
        pub fn set_alias(&mut self, index: usize, alias: usize) {
            self.aliases[index] = alias;
        }

        /// Consumes the state and returns the alias of every column as `u64`.
        pub fn get_aliases(self) -> Vec<u64> {
            // `usize` is at most 64 bits wide on supported targets, so this cast is lossless.
            self.aliases.into_iter().map(|a| a as u64).collect()
        }
    }
}
141
#[cfg(test)]
mod test {
    use crate::hash;
    use crate::rand::WeightedIndex;

    #[test]
    fn test_should_correctly_compute_odds_and_aliases() {
        // Example taken from https://www.keithschwarz.com/darts-dice-coins/
        let weights = vec![5, 8, 4, 10, 4, 4, 5];
        let weighted_index = WeightedIndex::new(weights);

        assert_eq!(weighted_index.get_aliases(), &[1, 0, 3, 1, 3, 3, 3]);

        assert_eq!(weighted_index.get_no_alias_odds(), &[35, 40, 28, 29, 28, 28, 35]);
    }

    #[test]
    fn test_single_weight_is_always_sampled() {
        let weighted_index = WeightedIndex::new(vec![7]);
        assert_eq!(weighted_index.get_aliases(), &[0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[7]);

        let mut seed = hash(&[1; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 0);
            seed = hash(&seed);
        }
    }

    #[test]
    fn test_zero_weight_is_never_sampled() {
        // Column 0 has no mass of its own, so it must always redirect to column 1.
        let weighted_index = WeightedIndex::new(vec![0, 3]);
        assert_eq!(weighted_index.get_aliases(), &[1, 0]);
        assert_eq!(weighted_index.get_no_alias_odds(), &[0, 3]);

        let mut seed = hash(&[2; 32]);
        for _ in 0..1_000 {
            assert_eq!(weighted_index.sample(seed), 1);
            seed = hash(&seed);
        }
    }

    #[test]
    #[should_panic]
    fn test_sample_from_empty_index_panics() {
        WeightedIndex::new(vec![]).sample([0; 32]);
    }

    #[test]
    fn test_sample_should_produce_correct_distribution() {
        let weights = vec![5, 1, 1];
        let weighted_index = WeightedIndex::new(weights);

        let n_samples = 1_000_000;
        let mut seed = hash(&[0; 32]);
        let mut counts: [i32; 3] = [0, 0, 0];
        for _ in 0..n_samples {
            let index = weighted_index.sample(seed);
            counts[index] += 1;
            seed = hash(&seed);
        }

        assert_relative_closeness(counts[0], 5 * counts[1]);
        assert_relative_closeness(counts[1], counts[2]);
    }

    /// Assert y is within 0.5% of x, reporting both values on failure.
    #[track_caller]
    fn assert_relative_closeness(x: i32, y: i32) {
        let diff = (y - x).abs();
        let relative_diff = f64::from(diff) / f64::from(x);
        assert!(
            relative_diff < 0.005,
            "expected {} to be within 0.5% of {}, but the relative difference was {}",
            y,
            x,
            relative_diff
        );
    }

    fn hash(input: &[u8]) -> [u8; 32] {
        hash::hash(input).0
    }
}