dynamic_weighted_sampler/
dynamic_weighted_sampler.rs1use rand::{distr::weighted::WeightedIndex, seq::IteratorRandom, Rng};
2use rand_distr::Distribution;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use sugars::cvec;
6
/// Number of ids the sampler pre-allocates storage for when it is built with
/// [`DynamicWeightedSampler::new`].
const DEFAULT_CAPACITY: usize = 1_000;
8
/// Samples ids with probability proportional to their weights, supporting
/// insertion, removal and weight updates between draws.
///
/// Ids are grouped into power-of-two levels by weight; a draw first picks a
/// level proportionally to its total weight, then an id within that level.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DynamicWeightedSampler {
    /// Upper bound on weights, rounded up to a power of two at construction.
    max_value: f64,
    /// Number of weight levels.
    n_levels: usize,
    /// Sum of all stored weights.
    total_weight: f64,
    /// Weight of each id, indexed by id; `0.0` marks an absent id.
    weights: Vec<f64>,
    /// Sum of the weights stored in each level.
    level_weight: Vec<f64>,
    /// Ids belonging to each level (order is arbitrary).
    level_bucket: Vec<Vec<usize>>,
    /// Position of each id inside its level's bucket, indexed by id.
    rev_level_bucket: Vec<usize>,
    /// Inclusive upper bound on the weights of each level (level 0 is largest).
    level_max: Vec<f64>,
}
21
22impl DynamicWeightedSampler {
23 pub fn new(max_value: f64) -> Self {
24 Self::new_with_capacity(max_value, DEFAULT_CAPACITY)
25 }
26
27 pub fn new_with_capacity(max_value: f64, physical_capacity: usize) -> Self {
28 assert!(physical_capacity > 0);
29 let n_levels = max_value.log2().ceil() as usize + 1;
30 let max_value = 2f64.powf(max_value.log2().ceil());
31 let total_weight = 0.;
32 let weights = vec![0.; physical_capacity];
33 let level_weight = vec![0.; n_levels];
34 let level_bucket = vec![vec![]; n_levels];
35 let rev_level_bucket = vec![0; physical_capacity];
36 let top_level = n_levels - 1;
37 let level_max = cvec![2usize.pow(top_level as u32 - i) as f64; i in 0u32..(n_levels as u32)];
38 Self {
39 max_value,
40 n_levels,
41 total_weight,
42 weights,
43 level_weight,
44 level_bucket,
45 rev_level_bucket,
46 level_max,
47 }
48 }
49
50 pub fn insert(&mut self, id: usize, weight: f64) {
51 assert!(weight > 0.);
52 if id > self.weights.len() - 1 {
53 self.weights.resize(id + 1, 0.);
54 self.rev_level_bucket.resize(id + 1, 0);
55 }
56 assert!(self.weights[id] == 0., "Inserting element id {id} with weight {weight}, but it already existed with weight {}", self.weights[id]);
57 assert!(weight <= self.max_value, "Adding element {id} with weight {weight} exceeds the maximum weight capacity of {}", self.max_value);
58 self.weights[id] = weight;
59 self.total_weight += weight;
60 let level = self.level(weight);
61 self.insert_to_level(id, level, weight)
62 }
63
64 fn level(&self, weight: f64) -> usize {
65 assert!(weight <= self.max_value, "{weight} > {}", self.max_value);
66 assert!(weight > 0.);
67 let top_level = self.n_levels - 1;
68 let level_from_top = log2_ceil2(weight);
69 assert!(top_level >= level_from_top);
70 let level = top_level - level_from_top;
71 level
72 }
73
74 #[inline(always)]
75 fn insert_to_level(&mut self, id: usize, level: usize, weight: f64) {
76 self.level_weight[level] += weight;
77 self.level_bucket[level].push(id);
78 self.rev_level_bucket[id] = self.level_bucket[level].len() - 1;
79 }
80
81 #[inline(always)]
82 fn remove_from_level(&mut self, id: usize, level: usize, weight: f64) {
83 assert_eq!(self.level_bucket[level][self.rev_level_bucket[id]], id);
84 self.level_weight[level] -= weight;
85 let idx_in_level = self.rev_level_bucket[id];
86 let last_idx_in_level = self.level_bucket[level].len() - 1;
87 if idx_in_level != last_idx_in_level {
88 let id_in_last_idx = self.level_bucket[level][last_idx_in_level];
90 self.level_bucket[level].swap(idx_in_level, last_idx_in_level);
91 self.rev_level_bucket[id_in_last_idx] = idx_in_level;
92 }
93 self.level_bucket[level].pop();
95 self.rev_level_bucket[id] = 0;
96 }
97
98 pub fn remove(&mut self, id: usize) -> f64 {
99 assert!(self.weights[id] > 0., "removing element {id} with 0 weight");
100 let weight = self.weights[id];
101 self.weights[id] = 0.;
102 self.total_weight -= weight;
103 let level = self.level(weight);
104 self.remove_from_level(id, level, weight);
105 weight
106 }
107
108 pub fn update(&mut self, id: usize, new_weight: f64) {
109 if self.get_weight(id) == new_weight {
110 return;
112 }
113 if new_weight == 0. {
114 self.remove(id);
116 return
117 }
118 let curr_weight = self.weights[id];
119 if curr_weight == 0. {
120 self.insert(id, new_weight);
122 return;
123 }
124 let curr_level = self.level(curr_weight);
126 let new_level = self.level(new_weight);
127 self.total_weight += new_weight - curr_weight;
129 self.weights[id] = new_weight;
130 if curr_level == new_level {
131 self.update_weight_in_level(curr_level, curr_weight, new_weight);
133 } else {
134 if curr_weight > 0. {
136 self.remove_from_level(id, curr_level, curr_weight);
137 }
138 if new_weight > 0. {
140 self.insert_to_level(id, new_level, new_weight);
141 }
142 }
143 }
144
145 pub fn update_delta(&mut self, id: usize, delta: f64) {
146 let new_weight = self.weights.get(id).unwrap_or(&0.) + delta;
147 self.update(id, new_weight);
148 }
149
150 #[inline(always)]
151 fn update_weight_in_level(&mut self, level: usize, curr_weight: f64, new_weight: f64) {
152 self.level_weight[level] += new_weight - curr_weight;
153 }
154
155 pub fn get_weight(&self, id: usize) -> f64 {
156 self.weights[id]
157 }
158
159 pub fn get_total_weight(&self) -> f64 {
160 self.total_weight
161 }
162
163 pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
164 assert!(self.total_weight <= self.max_value, "weighted sampler total weight {} is bigger than max weight {}.", self.total_weight, self.max_value);
165 let levels_sampler = WeightedIndex::new(self.level_weight.iter().copied()).unwrap();
166 let level = levels_sampler.sample(rng);
167
168 loop {
169 let idx_in_level = (0..self.level_bucket[level].len()).choose(rng).unwrap();
170 let sampled_id = self.level_bucket[level][idx_in_level];
171 let weight = self.weights[sampled_id];
172 debug_assert!(weight <= self.level_max[level] && (level == self.n_levels - 1 || (self.level_max[level+1] < weight )));
173 let u = rng.random::<f64>() * self.level_max[level];
174 if u <= weight {
175 break sampled_id;
176 }
177 }
178 }
179
180 pub fn check_invariant(&self) -> bool {
181 self.level_weight.iter().sum::<f64>() == self.total_weight &&
182 self.weights.iter().sum::<f64>() == self.total_weight &&
183 self.total_weight <= self.max_value
184 }
185}
186
/// Computes `ceil(log2(weight))` directly from the IEEE-754 bit pattern.
///
/// An exact power of two (zero mantissa) yields its unbiased exponent; any
/// other value yields the unbiased exponent plus one. The sign bit is ignored.
/// The result is only meaningful for normal values greater than 0.5. Smaller
/// inputs produce a negative exponent, which wraps around when cast to `usize`.
fn log2_ceil2(weight: f64) -> usize {
    const MANTISSA_BITS: u32 = 52;
    const EXPONENT_MASK: u64 = (1 << 11) - 1;
    const MANTISSA_MASK: u64 = (1 << MANTISSA_BITS) - 1;
    const EXPONENT_BIAS: i64 = 1023;

    let bits = weight.to_bits();
    let biased_exponent = ((bits >> MANTISSA_BITS) & EXPONENT_MASK) as i64;
    let unbiased_exponent = biased_exponent - EXPONENT_BIAS;
    let is_power_of_two = bits & MANTISSA_MASK == 0;
    let ceil_log2 = if is_power_of_two {
        unbiased_exponent
    } else {
        unbiased_exponent + 1
    };
    ceil_log2 as usize
}
195
/// Alternative `ceil(log2(weight))`. Exact powers of two from `2^0` to `2^33`
/// are resolved by binary search over a table, and any other value falls back
/// to `f64::log2`. Not referenced elsewhere in this file.
///
/// Panics if `weight` is NaN, because it cannot be ordered against the table.
fn _log2_ceil(weight: f64) -> usize {
    let powers_of_two: [f64; 34] = std::array::from_fn(|exp| (1u64 << exp) as f64);
    powers_of_two
        .binary_search_by(|bound| bound.partial_cmp(&weight).unwrap())
        .unwrap_or_else(|_| weight.log2().ceil() as usize)
}
241
#[cfg(test)]
mod test_weighted_sampler {
    use std::collections::HashMap;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;

