// dataprof_core/sampling/strategies.rs

use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use super::reservoir::ReservoirSampler;

/// Strategy deciding which rows of a dataset end up in the analysed sample.
///
/// Strategies are evaluated row by row via
/// [`SamplingStrategy::should_include_with_state`]. The stateful variants
/// (`Reservoir`, `Stratified`, `Progressive`) keep their bookkeeping in a
/// [`SamplingState`] that the caller passes in on every call.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// No sampling - analyze all data.
    None,

    /// Simple random sampling with a fixed target size.
    Random {
        /// Target number of rows to keep.
        size: usize,
    },

    /// Reservoir sampling for streaming data.
    Reservoir {
        /// Capacity of the reservoir.
        size: usize,
    },

    /// Stratified sampling balanced by categories.
    Stratified {
        /// Columns whose values jointly identify a stratum.
        key_columns: Vec<String>,
        /// Maximum number of rows kept for each distinct stratum.
        samples_per_stratum: usize,
    },

    /// Progressive sampling - stop when confidence is reached.
    Progressive {
        /// Number of leading rows that are always included.
        initial_size: usize,
        /// Stop including rows once the heuristic confidence reaches this value.
        confidence_level: f64,
        /// Upper bound on included rows once past the initial phase.
        max_size: usize,
    },

    /// Systematic sampling (every Nth row).
    Systematic {
        /// Keep rows whose index is a multiple of this value.
        interval: usize,
    },

    /// Importance sampling for anomaly detection.
    Importance {
        /// Minimum heuristic importance weight a row must reach to be kept.
        weight_threshold: f64,
    },

    /// Multi-stage sampling (combination of strategies).
    MultiStage {
        /// Stages evaluated in order; a row is kept only if every stage accepts it.
        stages: Vec<SamplingStrategy>,
    },
}

