// dataprof_core/sampling/reservoir.rs

use std::collections::HashMap;

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;

5#[derive(Debug, Clone)]
13pub struct ReservoirSampler {
14 capacity: usize,
16 reservoir: Vec<usize>,
18 total_processed: usize,
20 rng: ChaCha8Rng,
22 next_record: usize,
24 stats: ReservoirStats,
26}
27
28#[derive(Debug, Clone, Default)]
30pub struct ReservoirStats {
31 pub records_processed: usize,
32 pub records_sampled: usize,
33 pub replacement_count: usize,
34 pub skip_count: usize,
35 pub efficiency_ratio: f64,
36}
37
38impl ReservoirSampler {
39 pub fn new(capacity: usize) -> Self {
41 Self::seed(capacity, 42) }
43
44 pub fn seed(capacity: usize, seed: u64) -> Self {
46 Self {
47 capacity,
48 reservoir: Vec::with_capacity(capacity),
49 total_processed: 0,
50 rng: ChaCha8Rng::seed_from_u64(seed),
51 next_record: 0,
52 stats: ReservoirStats::default(),
53 }
54 }
55
56 pub fn process_record(&mut self, record_index: usize) -> bool {
59 self.total_processed += 1;
60 self.stats.records_processed += 1;
61
62 if self.reservoir.len() < self.capacity {
64 self.reservoir.push(record_index);
65 self.stats.records_sampled += 1;
66 return true;
67 }
68
69 self.apply_vitter_algorithm(record_index)
71 }
72
73 fn apply_vitter_algorithm(&mut self, record_index: usize) -> bool {
75 if self.total_processed < self.next_record {
77 return false;
78 }
79
80 let random_index = self.rng.random_range(0..self.total_processed);
82
83 if random_index < self.capacity {
84 let replace_position = random_index % self.capacity;
86 self.reservoir[replace_position] = record_index;
87 self.stats.replacement_count += 1;
88 self.stats.records_sampled += 1;
89
90 self.calculate_next_skip();
92
93 return true;
94 }
95
96 false
97 }
98
99 fn calculate_next_skip(&mut self) {
102 let u: f64 = self.rng.random();
105 let skip = if u > 0.0 {
106 ((self.total_processed as f64) * (u.powf(1.0 / self.capacity as f64) - 1.0)) as usize
107 } else {
108 1
109 };
110
111 self.next_record = self.total_processed + skip.max(1);
112 self.stats.skip_count += skip;
113 }
114
115 pub fn get_sample_indices(&self) -> &[usize] {
117 &self.reservoir
118 }
119
120 pub fn sample_size(&self) -> usize {
122 self.reservoir.len()
123 }
124
125 pub fn is_full(&self) -> bool {
127 self.reservoir.len() >= self.capacity
128 }
129
130 pub fn get_stats(&self) -> &ReservoirStats {
132 &self.stats
133 }
134
135 pub fn sampling_ratio(&self) -> f64 {
137 if self.total_processed > 0 {
138 self.reservoir.len() as f64 / self.total_processed as f64
139 } else {
140 0.0
141 }
142 }
143
144 pub fn reset(&mut self) {
146 self.reservoir.clear();
147 self.total_processed = 0;
148 self.next_record = 0;
149 self.stats = ReservoirStats::default();
150 }
151
152 pub fn set_seed(&mut self, seed: u64) {
154 self.rng = ChaCha8Rng::seed_from_u64(seed);
155 }
156
157 pub fn update_efficiency_stats(&mut self) {
159 self.stats.efficiency_ratio = if self.stats.records_processed > 0 {
160 self.stats.records_sampled as f64 / self.stats.records_processed as f64
161 } else {
162 0.0
163 };
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct WeightedReservoirSampler {
170 base_sampler: ReservoirSampler,
171 weights: HashMap<String, f64>,
173 total_weight: f64,
175}
176
177impl WeightedReservoirSampler {
178 pub fn new(capacity: usize, weights: HashMap<String, f64>) -> Self {
179 Self {
180 base_sampler: ReservoirSampler::new(capacity),
181 weights,
182 total_weight: 0.0,
183 }
184 }
185
186 pub fn process_weighted_record(&mut self, record_index: usize, category: &str) -> bool {
188 let weight = self.weights.get(category).copied().unwrap_or(1.0);
189 self.total_weight += weight;
190
191 let adjusted_probability = weight / self.total_weight;
193 let u: f64 = self.base_sampler.rng.random();
194
195 if u < adjusted_probability {
196 self.base_sampler.process_record(record_index)
197 } else {
198 self.base_sampler.total_processed += 1;
199 false
200 }
201 }
202
203 pub fn get_sample_indices(&self) -> &[usize] {
204 self.base_sampler.get_sample_indices()
205 }
206
207 pub fn sampling_ratio(&self) -> f64 {
208 self.base_sampler.sampling_ratio()
209 }
210}
211
212#[derive(Debug)]
214pub struct MultiReservoirSampler {
215 reservoirs: HashMap<String, ReservoirSampler>,
216 default_capacity: usize,
217}
218
219impl MultiReservoirSampler {
220 pub fn new(default_capacity: usize) -> Self {
221 Self {
222 reservoirs: HashMap::new(),
223 default_capacity,
224 }
225 }
226
227 pub fn process_categorized_record(&mut self, record_index: usize, category: &str) -> bool {
229 let reservoir = self
230 .reservoirs
231 .entry(category.to_string())
232 .or_insert_with(|| ReservoirSampler::new(self.default_capacity));
233
234 reservoir.process_record(record_index)
235 }
236
237 pub fn get_combined_sample(&self) -> Vec<usize> {
239 let mut combined = Vec::new();
240
241 for reservoir in self.reservoirs.values() {
242 combined.extend_from_slice(reservoir.get_sample_indices());
243 }
244
245 combined.sort_unstable();
247 combined
248 }
249
250 pub fn get_samples_by_category(&self) -> HashMap<String, Vec<usize>> {
252 self.reservoirs
253 .iter()
254 .map(|(category, reservoir)| {
255 (
256 category.to_string(),
257 reservoir.get_sample_indices().to_vec(),
258 )
259 })
260 .collect()
261 }
262
263 pub fn get_all_stats(&self) -> HashMap<String, ReservoirStats> {
265 self.reservoirs
266 .iter()
267 .map(|(category, reservoir)| (category.to_string(), reservoir.get_stats().clone()))
268 .collect()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Checks that `sample` has no duplicate indices and that every index is
    /// below `upper`, i.e. it was actually offered to the sampler.
    fn assert_valid_sample(sample: &[usize], upper: usize) {
        let distinct: HashSet<usize> = sample.iter().copied().collect();
        assert_eq!(distinct.len(), sample.len(), "duplicate index in sample");
        assert!(
            sample.iter().all(|&index| index < upper),
            "sample index out of range"
        );
    }

    #[test]
    fn test_basic_reservoir_sampling() {
        let mut sampler = ReservoirSampler::new(10);
        let selected_count = (0..100).filter(|&i| sampler.process_record(i)).count();

        assert_eq!(sampler.sample_size(), 10);
        assert_valid_sample(sampler.get_sample_indices(), 100);
        // The first 10 records are always kept while the reservoir fills.
        assert!(selected_count >= 10);

        let stats = sampler.get_stats();
        assert_eq!(stats.records_processed, 100);
        assert_eq!(stats.records_sampled, selected_count);
        assert_eq!(stats.replacement_count, selected_count - 10);
    }

    #[test]
    fn test_reservoir_filling_phase() {
        let mut sampler = ReservoirSampler::new(5);

        for i in 0..5 {
            assert!(sampler.process_record(i));
        }

        assert_eq!(sampler.sample_size(), 5);
        assert!(sampler.is_full());
        assert_eq!(sampler.get_sample_indices(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_replacement_phase() {
        let mut sampler = ReservoirSampler::seed(3, 42);

        for i in 0..20 {
            sampler.process_record(i);
        }

        assert_eq!(sampler.sample_size(), 3);
        assert_valid_sample(sampler.get_sample_indices(), 20);
        assert!(sampler.get_stats().replacement_count > 0);
    }

    #[test]
    fn test_zero_capacity() {
        let mut sampler = ReservoirSampler::new(0);

        let selected = (0..10).filter(|&i| sampler.process_record(i)).count();

        assert_eq!(selected, 0);
        assert_eq!(sampler.sample_size(), 0);
        assert!(sampler.is_full());
    }

    #[test]
    fn test_sampling_ratio() {
        let mut sampler = ReservoirSampler::new(10);

        for i in 0..100 {
            sampler.process_record(i);
        }

        // 10 kept out of 100 offered.
        assert!((sampler.sampling_ratio() - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_reset_functionality() {
        let mut sampler = ReservoirSampler::new(5);

        for i in 0..10 {
            sampler.process_record(i);
        }

        assert_eq!(sampler.sample_size(), 5);
        assert!(sampler.total_processed > 0);

        sampler.reset();

        assert_eq!(sampler.sample_size(), 0);
        assert_eq!(sampler.total_processed, 0);
        assert_eq!(sampler.get_stats().records_processed, 0);

        // After a reset the sampler fills up again from scratch.
        for i in 0..3 {
            assert!(sampler.process_record(i));
        }
        assert_eq!(sampler.get_sample_indices(), &[0, 1, 2]);
    }

    #[test]
    fn test_weighted_sampling() {
        let mut weights = HashMap::new();
        weights.insert("high".to_string(), 3.0);
        weights.insert("low".to_string(), 1.0);

        let mut sampler = WeightedReservoirSampler::new(10, weights);

        // The very first record is always kept.
        assert!(sampler.process_weighted_record(0, "high"));
        for i in 1..50 {
            let category = if i % 2 == 0 { "high" } else { "low" };
            sampler.process_weighted_record(i, category);
        }

        let sample = sampler.get_sample_indices();
        assert!(sample.len() <= 10);
        assert_valid_sample(sample, 50);
    }

    #[test]
    fn test_multi_reservoir() {
        let mut sampler = MultiReservoirSampler::new(5);

        for i in 0..20 {
            let category = format!("type_{}", i % 3);
            sampler.process_categorized_record(i, &category);
        }

        // 7 + 7 + 6 records over three categories: every reservoir fills up.
        let combined = sampler.get_combined_sample();
        assert_eq!(combined.len(), 15);
        assert!(combined.windows(2).all(|pair| pair[0] < pair[1]));

        let by_category = sampler.get_samples_by_category();
        assert_eq!(by_category.len(), 3);
        for (category, indices) in &by_category {
            assert_eq!(indices.len(), 5);
            assert!(indices.iter().all(|i| format!("type_{}", i % 3) == *category));
        }

        assert_eq!(sampler.get_all_stats().len(), 3);
    }

    #[test]
    fn test_deterministic_with_seed() {
        let mut sampler1 = ReservoirSampler::seed(5, 123);
        let mut sampler2 = ReservoirSampler::seed(5, 123);

        for i in 0..50 {
            sampler1.process_record(i);
            sampler2.process_record(i);
        }

        assert_eq!(sampler1.get_sample_indices(), sampler2.get_sample_indices());
    }
}