use crate::core::errors::DataProfilerError;
use crate::database::security::{validate_base_query, validate_sql_identifier};
use serde::{Deserialize, Serialize};
/// Parameters that control how [`SamplingConfig::generate_sample_query`] builds a sampling query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Sampling algorithm to emit SQL for.
    pub strategy: SamplingStrategy,
    /// Target number of rows. Sources with at most this many rows are returned unsampled;
    /// most generated queries are capped with `LIMIT sample_size`.
    pub sample_size: usize,
    /// Seed passed to `RANDOM(...)` by the random-ordering queries; `42` is used when `None`.
    /// NOTE(review): the stratified query calls `RANDOM()` without this seed.
    pub seed: Option<u64>,
    /// Column to partition by for [`SamplingStrategy::Stratified`]; when `None`, that
    /// strategy falls back to random sampling.
    pub stratify_column: Option<String>,
}
/// Row-selection algorithm used by [`SamplingConfig::generate_sample_query`].
///
/// Derives `PartialEq`/`Eq` so callers can compare strategies directly
/// (e.g. `assert_eq!(config.strategy, SamplingStrategy::Random)`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Seeded random ordering (`ORDER BY RANDOM(seed)`) capped at `sample_size` rows.
    Random,
    /// Every k-th row by `ROW_NUMBER()`, where k = ceil(total_rows / sample_size).
    Systematic,
    /// `TABLESAMPLE SYSTEM` for bare tables, seeded random ordering for `SELECT` queries.
    /// NOTE(review): despite the name, this is not classic reservoir sampling.
    Reservoir,
    /// Random rows per group of `SamplingConfig::stratify_column`; falls back to
    /// `Random` when no stratification column is configured.
    Stratified,
    /// Evenly spaced rows after ordering by `column_name`.
    Temporal { column_name: String },
    /// Currently generates the same query as `Systematic`.
    MultiStage,
}
impl Default for SamplingConfig {
fn default() -> Self {
Self {
strategy: SamplingStrategy::Reservoir,
sample_size: 10000,
seed: None,
stratify_column: None,
}
}
}
impl SamplingConfig {
    /// Creates a seeded random-sampling configuration for fast exploratory profiling.
    ///
    /// The seed is fixed at `42` so repeated calls generate identical query text.
    pub fn quick_sample(sample_size: usize) -> Self {
        Self {
            strategy: SamplingStrategy::Random,
            sample_size,
            seed: Some(42),
            stratify_column: None,
        }
    }

    /// Creates a configuration aimed at a representative sample.
    ///
    /// Uses [`SamplingStrategy::Stratified`] when `stratify_column` is supplied and
    /// [`SamplingStrategy::Systematic`] otherwise. The seed is fixed at `42`.
    pub fn representative_sample(sample_size: usize, stratify_column: Option<String>) -> Self {
        let strategy = if stratify_column.is_some() {
            SamplingStrategy::Stratified
        } else {
            SamplingStrategy::Systematic
        };
        Self {
            strategy,
            sample_size,
            seed: Some(42),
            stratify_column,
        }
    }

    /// Creates a configuration that samples evenly spaced rows ordered by `time_column`.
    pub fn temporal_sample(sample_size: usize, time_column: String) -> Self {
        Self {
            strategy: SamplingStrategy::Temporal {
                column_name: time_column,
            },
            sample_size,
            seed: Some(42),
            stratify_column: None,
        }
    }

