1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4
5use super::reservoir::ReservoirSampler;
6
/// Strategy deciding which rows of a dataset are kept for analysis.
///
/// A strategy is evaluated row by row through `should_include` /
/// `should_include_with_state`; mutable bookkeeping lives in `SamplingState`.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Keep every row (no sampling).
    None,

    /// Keep every row while at most `size` rows have been processed; after that
    /// a row is kept when a hash of its index falls below `size / total_processed`.
    Random { size: usize },

    /// Delegate the decision to a `ReservoirSampler` created with capacity `size`.
    Reservoir { size: usize },

    /// Keep up to `samples_per_stratum` rows for each distinct combination of
    /// values found in `key_columns`. Requires row data to be supplied.
    Stratified {
        key_columns: Vec<String>,
        samples_per_stratum: usize,
    },

    /// Accept the first `initial_size` rows, then keep accepting while the
    /// confidence estimate is below `confidence_level` and fewer than
    /// `max_size` rows have been accepted.
    Progressive {
        initial_size: usize,
        confidence_level: f64,
        max_size: usize,
    },

    /// Keep rows whose index is a multiple of `interval`
    /// (`interval` is expected to be non-zero).
    Systematic { interval: usize },

    /// Keep rows whose computed importance weight is at least
    /// `weight_threshold`. Requires row data to be supplied.
    Importance { weight_threshold: f64 },

    /// Keep a row only if every stage accepts it; stages are evaluated in order
    /// and evaluation stops at the first stage that rejects the row.
    MultiStage { stages: Vec<SamplingStrategy> },
}
40
/// Mutable bookkeeping carried between successive sampling decisions.
///
/// One value is meant to be passed to every `should_include_with_state` call
/// made for the same pass over the data.
pub struct SamplingState {
    /// Number of rows accepted so far by progressive sampling.
    progressive_samples: usize,
    /// Most recent confidence estimate computed by progressive sampling.
    progressive_confidence: f64,

    /// Rows seen so far, per stratum key.
    stratum_counts: HashMap<String, usize>,
    /// Rows accepted so far, per stratum key.
    stratum_samples: HashMap<String, usize>,

