// dataprof_db/sampling.rs

//! Database table sampling strategies for large datasets.

3use crate::DataProfilerError;
4use crate::security::{validate_base_query, validate_sql_identifier};
5use serde::{Deserialize, Serialize};
6
7/// Configuration for database table sampling
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct SamplingConfig {
10    /// Sampling strategy to use
11    pub strategy: SamplingStrategy,
12    /// Target sample size (number of rows)
13    pub sample_size: usize,
14    /// Random seed for reproducible sampling (optional)
15    pub seed: Option<u64>,
16    /// Whether to stratify sampling by a column (optional)
17    pub stratify_column: Option<String>,
18}
19
20/// Available sampling strategies for large databases
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum SamplingStrategy {
23    /// Simple random sampling - fastest but may skip patterns
24    Random,
25    /// Systematic sampling - every nth row
26    Systematic,
27    /// Reservoir sampling - single pass, memory efficient
28    Reservoir,
29    /// Stratified sampling - maintains class distribution
30    Stratified,
31    /// Time-based sampling - for temporal data
32    Temporal { column_name: String },
33    /// Multi-stage sampling - first sample tables, then rows
34    MultiStage,
35}
36
37impl Default for SamplingConfig {
38    fn default() -> Self {
39        Self {
40            strategy: SamplingStrategy::Reservoir,
41            sample_size: 10000,
42            seed: None,
43            stratify_column: None,
44        }
45    }
46}
47
48impl SamplingConfig {
49    /// Create a new sampling config for quick analysis
50    pub fn quick_sample(sample_size: usize) -> Self {
51        Self {
52            strategy: SamplingStrategy::Random,
53            sample_size,
54            seed: Some(42),
55            stratify_column: None,
56        }
57    }
58
59    /// Create a config for representative sampling
60    pub fn representative_sample(sample_size: usize, stratify_column: Option<String>) -> Self {
61        Self {
62            strategy: if stratify_column.is_some() {
63                SamplingStrategy::Stratified
64            } else {
65                SamplingStrategy::Systematic
66            },
67            sample_size,
68            seed: Some(42),
69            stratify_column,
70        }
71    }
72
73    /// Create a config for temporal data sampling
74    pub fn temporal_sample(sample_size: usize, time_column: String) -> Self {
75        Self {
76            strategy: SamplingStrategy::Temporal {
77                column_name: time_column,
78            },
79            sample_size,
80            seed: Some(42),
81            stratify_column: None,
82        }
83    }
84
85    /// Generate the appropriate SQL sampling query
86    pub fn generate_sample_query(
87        &self,
88        base_query: &str,
89        total_rows: u64,
90    ) -> Result<String, DataProfilerError> {
91        if total_rows as usize <= self.sample_size {
92            return Ok(base_query.to_string());
93        }
94
95        let sampling_ratio = self.sample_size as f64 / total_rows as f64;
96
97        match &self.strategy {
98            SamplingStrategy::Random => {
99                let seed = self.seed.unwrap_or(42);
100                if base_query.trim().to_uppercase().starts_with("SELECT") {
101                    let validated_query = validate_base_query(base_query)?;
102                    Ok(format!(
103                        "SELECT * FROM ({}) AS sample_subquery ORDER BY RANDOM({}) LIMIT {}",
104                        validated_query, seed, self.sample_size
105                    ))
106                } else {
107                    validate_sql_identifier(base_query)?;
108                    Ok(format!(
109                        "SELECT * FROM {} ORDER BY RANDOM({}) LIMIT {}",
110                        base_query, seed, self.sample_size
111                    ))
112                }
113            }
114            SamplingStrategy::Systematic => {
115                let step = (total_rows as f64 / self.sample_size as f64).ceil() as u64;
116                if base_query.trim().to_uppercase().starts_with("SELECT") {
117                    let validated_query = validate_base_query(base_query)?;
118                    Ok(format!(
119                        "SELECT * FROM (SELECT *, ROW_NUMBER() OVER() as rn FROM ({})) AS numbered WHERE rn % {} = 1",
120                        validated_query, step
121                    ))
122                } else {
123                    validate_sql_identifier(base_query)?;
124                    Ok(format!(
125                        "SELECT * FROM (SELECT *, ROW_NUMBER() OVER() as rn FROM {}) AS numbered WHERE rn % {} = 1",
126                        base_query, step
127                    ))
128                }
129            }
130            SamplingStrategy::Reservoir => {
131                self.generate_tablesample_query(base_query, sampling_ratio)
132            }
133            SamplingStrategy::Stratified => {
134                if let Some(stratify_col) = &self.stratify_column {
135                    validate_sql_identifier(stratify_col)?;
136                    self.generate_stratified_query(base_query, stratify_col, total_rows)
137                } else {
138                    let mut fallback_config = self.clone();
139                    fallback_config.strategy = SamplingStrategy::Random;
140                    fallback_config.generate_sample_query(base_query, total_rows)
141                }
142            }
143            SamplingStrategy::Temporal { column_name } => {
144                validate_sql_identifier(column_name)?;
145                self.generate_temporal_query(base_query, column_name, total_rows)
146            }
147            SamplingStrategy::MultiStage => {
148                let mut config = self.clone();
149                config.strategy = SamplingStrategy::Systematic;
150                config.generate_sample_query(base_query, total_rows)
151            }
152        }
153    }
154
155    /// Generate a TABLESAMPLE query (PostgreSQL/SQL Server)
156    fn generate_tablesample_query(
157        &self,
158        base_query: &str,
159        sampling_ratio: f64,
160    ) -> Result<String, DataProfilerError> {
161        let percentage = (sampling_ratio * 100.0).min(100.0);
162
163        if base_query.trim().to_uppercase().starts_with("SELECT") {
164            let validated_query = validate_base_query(base_query)?;
165            let seed = self.seed.unwrap_or(42);
166            Ok(format!(
167                "SELECT * FROM ({}) AS sample_subquery ORDER BY RANDOM({}) LIMIT {}",
168                validated_query, seed, self.sample_size
169            ))
170        } else {
171            validate_sql_identifier(base_query)?;
172            Ok(format!(
173                "SELECT * FROM {} TABLESAMPLE SYSTEM ({:.2}) LIMIT {}",
174                base_query, percentage, self.sample_size
175            ))
176        }
177    }
178
179    /// Generate stratified sampling query
180    fn generate_stratified_query(
181        &self,
182        base_query: &str,
183        stratify_col: &str,
184        _total_rows: u64,
185    ) -> Result<String, DataProfilerError> {
186        let sample_per_stratum = self.sample_size / 10;
187
188        if base_query.trim().to_uppercase().starts_with("SELECT") {
189            let validated_query = validate_base_query(base_query)?;
190            Ok(format!(
191                r#"
192                SELECT * FROM (
193                    SELECT *, ROW_NUMBER() OVER(PARTITION BY {} ORDER BY RANDOM()) as stratum_rn
194                    FROM ({}) AS base_query
195                ) stratified
196                WHERE stratum_rn <= {}
197                LIMIT {}
198                "#,
199                stratify_col, validated_query, sample_per_stratum, self.sample_size
200            ))
201        } else {
202            validate_sql_identifier(base_query)?;
203            Ok(format!(
204                r#"
205                SELECT * FROM (
206                    SELECT *, ROW_NUMBER() OVER(PARTITION BY {} ORDER BY RANDOM()) as stratum_rn
207                    FROM {}
208                ) stratified
209                WHERE stratum_rn <= {}
210                LIMIT {}
211                "#,
212                stratify_col, base_query, sample_per_stratum, self.sample_size
213            ))
214        }
215    }
216
217    /// Generate temporal sampling query
218    fn generate_temporal_query(
219        &self,
220        base_query: &str,
221        time_col: &str,
222        total_rows: u64,
223    ) -> Result<String, DataProfilerError> {
224        if base_query.trim().to_uppercase().starts_with("SELECT") {
225            let validated_query = validate_base_query(base_query)?;
226            Ok(format!(
227                r#"
228                SELECT * FROM (
229                    SELECT *, ROW_NUMBER() OVER(ORDER BY {}) as time_rn
230                    FROM ({}) AS base_query
231                ) temporal
232                WHERE time_rn % {} = 1
233                LIMIT {}
234                "#,
235                time_col,
236                validated_query,
237                (total_rows as f64 / self.sample_size as f64).ceil() as u64,
238                self.sample_size
239            ))
240        } else {
241            validate_sql_identifier(base_query)?;
242            Ok(format!(
243                r#"
244                SELECT * FROM (
245                    SELECT *, ROW_NUMBER() OVER(ORDER BY {}) as time_rn
246                    FROM {}
247                ) temporal
248                WHERE time_rn % {} = 1
249                LIMIT {}
250                "#,
251                time_col,
252                base_query,
253                (total_rows as f64 / self.sample_size as f64).ceil() as u64,
254                self.sample_size
255            ))
256        }
257    }
258}
259
260/// Information about the sampling process
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct SampleInfo {
263    /// Total rows in the original table/query
264    pub total_rows: u64,
265    /// Number of rows in the sample
266    pub sampled_rows: u64,
267    /// Sampling ratio (0.0 - 1.0)
268    pub sampling_ratio: f64,
269    /// Sampling strategy used
270    pub strategy: SamplingStrategy,
271    /// Whether the sample is representative
272    pub is_representative: bool,
273    /// Estimated confidence interval for statistics
274    pub confidence_margin: f64,
275}
276
277impl SampleInfo {
278    /// Create new sample info
279    pub fn new(total_rows: u64, sampled_rows: u64, strategy: SamplingStrategy) -> Self {
280        let sampling_ratio = if total_rows > 0 {
281            sampled_rows as f64 / total_rows as f64
282        } else {
283            1.0
284        };
285
286        let is_representative = match strategy {
287            SamplingStrategy::Systematic | SamplingStrategy::Stratified => sampled_rows >= 1000,
288            SamplingStrategy::Random | SamplingStrategy::Reservoir => sampled_rows >= 500,
289            SamplingStrategy::Temporal { .. } => sampled_rows >= 2000,
290            SamplingStrategy::MultiStage => sampled_rows >= 1500,
291        };
292
293        let confidence_margin = if sampled_rows > 0 {
294            1.96 / (sampled_rows as f64).sqrt()
295        } else {
296            1.0
297        };
298
299        Self {
300            total_rows,
301            sampled_rows,
302            sampling_ratio,
303            strategy,
304            is_representative,
305            confidence_margin,
306        }
307    }
308
309    /// Get a warning message if the sample might not be representative
310    pub fn get_warning(&self) -> Option<String> {
311        if !self.is_representative {
312            Some(format!(
313                "Sample size ({}) may be too small for reliable analysis. \
314                Consider increasing sample size for better representation.",
315                self.sampled_rows
316            ))
317        } else if self.confidence_margin > 0.1 {
318            Some(format!(
319                "Large confidence margin ({:.2}). \
320                Statistics may have high uncertainty.",
321                self.confidence_margin
322            ))
323        } else {
324            None
325        }
326    }
327
328    /// Get recommended actions for improving sample quality
329    pub fn get_recommendations(&self) -> Vec<String> {
330        let mut recommendations = Vec::new();
331
332        if self.sampled_rows < 1000 {
333            recommendations.push(
334                "Increase sample size to at least 1000 rows for better reliability".to_string(),
335            );
336        }
337
338        if self.sampling_ratio < 0.01 && self.total_rows > 100000 {
339            recommendations.push(
340                "Consider stratified sampling for large datasets to ensure representativeness"
341                    .to_string(),
342            );
343        }
344
345        match &self.strategy {
346            SamplingStrategy::Random if self.total_rows > 1000000 => {
347                recommendations.push(
348                    "For very large datasets, consider systematic or reservoir sampling"
349                        .to_string(),
350                );
351            }
352            SamplingStrategy::Temporal { .. } if self.sampled_rows < 2000 => {
353                recommendations.push(
354                    "Temporal sampling requires larger samples to capture time patterns"
355                        .to_string(),
356                );
357            }
358            _ => {}
359        }
360
361        recommendations
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_random_sample_query() {
        let query = SamplingConfig::quick_sample(1000)
            .generate_sample_query("users", 10000)
            .expect("Failed to generate sample query");

        for fragment in ["RANDOM", "LIMIT 1000"] {
            assert!(query.contains(fragment));
        }
    }

    #[test]
    fn test_generate_systematic_sample_query() {
        // quick_sample already supplies seed 42 and no stratify column; only swap the strategy.
        let config = SamplingConfig {
            strategy: SamplingStrategy::Systematic,
            ..SamplingConfig::quick_sample(1000)
        };

        let query = config
            .generate_sample_query("orders", 10000)
            .expect("Failed to generate sample query");

        for fragment in ["ROW_NUMBER()", "% 10 = 1"] {
            assert!(query.contains(fragment));
        }
    }

    #[test]
    fn test_sample_info_calculations() {
        let SampleInfo {
            sampling_ratio,
            is_representative,
            confidence_margin,
            ..
        } = SampleInfo::new(10000, 1000, SamplingStrategy::Random);

        assert_eq!(sampling_ratio, 0.1);
        assert!(is_representative);
        assert!(confidence_margin < 0.1);
    }

    #[test]
    fn test_small_sample_warning() {
        let info = SampleInfo::new(10000, 100, SamplingStrategy::Random);
        let warning = info.get_warning();
        let recommendations = info.get_recommendations();

        assert!(!info.is_representative);
        assert!(warning.is_some());
        assert!(!recommendations.is_empty());
    }
}