dynamic_weighted_sampler/
dynamic_weighted_sampler.rs1use rand::{distr::weighted::WeightedIndex, seq::IteratorRandom, Rng};
2use rand_distr::Distribution;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use sugars::cvec;
6
/// Number of element slots preallocated when a sampler is created without an
/// explicit capacity.
const DEFAULT_CAPACITY: usize = 1_000;
8
/// A weighted sampler over `usize` ids whose weights can be inserted, removed
/// and updated between draws.
///
/// Elements are grouped into levels by the power of two that bounds their
/// weight. A draw first picks a level in proportion to its summed weight and
/// then picks an element inside that level by rejection sampling.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DynamicWeightedSampler {
    /// Upper bound on any single weight (and on the total weight at sampling
    /// time), stored as a power of two.
    max_value: f64,
    /// Number of levels; the last level has the smallest bound.
    n_levels: usize,
    /// Running sum of all element weights.
    total_weight: f64,
    /// Weight of each element indexed by id; `0.` marks an absent element.
    weights: Vec<f64>,
    /// Running sum of the weights stored in each level.
    level_weight: Vec<f64>,
    /// Ids of the elements stored in each level.
    level_bucket: Vec<Vec<usize>>,
    /// For each id, its position inside the bucket of the level that holds it.
    rev_level_bucket: Vec<usize>,
    /// Largest weight each level can hold (a power of two per level).
    level_max: Vec<f64>,
}
21
22impl DynamicWeightedSampler {
23 pub fn new(max_value: f64) -> Self {
24 Self::new_with_capacity(max_value, DEFAULT_CAPACITY)
25 }
26
27 pub fn new_with_capacity(max_value: f64, physical_capacity: usize) -> Self {
28 assert!(physical_capacity > 0);
29 let n_levels = max_value.log2().ceil() as usize + 1;
30 let max_value = 2f64.powf(max_value.log2().ceil());
31 let total_weight = 0.;
32 let weights = vec![0.; physical_capacity];
33 let level_weight = vec![0.; n_levels];
34 let level_bucket = vec![vec![]; n_levels];
35 let rev_level_bucket = vec![0; physical_capacity];
36 let top_level = n_levels - 1;
37 let level_max = cvec![2usize.pow(top_level as u32 - i) as f64; i in 0u32..(n_levels as u32)];
38 Self {
39 max_value,
40 n_levels,
41 total_weight,
42 weights,
43 level_weight,
44 level_bucket,
45 rev_level_bucket,
46 level_max,
47 }
48 }
49
50 pub fn insert(&mut self, id: usize, weight: f64) {
51 assert!(weight > 0.);
52 if id > self.weights.len() - 1 {
53 self.weights.resize(id + 1, 0.);
54 self.rev_level_bucket.resize(id + 1, 0);
55 }
56 assert!(self.weights[id] == 0., "Inserting element id {id} with weight {weight}, but it already existed with weight {}", self.weights[id]);
57 assert!(weight <= self.max_value, "Adding element {id} with weight {weight} exceeds the maximum weight capacity of {}", self.max_value);
58 self.weights[id] = weight;
59 self.total_weight += weight;
60 let level = self.level(weight);
61 self.insert_to_level(id, level, weight)
62 }
63
64 #[inline]
65 fn level(&self, weight: f64) -> usize {
66 assert!(weight <= self.max_value, "{weight} > {}", self.max_value);
67 assert!(weight > 0.);
68 let top_level = self.n_levels - 1;
69 let level_from_top = log2_ceil2(weight);
70 assert!(top_level >= level_from_top);
71 let level = top_level - level_from_top;
72 level
73 }
74
75 #[inline]
76 fn insert_to_level(&mut self, id: usize, level: usize, weight: f64) {
77 self.level_weight[level] += weight;
78 self.level_bucket[level].push(id);
79 self.rev_level_bucket[id] = self.level_bucket[level].len() - 1;
80 }
81
82 #[inline]
83 fn remove_from_level(&mut self, id: usize, level: usize, weight: f64) {
84 debug_assert_eq!(self.level_bucket[level][self.rev_level_bucket[id]], id);
85 self.level_weight[level] -= weight;
86 let idx_in_level = self.rev_level_bucket[id];
87 let last_idx_in_level = self.level_bucket[level].len() - 1;
88 if idx_in_level != last_idx_in_level {
89 let id_in_last_idx = self.level_bucket[level][last_idx_in_level];
91 self.level_bucket[level].swap(idx_in_level, last_idx_in_level);
92 self.rev_level_bucket[id_in_last_idx] = idx_in_level;
93 }
94 self.level_bucket[level].pop();
96 self.rev_level_bucket[id] = 0;
97 }
98
99 pub fn remove(&mut self, id: usize) -> f64 {
100 assert!(self.weights[id] > 0., "removing element {id} with 0 weight");
101 let weight = self.weights[id];
102 self.weights[id] = 0.;
103 self.total_weight -= weight;
104 let level = self.level(weight);
105 self.remove_from_level(id, level, weight);
106 weight
107 }
108
109 pub fn update(&mut self, id: usize, new_weight: f64) {
110 if self.get_weight(id) == new_weight {
111 return;
113 }
114 if new_weight == 0. {
115 self.remove(id);
117 return
118 }
119 let curr_weight = self.weights[id];
120 if curr_weight == 0. {
121 self.insert(id, new_weight);
123 return;
124 }
125 let curr_level = self.level(curr_weight);
127 let new_level = self.level(new_weight);
128 self.total_weight += new_weight - curr_weight;
130 self.weights[id] = new_weight;
131 if curr_level == new_level {
132 self.update_weight_in_level(curr_level, curr_weight, new_weight);
134 } else {
135 self.remove_from_level(id, curr_level, curr_weight);
137 self.insert_to_level(id, new_level, new_weight);
139 }
140 }
141
142 pub fn update_delta(&mut self, id: usize, delta: f64) {
143 let new_weight = self.weights.get(id).unwrap_or(&0.) + delta;
144 self.update(id, new_weight);
145 }
146
147 #[inline]
148 fn update_weight_in_level(&mut self, level: usize, curr_weight: f64, new_weight: f64) {
149 self.level_weight[level] += new_weight - curr_weight;
150 }
151
152 #[inline]
153 pub fn get_weight(&self, id: usize) -> f64 {
154 self.weights[id]
155 }
156
157 #[inline]
158 pub fn get_total_weight(&self) -> f64 {
159 self.total_weight
160 }
161
162 pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<usize> {
163 assert!(self.total_weight <= self.max_value, "weighted sampler total weight {} is bigger than max weight {}.", self.total_weight, self.max_value);
164 let levels_sampler = WeightedIndex::new(self.level_weight.iter().copied()).ok()?;
165 let level = levels_sampler.sample(rng);
166
167 loop {
168 let idx_in_level = (0..self.level_bucket[level].len()).choose(rng).unwrap();
169 let sampled_id = self.level_bucket[level][idx_in_level];
170 let weight = self.weights[sampled_id];
171 debug_assert!(weight <= self.level_max[level] && (level == self.n_levels - 1 || (self.level_max[level+1] < weight )));
172 let u = rng.random::<f64>() * self.level_max[level];
173 if u <= weight {
174 break Some(sampled_id);
175 }
176 }
177 }
178
179 pub fn check_invariant(&self) -> bool {
180 self.level_weight.iter().sum::<f64>() == self.total_weight &&
181 self.weights.iter().sum::<f64>() == self.total_weight &&
182 self.total_weight <= self.max_value
183 }
184}
185
/// Returns `ceil(log2(weight))`, read directly from the IEEE 754 binary64
/// exponent and mantissa bits, saturating at zero.
///
/// For a normal value `weight = 2^e * (1 + m)` the result is `e` when the
/// mantissa `m` is zero (an exact power of two) and `e + 1` otherwise. The sign
/// bit is ignored. Values whose logarithm would be negative (anything below 1,
/// including zero and subnormals) yield `0`, because a negative result cannot
/// be represented in `usize`.
fn log2_ceil2(weight: f64) -> usize {
    const MANTISSA_BITS: u32 = 52;
    const EXPONENT_BITS: u32 = 11;
    const EXPONENT_BIAS: i64 = 1023;

    let bits: u64 = weight.to_bits();
    let biased_exponent = ((bits >> MANTISSA_BITS) & ((1 << EXPONENT_BITS) - 1)) as i64;
    let mantissa = bits & ((1 << MANTISSA_BITS) - 1);
    // A non-zero mantissa means the value lies strictly above 2^e: round up.
    let round_up = if mantissa == 0 { 0 } else { 1 };
    let exponent = biased_exponent - EXPONENT_BIAS + round_up;
    // Clamp before casting: a negative i64 would otherwise wrap to a huge usize.
    exponent.max(0) as usize
}
194
/// Returns `ceil(log2(weight))` using a table of powers of two, saturating at
/// zero for weights below 1.
///
/// Inside the table's range (up to 2^33) the answer is the position found by
/// the binary search: an exact hit is that power's exponent, and a miss lands
/// on the first power above `weight`, which is the ceiling. This stays exact
/// for values just above a power of two, where `log2().ceil()` can round down.
/// Above the table the floating-point logarithm is used.
///
/// # Panics
///
/// Panics if `weight` is NaN.
fn _log2_ceil(weight: f64) -> usize {
    const POWERS_OF_TWO: [f64; 34] = [
        1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0,
        8192.0, 16384.0, 32768.0, 65536.0, 131072.0, 262144.0, 524288.0, 1048576.0, 2097152.0,
        4194304.0, 8388608.0, 16777216.0, 33554432.0, 67108864.0, 134217728.0, 268435456.0,
        536870912.0, 1073741824.0, 2147483648.0, 4294967296.0, 8589934592.0,
    ];

    let position = POWERS_OF_TWO.binary_search_by(|&upper_bound| {
        upper_bound
            .partial_cmp(&weight)
            .expect("weight must not be NaN")
    });
    match position {
        Ok(index) => index,
        // The insertion point is the index of the first power above `weight`;
        // index 0 covers every weight below 1.
        Err(index) if index < POWERS_OF_TWO.len() => index,
        Err(_) => weight.log2().ceil() as usize,
    }
}
240
#[cfg(test)]
mod test_weighted_sampler {
    use std::collections::HashMap;
    use std::time::Instant;

