// dynamic_weighted_sampler/dynamic_weighted_sampler.rs

use rand::{distr::weighted::WeightedIndex, seq::IteratorRandom, Rng};
use rand_distr::Distribution;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use sugars::cvec;

/// Default number of preallocated id slots when no explicit capacity is given.
const DEFAULT_CAPACITY: usize = 1_000;

/// A dynamic weighted sampler: ids are grouped into power-of-two weight
/// "levels"; sampling picks a level proportionally to its total weight, then
/// rejection-samples an id inside that level. Insert/remove/update are O(1).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DynamicWeightedSampler {
    /// Maximum admissible total weight, rounded up to a power of two.
    max_value: f64,
    /// Number of levels; level `i` admits weights in `(level_max[i]/2, level_max[i]]`.
    n_levels: usize,
    /// Running sum of all stored weights.
    total_weight: f64,
    /// Weight per id; 0 means "not present".
    weights: Vec<f64>,
    /// Running sum of weights stored in each level.
    level_weight: Vec<f64>,
    /// Ids stored in each level, in arbitrary order (swap-remove friendly).
    level_bucket: Vec<Vec<usize>>,
    rev_level_bucket: Vec<usize>, // maps id -> idx within a level
    /// Upper bound of the weights admissible in each level: 2^(top_level - i).
    level_max: Vec<f64>,
}

