dynamic_weighted_sampler/
dynamic_weighted_sampler.rs1use rand::{distr::weighted::WeightedIndex, seq::IteratorRandom, Rng};
2use rand_distr::Distribution;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use sugars::cvec;
6
/// Number of id slots pre-allocated by [`DynamicWeightedSampler::new`].
const DEFAULT_CAPACITY: usize = 1_000;
8
/// Weighted sampler over `usize` ids whose weights can be inserted, removed
/// and updated between draws.
///
/// Ids are grouped into levels by the power of two their weight rounds up to.
/// Level `0` holds the largest weights and the last level the smallest.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DynamicWeightedSampler {
    /// Upper bound for weights, rounded up to a power of two by the constructor.
    max_value: f64,
    /// Number of levels; equals the length of the three `level_*` vectors.
    n_levels: usize,
    /// Running sum of all stored weights.
    total_weight: f64,
    /// Weight of each id, indexed by id; `0.` marks an id that is not stored.
    weights: Vec<f64>,
    /// Running sum of the weights stored in each level.
    level_weight: Vec<f64>,
    /// Ids currently stored in each level.
    level_bucket: Vec<Vec<usize>>,
    /// For each id, its position inside the bucket of its level
    /// (reset to `0` when the id is removed).
    rev_level_bucket: Vec<usize>,
    /// Upper weight bound of each level: descending powers of two.
    level_max: Vec<f64>,
}
21
22impl DynamicWeightedSampler {
23 pub fn new(max_value: f64) -> Self {
24 Self::new_with_capacity(max_value, DEFAULT_CAPACITY)
25 }
26
27 pub fn new_with_capacity(max_value: f64, physical_capacity: usize) -> Self {
28 assert!(physical_capacity > 0);
29 let n_levels = max_value.log2().ceil() as usize + 1;
30 let max_value = 2f64.powf(max_value.log2().ceil());
31 let total_weight = 0.;
32 let weights = vec![0.; physical_capacity];
33 let level_weight = vec![0.; n_levels];
34 let level_bucket = vec![vec![]; n_levels];
35 let rev_level_bucket = vec![0; physical_capacity];
36 let top_level = n_levels - 1;
37 let level_max = cvec![2usize.pow(top_level as u32 - i) as f64; i in 0u32..(n_levels as u32)];
38 Self {
39 max_value,
40 n_levels,
41 total_weight,
42 weights,
43 level_weight,
44 level_bucket,
45 rev_level_bucket,
46 level_max,
47 }
48 }
49
50 pub fn insert(&mut self, id: usize, weight: f64) {
51 assert!(weight > 0.);
52 if id > self.weights.len() - 1 {
53 self.weights.resize(id + 1, 0.);
54 self.rev_level_bucket.resize(id + 1, 0);
55 }
56 assert!(self.weights[id] == 0., "Inserting element id {id} with weight {weight}, but it already existed with weight {}", self.weights[id]);
57 assert!(weight <= self.max_value, "Adding element {id} with weight {weight} exceeds the maximum weight capacity of {}", self.max_value);
58 self.weights[id] = weight;
59 self.total_weight += weight;
60 let level = self.level(weight);
61 self.insert_to_level(id, level, weight)
62 }
63
64 #[inline(always)]
65 fn level(&self, weight: f64) -> usize {
66 assert!(weight <= self.max_value, "{weight} > {}", self.max_value);
67 assert!(weight > 0.);
68 let top_level = self.n_levels - 1;
69 let level_from_top = log2_ceil2(weight);
70 assert!(top_level >= level_from_top);
71 let level = top_level - level_from_top;
72 level
73 }
74
75 #[inline(always)]
76 fn insert_to_level(&mut self, id: usize, level: usize, weight: f64) {
77 self.level_weight[level] += weight;
78 self.level_bucket[level].push(id);
79 self.rev_level_bucket[id] = self.level_bucket[level].len() - 1;
80 }
81
82 #[inline(always)]
83 fn remove_from_level(&mut self, id: usize, level: usize, weight: f64) {
84 debug_assert_eq!(self.level_bucket[level][self.rev_level_bucket[id]], id);
85 self.level_weight[level] -= weight;
86 let idx_in_level = self.rev_level_bucket[id];
87 let last_idx_in_level = self.level_bucket[level].len() - 1;
88 if idx_in_level != last_idx_in_level {
89 let id_in_last_idx = self.level_bucket[level][last_idx_in_level];
91 self.level_bucket[level].swap(idx_in_level, last_idx_in_level);
92 self.rev_level_bucket[id_in_last_idx] = idx_in_level;
93 }
94 self.level_bucket[level].pop();
96 self.rev_level_bucket[id] = 0;
97 }
98
99 pub fn remove(&mut self, id: usize) -> f64 {
100 assert!(self.weights[id] > 0., "removing element {id} with 0 weight");
101 let weight = self.weights[id];
102 self.weights[id] = 0.;
103 self.total_weight -= weight;
104 let level = self.level(weight);
105 self.remove_from_level(id, level, weight);
106 weight
107 }
108
109 pub fn update(&mut self, id: usize, new_weight: f64) {
110 if self.get_weight(id) == new_weight {
111 return;
113 }
114 if new_weight == 0. {
115 self.remove(id);
117 return
118 }
119 let curr_weight = self.weights[id];
120 if curr_weight == 0. {
121 self.insert(id, new_weight);
123 return;
124 }
125 let curr_level = self.level(curr_weight);
127 let new_level = self.level(new_weight);
128 self.total_weight += new_weight - curr_weight;
130 self.weights[id] = new_weight;
131 if curr_level == new_level {
132 self.update_weight_in_level(curr_level, curr_weight, new_weight);
134 } else {
135 if curr_weight > 0. {
137 self.remove_from_level(id, curr_level, curr_weight);
138 }
139 if new_weight > 0. {
141 self.insert_to_level(id, new_level, new_weight);
142 }
143 }
144 }
145
146 pub fn update_delta(&mut self, id: usize, delta: f64) {
147 let new_weight = self.weights.get(id).unwrap_or(&0.) + delta;
148 self.update(id, new_weight);
149 }
150
151 #[inline(always)]
152 fn update_weight_in_level(&mut self, level: usize, curr_weight: f64, new_weight: f64) {
153 self.level_weight[level] += new_weight - curr_weight;
154 }
155
156 #[inline(always)]
157 pub fn get_weight(&self, id: usize) -> f64 {
158 self.weights[id]
159 }
160
161 #[inline(always)]
162 pub fn get_total_weight(&self) -> f64 {
163 self.total_weight
164 }
165
166 pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<usize> {
167 assert!(self.total_weight <= self.max_value, "weighted sampler total weight {} is bigger than max weight {}.", self.total_weight, self.max_value);
168 let levels_sampler = WeightedIndex::new(self.level_weight.iter().copied()).ok()?;
169 let level = levels_sampler.sample(rng);
170
171 loop {
172 let idx_in_level = (0..self.level_bucket[level].len()).choose(rng).unwrap();
173 let sampled_id = self.level_bucket[level][idx_in_level];
174 let weight = self.weights[sampled_id];
175 debug_assert!(weight <= self.level_max[level] && (level == self.n_levels - 1 || (self.level_max[level+1] < weight )));
176 let u = rng.random::<f64>() * self.level_max[level];
177 if u <= weight {
178 break Some(sampled_id);
179 }
180 }
181 }
182
183 pub fn check_invariant(&self) -> bool {
184 self.level_weight.iter().sum::<f64>() == self.total_weight &&
185 self.weights.iter().sum::<f64>() == self.total_weight &&
186 self.total_weight <= self.max_value
187 }
188}
189
/// Computes `ceil(log2(weight))` directly from the IEEE-754 bit pattern.
///
/// The sign bit is ignored. When the mantissa is all zeros the value is an
/// exact power of two and the unbiased exponent is the answer; otherwise the
/// result is one higher. A negative result (weights of `0.5` or less) wraps
/// around in the final cast to `usize`.
fn log2_ceil2(weight: f64) -> usize {
    const MANTISSA_BITS: u32 = 52;
    const EXPONENT_MASK: u64 = (1 << 11) - 1;
    const MANTISSA_MASK: u64 = (1 << MANTISSA_BITS) - 1;

    let bits = weight.to_bits();
    let biased_exponent = ((bits >> MANTISSA_BITS) & EXPONENT_MASK) as i64;
    let is_power_of_two = bits & MANTISSA_MASK == 0;
    // Subtracting 1022 instead of the bias 1023 performs the "round up" step.
    let bias = if is_power_of_two { 1023 } else { 1022 };
    (biased_exponent - bias) as usize
}
198
/// Computes `ceil(log2(weight))`, answering exact powers of two up to `2^33`
/// from a lookup table and falling back to floating-point `log2` otherwise.
///
/// # Panics
/// Panics if `weight` is NaN (the comparison used by the search is undefined).
fn _log2_ceil(weight: f64) -> usize {
    // Entry `i` holds `2^i`, so a hit's index is the logarithm itself.
    const POWERS_OF_TWO: [f64; 34] = [
        1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0,
        8192.0, 16384.0, 32768.0, 65536.0, 131072.0, 262144.0, 524288.0, 1048576.0, 2097152.0,
        4194304.0, 8388608.0, 16777216.0, 33554432.0, 67108864.0, 134217728.0, 268435456.0,
        536870912.0, 1073741824.0, 2147483648.0, 4294967296.0, 8589934592.0,
    ];

    POWERS_OF_TWO
        .binary_search_by(|power| power.partial_cmp(&weight).unwrap())
        .unwrap_or_else(|_| weight.log2().ceil() as usize)
}
244
#[cfg(test)]
mod test_weighted_sampler {
    use std::time::Instant;