    use rand::rng;

    use super::*;

    /// Draws exactly `n_samples` ids and counts how often each one appears.
    fn sample_counts(sampler: &DynamicWeightedSampler, n_samples: usize) -> HashMap<usize, usize> {
        let mut rng = rng();
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for _ in 0..n_samples {
            let sample = sampler.sample(&mut rng).unwrap();
            *counts.entry(sample).or_default() += 1;
        }
        counts
    }

    /// Observed frequency of `id`; an id that was never drawn counts as zero
    /// instead of panicking on a missing key.
    fn frequency(counts: &HashMap<usize, usize>, id: usize, n_samples: usize) -> f64 {
        counts.get(&id).copied().unwrap_or(0) as f64 / n_samples as f64
    }

    #[test]
    fn test_distr() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);

        sampler.insert(1, 999.);
        sampler.insert(2, 1.);

        let n_samples = 1_000_000;
        let start = Instant::now();
        let samples = sample_counts(&sampler, n_samples);
        let duration = start.elapsed();

        // Coarse performance guard for one million draws.
        assert!(duration.as_secs() <= 3);
        // The standard deviation of the observed frequency is about 3.2e-5, so
        // this tolerance sits roughly six deviations out.
        approx::assert_abs_diff_eq!(frequency(&samples, 1, n_samples), 0.999, epsilon = 2e-4);
        approx::assert_abs_diff_eq!(frequency(&samples, 2, n_samples), 0.001, epsilon = 2e-4);

        sampler.update(1, 99.);

        // Enough draws that the 1e-2 tolerance is far outside sampling noise.
        let n_samples = 100_000;
        let samples = sample_counts(&sampler, n_samples);

        approx::assert_abs_diff_eq!(frequency(&samples, 1, n_samples), 0.99, epsilon = 1e-2);
        approx::assert_abs_diff_eq!(frequency(&samples, 2, n_samples), 0.01, epsilon = 1e-2);
    }

    #[test]
    fn test_remove() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let level = sampler.level(500.);
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].first());
        sampler.insert(2, 510.);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(1));
        // Removing the first element moves the last one into its slot.
        sampler.remove(1);
        assert_eq!(Some(&2), sampler.level_bucket[level].first());
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(1));
        sampler.remove(1);
    }

    #[test]
    fn test_level() {
        let sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        assert_eq!(11, sampler.n_levels);
        assert_eq!(11 - 1, sampler.level(1.));
        assert_eq!(11 - 2, sampler.level(2.));
        assert_eq!(11 - 3, sampler.level(3.));
        assert_eq!(11 - 3, sampler.level(4.));
    }
}