22impl DynamicWeightedSampler {
23    pub fn new(max_value: f64) -> Self {
24        Self::new_with_capacity(max_value, DEFAULT_CAPACITY)
25    }
26
27    pub fn new_with_capacity(max_value: f64, physical_capacity: usize) -> Self {
28        assert!(physical_capacity > 0);
29        let n_levels = max_value.log2().ceil() as usize + 1;
30        let max_value = 2f64.powf(max_value.log2().ceil());
31        let total_weight = 0.;
32        let weights = vec![0.; physical_capacity];
33        let level_weight = vec![0.; n_levels];
34        let level_bucket = vec![vec![]; n_levels];
35        let rev_level_bucket = vec![0; physical_capacity];
36        let top_level = n_levels - 1;
37        let level_max = cvec![2usize.pow(top_level as u32 - i) as f64; i in 0u32..(n_levels as u32)];
38        Self {
39            max_value,
40            n_levels,
41            total_weight,
42            weights,
43            level_weight,
44            level_bucket,
45            rev_level_bucket,
46            level_max,
47        }
48    }
49
50    pub fn insert(&mut self, id: usize, weight: f64) {
51        assert!(weight > 0.);
52        if id > self.weights.len() - 1 {
53            self.weights.resize(id + 1, 0.);
54            self.rev_level_bucket.resize(id + 1, 0);
55        }
56        assert!(self.weights[id] == 0., "Inserting element id {id} with weight {weight}, but it already existed with weight {}", self.weights[id]);
57        assert!(weight <= self.max_value, "Adding element {id} with weight {weight} exceeds the maximum weight capacity of {}", self.max_value);
58        self.weights[id] = weight;
59        self.total_weight += weight;
60        let level = self.level(weight);
61        self.insert_to_level(id, level, weight)
62    }
63
64    #[inline(always)]
65    fn level(&self, weight: f64) -> usize {
66        assert!(weight <= self.max_value, "{weight} > {}", self.max_value);
67        assert!(weight > 0.);
68        let top_level = self.n_levels - 1;
69        let level_from_top = log2_ceil2(weight);
70        assert!(top_level >= level_from_top);
71        let level =  top_level - level_from_top;
72        level
73    }
74
75    #[inline(always)]
76    fn insert_to_level(&mut self, id: usize, level: usize, weight: f64) {
77        self.level_weight[level] += weight;
78        self.level_bucket[level].push(id);
79        self.rev_level_bucket[id] = self.level_bucket[level].len() - 1;
80    }
81
82    #[inline(always)]
83    fn remove_from_level(&mut self, id: usize, level: usize, weight: f64) {
84        debug_assert_eq!(self.level_bucket[level][self.rev_level_bucket[id]], id);
85        self.level_weight[level] -= weight;
86        let idx_in_level = self.rev_level_bucket[id];
87        let last_idx_in_level = self.level_bucket[level].len() - 1;
88        if idx_in_level != last_idx_in_level {
89            // swap with last element
90            let id_in_last_idx = self.level_bucket[level][last_idx_in_level];
91            self.level_bucket[level].swap(idx_in_level, last_idx_in_level);
92            self.rev_level_bucket[id_in_last_idx] = idx_in_level;
93        }
94        // idx is last, just remove
95        self.level_bucket[level].pop();
96        self.rev_level_bucket[id] = 0;
97    }
98
99    pub fn remove(&mut self, id: usize) -> f64 {
100        assert!(self.weights[id] > 0., "removing element {id} with 0 weight");
101        let weight = self.weights[id];
102        self.weights[id] = 0.;
103        self.total_weight -= weight;
104        let level = self.level(weight);
105        self.remove_from_level(id, level, weight);
106        weight
107    }
108
109    pub fn update(&mut self, id: usize, new_weight: f64) {
110        if self.get_weight(id) == new_weight {
111            // nothing to do
112            return;
113        }
114        if new_weight == 0. {
115            // remove it completely if the weight is 0
116            self.remove(id);
117            return
118        }
119        let curr_weight = self.weights[id];
120        if curr_weight == 0. {
121            // if the previous weight was 0, just insert it
122            self.insert(id, new_weight);
123            return;
124        }
125        // otherwise update the weight
126        let curr_level = self.level(curr_weight);
127        let new_level = self.level(new_weight);
128        // Update the weight at the global level
129        self.total_weight += new_weight - curr_weight;
130        self.weights[id] = new_weight;
131        if curr_level == new_level {
132            // If the level didn't change, just update the level's weight
133            self.update_weight_in_level(curr_level, curr_weight, new_weight);
134        } else {
135            // Otherwise, remove the element from the current level (if any)
136            if curr_weight > 0. {
137                self.remove_from_level(id, curr_level, curr_weight);
138            }
139            // and insert it to the new level (if any)
140            if new_weight > 0. {
141                self.insert_to_level(id, new_level, new_weight);
142            }
143        }
144    }
145
146    pub fn update_delta(&mut self, id: usize, delta: f64) {
147        let new_weight = self.weights.get(id).unwrap_or(&0.) + delta;
148        self.update(id, new_weight);
149    }
150
151    #[inline(always)]
152    fn update_weight_in_level(&mut self, level: usize, curr_weight: f64, new_weight: f64) {
153        self.level_weight[level] += new_weight - curr_weight;
154    }
155
156    #[inline(always)]
157    pub fn get_weight(&self, id: usize) -> f64 {
158        self.weights[id]
159    }
160
161    #[inline(always)]
162    pub fn get_total_weight(&self) -> f64 {
163        self.total_weight
164    }
165
166    pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<usize> {
167        assert!(self.total_weight <= self.max_value, "weighted sampler total weight {} is bigger than max weight {}.", self.total_weight, self.max_value);
168        let levels_sampler = WeightedIndex::new(self.level_weight.iter().copied()).ok()?;
169        let level = levels_sampler.sample(rng);
170
171        loop {
172            let idx_in_level = (0..self.level_bucket[level].len()).choose(rng).unwrap();
173            let sampled_id = self.level_bucket[level][idx_in_level];
174            let weight = self.weights[sampled_id];
175            debug_assert!(weight <= self.level_max[level] && (level == self.n_levels - 1 || (self.level_max[level+1] < weight )));
176            let u = rng.random::<f64>() * self.level_max[level];
177            if u <= weight {
178                break Some(sampled_id);
179            }
180        }
181    }
182
183    pub fn check_invariant(&self) -> bool {
184        self.level_weight.iter().sum::<f64>() == self.total_weight &&
185        self.weights.iter().sum::<f64>() == self.total_weight &&
186            self.total_weight <= self.max_value
187    }
188}
189
/// Computes `max(ceil(log2(weight)), 0)` for a positive finite `weight` by
/// decoding its IEEE-754 bit pattern, avoiding a floating-point `log2` call.
///
/// For an exact power of two (zero mantissa fraction) the result is the
/// unbiased exponent; otherwise the exponent rounded up by one. Results below
/// zero (weights `<= 0.5`) are clamped to 0, mapping such weights onto the
/// sampler's bottom level. Previously the negative value was cast straight to
/// `usize`, wrapping to a huge number and tripping the level assertion.
fn log2_ceil2(weight: f64) -> usize {
    let bits: u64 = weight.to_bits();
    // Biased exponent: bits 52..63.
    let exponent = (bits >> 52) & ((1 << 11) - 1);
    // Mantissa fraction: bits 0..52; zero means `weight` is a power of two.
    let frac = bits & ((1 << 52) - 1);
    let z = if frac == 0 {
        exponent as i64 - 1023 // exact power of two: log2 is the exponent
    } else {
        exponent as i64 - 1022 // not a power of two: round up
    };
    // Clamp negative logs to 0 instead of wrapping on the cast.
    z.max(0) as usize
}

