dynamic_weighted_sampler/
dynamic_weighted_sampler.rs1use rand::{distr::weighted::WeightedIndex, seq::IteratorRandom, Rng};
2use rand_distr::Distribution;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use sugars::cvec;
6
/// Number of id slots pre-allocated by [`DynamicWeightedSampler::new`].
///
/// Ids at or beyond this value are still accepted: `insert` grows storage on demand.
const DEFAULT_CAPACITY: usize = 1_000;

/// Dynamic weighted sampler over integer ids supporting insertion, removal and
/// weight updates between draws.
///
/// Elements are grouped into power-of-two weight levels. A draw first picks a
/// level in proportion to its summed weight, then rejection-samples an element
/// inside that level against the level's upper bound.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct DynamicWeightedSampler {
    /// Weight cap (rounded up to a power of two at construction); checked per
    /// element by `insert` and against the total by `sample`.
    max_value: f64,
    /// Number of power-of-two weight levels.
    n_levels: usize,
    /// Running sum of every element weight.
    total_weight: f64,
    /// Weight of each id, indexed by id; `0.0` marks an absent id.
    weights: Vec<f64>,
    /// Running sum of the weights stored in each level.
    level_weight: Vec<f64>,
    /// Ids held by each level, in no particular order (removal swaps with the last entry).
    level_bucket: Vec<Vec<usize>>,
    /// Position of each id inside its level's bucket, indexed by id.
    rev_level_bucket: Vec<usize>,
    /// Upper weight bound of each level, used as the rejection bound when sampling.
    level_max: Vec<f64>,
}