41/// State for advanced sampling strategies
42pub struct SamplingState {
43    /// Progressive sampling state
44    progressive_samples: usize,
45    progressive_confidence: f64,
46
47    /// Stratified sampling state
48    stratum_counts: HashMap<String, usize>,
49    stratum_samples: HashMap<String, usize>,
50
51    /// Enhanced reservoir sampler
52    reservoir_sampler: Option<ReservoirSampler>,
53}
54
55impl SamplingState {
56    pub fn new() -> Self {
57        Self {
58            progressive_samples: 0,
59            progressive_confidence: 0.0,
60            stratum_counts: HashMap::new(),
61            stratum_samples: HashMap::new(),
62            reservoir_sampler: None,
63        }
64    }
65
66    /// Initialize reservoir sampler with given capacity
67    pub fn init_reservoir(&mut self, capacity: usize) {
68        self.reservoir_sampler = Some(ReservoirSampler::new(capacity));
69    }
70
71    /// Get or initialize reservoir sampler
72    pub fn get_or_init_reservoir(&mut self, capacity: usize) -> &mut ReservoirSampler {
73        if self.reservoir_sampler.is_none() {
74            self.init_reservoir(capacity);
75        }
76        self.reservoir_sampler
77            .as_mut()
78            .expect("Reservoir sampler should be initialized after init_reservoir call")
79    }
80}
81
82impl Default for SamplingState {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl SamplingStrategy {
89    /// Create adaptive strategy based on data characteristics
90    pub fn adaptive(total_rows: Option<usize>, file_size_mb: f64) -> Self {
91        match (total_rows, file_size_mb) {
92            (Some(rows), size_mb) if rows <= 10_000 && size_mb < 10.0 => SamplingStrategy::None,
93            (Some(rows), _) if rows <= 100_000 => SamplingStrategy::Random { size: 10_000 },
94            (Some(rows), _) if rows <= 1_000_000 => SamplingStrategy::Progressive {
95                initial_size: 10_000,
96                confidence_level: 0.95,
97                max_size: 50_000,
98            },
99            (_, size_mb) if size_mb > 1000.0 => SamplingStrategy::MultiStage {
100                stages: vec![
101                    SamplingStrategy::Systematic { interval: 100 },
102                    SamplingStrategy::Progressive {
103                        initial_size: 5_000,
104                        confidence_level: 0.99,
105                        max_size: 25_000,
106                    },
107                ],
108            },
109            _ => SamplingStrategy::Reservoir { size: 100_000 },
110        }
111    }
112
113    /// Create stratified sampling strategy
114    pub fn stratified(key_columns: Vec<String>, samples_per_stratum: usize) -> Self {
115        Self::Stratified {
116            key_columns,
117            samples_per_stratum,
118        }
119    }
120
121    /// Create importance sampling strategy
122    pub fn importance(weight_threshold: f64) -> Self {
123        Self::Importance { weight_threshold }
124    }
125
126    /// Check if row should be included in sample
127    pub fn should_include(&self, row_index: usize, total_processed: usize) -> bool {
128        self.should_include_with_state(row_index, total_processed, &mut SamplingState::new(), None)
129    }
130
131    /// Check if row should be included with state tracking
132    pub fn should_include_with_state(
133        &self,
134        row_index: usize,
135        total_processed: usize,
136        state: &mut SamplingState,
137        row_data: Option<&HashMap<String, String>>,
138    ) -> bool {
139        match self {
140            SamplingStrategy::None => true,
141
142            SamplingStrategy::Random { size } => {
143                self.random_sample(row_index, total_processed, *size)
144            }
145
146            #[allow(clippy::manual_is_multiple_of)]
147            SamplingStrategy::Systematic { interval } => row_index % interval == 0,
148
149            SamplingStrategy::Reservoir { size } => {
150                self.reservoir_sample(row_index, total_processed, *size, state)
151            }
152
153            SamplingStrategy::Stratified {
154                key_columns,
155                samples_per_stratum,
156            } => self.stratified_sample(row_data, key_columns, *samples_per_stratum, state),
157
158            SamplingStrategy::Progressive {
159                initial_size,
160                confidence_level,
161                max_size,
162            } => self.progressive_sample(*initial_size, *confidence_level, *max_size, state),
163
164            SamplingStrategy::Importance { weight_threshold } => {
165                self.importance_sample(row_data, *weight_threshold)
166            }
167
168            SamplingStrategy::MultiStage { stages } => {
169                // Apply all stages in sequence
170                stages.iter().all(|stage| {
171                    stage.should_include_with_state(row_index, total_processed, state, row_data)
172                })
173            }
174        }
175    }
176
177    fn random_sample(&self, row_index: usize, total_processed: usize, size: usize) -> bool {
178        if total_processed <= size {
179            return true;
180        }
181
182        let mut hasher = DefaultHasher::new();
183        row_index.hash(&mut hasher);
184        let hash = hasher.finish();
185
186        let probability = size as f64 / total_processed as f64;
187        let threshold = (probability * u64::MAX as f64) as u64;
188
189        hash < threshold
190    }
191
192    fn reservoir_sample(
193        &self,
194        row_index: usize,
195        _total_processed: usize,
196        size: usize,
197        state: &mut SamplingState,
198    ) -> bool {
199        // Use the enhanced reservoir sampler
200        let reservoir = state.get_or_init_reservoir(size);
201        reservoir.process_record(row_index)
202    }
203
204    fn stratified_sample(
205        &self,
206        row_data: Option<&HashMap<String, String>>,
207        key_columns: &[String],
208        samples_per_stratum: usize,
209        state: &mut SamplingState,
210    ) -> bool {
211        if let Some(data) = row_data {
212            // Create stratum identifier from specified columns
213            let stratum_id = key_columns
214                .iter()
215                .filter_map(|col| data.get(col))
216                .cloned()
217                .collect::<Vec<_>>()
218                .join("|");
219
220            // Count total rows in this stratum
221            *state
222                .stratum_counts
223                .entry(stratum_id.to_string())
224                .or_insert(0) += 1;
225
226            // Check if we need more samples from this stratum
227            let current_samples = *state.stratum_samples.get(&stratum_id).unwrap_or(&0);
228
229            if current_samples < samples_per_stratum {
230                *state.stratum_samples.entry(stratum_id).or_insert(0) += 1;
231                true
232            } else {
233                false
234            }
235        } else {
236            // No row data available, fall back to random sampling
237            false
238        }
239    }
240
241    fn progressive_sample(
242        &self,
243        initial_size: usize,
244        confidence_level: f64,
245        max_size: usize,
246        state: &mut SamplingState,
247    ) -> bool {
248        if state.progressive_samples < initial_size {
249            state.progressive_samples += 1;
250            return true;
251        }
252
253        // Calculate confidence based on current sample size
254        // This is a simplified confidence calculation
255        let current_confidence = 1.0 - (1.0 / (state.progressive_samples as f64).sqrt());
256        state.progressive_confidence = current_confidence;
257
258        if current_confidence < confidence_level && state.progressive_samples < max_size {
259            state.progressive_samples += 1;
260            true
261        } else {
262            false
263        }
264    }
265
266    fn importance_sample(
267        &self,
268        row_data: Option<&HashMap<String, String>>,
269        weight_threshold: f64,
270    ) -> bool {
271        if let Some(data) = row_data {
272            // Calculate importance weight based on data characteristics
273            let weight = self.calculate_importance_weight(data);
274            weight >= weight_threshold
275        } else {
276            false
277        }
278    }
279
280    fn calculate_importance_weight(&self, data: &HashMap<String, String>) -> f64 {
281        // Simple importance calculation based on:
282        // 1. Number of non-empty values
283        // 2. Diversity of values
284        // 3. Presence of anomalous patterns
285
286        let non_empty_count = data.values().filter(|v| !v.is_empty()).count();
287        let total_values = data.len();
288
289        if total_values == 0 {
290            return 0.0;
291        }
292
293        let completeness = non_empty_count as f64 / total_values as f64;
294
295        // Check for unusual patterns that might indicate anomalies
296        let has_unusual_patterns = data.values().any(|v| {
297            // Very long strings might be anomalous
298            v.len() > 1000 ||
299            // All digits might be IDs
300            v.chars().all(|c| c.is_ascii_digit()) ||
301            // Mixed case and special characters
302            v.chars().any(|c| !c.is_ascii_alphanumeric() && !c.is_whitespace())
303        });
304
305        let anomaly_score = if has_unusual_patterns { 0.3 } else { 0.0 };
306
307        // Combine scores
308        completeness * 0.7 + anomaly_score
309    }
310
311    pub fn target_sample_size(&self) -> Option<usize> {
312        match self {
313            SamplingStrategy::None => None,
314            SamplingStrategy::Random { size } => Some(*size),
315            SamplingStrategy::Reservoir { size } => Some(*size),
316            SamplingStrategy::Stratified {
317                samples_per_stratum,
318                ..
319            } => Some(*samples_per_stratum),
320            SamplingStrategy::Progressive { max_size, .. } => Some(*max_size),
321            SamplingStrategy::Systematic { .. } => None,
322            SamplingStrategy::Importance { .. } => None,
323            SamplingStrategy::MultiStage { stages } => {
324                // Return the minimum target size across all stages
325                stages.iter().filter_map(|s| s.target_sample_size()).min()
326            }
327        }
328    }
329
330    /// Get description of the sampling strategy
331    pub fn description(&self) -> String {
332        match self {
333            SamplingStrategy::None => "Full dataset analysis".to_string(),
334            SamplingStrategy::Random { size } => format!("Random sampling ({} records)", size),
335            SamplingStrategy::Reservoir { size } => {
336                format!("Reservoir sampling ({} records)", size)
337            }
338            SamplingStrategy::Stratified {
339                key_columns,
340                samples_per_stratum,
341            } => {
342                format!(
343                    "Stratified by {} ({} per stratum)",
344                    key_columns.join(", "),
345                    samples_per_stratum
346                )
347            }
348            SamplingStrategy::Progressive {
349                initial_size,
350                confidence_level,
351                max_size,
352            } => {
353                format!(
354                    "Progressive sampling ({}-{} records, {}% confidence)",
355                    initial_size,
356                    max_size,
357                    (confidence_level * 100.0) as u8
358                )
359            }
360            SamplingStrategy::Systematic { interval } => {
361                format!("Systematic (every {}th record)", interval)
362            }
363            SamplingStrategy::Importance { weight_threshold } => {
364                format!("Importance sampling (weight > {:.2})", weight_threshold)
365            }
366            SamplingStrategy::MultiStage { stages } => {
367                format!("Multi-stage ({} stages)", stages.len())
368            }
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a single-column row for the row-data based strategies.
    fn row(column: &str, value: &str) -> HashMap<String, String> {
        let mut data = HashMap::new();
        data.insert(column.to_string(), value.to_string());
        data
    }

    #[test]
    fn test_random_sampling() {
        let strategy = SamplingStrategy::Random { size: 100 };
        let mut included_count = 0;

        for i in 0..1000 {
            if strategy.should_include(i, 1000) {
                included_count += 1;
            }
        }

        // Should be approximately 100 (within reasonable variance)
        assert!(included_count > 50 && included_count < 150);
    }

    #[test]
    fn test_random_sampling_keeps_everything_below_target() {
        let strategy = SamplingStrategy::Random { size: 100 };
        assert!((0..50).all(|i| strategy.should_include(i, 50)));
    }

    #[test]
    fn test_systematic_sampling() {
        let strategy = SamplingStrategy::Systematic { interval: 10 };
        let mut state = SamplingState::new();

        for i in 0..100 {
            let included = strategy.should_include_with_state(i, i + 1, &mut state, None);
            assert_eq!(included, i % 10 == 0, "row {}", i);
        }
    }

    #[test]
    fn test_progressive_sampling() {
        let strategy = SamplingStrategy::Progressive {
            initial_size: 10,
            confidence_level: 0.95,
            max_size: 50,
        };
        let mut state = SamplingState::new();
        let mut included_count = 0;

        for i in 0..100 {
            if strategy.should_include_with_state(i, i + 1, &mut state, None) {
                included_count += 1;
            }
        }

        // 1 - 1/sqrt(n) stays below 0.95 for every n < 400, so only `max_size` stops it.
        assert_eq!(included_count, 50);
    }

    #[test]
    fn test_stratified_sampling_caps_each_stratum() {
        let strategy = SamplingStrategy::stratified(vec!["cat".to_string()], 2);
        let mut state = SamplingState::new();

        let included: Vec<bool> = ["a", "a", "a", "b"]
            .iter()
            .enumerate()
            .map(|(i, value)| {
                strategy.should_include_with_state(i, i + 1, &mut state, Some(&row("cat", value)))
            })
            .collect();

        assert_eq!(included, vec![true, true, false, true]);
    }

    #[test]
    fn test_stratified_sampling_without_row_data_excludes() {
        let strategy = SamplingStrategy::stratified(vec!["cat".to_string()], 2);
        let mut state = SamplingState::new();
        assert!(!strategy.should_include_with_state(0, 1, &mut state, None));
    }

    #[test]
    fn test_importance_sampling_threshold() {
        let mut state = SamplingState::new();
        let data = row("name", "alice");

        // Complete, plain row: weight 0.7.
        assert!(SamplingStrategy::importance(0.5).should_include_with_state(
            0,
            1,
            &mut state,
            Some(&data)
        ));
        assert!(!SamplingStrategy::importance(0.8).should_include_with_state(
            0,
            1,
            &mut state,
            Some(&data)
        ));
        assert!(!SamplingStrategy::importance(0.5).should_include_with_state(0, 1, &mut state, None));
    }

    #[test]
    fn test_target_sample_size() {
        assert_eq!(SamplingStrategy::None.target_sample_size(), None);
        assert_eq!(SamplingStrategy::Systematic { interval: 5 }.target_sample_size(), None);
        assert_eq!(SamplingStrategy::Random { size: 7 }.target_sample_size(), Some(7));

        let multi = SamplingStrategy::adaptive(Some(10_000_000), 2000.0);
        assert_eq!(multi.target_sample_size(), Some(25_000));
    }

    #[test]
    fn test_description() {
        assert_eq!(
            SamplingStrategy::Random { size: 5 }.description(),
            "Random sampling (5 records)"
        );
        assert_eq!(
            SamplingStrategy::Systematic { interval: 10 }.description(),
            "Systematic (every 10th record)"
        );
        assert_eq!(
            SamplingStrategy::Progressive {
                initial_size: 10,
                confidence_level: 0.95,
                max_size: 50,
            }
            .description(),
            "Progressive sampling (10-50 records, 95% confidence)"
        );
    }

    #[test]
    fn test_adaptive_strategy() {
        // Small dataset - should use no sampling
        let small = SamplingStrategy::adaptive(Some(5_000), 1.0);
        assert!(matches!(small, SamplingStrategy::None));

        // Medium dataset - should use random sampling
        let medium = SamplingStrategy::adaptive(Some(50_000), 10.0);
        assert!(matches!(medium, SamplingStrategy::Random { .. }));

        // Up to a million rows - progressive sampling
        let large_rows = SamplingStrategy::adaptive(Some(500_000), 50.0);
        assert!(matches!(large_rows, SamplingStrategy::Progressive { .. }));

        // Large file - should use multi-stage
        let large = SamplingStrategy::adaptive(Some(10_000_000), 2000.0);
        assert!(matches!(large, SamplingStrategy::MultiStage { .. }));

        // Unknown row count on a moderate file - reservoir
        let unknown = SamplingStrategy::adaptive(None, 10.0);
        assert!(matches!(unknown, SamplingStrategy::Reservoir { .. }));
    }
}