    /// Draws `n_samples` ids from `sampler` and returns each id's observed frequency.
    fn frequencies<R: Rng>(
        sampler: &DynamicWeightedSampler,
        rng: &mut R,
        n_samples: usize,
    ) -> HashMap<usize, f64> {
        let mut counts: HashMap<usize, usize> = HashMap::new();
        // `0..n_samples` so the divisor below matches the number of draws.
        for _ in 0..n_samples {
            *counts.entry(sampler.sample(rng)).or_default() += 1;
        }
        counts
            .into_iter()
            .map(|(id, count)| (id, count as f64 / n_samples as f64))
            .collect()
    }

    /// Frequency of `id`; an id that was never drawn has frequency 0 instead of
    /// panicking on a missing map key.
    fn freq(freqs: &HashMap<usize, f64>, id: usize) -> f64 {
        freqs.get(&id).copied().unwrap_or(0.)
    }

    #[test]
    fn test_distr() {
        // Seeded so the statistical assertions are reproducible across runs.
        let mut rng = StdRng::seed_from_u64(42);
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        sampler.insert(1, 999.);
        sampler.insert(2, 1.);

        let freqs = frequencies(&sampler, &mut rng, 1_000_000);
        approx::assert_abs_diff_eq!(freq(&freqs, 1), 0.999, epsilon = 1e-4);
        approx::assert_abs_diff_eq!(freq(&freqs, 2), 0.001, epsilon = 1e-4);

        sampler.update(1, 99.);

        // 100k draws keep the 1e-2 tolerance far outside sampling noise (~3e-4 std).
        let freqs = frequencies(&sampler, &mut rng, 100_000);
        approx::assert_abs_diff_eq!(freq(&freqs, 1), 0.99, epsilon = 1e-2);
        approx::assert_abs_diff_eq!(freq(&freqs, 2), 0.01, epsilon = 1e-2);
    }

    #[test]
    fn test_remove() {
        let mut sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        let level = sampler.level(500.);
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(0));
        sampler.insert(2, 510.);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(1));
        sampler.remove(1);
        assert_eq!(Some(&2), sampler.level_bucket[level].get(0));
        sampler.insert(1, 500.);
        assert_eq!(Some(&1), sampler.level_bucket[level].get(1));
        sampler.remove(1);
    }

    #[test]
    fn test_level() {
        let sampler = DynamicWeightedSampler::new_with_capacity(1000., 5);
        assert_eq!(11, sampler.n_levels);
        assert_eq!(11 - 1, sampler.level(1.));
        assert_eq!(11 - 2, sampler.level(2.));
        assert_eq!(11 - 3, sampler.level(3.));
        assert_eq!(11 - 3, sampler.level(4.));
    }
}