    /// Builds a SQL query that samples roughly `sample_size` rows from `base_query`.
    ///
    /// `base_query` is either a `SELECT` statement or a bare table identifier:
    /// * A `SELECT` statement is checked with `validate_base_query` and wrapped as an
    ///   aliased derived table.
    /// * A bare identifier is checked with `validate_sql_identifier` and used verbatim.
    ///
    /// When `total_rows` does not exceed `sample_size`, `base_query` is returned as-is.
    /// For a bare identifier that is the table name itself, not a query.
    ///
    /// The window-function strategies (systematic, stratified, temporal, multi-stage) add a
    /// helper row-number column (`rn`, `stratum_rn` or `time_rn`) to the result set.
    ///
    /// # Errors
    ///
    /// Propagates the validators' error for an invalid `base_query`, stratification column
    /// or time column.
    pub fn generate_sample_query(
        &self,
        base_query: &str,
        total_rows: u64,
    ) -> Result<String, DataProfilerError> {
        // Compare in u64: casting `total_rows` to usize would truncate on 32-bit targets.
        // NOTE(review): this early return skips validation of `base_query`; confirm callers
        // do not rely on this function as their only validation gate.
        if total_rows <= self.sample_size as u64 {
            return Ok(base_query.to_string());
        }
        match &self.strategy {
            SamplingStrategy::Random => self.random_order_query(base_query),
            SamplingStrategy::Systematic => self.systematic_query(base_query, total_rows),
            SamplingStrategy::Reservoir => {
                let sampling_ratio = self.sample_size as f64 / total_rows as f64;
                self.generate_tablesample_query(base_query, sampling_ratio)
            }
            SamplingStrategy::Stratified => match &self.stratify_column {
                Some(stratify_col) => {
                    validate_sql_identifier(stratify_col)?;
                    self.generate_stratified_query(base_query, stratify_col, total_rows)
                }
                // Without a column to stratify on, degrade to plain random sampling.
                None => self.random_order_query(base_query),
            },
            SamplingStrategy::Temporal { column_name } => {
                validate_sql_identifier(column_name)?;
                self.generate_temporal_query(base_query, column_name, total_rows)
            }
            // Multi-stage sampling is currently a single systematic pass.
            SamplingStrategy::MultiStage => self.systematic_query(base_query, total_rows),
        }
    }

    /// Returns `true` when `query` starts with the `SELECT` keyword.
    ///
    /// Leading whitespace is ignored and the match is ASCII case-insensitive. The keyword
    /// must be followed by a non-identifier character (or the end of input), so table names
    /// such as `selections` are treated as identifiers rather than queries.
    fn is_select_query(query: &str) -> bool {
        const KEYWORD: &str = "SELECT";
        let trimmed = query.trim_start();
        match trimmed.get(..KEYWORD.len()) {
            Some(head) if head.eq_ignore_ascii_case(KEYWORD) => {
                // `get` succeeded on an ASCII match, so `KEYWORD.len()` is a char boundary.
                let next = trimmed[KEYWORD.len()..].chars().next();
                !matches!(next, Some(c) if c.is_alphanumeric() || c == '_')
            }
            _ => false,
        }
    }

    /// Renders `base_query` as a FROM-clause source.
    ///
    /// A validated `SELECT` statement becomes `(<query>) AS <alias>`. Any other input must
    /// pass `validate_sql_identifier` and is used verbatim.
    fn source_clause(base_query: &str, alias: &str) -> Result<String, DataProfilerError> {
        if Self::is_select_query(base_query) {
            let validated_query = validate_base_query(base_query)?;
            Ok(format!("({}) AS {}", validated_query, alias))
        } else {
            validate_sql_identifier(base_query)?;
            Ok(base_query.to_string())
        }
    }

    /// Row interval for systematic and temporal sampling: `ceil(total_rows / sample_size)`.
    ///
    /// Uses exact integer arithmetic (no float rounding or overflow). A zero `sample_size`
    /// is treated as 1 to avoid dividing by zero.
    fn systematic_step(&self, total_rows: u64) -> u64 {
        let size = (self.sample_size as u64).max(1);
        total_rows / size + u64::from(total_rows % size != 0)
    }

    /// Seeded random ordering of `base_query`, truncated to `sample_size` rows.
    /// The seed defaults to `42`.
    fn random_order_query(&self, base_query: &str) -> Result<String, DataProfilerError> {
        let source = Self::source_clause(base_query, "sample_subquery")?;
        Ok(format!(
            "SELECT * FROM {} ORDER BY RANDOM({}) LIMIT {}",
            source,
            self.seed.unwrap_or(42),
            self.sample_size
        ))
    }

    /// Keeps every `systematic_step`-th row of `base_query` by `ROW_NUMBER()`.
    fn systematic_query(
        &self,
        base_query: &str,
        total_rows: u64,
    ) -> Result<String, DataProfilerError> {
        // The inner derived table must carry an alias. PostgreSQL (before 16), MySQL and
        // SQL Server reject `FROM (subquery)` without one.
        let source = Self::source_clause(base_query, "base_query")?;
        Ok(format!(
            "SELECT * FROM (SELECT *, ROW_NUMBER() OVER() as rn FROM {}) AS numbered WHERE rn % {} = 1",
            source,
            self.systematic_step(total_rows)
        ))
    }