/// Computes `ceil(log2(weight))` (clamped to 0) via binary search over a
/// power-of-two lookup table covering `2^0 ..= 2^33`, falling back to the
/// floating-point `log2` only for weights beyond the table. Currently unused;
/// kept as an alternative to `log2_ceil2`.
fn _log2_ceil(weight: f64) -> usize {
    // Entry `i` is 2^i, the upper bound of the half-open interval
    // (2^(i-1), 2^i] whose ceil-log2 is exactly `i`.
    const LOOKUP_TABLE: [f64; 34] = [
        1.0,          // ceil(log2(weight)) == 0 for (0, 1.0]
        2.0,          // ceil(log2(weight)) == 1 for (1.0, 2.0]
        4.0,          // ceil(log2(weight)) == 2 for (2.0, 4.0]
        8.0,          // ceil(log2(weight)) == 3 for (4.0, 8.0]
        16.0,         // ceil(log2(weight)) == 4 for (8.0, 16.0]
        32.0,         // ceil(log2(weight)) == 5 for (16.0, 32.0]
        64.0,         // ceil(log2(weight)) == 6 for (32.0, 64.0]
        128.0,        // ceil(log2(weight)) == 7 for (64.0, 128.0]
        256.0,        // ceil(log2(weight)) == 8 for (128.0, 256.0]
        512.0,        // ceil(log2(weight)) == 9 for (256.0, 512.0]
        1024.0,       // ceil(log2(weight)) == 10 for (512.0, 1024.0]
        2048.0,       // ceil(log2(weight)) == 11 for (1024.0, 2048.0]
        4096.0,       // ceil(log2(weight)) == 12 for (2048.0, 4096.0]
        8192.0,       // ceil(log2(weight)) == 13 for (4096.0, 8192.0]
        16384.0,      // ceil(log2(weight)) == 14 for (8192.0, 16384.0]
        32768.0,      // ceil(log2(weight)) == 15 for (16384.0, 32768.0]
        65536.0,      // ceil(log2(weight)) == 16 for (32768.0, 65536.0]
        131072.0,     // ceil(log2(weight)) == 17 for (65536.0, 131072.0]
        262144.0,     // ceil(log2(weight)) == 18 for (131072.0, 262144.0]
        524288.0,     // ceil(log2(weight)) == 19 for (262144.0, 524288.0]
        1048576.0,    // ceil(log2(weight)) == 20 for (524288.0, 1048576.0]
        2097152.0,    // ceil(log2(weight)) == 21 for (1048576.0, 2097152.0]
        4194304.0,    // ceil(log2(weight)) == 22 for (2097152.0, 4194304.0]
        8388608.0,    // ceil(log2(weight)) == 23 for (4194304.0, 8388608.0]
        16777216.0,   // ceil(log2(weight)) == 24 for (8388608.0, 16777216.0]
        33554432.0,   // ceil(log2(weight)) == 25 for (16777216.0, 33554432.0]
        67108864.0,   // ceil(log2(weight)) == 26 for (33554432.0, 67108864.0]
        134217728.0,  // ceil(log2(weight)) == 27 for (67108864.0, 134217728.0]
        268435456.0,  // ceil(log2(weight)) == 28 for (134217728.0, 268435456.0]
        536870912.0,  // ceil(log2(weight)) == 29 for (268435456.0, 536870912.0]
        1073741824.0, // ceil(log2(weight)) == 30 for (536870912.0, 1073741824.0]
        2147483648.0, // ceil(log2(weight)) == 31 for (1073741824.0, 2147483648.0]
        4294967296.0, // ceil(log2(weight)) == 32 for (2147483648.0, 4294967296.0]
        8589934592.0, // ceil(log2(weight)) == 33 for (4294967296.0, 8589934592.0]
    ];

    match LOOKUP_TABLE.binary_search_by(|&upper_bound| upper_bound.partial_cmp(&weight).unwrap()) {
        // Exact power of two found at `index`.
        Ok(index) => index,
        // The insertion point counts the table entries strictly below
        // `weight`, which is exactly ceil(log2(weight)) while in range.
        // (The old code recomputed this with float `log2` despite its own
        // comment saying `Err` already gives the answer.)
        Err(index) if index < LOOKUP_TABLE.len() => index,
        // Beyond the table: fall back to the float computation.
        Err(_) => weight.log2().ceil() as usize,
    }
}