    /// Reservoir sampler, created on first use (`None` until then).
    reservoir_sampler: Option<ReservoirSampler>,
}
54
55impl SamplingState {
56 pub fn new() -> Self {
57 Self {
58 progressive_samples: 0,
59 progressive_confidence: 0.0,
60 stratum_counts: HashMap::new(),
61 stratum_samples: HashMap::new(),
62 reservoir_sampler: None,
63 }
64 }
65
66 pub fn init_reservoir(&mut self, capacity: usize) {
68 self.reservoir_sampler = Some(ReservoirSampler::new(capacity));
69 }
70
71 pub fn get_or_init_reservoir(&mut self, capacity: usize) -> &mut ReservoirSampler {
73 if self.reservoir_sampler.is_none() {
74 self.init_reservoir(capacity);
75 }
76 self.reservoir_sampler
77 .as_mut()
78 .expect("Reservoir sampler should be initialized after init_reservoir call")
79 }
80}
81
82impl Default for SamplingState {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl SamplingStrategy {
89 pub fn adaptive(total_rows: Option<usize>, file_size_mb: f64) -> Self {
91 match (total_rows, file_size_mb) {
92 (Some(rows), size_mb) if rows <= 10_000 && size_mb < 10.0 => SamplingStrategy::None,
93 (Some(rows), _) if rows <= 100_000 => SamplingStrategy::Random { size: 10_000 },
94 (Some(rows), _) if rows <= 1_000_000 => SamplingStrategy::Progressive {
95 initial_size: 10_000,
96 confidence_level: 0.95,
97 max_size: 50_000,
98 },
99 (_, size_mb) if size_mb > 1000.0 => SamplingStrategy::MultiStage {
100 stages: vec![
101 SamplingStrategy::Systematic { interval: 100 },
102 SamplingStrategy::Progressive {
103 initial_size: 5_000,
104 confidence_level: 0.99,
105 max_size: 25_000,
106 },
107 ],
108 },
109 _ => SamplingStrategy::Reservoir { size: 100_000 },
110 }
111 }
112
113 pub fn stratified(key_columns: Vec<String>, samples_per_stratum: usize) -> Self {
115 Self::Stratified {
116 key_columns,
117 samples_per_stratum,
118 }
119 }
120
121 pub fn importance(weight_threshold: f64) -> Self {
123 Self::Importance { weight_threshold }
124 }
125
126 pub fn should_include(&self, row_index: usize, total_processed: usize) -> bool {
128 self.should_include_with_state(row_index, total_processed, &mut SamplingState::new(), None)
129 }
130
131 pub fn should_include_with_state(
133 &self,
134 row_index: usize,
135 total_processed: usize,
136 state: &mut SamplingState,
137 row_data: Option<&HashMap<String, String>>,
138 ) -> bool {
139 match self {
140 SamplingStrategy::None => true,
141
142 SamplingStrategy::Random { size } => {
143 self.random_sample(row_index, total_processed, *size)
144 }
145
146 #[allow(clippy::manual_is_multiple_of)]
147 SamplingStrategy::Systematic { interval } => row_index % interval == 0,
148
149 SamplingStrategy::Reservoir { size } => {
150 self.reservoir_sample(row_index, total_processed, *size, state)
151 }
152
153 SamplingStrategy::Stratified {
154 key_columns,
155 samples_per_stratum,
156 } => self.stratified_sample(row_data, key_columns, *samples_per_stratum, state),
157
158 SamplingStrategy::Progressive {
159 initial_size,
160 confidence_level,
161 max_size,
162 } => self.progressive_sample(*initial_size, *confidence_level, *max_size, state),
163
164 SamplingStrategy::Importance { weight_threshold } => {
165 self.importance_sample(row_data, *weight_threshold)
166 }
167
168 SamplingStrategy::MultiStage { stages } => {
169 stages.iter().all(|stage| {
171 stage.should_include_with_state(row_index, total_processed, state, row_data)
172 })
173 }
174 }
175 }
176
177 fn random_sample(&self, row_index: usize, total_processed: usize, size: usize) -> bool {
178 if total_processed <= size {
179 return true;
180 }
181
182 let mut hasher = DefaultHasher::new();
183 row_index.hash(&mut hasher);
184 let hash = hasher.finish();
185
186 let probability = size as f64 / total_processed as f64;
187 let threshold = (probability * u64::MAX as f64) as u64;
188
189 hash < threshold
190 }
191
192 fn reservoir_sample(
193 &self,
194 row_index: usize,
195 _total_processed: usize,
196 size: usize,
197 state: &mut SamplingState,
198 ) -> bool {
199 let reservoir = state.get_or_init_reservoir(size);
201 reservoir.process_record(row_index)
202 }
203
204 fn stratified_sample(
205 &self,
206 row_data: Option<&HashMap<String, String>>,
207 key_columns: &[String],
208 samples_per_stratum: usize,
209 state: &mut SamplingState,
210 ) -> bool {
211 if let Some(data) = row_data {
212 let stratum_id = key_columns
214 .iter()
215 .filter_map(|col| data.get(col))
216 .cloned()
217 .collect::<Vec<_>>()
218 .join("|");
219
220 *state
222 .stratum_counts
223 .entry(stratum_id.to_string())
224 .or_insert(0) += 1;
225
226 let current_samples = *state.stratum_samples.get(&stratum_id).unwrap_or(&0);
228
229 if current_samples < samples_per_stratum {
230 *state.stratum_samples.entry(stratum_id).or_insert(0) += 1;
231 true
232 } else {
233 false
234 }
235 } else {
236 false
238 }
239 }
240
241 fn progressive_sample(
242 &self,
243 initial_size: usize,
244 confidence_level: f64,
245 max_size: usize,
246 state: &mut SamplingState,
247 ) -> bool {
248 if state.progressive_samples < initial_size {
249 state.progressive_samples += 1;
250 return true;
251 }
252
253 let current_confidence = 1.0 - (1.0 / (state.progressive_samples as f64).sqrt());
256 state.progressive_confidence = current_confidence;
257
258 if current_confidence < confidence_level && state.progressive_samples < max_size {
259 state.progressive_samples += 1;
260 true
261 } else {
262 false
263 }
264 }
265
266 fn importance_sample(
267 &self,
268 row_data: Option<&HashMap<String, String>>,
269 weight_threshold: f64,
270 ) -> bool {
271 if let Some(data) = row_data {
272 let weight = self.calculate_importance_weight(data);
274 weight >= weight_threshold
275 } else {
276 false
277 }
278 }
279
280 fn calculate_importance_weight(&self, data: &HashMap<String, String>) -> f64 {
281 let non_empty_count = data.values().filter(|v| !v.is_empty()).count();
287 let total_values = data.len();
288
289 if total_values == 0 {
290 return 0.0;
291 }
292
293 let completeness = non_empty_count as f64 / total_values as f64;
294
295 let has_unusual_patterns = data.values().any(|v| {
297 v.len() > 1000 ||
299 v.chars().all(|c| c.is_ascii_digit()) ||
301 v.chars().any(|c| !c.is_ascii_alphanumeric() && !c.is_whitespace())
303 });
304
305 let anomaly_score = if has_unusual_patterns { 0.3 } else { 0.0 };
306
307 completeness * 0.7 + anomaly_score
309 }
310
311 pub fn target_sample_size(&self) -> Option<usize> {
312 match self {
313 SamplingStrategy::None => None,
314 SamplingStrategy::Random { size } => Some(*size),
315 SamplingStrategy::Reservoir { size } => Some(*size),
316 SamplingStrategy::Stratified {
317 samples_per_stratum,
318 ..
319 } => Some(*samples_per_stratum),
320 SamplingStrategy::Progressive { max_size, .. } => Some(*max_size),
321 SamplingStrategy::Systematic { .. } => None,
322 SamplingStrategy::Importance { .. } => None,
323 SamplingStrategy::MultiStage { stages } => {
324 stages.iter().filter_map(|s| s.target_sample_size()).min()
326 }
327 }
328 }
329
330 pub fn description(&self) -> String {
332 match self {
333 SamplingStrategy::None => "Full dataset analysis".to_string(),
334 SamplingStrategy::Random { size } => format!("Random sampling ({} records)", size),
335 SamplingStrategy::Reservoir { size } => {
336 format!("Reservoir sampling ({} records)", size)
337 }
338 SamplingStrategy::Stratified {
339 key_columns,
340 samples_per_stratum,
341 } => {
342 format!(
343 "Stratified by {} ({} per stratum)",
344 key_columns.join(", "),
345 samples_per_stratum
346 )
347 }
348 SamplingStrategy::Progressive {
349 initial_size,
350 confidence_level,
351 max_size,
352 } => {
353 format!(
354 "Progressive sampling ({}-{} records, {}% confidence)",
355 initial_size,
356 max_size,
357 (confidence_level * 100.0) as u8
358 )
359 }
360 SamplingStrategy::Systematic { interval } => {
361 format!("Systematic (every {}th record)", interval)
362 }
363 SamplingStrategy::Importance { weight_threshold } => {
364 format!("Importance sampling (weight > {:.2})", weight_threshold)
365 }
366 SamplingStrategy::MultiStage { stages } => {
367 format!("Multi-stage ({} stages)", stages.len())
368 }
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Random sampling of 100 out of 1000 rows should keep roughly 100 rows.
    #[test]
    fn test_random_sampling() {
        let strategy = SamplingStrategy::Random { size: 100 };
        let mut included_count = 0;

        for i in 0..1000 {
            if strategy.should_include(i, 1000) {
                included_count += 1;
            }
        }

        // Selection is hash-based, so only a tolerance band can be asserted.
        assert!(included_count > 50 && included_count < 150);
    }

    /// Systematic sampling keeps exactly the rows whose index is a multiple
    /// of the interval.
    #[test]
    fn test_systematic_sampling() {
        let strategy = SamplingStrategy::Systematic { interval: 10 };
        let mut state = SamplingState::new();

        for i in 0..100 {
            let included = strategy.should_include_with_state(i, i + 1, &mut state, None);
            if i % 10 == 0 {
                assert!(included);
            } else {
                assert!(!included);
            }
        }
    }

    /// Progressive sampling accepts at least the initial size and never more
    /// than the maximum size.
    #[test]
    fn test_progressive_sampling() {
        let strategy = SamplingStrategy::Progressive {
            initial_size: 10,
            confidence_level: 0.95,
            max_size: 50,
        };
        let mut state = SamplingState::new();
        let mut included_count = 0;

        for i in 0..100 {
            if strategy.should_include_with_state(i, i + 1, &mut state, None) {
                included_count += 1;
            }
        }

        assert!((10..=50).contains(&included_count));
    }

    /// The adaptive constructor picks the expected variant for each size class.
    ///
    /// `matches!` only yields a `bool`; it has to be wrapped in `assert!`,
    /// otherwise the test can never fail.
    #[test]
    fn test_adaptive_strategy() {
        let small = SamplingStrategy::adaptive(Some(5_000), 1.0);
        assert!(matches!(small, SamplingStrategy::None));

        let medium = SamplingStrategy::adaptive(Some(50_000), 10.0);
        assert!(matches!(medium, SamplingStrategy::Random { .. }));

        let mid_large = SamplingStrategy::adaptive(Some(500_000), 50.0);
        assert!(matches!(mid_large, SamplingStrategy::Progressive { .. }));

        let large = SamplingStrategy::adaptive(Some(10_000_000), 2000.0);
        assert!(matches!(large, SamplingStrategy::MultiStage { .. }));

        let unknown = SamplingStrategy::adaptive(None, 5.0);
        assert!(matches!(unknown, SamplingStrategy::Reservoir { .. }));
    }
}