    use std::collections::HashMap;
    use rand::rng;

    use super::*;

    /// Empirical sampling frequencies must match the weight ratios, both
    /// before and after a weight update that moves an id to another level.
    #[test]
    fn test_distr() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let mut samples: HashMap<usize, usize> = HashMap::new();

        sampler.insert(1, 999.);
        sampler.insert(2, 1.);

        let n_samples = 1_000_000;
        let start = Instant::now();
        // Draw exactly `n_samples` values so the counts match the denominator
        // used for the frequencies below.
        for _ in 0..n_samples {
            let sample = sampler.sample(&mut rng()).unwrap();
            *samples.entry(sample).or_default() += 1;
        }
        let duration = start.elapsed();

        // Coarse performance guard for the sampling loop.
        assert!(duration.as_secs() <= 3);
        approx::assert_abs_diff_eq!(samples[&1] as f64 / n_samples as f64, 0.999, epsilon=1e-4);
        approx::assert_abs_diff_eq!(samples[&2] as f64 / n_samples as f64, 0.001, epsilon=1e-4);

        println!("{:?}", sampler);
        sampler.update(1, 99.);
        println!("{:?}", sampler);

        samples.clear();
        let n_samples = 1_000;
        for _ in 0..n_samples {
            let sample = sampler.sample(&mut rng()).unwrap();
            *samples.entry(sample).or_default() += 1;
        }

        approx::assert_abs_diff_eq!(samples[&1] as f64 / n_samples as f64, 0.99, epsilon=1e-2);
        approx::assert_abs_diff_eq!(samples[&2] as f64 / n_samples as f64, 0.01, epsilon=1e-2);
    }

    /// Removing an id moves the last id of the bucket into the freed slot,
    /// and a re-inserted id is appended at the end.
    #[test]
    fn test_remove() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let level = sampler.level(500.);
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].first());
        sampler.insert(2, 510.);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(1));
        sampler.remove(1);
        assert_eq!(Some(&2), sampler.level_bucket[level].first());
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(1));
        sampler.remove(1);
    }

    /// A maximum of 1000 rounds up to 1024, giving 11 levels; smaller weights
    /// map to higher level indices.
    #[test]
    fn test_level() {
        let sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        assert_eq!(11, sampler.n_levels);
        assert_eq!(11-1, sampler.level(1.));
        assert_eq!(11-2, sampler.level(2.));
        assert_eq!(11-3, sampler.level(3.));
        assert_eq!(11-3, sampler.level(4.));
    }
}