#[cfg(test)]
mod test_weighted_sampler {
    use std::time::Instant;

    use std::collections::HashMap;
    use rand::rng;

    use super::*;

    /// Empirical frequencies should match the configured weights, before and
    /// after an update, and sampling should stay fast.
    #[test]
    fn test_distr() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let mut samples: HashMap<usize, usize> = HashMap::new();

        sampler.insert(1, 999.);
        sampler.insert(2, 1.);

        let n_samples = 1_000_000;
        let start = Instant::now();
        // Draw exactly `n_samples` samples (the old `1..n_samples` range drew
        // one fewer than the denominator used in the frequency checks below).
        for _ in 0..n_samples {
            let sample = sampler.sample(&mut rng()).unwrap();
            *samples.entry(sample).or_default() += 1;
        }
        let duration = start.elapsed();

        assert!(duration.as_secs() <= 3); // 2-3 microseconds per sample
        approx::assert_abs_diff_eq!(samples[&1] as f64 / n_samples as f64, 0.999, epsilon=1e-4);
        approx::assert_abs_diff_eq!(samples[&2] as f64 / n_samples as f64, 0.001, epsilon=1e-4);

        println!("{:?}", sampler);
        sampler.update(1, 99.);
        println!("{:?}", sampler);

        // Reset the counters and re-check against the updated weights.
        samples.clear();
        let n_samples = 1_000;
        for _ in 0..n_samples {
            let sample = sampler.sample(&mut rng()).unwrap();
            *samples.entry(sample).or_default() += 1;
        }

        approx::assert_abs_diff_eq!(samples[&1] as f64 / n_samples as f64, 0.99, epsilon=1e-2);
        approx::assert_abs_diff_eq!(samples[&2] as f64 / n_samples as f64, 0.01, epsilon=1e-2);
    }

    /// Exercises the swap-remove bookkeeping between `level_bucket` and
    /// `rev_level_bucket` when ids share a level.
    #[test]
    fn test_remove() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let level = sampler.level(500.);
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(0));
        sampler.insert(2, 510.);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(1));
        sampler.remove(1);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(0));
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(1));
        sampler.remove(1);
    }

    /// Verifies the level assignment for a 1000-capacity sampler (11 levels).
    #[test]
    fn test_level() {
        let sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        assert_eq!(11, sampler.n_levels);
        assert_eq!(11-1, sampler.level(1.));
        assert_eq!(11-2, sampler.level(2.));
        assert_eq!(11-3, sampler.level(3.));
        assert_eq!(11-3, sampler.level(4.));
    }
}