22impl DynamicWeightedSampler {
23 pub fn new(max_value: f64) -> Self {
24 Self::new_with_capacity(max_value, DEFAULT_CAPACITY)
25 }
26
27 pub fn new_with_capacity(max_value: f64, physical_capacity: usize) -> Self {
28 assert!(physical_capacity > 0);
29 let n_levels = max_value.log2().ceil() as usize + 1;
30 let max_value = 2f64.powf(max_value.log2().ceil());
31 let total_weight = 0.;
32 let weights = vec![0.; physical_capacity];
33 let level_weight = vec![0.; n_levels];
34 let level_bucket = vec![vec![]; n_levels];
35 let rev_level_bucket = vec![0; physical_capacity];
36 let top_level = n_levels - 1;
37 let level_max = cvec![2usize.pow(top_level as u32 - i) as f64; i in 0u32..(n_levels as u32)];
38 Self {
39 max_value,
40 n_levels,
41 total_weight,
42 weights,
43 level_weight,
44 level_bucket,
45 rev_level_bucket,
46 level_max,
47 }
48 }
49
50 pub fn insert(&mut self, id: usize, weight: f64) {
51 assert!(weight > 0.);
52 if id > self.weights.len() - 1 {
53 self.weights.resize(id + 1, 0.);
54 self.rev_level_bucket.resize(id + 1, 0);
55 }
56 assert!(self.weights[id] == 0., "Inserting element id {id} with weight {weight}, but it already existed with weight {}", self.weights[id]);
57 assert!(weight <= self.max_value, "Adding element {id} with weight {weight} exceeds the maximum weight capacity of {}", self.max_value);
58 self.weights[id] = weight;
59 self.total_weight += weight;
60 let level = self.level(weight);
61 self.insert_to_level(id, level, weight)
62 }
63
64 fn level(&self, weight: f64) -> usize {
65 assert!(weight <= self.max_value, "{weight} > {}", self.max_value);
66 assert!(weight > 0.);
67 let top_level = self.n_levels - 1;
68 let level_from_top = log2_ceil2(weight);
69 assert!(top_level >= level_from_top);
70 let level = top_level - level_from_top;
71 level
72 }
73
74 #[inline(always)]
75 fn insert_to_level(&mut self, id: usize, level: usize, weight: f64) {
76 self.level_weight[level] += weight;
77 self.level_bucket[level].push(id);
78 self.rev_level_bucket[id] = self.level_bucket[level].len() - 1;
79 }
80
81 #[inline(always)]
82 fn remove_from_level(&mut self, id: usize, level: usize, weight: f64) {
83 assert_eq!(self.level_bucket[level][self.rev_level_bucket[id]], id);
84 self.level_weight[level] -= weight;
85 let idx_in_level = self.rev_level_bucket[id];
86 let last_idx_in_level = self.level_bucket[level].len() - 1;
87 if idx_in_level != last_idx_in_level {
88 let id_in_last_idx = self.level_bucket[level][last_idx_in_level];
90 self.level_bucket[level].swap(idx_in_level, last_idx_in_level);
91 self.rev_level_bucket[id_in_last_idx] = idx_in_level;
92 }
93 self.level_bucket[level].pop();
95 self.rev_level_bucket[id] = 0;
96 }
97
98 pub fn remove(&mut self, id: usize) -> f64 {
99 assert!(self.weights[id] > 0., "removing element {id} with 0 weight");
100 let weight = self.weights[id];
101 self.weights[id] = 0.;
102 self.total_weight -= weight;
103 let level = self.level(weight);
104 self.remove_from_level(id, level, weight);
105 weight
106 }
107
108 pub fn update(&mut self, id: usize, new_weight: f64) {
109 if new_weight == 0. {
110 self.remove(id);
112 return
113 }
114 let curr_weight = self.weights[id];
115 if curr_weight == 0. {
116 self.insert(id, new_weight);
118 return;
119 }
120 let curr_level = self.level(curr_weight);
122 let new_level = self.level(new_weight);
123 self.total_weight += new_weight - curr_weight;
125 self.weights[id] = new_weight;
126 if curr_level == new_level {
127 self.update_weight_in_level(curr_level, curr_weight, new_weight);
129 } else {
130 if curr_weight > 0. {
132 self.remove_from_level(id, curr_level, curr_weight);
133 }
134 if new_weight > 0. {
136 self.insert_to_level(id, new_level, new_weight);
137 }
138 }
139 }
140
141 pub fn update_delta(&mut self, id: usize, delta: f64) {
142 let new_weight = self.weights.get(id).unwrap_or(&0.) + delta;
143 self.update(id, new_weight);
144 }
145
146 #[inline(always)]
147 fn update_weight_in_level(&mut self, level: usize, curr_weight: f64, new_weight: f64) {
148 self.level_weight[level] += new_weight - curr_weight;
149 }
150
151 pub fn get_weight(&self, id: usize) -> f64 {
152 self.weights[id]
153 }
154
155 pub fn get_total_weight(&self) -> f64 {
156 self.total_weight
157 }
158
159 pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
160 assert!(self.total_weight <= self.max_value, "weighted sampler total weight {} is bigger than max weight {}.", self.total_weight, self.max_value);
161 let levels_sampler = WeightedIndex::new(self.level_weight.iter().copied()).unwrap();
162 let level = levels_sampler.sample(rng);
163
164 loop {
165 let idx_in_level = (0..self.level_bucket[level].len()).choose(rng).unwrap();
166 let sampled_id = self.level_bucket[level][idx_in_level];
167 let weight = self.weights[sampled_id];
168 debug_assert!(weight <= self.level_max[level] && (level == self.n_levels - 1 || (self.level_max[level+1] < weight )));
169 let u = rng.random::<f64>() * self.level_max[level];
170 if u <= weight {
171 break sampled_id;
172 }
173 }
174 }
175
176 pub fn check_invariant(&self) -> bool {
177 self.level_weight.iter().sum::<f64>() == self.total_weight &&
178 self.weights.iter().sum::<f64>() == self.total_weight &&
179 self.total_weight <= self.max_value
180 }
181}
182
/// Returns `ceil(log2(weight))` for `weight >= 1`, saturating to 0 for weights
/// in `(0, 1]` (and for zero/subnormal inputs).
///
/// Reads the IEEE-754 exponent and mantissa bits directly instead of calling
/// `f64::log2`. Assumes a positive, finite `weight`; the sign bit is ignored.
fn log2_ceil2(weight: f64) -> usize {
    const MANTISSA_BITS: u32 = 52;
    const EXPONENT_MASK: u64 = (1 << 11) - 1;
    const MANTISSA_MASK: u64 = (1 << MANTISSA_BITS) - 1;
    const EXPONENT_BIAS: i64 = 1023;

    let bits = weight.to_bits();
    let biased_exponent = ((bits >> MANTISSA_BITS) & EXPONENT_MASK) as i64;
    let floor_log2 = biased_exponent - EXPONENT_BIAS;
    // An exact power of two has an all-zero mantissa; anything else rounds up.
    let ceil_log2 = if bits & MANTISSA_MASK == 0 {
        floor_log2
    } else {
        floor_log2 + 1
    };
    // Negative results (weight <= 0.5) previously wrapped to huge values via `as usize`.
    ceil_log2.max(0) as usize
}

