1use crate::DataProfilerError;
4use crate::security::{validate_base_query, validate_sql_identifier};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct SamplingConfig {
10 pub strategy: SamplingStrategy,
12 pub sample_size: usize,
14 pub seed: Option<u64>,
16 pub stratify_column: Option<String>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum SamplingStrategy {
23 Random,
25 Systematic,
27 Reservoir,
29 Stratified,
31 Temporal { column_name: String },
33 MultiStage,
35}
36
37impl Default for SamplingConfig {
38 fn default() -> Self {
39 Self {
40 strategy: SamplingStrategy::Reservoir,
41 sample_size: 10000,
42 seed: None,
43 stratify_column: None,
44 }
45 }
46}
47
48impl SamplingConfig {
49 pub fn quick_sample(sample_size: usize) -> Self {
51 Self {
52 strategy: SamplingStrategy::Random,
53 sample_size,
54 seed: Some(42),
55 stratify_column: None,
56 }
57 }
58
59 pub fn representative_sample(sample_size: usize, stratify_column: Option<String>) -> Self {
61 Self {
62 strategy: if stratify_column.is_some() {
63 SamplingStrategy::Stratified
64 } else {
65 SamplingStrategy::Systematic
66 },
67 sample_size,
68 seed: Some(42),
69 stratify_column,
70 }
71 }
72
73 pub fn temporal_sample(sample_size: usize, time_column: String) -> Self {
75 Self {
76 strategy: SamplingStrategy::Temporal {
77 column_name: time_column,
78 },
79 sample_size,
80 seed: Some(42),
81 stratify_column: None,
82 }
83 }
84
85 pub fn generate_sample_query(
87 &self,
88 base_query: &str,
89 total_rows: u64,
90 ) -> Result<String, DataProfilerError> {
91 if total_rows as usize <= self.sample_size {
92 return Ok(base_query.to_string());
93 }
94
95 let sampling_ratio = self.sample_size as f64 / total_rows as f64;
96
97 match &self.strategy {
98 SamplingStrategy::Random => {
99 let seed = self.seed.unwrap_or(42);
100 if base_query.trim().to_uppercase().starts_with("SELECT") {
101 let validated_query = validate_base_query(base_query)?;
102 Ok(format!(
103 "SELECT * FROM ({}) AS sample_subquery ORDER BY RANDOM({}) LIMIT {}",
104 validated_query, seed, self.sample_size
105 ))
106 } else {
107 validate_sql_identifier(base_query)?;
108 Ok(format!(
109 "SELECT * FROM {} ORDER BY RANDOM({}) LIMIT {}",
110 base_query, seed, self.sample_size
111 ))
112 }
113 }
114 SamplingStrategy::Systematic => {
115 let step = (total_rows as f64 / self.sample_size as f64).ceil() as u64;
116 if base_query.trim().to_uppercase().starts_with("SELECT") {
117 let validated_query = validate_base_query(base_query)?;
118 Ok(format!(
119 "SELECT * FROM (SELECT *, ROW_NUMBER() OVER() as rn FROM ({})) AS numbered WHERE rn % {} = 1",
120 validated_query, step
121 ))
122 } else {
123 validate_sql_identifier(base_query)?;
124 Ok(format!(
125 "SELECT * FROM (SELECT *, ROW_NUMBER() OVER() as rn FROM {}) AS numbered WHERE rn % {} = 1",
126 base_query, step
127 ))
128 }
129 }
130 SamplingStrategy::Reservoir => {
131 self.generate_tablesample_query(base_query, sampling_ratio)
132 }
133 SamplingStrategy::Stratified => {
134 if let Some(stratify_col) = &self.stratify_column {
135 validate_sql_identifier(stratify_col)?;
136 self.generate_stratified_query(base_query, stratify_col, total_rows)
137 } else {
138 let mut fallback_config = self.clone();
139 fallback_config.strategy = SamplingStrategy::Random;
140 fallback_config.generate_sample_query(base_query, total_rows)
141 }
142 }
143 SamplingStrategy::Temporal { column_name } => {
144 validate_sql_identifier(column_name)?;
145 self.generate_temporal_query(base_query, column_name, total_rows)
146 }
147 SamplingStrategy::MultiStage => {
148 let mut config = self.clone();
149 config.strategy = SamplingStrategy::Systematic;
150 config.generate_sample_query(base_query, total_rows)
151 }
152 }
153 }
154
155 fn generate_tablesample_query(
157 &self,
158 base_query: &str,
159 sampling_ratio: f64,
160 ) -> Result<String, DataProfilerError> {
161 let percentage = (sampling_ratio * 100.0).min(100.0);
162
163 if base_query.trim().to_uppercase().starts_with("SELECT") {
164 let validated_query = validate_base_query(base_query)?;
165 let seed = self.seed.unwrap_or(42);
166 Ok(format!(
167 "SELECT * FROM ({}) AS sample_subquery ORDER BY RANDOM({}) LIMIT {}",
168 validated_query, seed, self.sample_size
169 ))
170 } else {
171 validate_sql_identifier(base_query)?;
172 Ok(format!(
173 "SELECT * FROM {} TABLESAMPLE SYSTEM ({:.2}) LIMIT {}",
174 base_query, percentage, self.sample_size
175 ))
176 }
177 }
178
179 fn generate_stratified_query(
181 &self,
182 base_query: &str,
183 stratify_col: &str,
184 _total_rows: u64,
185 ) -> Result<String, DataProfilerError> {
186 let sample_per_stratum = self.sample_size / 10;
187
188 if base_query.trim().to_uppercase().starts_with("SELECT") {
189 let validated_query = validate_base_query(base_query)?;
190 Ok(format!(
191 r#"
192 SELECT * FROM (
193 SELECT *, ROW_NUMBER() OVER(PARTITION BY {} ORDER BY RANDOM()) as stratum_rn
194 FROM ({}) AS base_query
195 ) stratified
196 WHERE stratum_rn <= {}
197 LIMIT {}
198 "#,
199 stratify_col, validated_query, sample_per_stratum, self.sample_size
200 ))
201 } else {
202 validate_sql_identifier(base_query)?;
203 Ok(format!(
204 r#"
205 SELECT * FROM (
206 SELECT *, ROW_NUMBER() OVER(PARTITION BY {} ORDER BY RANDOM()) as stratum_rn
207 FROM {}
208 ) stratified
209 WHERE stratum_rn <= {}
210 LIMIT {}
211 "#,
212 stratify_col, base_query, sample_per_stratum, self.sample_size
213 ))
214 }
215 }
216
217 fn generate_temporal_query(
219 &self,
220 base_query: &str,
221 time_col: &str,
222 total_rows: u64,
223 ) -> Result<String, DataProfilerError> {
224 if base_query.trim().to_uppercase().starts_with("SELECT") {
225 let validated_query = validate_base_query(base_query)?;
226 Ok(format!(
227 r#"
228 SELECT * FROM (
229 SELECT *, ROW_NUMBER() OVER(ORDER BY {}) as time_rn
230 FROM ({}) AS base_query
231 ) temporal
232 WHERE time_rn % {} = 1
233 LIMIT {}
234 "#,
235 time_col,
236 validated_query,
237 (total_rows as f64 / self.sample_size as f64).ceil() as u64,
238 self.sample_size
239 ))
240 } else {
241 validate_sql_identifier(base_query)?;
242 Ok(format!(
243 r#"
244 SELECT * FROM (
245 SELECT *, ROW_NUMBER() OVER(ORDER BY {}) as time_rn
246 FROM {}
247 ) temporal
248 WHERE time_rn % {} = 1
249 LIMIT {}
250 "#,
251 time_col,
252 base_query,
253 (total_rows as f64 / self.sample_size as f64).ceil() as u64,
254 self.sample_size
255 ))
256 }
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct SampleInfo {
263 pub total_rows: u64,
265 pub sampled_rows: u64,
267 pub sampling_ratio: f64,
269 pub strategy: SamplingStrategy,
271 pub is_representative: bool,
273 pub confidence_margin: f64,
275}
276
277impl SampleInfo {
278 pub fn new(total_rows: u64, sampled_rows: u64, strategy: SamplingStrategy) -> Self {
280 let sampling_ratio = if total_rows > 0 {
281 sampled_rows as f64 / total_rows as f64
282 } else {
283 1.0
284 };
285
286 let is_representative = match strategy {
287 SamplingStrategy::Systematic | SamplingStrategy::Stratified => sampled_rows >= 1000,
288 SamplingStrategy::Random | SamplingStrategy::Reservoir => sampled_rows >= 500,
289 SamplingStrategy::Temporal { .. } => sampled_rows >= 2000,
290 SamplingStrategy::MultiStage => sampled_rows >= 1500,
291 };
292
293 let confidence_margin = if sampled_rows > 0 {
294 1.96 / (sampled_rows as f64).sqrt()
295 } else {
296 1.0
297 };
298
299 Self {
300 total_rows,
301 sampled_rows,
302 sampling_ratio,
303 strategy,
304 is_representative,
305 confidence_margin,
306 }
307 }
308
309 pub fn get_warning(&self) -> Option<String> {
311 if !self.is_representative {
312 Some(format!(
313 "Sample size ({}) may be too small for reliable analysis. \
314 Consider increasing sample size for better representation.",
315 self.sampled_rows
316 ))
317 } else if self.confidence_margin > 0.1 {
318 Some(format!(
319 "Large confidence margin ({:.2}). \
320 Statistics may have high uncertainty.",
321 self.confidence_margin
322 ))
323 } else {
324 None
325 }
326 }
327
328 pub fn get_recommendations(&self) -> Vec<String> {
330 let mut recommendations = Vec::new();
331
332 if self.sampled_rows < 1000 {
333 recommendations.push(
334 "Increase sample size to at least 1000 rows for better reliability".to_string(),
335 );
336 }
337
338 if self.sampling_ratio < 0.01 && self.total_rows > 100000 {
339 recommendations.push(
340 "Consider stratified sampling for large datasets to ensure representativeness"
341 .to_string(),
342 );
343 }
344
345 match &self.strategy {
346 SamplingStrategy::Random if self.total_rows > 1000000 => {
347 recommendations.push(
348 "For very large datasets, consider systematic or reservoir sampling"
349 .to_string(),
350 );
351 }
352 SamplingStrategy::Temporal { .. } if self.sampled_rows < 2000 => {
353 recommendations.push(
354 "Temporal sampling requires larger samples to capture time patterns"
355 .to_string(),
356 );
357 }
358 _ => {}
359 }
360
361 recommendations
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Random sampling of a bare table name orders by RANDOM and caps the rows.
    #[test]
    fn test_generate_random_sample_query() {
        let sql = SamplingConfig::quick_sample(1000)
            .generate_sample_query("users", 10000)
            .expect("Failed to generate sample query");

        for fragment in ["RANDOM", "LIMIT 1000"] {
            assert!(sql.contains(fragment));
        }
    }

    /// Systematic sampling numbers the rows and keeps every tenth one when a
    /// tenth of the table is requested.
    #[test]
    fn test_generate_systematic_sample_query() {
        let systematic = SamplingConfig {
            strategy: SamplingStrategy::Systematic,
            sample_size: 1000,
            seed: Some(42),
            stratify_column: None,
        };

        let sql = systematic
            .generate_sample_query("orders", 10000)
            .expect("Failed to generate sample query");

        for fragment in ["ROW_NUMBER()", "% 10 = 1"] {
            assert!(sql.contains(fragment));
        }
    }

    /// A 1000-row random sample of 10000 rows is representative with a margin
    /// below 0.1.
    #[test]
    fn test_sample_info_calculations() {
        let summary = SampleInfo::new(10000, 1000, SamplingStrategy::Random);

        assert_eq!(summary.sampling_ratio, 0.1);
        assert!(summary.is_representative);
        assert!(summary.confidence_margin < 0.1);
    }

    /// A 100-row random sample is flagged with a warning and a recommendation.
    #[test]
    fn test_small_sample_warning() {
        let summary = SampleInfo::new(10000, 100, SamplingStrategy::Random);

        assert!(!summary.is_representative);
        assert!(summary.get_warning().is_some());
        assert!(!summary.get_recommendations().is_empty());
    }
}