Skip to main content

dataprof_core/sampling/
reservoir.rs

1use rand::{Rng, SeedableRng};
2use rand_chacha::ChaCha8Rng;
3use std::collections::HashMap;
4
5/// Enhanced reservoir sampling implementation based on Vitter's algorithm
6///
7/// This is a proper implementation of Algorithm R with optimizations:
8/// - True randomness with seedable RNG for reproducibility
9/// - Optimized skip calculation using geometric distribution
10/// - Memory-efficient storage of sample indices
11/// - Support for weighted sampling
12#[derive(Debug, Clone)]
13pub struct ReservoirSampler {
14    /// Maximum size of the reservoir
15    capacity: usize,
16    /// Current sample (stores row indices)
17    reservoir: Vec<usize>,
18    /// Total number of records processed
19    total_processed: usize,
20    /// Random number generator (seeded for reproducibility)
21    rng: ChaCha8Rng,
22    /// Skip optimization - next record to consider
23    next_record: usize,
24    /// Statistics for analysis
25    stats: ReservoirStats,
26}
27
/// Activity counters collected by a reservoir sampler.
///
/// `Default` yields all-zero counters.
#[derive(Clone, Debug, Default)]
pub struct ReservoirStats {
    /// Number of records handed to the sampler.
    pub records_processed: usize,
    /// Number of times a record was accepted into the sample, counting both
    /// the initial fill and later replacements.
    pub records_sampled: usize,
    /// Number of times an already-held index was overwritten by a newer one.
    pub replacement_count: usize,
    /// Accumulated skip distance recorded by the sampler.
    pub skip_count: usize,
    /// `records_sampled / records_processed`, or `0.0` when nothing has been
    /// processed yet.
    pub efficiency_ratio: f64,
}
37
38impl ReservoirSampler {
39    /// Create a new reservoir sampler with specified capacity
40    pub fn new(capacity: usize) -> Self {
41        Self::seed(capacity, 42) // Default seed for reproducibility
42    }
43
44    /// Create a new reservoir sampler with custom seed
45    pub fn seed(capacity: usize, seed: u64) -> Self {
46        Self {
47            capacity,
48            reservoir: Vec::with_capacity(capacity),
49            total_processed: 0,
50            rng: ChaCha8Rng::seed_from_u64(seed),
51            next_record: 0,
52            stats: ReservoirStats::default(),
53        }
54    }
55
56    /// Process a new record and decide if it should be included
57    /// Returns true if the record is selected for the sample
58    pub fn process_record(&mut self, record_index: usize) -> bool {
59        self.total_processed += 1;
60        self.stats.records_processed += 1;
61
62        // Phase 1: Fill the reservoir with first k records
63        if self.reservoir.len() < self.capacity {
64            self.reservoir.push(record_index);
65            self.stats.records_sampled += 1;
66            return true;
67        }
68
69        // Phase 2: Reservoir is full, use replacement algorithm
70        self.apply_vitter_algorithm(record_index)
71    }
72
73    /// Apply Vitter's Algorithm R with skip optimization
74    fn apply_vitter_algorithm(&mut self, record_index: usize) -> bool {
75        // Skip records using geometric distribution for efficiency
76        if self.total_processed < self.next_record {
77            return false;
78        }
79
80        // Calculate if this record should replace one in the reservoir
81        let random_index = self.rng.random_range(0..self.total_processed);
82
83        if random_index < self.capacity {
84            // Replace the record at random_index in reservoir
85            let replace_position = random_index % self.capacity;
86            self.reservoir[replace_position] = record_index;
87            self.stats.replacement_count += 1;
88            self.stats.records_sampled += 1;
89
90            // Calculate next skip using geometric distribution
91            self.calculate_next_skip();
92
93            return true;
94        }
95
96        false
97    }
98
99    /// Calculate next skip distance using geometric distribution
100    /// This optimizes performance by skipping records that won't be selected
101    fn calculate_next_skip(&mut self) {
102        // Use geometric distribution to calculate skip distance
103        // This is based on Vitter's Algorithm S optimization
104        let u: f64 = self.rng.random();
105        let skip = if u > 0.0 {
106            ((self.total_processed as f64) * (u.powf(1.0 / self.capacity as f64) - 1.0)) as usize
107        } else {
108            1
109        };
110
111        self.next_record = self.total_processed + skip.max(1);
112        self.stats.skip_count += skip;
113    }
114
115    /// Get current sample as a vector of indices
116    pub fn get_sample_indices(&self) -> &[usize] {
117        &self.reservoir
118    }
119
120    /// Get current sample size
121    pub fn sample_size(&self) -> usize {
122        self.reservoir.len()
123    }
124
125    /// Check if reservoir is full
126    pub fn is_full(&self) -> bool {
127        self.reservoir.len() >= self.capacity
128    }
129
130    /// Get sampling statistics
131    pub fn get_stats(&self) -> &ReservoirStats {
132        &self.stats
133    }
134
135    /// Calculate current sampling ratio
136    pub fn sampling_ratio(&self) -> f64 {
137        if self.total_processed > 0 {
138            self.reservoir.len() as f64 / self.total_processed as f64
139        } else {
140            0.0
141        }
142    }
143
144    /// Reset the sampler for reuse
145    pub fn reset(&mut self) {
146        self.reservoir.clear();
147        self.total_processed = 0;
148        self.next_record = 0;
149        self.stats = ReservoirStats::default();
150    }
151
152    /// Set new seed for reproducible results
153    pub fn set_seed(&mut self, seed: u64) {
154        self.rng = ChaCha8Rng::seed_from_u64(seed);
155    }
156
157    /// Update efficiency statistics
158    pub fn update_efficiency_stats(&mut self) {
159        self.stats.efficiency_ratio = if self.stats.records_processed > 0 {
160            self.stats.records_sampled as f64 / self.stats.records_processed as f64
161        } else {
162            0.0
163        };
164    }
165}
166
167/// Weighted reservoir sampling for stratified sampling
168#[derive(Debug, Clone)]
169pub struct WeightedReservoirSampler {
170    base_sampler: ReservoirSampler,
171    /// Weights for each record type/stratum
172    weights: HashMap<String, f64>,
173    /// Total weight processed
174    total_weight: f64,
175}
176
177impl WeightedReservoirSampler {
178    pub fn new(capacity: usize, weights: HashMap<String, f64>) -> Self {
179        Self {
180            base_sampler: ReservoirSampler::new(capacity),
181            weights,
182            total_weight: 0.0,
183        }
184    }
185
186    /// Process a record with associated weight category
187    pub fn process_weighted_record(&mut self, record_index: usize, category: &str) -> bool {
188        let weight = self.weights.get(category).copied().unwrap_or(1.0);
189        self.total_weight += weight;
190
191        // Adjust sampling probability based on weight
192        let adjusted_probability = weight / self.total_weight;
193        let u: f64 = self.base_sampler.rng.random();
194
195        if u < adjusted_probability {
196            self.base_sampler.process_record(record_index)
197        } else {
198            self.base_sampler.total_processed += 1;
199            false
200        }
201    }
202
203    pub fn get_sample_indices(&self) -> &[usize] {
204        self.base_sampler.get_sample_indices()
205    }
206
207    pub fn sampling_ratio(&self) -> f64 {
208        self.base_sampler.sampling_ratio()
209    }
210}
211
212/// Multi-reservoir sampling for handling multiple data types
213#[derive(Debug)]
214pub struct MultiReservoirSampler {
215    reservoirs: HashMap<String, ReservoirSampler>,
216    default_capacity: usize,
217}
218
219impl MultiReservoirSampler {
220    pub fn new(default_capacity: usize) -> Self {
221        Self {
222            reservoirs: HashMap::new(),
223            default_capacity,
224        }
225    }
226
227    /// Process a record for a specific category/type
228    pub fn process_categorized_record(&mut self, record_index: usize, category: &str) -> bool {
229        let reservoir = self
230            .reservoirs
231            .entry(category.to_string())
232            .or_insert_with(|| ReservoirSampler::new(self.default_capacity));
233
234        reservoir.process_record(record_index)
235    }
236
237    /// Get combined sample from all reservoirs
238    pub fn get_combined_sample(&self) -> Vec<usize> {
239        let mut combined = Vec::new();
240
241        for reservoir in self.reservoirs.values() {
242            combined.extend_from_slice(reservoir.get_sample_indices());
243        }
244
245        // Sort for consistent ordering
246        combined.sort_unstable();
247        combined
248    }
249
250    /// Get samples by category
251    pub fn get_samples_by_category(&self) -> HashMap<String, Vec<usize>> {
252        self.reservoirs
253            .iter()
254            .map(|(category, reservoir)| {
255                (
256                    category.to_string(),
257                    reservoir.get_sample_indices().to_vec(),
258                )
259            })
260            .collect()
261    }
262
263    /// Get statistics for all reservoirs
264    pub fn get_all_stats(&self) -> HashMap<String, ReservoirStats> {
265        self.reservoirs
266            .iter()
267            .map(|(category, reservoir)| (category.to_string(), reservoir.get_stats().clone()))
268            .collect()
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A stream longer than the capacity yields exactly `capacity` distinct,
    /// in-range indices, and the stats agree with the observed return values.
    #[test]
    fn test_basic_reservoir_sampling() {
        let mut sampler = ReservoirSampler::new(10);

        let selected_count = (0..100).filter(|&i| sampler.process_record(i)).count();

        assert_eq!(sampler.sample_size(), 10);
        assert_eq!(sampler.get_sample_indices().len(), 10);
        // Replacements also report `true`, so this may exceed the capacity.
        assert!(selected_count >= 10);

        // Each input index is unique, so the sample must be duplicate-free.
        let distinct: HashSet<usize> = sampler.get_sample_indices().iter().copied().collect();
        assert_eq!(distinct.len(), 10);
        assert!(distinct.iter().all(|&index| index < 100));

        let stats = sampler.get_stats();
        assert_eq!(stats.records_processed, 100);
        assert_eq!(stats.records_sampled, selected_count);
    }

    /// While the reservoir has free slots every record is accepted, in order.
    #[test]
    fn test_reservoir_filling_phase() {
        let mut sampler = ReservoirSampler::new(5);

        for i in 0..5 {
            assert!(sampler.process_record(i));
        }

        assert_eq!(sampler.sample_size(), 5);
        assert!(sampler.is_full());
        assert_eq!(sampler.get_sample_indices(), &[0, 1, 2, 3, 4]);
        assert_eq!(sampler.get_stats().replacement_count, 0);
    }

    /// Once full, the reservoir keeps its size while replacements happen.
    #[test]
    fn test_replacement_phase() {
        let mut sampler = ReservoirSampler::seed(3, 42); // Fixed seed for reproducibility

        for i in 0..20 {
            sampler.process_record(i);
        }

        // Sample size stays pinned at the capacity.
        assert_eq!(sampler.get_sample_indices().len(), 3);

        // With 17 records after the fill, the chance of zero replacements is
        // below 0.1 %, and the seed is fixed, so this is deterministic.
        assert!(sampler.get_stats().replacement_count > 0);
    }

    /// 10 retained out of 100 seen is a 10 % sampling ratio.
    #[test]
    fn test_sampling_ratio() {
        let mut sampler = ReservoirSampler::new(10);
        assert_eq!(sampler.sampling_ratio(), 0.0);

        for i in 0..100 {
            sampler.process_record(i);
        }

        let ratio = sampler.sampling_ratio();
        assert!((ratio - 0.1).abs() < 0.01);
    }

    /// `reset` empties the sample and zeroes the counters.
    #[test]
    fn test_reset_functionality() {
        let mut sampler = ReservoirSampler::new(5);

        for i in 0..10 {
            sampler.process_record(i);
        }

        assert_eq!(sampler.sample_size(), 5);
        assert!(sampler.total_processed > 0);

        sampler.reset();

        assert_eq!(sampler.sample_size(), 0);
        assert_eq!(sampler.total_processed, 0);
        assert!(!sampler.is_full());
        assert_eq!(sampler.get_stats().records_processed, 0);
    }

    /// The weighted sampler never exceeds its capacity and only ever holds
    /// indices it was actually given. The category mix of the sample is
    /// probabilistic and deliberately not asserted.
    #[test]
    fn test_weighted_sampling() {
        let weights = HashMap::from([("high".to_string(), 3.0), ("low".to_string(), 1.0)]);
        let mut sampler = WeightedReservoirSampler::new(10, weights);

        for i in 0..50 {
            let category = if i % 2 == 0 { "high" } else { "low" };
            sampler.process_weighted_record(i, category);
        }

        let sample = sampler.get_sample_indices();
        assert!(sample.len() <= 10);
        assert!(sample.iter().all(|&index| index < 50));
        assert!(sampler.sampling_ratio() <= 10.0 / 50.0);
    }

    /// Each category gets its own reservoir holding only its own records.
    #[test]
    fn test_multi_reservoir() {
        let mut sampler = MultiReservoirSampler::new(5);

        for i in 0..20 {
            let category = format!("type_{}", i % 3);
            sampler.process_categorized_record(i, &category);
        }

        // Every category receives at least 6 records, so all three reservoirs
        // are full: 5 per category * 3 categories.
        let combined = sampler.get_combined_sample();
        assert_eq!(combined.len(), 15);
        assert!(combined.windows(2).all(|pair| pair[0] <= pair[1]));

        let by_category = sampler.get_samples_by_category();
        assert_eq!(by_category.len(), 3);
        for (category, indices) in &by_category {
            assert_eq!(indices.len(), 5);
            assert!(indices
                .iter()
                .all(|index| format!("type_{}", index % 3) == *category));
        }

        assert_eq!(sampler.get_all_stats().len(), 3);
    }

    /// Two samplers with the same seed and input agree on the sample.
    #[test]
    fn test_deterministic_with_seed() {
        let mut sampler1 = ReservoirSampler::seed(5, 123);
        let mut sampler2 = ReservoirSampler::seed(5, 123);

        for i in 0..50 {
            sampler1.process_record(i);
            sampler2.process_record(i);
        }

        assert_eq!(sampler1.get_sample_indices(), sampler2.get_sample_indices());
    }
}