/// Alternative `ceil(log2(weight))`: exact powers of two in `2^0..=2^33` are
/// resolved by binary search over a table; every other input falls back to
/// `f64::log2`. Negative results saturate to 0 through the `as usize` cast.
///
/// Panics if `weight` is NaN (the comparison in the search unwraps).
fn _log2_ceil(weight: f64) -> usize {
    // powers[k] == 2^k, exactly representable in f64 and strictly increasing.
    let powers: [f64; 34] = std::array::from_fn(|k| (1u64 << k) as f64);

    let search = powers.binary_search_by(|&probe| probe.partial_cmp(&weight).unwrap());
    if let Ok(exponent) = search {
        exponent
    } else {
        weight.log2().ceil() as usize
    }
}

#[cfg(test)]
mod test_weighted_sampler {
    use std::collections::HashMap;
    use std::time::Instant;

    use rand::rng;

    use super::*;

    /// Draws exactly `n_samples` ids from `sampler` and counts each id.
    fn draw_counts(sampler: &DynamicWeightedSampler, n_samples: usize) -> HashMap<usize, usize> {
        let mut generator = rng();
        let mut counts = HashMap::new();
        for _ in 0..n_samples {
            *counts.entry(sampler.sample(&mut generator)).or_default() += 1;
        }
        counts
    }

    /// Empirical frequency of `id`; an id that was never drawn has frequency 0.
    fn frequency(counts: &HashMap<usize, usize>, id: usize, n_samples: usize) -> f64 {
        counts.get(&id).copied().unwrap_or(0) as f64 / n_samples as f64
    }

    #[test]
    fn test_distr() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        sampler.insert(1, 999.);
        sampler.insert(2, 1.);

        let n_samples = 1_000_000;
        let start = Instant::now();
        let counts = draw_counts(&sampler, n_samples);
        // Reported, not asserted: wall time depends on build profile and machine load.
        println!("drew {n_samples} samples in {:?}", start.elapsed());
        // Binomial std dev for p = 0.001, n = 1e6 is ~3.2e-5, so 2e-4 is ~6 sigma.
        approx::assert_abs_diff_eq!(frequency(&counts, 1, n_samples), 0.999, epsilon = 2e-4);
        approx::assert_abs_diff_eq!(frequency(&counts, 2, n_samples), 0.001, epsilon = 2e-4);

        println!("{:?}", sampler);
        sampler.update(1, 99.);
        println!("{:?}", sampler);

        let n_samples = 10_000;
        let counts = draw_counts(&sampler, n_samples);
        // Std dev for p = 0.01, n = 1e4 is ~1e-3, so 1e-2 is ~10 sigma.
        approx::assert_abs_diff_eq!(frequency(&counts, 1, n_samples), 0.99, epsilon = 1e-2);
        approx::assert_abs_diff_eq!(frequency(&counts, 2, n_samples), 0.01, epsilon = 1e-2);
    }

    #[test]
    fn test_remove() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let level = sampler.level(500.);
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(0));
        sampler.insert(2, 510.);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(1));
        sampler.remove(1);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(0));
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(1));
        sampler.remove(1);
    }

    #[test]
    fn test_level() {
        let sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        assert_eq!(11, sampler.n_levels);
        assert_eq!(11 - 1, sampler.level(1.));
        assert_eq!(11 - 2, sampler.level(2.));
        assert_eq!(11 - 3, sampler.level(3.));
        assert_eq!(11 - 3, sampler.level(4.));
    }
}