    /// Percentage-based sampling for tables; seeded random ordering for `SELECT` queries.
    fn generate_tablesample_query(
        &self,
        base_query: &str,
        sampling_ratio: f64,
    ) -> Result<String, DataProfilerError> {
        if Self::is_select_query(base_query) {
            // Presumably because TABLESAMPLE cannot be applied to a derived table in the
            // target dialects, arbitrary queries use seeded random ordering instead.
            return self.random_order_query(base_query);
        }
        validate_sql_identifier(base_query)?;
        let percentage = (sampling_ratio * 100.0).min(100.0);
        // Plain `{}` prints the full decimal value. The previous `{:.2}` rounded small
        // percentages to `0.00`, e.g. 10 000 rows out of 300M, which samples nothing.
        Ok(format!(
            "SELECT * FROM {} TABLESAMPLE SYSTEM ({}) LIMIT {}",
            base_query, percentage, self.sample_size
        ))
    }

    /// Random rows per distinct value of `stratify_col`, capped at `sample_size` overall.
    fn generate_stratified_query(
        &self,
        base_query: &str,
        stratify_col: &str,
        _total_rows: u64,
    ) -> Result<String, DataProfilerError> {
        // NOTE(review): the per-stratum quota assumes roughly ten strata. It is clamped to
        // at least one row; otherwise a `sample_size` below 10 yields `stratum_rn <= 0`
        // and an empty sample.
        let sample_per_stratum = (self.sample_size / 10).max(1);
        let source = Self::source_clause(base_query, "base_query")?;
        Ok(format!(
            "SELECT * FROM (SELECT *, ROW_NUMBER() OVER(PARTITION BY {} ORDER BY RANDOM()) as stratum_rn \
             FROM {}) stratified WHERE stratum_rn <= {} LIMIT {}",
            stratify_col, source, sample_per_stratum, self.sample_size
        ))
    }

    /// Evenly spaced rows of `base_query` after ordering by `time_col`.
    fn generate_temporal_query(
        &self,
        base_query: &str,
        time_col: &str,
        total_rows: u64,
    ) -> Result<String, DataProfilerError> {
        let source = Self::source_clause(base_query, "base_query")?;
        Ok(format!(
            "SELECT * FROM (SELECT *, ROW_NUMBER() OVER(ORDER BY {}) as time_rn FROM {}) temporal \
             WHERE time_rn % {} = 1 LIMIT {}",
            time_col,
            source,
            self.systematic_step(total_rows),
            self.sample_size
        ))
    }
}
/// Metadata describing a sample, used to warn about and advise on its reliability.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SampleInfo {
    /// Total number of rows in the sampled source.
    pub total_rows: u64,
    /// Number of rows included in the sample.
    pub sampled_rows: u64,
    /// `sampled_rows / total_rows`; `SampleInfo::new` sets it to `1.0` when `total_rows` is 0.
    pub sampling_ratio: f64,
    /// Strategy that produced the sample.
    pub strategy: SamplingStrategy,
    /// Whether the sample passes the strategy-specific size check in `SampleInfo::new`.
    pub is_representative: bool,
    /// Heuristic margin of error computed by `SampleInfo::new`; `get_warning` reports
    /// values above 0.1.
    pub confidence_margin: f64,
}
impl SampleInfo {
    /// Builds sampling metadata and derives the representativeness flag and error margin.
    ///
    /// * `sampling_ratio` is `sampled_rows / total_rows`, or `1.0` when `total_rows` is 0.
    /// * `is_representative` requires a strategy-specific minimum number of rows. A scan
    ///   that covered every row of a non-empty dataset always counts as representative.
    /// * `confidence_margin` is the heuristic `1.96 / sqrt(sampled_rows)`, `1.0` for an
    ///   empty sample. It is `0.0` when every row was read, since a complete scan has no
    ///   sampling error.
    pub fn new(total_rows: u64, sampled_rows: u64, strategy: SamplingStrategy) -> Self {
        let sampling_ratio = if total_rows > 0 {
            sampled_rows as f64 / total_rows as f64
        } else {
            1.0
        };
        let complete = Self::is_complete_scan(total_rows, sampled_rows);
        let is_representative =
            complete || sampled_rows >= Self::min_representative_rows(&strategy);
        let confidence_margin = if complete {
            0.0
        } else if sampled_rows > 0 {
            1.96 / (sampled_rows as f64).sqrt()
        } else {
            1.0
        };
        Self {
            total_rows,
            sampled_rows,
            sampling_ratio,
            strategy,
            is_representative,
            confidence_margin,
        }
    }

    /// Minimum sample size each strategy needs before it is considered representative.
    fn min_representative_rows(strategy: &SamplingStrategy) -> u64 {
        match strategy {
            SamplingStrategy::Systematic | SamplingStrategy::Stratified => 1000,
            SamplingStrategy::Random | SamplingStrategy::Reservoir => 500,
            SamplingStrategy::Temporal { .. } => 2000,
            SamplingStrategy::MultiStage => 1500,
        }
    }

    /// True when a non-empty dataset was read in full, i.e. no sampling actually happened.
    fn is_complete_scan(total_rows: u64, sampled_rows: u64) -> bool {
        total_rows > 0 && sampled_rows >= total_rows
    }

    /// Returns a user-facing warning, or `None` when no warning is warranted.
    ///
    /// An unrepresentative sample takes precedence. Otherwise a warning is produced when
    /// `confidence_margin` exceeds 0.1.
    pub fn get_warning(&self) -> Option<String> {
        if !self.is_representative {
            Some(format!(
                "Sample size ({}) may be too small for reliable analysis. \
                 Consider increasing sample size for better representation.",
                self.sampled_rows
            ))
        } else if self.confidence_margin > 0.1 {
            Some(format!(
                "Large confidence margin ({:.2}). \
                 Statistics may have high uncertainty.",
                self.confidence_margin
            ))
        } else {
            None
        }
    }

    /// Returns human-readable suggestions for improving the sample (possibly empty).
    ///
    /// Advice to enlarge the sample is skipped when every row was already read, because
    /// it cannot be acted on.
    pub fn get_recommendations(&self) -> Vec<String> {
        let complete = Self::is_complete_scan(self.total_rows, self.sampled_rows);
        let mut recommendations = Vec::new();
        if self.sampled_rows < 1000 && !complete {
            recommendations.push(
                "Increase sample size to at least 1000 rows for better reliability".to_string(),
            );
        }
        if self.sampling_ratio < 0.01 && self.total_rows > 100000 {
            recommendations.push(
                "Consider stratified sampling for large datasets to ensure representativeness"
                    .to_string(),
            );
        }
        match &self.strategy {
            SamplingStrategy::Random if self.total_rows > 1000000 => {
                recommendations.push(
                    "For very large datasets, consider systematic or reservoir sampling"
                        .to_string(),
                );
            }
            SamplingStrategy::Temporal { .. } if self.sampled_rows < 2000 && !complete => {
                recommendations.push(
                    "Temporal sampling requires larger samples to capture time patterns"
                        .to_string(),
                );
            }
            _ => {}
        }
        recommendations
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- query generation -------------------------------------------------

    /// Random sampling of a bare table orders by RANDOM and caps at the sample size.
    #[test]
    fn test_generate_random_sample_query() {
        let sql = SamplingConfig::quick_sample(1000)
            .generate_sample_query("users", 10000)
            .expect("Failed to generate sample query");
        assert!(sql.contains("RANDOM"));
        assert!(sql.contains("LIMIT 1000"));
    }

    /// Systematic sampling keeps every k-th row numbered by ROW_NUMBER().
    #[test]
    fn test_generate_systematic_sample_query() {
        let config = SamplingConfig {
            strategy: SamplingStrategy::Systematic,
            ..SamplingConfig::quick_sample(1000)
        };
        let sql = config
            .generate_sample_query("orders", 10000)
            .expect("Failed to generate sample query");
        assert!(sql.contains("ROW_NUMBER()"));
        // 10 000 rows / 1 000 samples => every 10th row.
        assert!(sql.contains("% 10 = 1"));
    }

    // --- sample metadata --------------------------------------------------

    /// A 10% random sample of 1 000 rows is representative with a tight margin.
    #[test]
    fn test_sample_info_calculations() {
        let SampleInfo {
            sampling_ratio,
            is_representative,
            confidence_margin,
            ..
        } = SampleInfo::new(10000, 1000, SamplingStrategy::Random);
        assert_eq!(sampling_ratio, 0.1);
        assert!(is_representative);
        assert!(confidence_margin < 0.1);
    }

    /// A 100-row sample is flagged and yields at least one recommendation.
    #[test]
    fn test_small_sample_warning() {
        let info = SampleInfo::new(10000, 100, SamplingStrategy::Random);
        let warning = info.get_warning();
        let advice = info.get_recommendations();
        assert!(!info.is_representative);
        assert!(warning.is_some());
        assert!(!advice.is_empty());
    }
}