use std::collections::HashMap;
use std::sync::Arc;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::{info, instrument};
use crate::analyzers::errors::AnalyzerError;
/// Convenience alias: every profiler operation surfaces failures as [`AnalyzerError`].
pub type ProfilerResult<T> = Result<T, AnalyzerError>;
/// Tunable knobs controlling how much work the profiler performs per column.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
// Distinct-value count at or below which a full categorical histogram (pass 2) is built.
pub cardinality_threshold: u64,
// Maximum number of non-null rows fetched in pass 1 for type detection sampling.
pub sample_size: u64,
// Soft memory budget in bytes. NOTE(review): not referenced anywhere in this
// file's logic — confirm it is enforced elsewhere.
pub max_memory_bytes: u64,
// When true, `profile_columns` spawns one tokio task per column.
pub enable_parallel: bool,
}
impl Default for ProfilerConfig {
fn default() -> Self {
Self {
cardinality_threshold: 100,
sample_size: 10000,
max_memory_bytes: 512 * 1024 * 1024, enable_parallel: true,
}
}
}
/// Thread-safe progress callback; `Arc` so it can be shared with spawned tasks.
pub type ProgressCallback = Arc<dyn Fn(ProfilerProgress) + Send + Sync>;
/// Snapshot of profiling progress delivered to a [`ProgressCallback`].
#[derive(Debug, Clone)]
pub struct ProfilerProgress {
// 1-based index of the pass currently running.
pub current_pass: u8,
// Total passes the profiler may run (always reported as 3 in this file).
pub total_passes: u8,
// Column currently being profiled.
pub column_name: String,
// Human-readable description of the current step.
pub message: String,
}
/// Data type inferred from sampled string values (see `classify_value`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectedDataType {
Boolean,
Integer,
Double,
Date,
Timestamp,
String,
// NOTE(review): never produced by the current detection logic, which always
// picks a single dominant type — presumably kept for serialized compatibility.
Mixed,
// Returned when no non-null sample values were available.
Unknown,
}
/// Pass-1 output: row counts, nullability, cardinality and a small value sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicStatistics {
pub row_count: u64,
pub null_count: u64,
// NOTE(review): despite the name this is a fraction in [0, 1], not a 0-100
// percentage (see the computation in `execute_pass1`).
pub null_percentage: f64,
// Exact `COUNT(DISTINCT …)` result — the "approximate" in the name is aspirational.
pub approximate_cardinality: u64,
pub min_value: Option<String>,
pub max_value: Option<String>,
// Up to 10 stringified non-null values per record batch, used for type detection.
pub sample_values: Vec<String>,
}
/// One distinct value and its occurrence count in a categorical histogram.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoricalBucket {
pub value: String,
pub count: u64,
}
/// Pass-2 output: full value/count histogram for low-cardinality columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoricalHistogram {
// Every distinct non-null value, ordered by descending count.
pub buckets: Vec<CategoricalBucket>,
// Sum of all bucket counts (total non-null rows).
pub total_count: u64,
// Shannon entropy (bits) of the value distribution.
pub entropy: f64,
// The 10 most frequent values with their counts.
pub top_values: Vec<(String, u64)>,
}
/// Pass-3 output: moments and quantiles for high-cardinality numeric columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumericDistribution {
    pub mean: Option<f64>,
    pub std_dev: Option<f64>,
    pub variance: Option<f64>,
    /// Quantile label ("P50", "P90", "P95", "P99") to value.
    pub quantiles: HashMap<String, f64>,
    /// NOTE(review): always 0 — outlier detection is not implemented in `execute_pass3`.
    pub outlier_count: u64,
    /// NOTE(review): always `None` — not computed by `execute_pass3`.
    pub skewness: Option<f64>,
    /// NOTE(review): always `None` — not computed by `execute_pass3`.
    pub kurtosis: Option<f64>,
}
/// Complete profiling result for a single column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnProfile {
pub column_name: String,
pub data_type: DetectedDataType,
pub basic_stats: BasicStatistics,
// Present only when pass 2 ran (cardinality at or below the threshold).
pub categorical_histogram: Option<CategoricalHistogram>,
// Present only when pass 3 ran (high-cardinality Integer/Double column).
pub numeric_distribution: Option<NumericDistribution>,
// Wall-clock time spent profiling this column.
pub profiling_time_ms: u64,
// Which of the passes (1, 2, 3) actually executed.
pub passes_executed: Vec<u8>,
}
/// Fluent builder for [`ColumnProfiler`]; obtain one via `ColumnProfiler::builder()`.
pub struct ColumnProfilerBuilder {
config: ProfilerConfig,
progress_callback: Option<ProgressCallback>,
}
impl ColumnProfilerBuilder {
    /// Sets the distinct-value cutoff below which a categorical histogram is built.
    pub fn cardinality_threshold(self, threshold: u64) -> Self {
        let Self { mut config, progress_callback } = self;
        config.cardinality_threshold = threshold;
        Self { config, progress_callback }
    }
    /// Sets the maximum number of rows sampled during pass 1.
    pub fn sample_size(self, size: u64) -> Self {
        let Self { mut config, progress_callback } = self;
        config.sample_size = size;
        Self { config, progress_callback }
    }
    /// Sets the advisory memory budget in bytes.
    pub fn max_memory_bytes(self, bytes: u64) -> Self {
        let Self { mut config, progress_callback } = self;
        config.max_memory_bytes = bytes;
        Self { config, progress_callback }
    }
    /// Enables or disables per-column parallel profiling.
    pub fn enable_parallel(self, enable: bool) -> Self {
        let Self { mut config, progress_callback } = self;
        config.enable_parallel = enable;
        Self { config, progress_callback }
    }
    /// Registers a progress callback, replacing any previously registered one.
    pub fn progress_callback<F>(self, callback: F) -> Self
    where
        F: Fn(ProfilerProgress) + Send + Sync + 'static,
    {
        let Self { config, .. } = self;
        Self {
            config,
            progress_callback: Some(Arc::new(callback)),
        }
    }
    /// Finalizes the builder into a [`ColumnProfiler`].
    pub fn build(self) -> ColumnProfiler {
        let Self { config, progress_callback } = self;
        ColumnProfiler { config, progress_callback }
    }
}
/// Three-pass column profiler backed by DataFusion SQL queries.
///
/// Pass 1 gathers counts/cardinality and samples values; pass 2 builds a
/// categorical histogram for low-cardinality columns; pass 3 computes numeric
/// moments and quantiles for high-cardinality numeric columns.
pub struct ColumnProfiler {
config: ProfilerConfig,
progress_callback: Option<ProgressCallback>,
}
impl ColumnProfiler {
/// Returns a builder initialized with `ProfilerConfig::default()`.
pub fn builder() -> ColumnProfilerBuilder {
ColumnProfilerBuilder {
config: ProfilerConfig::default(),
progress_callback: None,
}
}
/// Creates a profiler with default configuration and no progress callback.
pub fn new() -> Self {
Self::builder().build()
}
/// Profiles a single column of `table_name` registered in `ctx`.
///
/// Pass 1 always runs (counts, cardinality, min/max, value sample, type
/// detection). Afterwards at most ONE of:
/// * pass 2 — categorical histogram, when the distinct count is at or below
///   `cardinality_threshold`;
/// * pass 3 — numeric distribution, when the column is high-cardinality and
///   detected as Integer/Double.
///
/// NOTE(review): the log message says "three-pass", but passes 2 and 3 are
/// mutually exclusive here — confirm whether a low-cardinality numeric column
/// should also receive a numeric distribution.
/// NOTE(review): identifiers are interpolated verbatim into SQL by the pass
/// helpers — callers must not pass untrusted table/column names.
#[instrument(skip(self, ctx))]
pub async fn profile_column(
&self,
ctx: &SessionContext,
table_name: &str,
column_name: &str,
) -> ProfilerResult<ColumnProfile> {
let start_time = std::time::Instant::now();
let mut passes_executed = Vec::new();
info!(
table = table_name,
column = column_name,
"Starting three-pass column profiling"
);
self.report_progress(
1,
3,
column_name,
"Computing basic statistics and type detection",
);
let basic_stats = self.execute_pass1(ctx, table_name, column_name).await?;
let data_type = self.detect_data_type(&basic_stats).await?;
passes_executed.push(1);
let mut categorical_histogram = None;
let mut numeric_distribution = None;
// Low cardinality -> full histogram; otherwise numeric columns get
// distribution statistics. All other columns stop after pass 1.
if basic_stats.approximate_cardinality <= self.config.cardinality_threshold {
self.report_progress(2, 3, column_name, "Computing categorical histogram");
categorical_histogram = Some(
self.execute_pass2(ctx, table_name, column_name, &basic_stats)
.await?,
);
passes_executed.push(2);
} else if matches!(
data_type,
DetectedDataType::Integer | DetectedDataType::Double
) {
self.report_progress(3, 3, column_name, "Analyzing numeric distribution");
numeric_distribution = Some(
self.execute_pass3(ctx, table_name, column_name, &basic_stats)
.await?,
);
passes_executed.push(3);
}
let profiling_time_ms = start_time.elapsed().as_millis() as u64;
info!(
table = table_name,
column = column_name,
time_ms = profiling_time_ms,
passes = ?passes_executed,
"Completed column profiling"
);
Ok(ColumnProfile {
column_name: column_name.to_string(),
data_type,
basic_stats,
categorical_histogram,
numeric_distribution,
profiling_time_ms,
passes_executed,
})
}
#[instrument(skip(self, ctx))]
pub async fn profile_columns(
&self,
ctx: &SessionContext,
table_name: &str,
column_names: &[String],
) -> ProfilerResult<Vec<ColumnProfile>> {
if self.config.enable_parallel && column_names.len() > 1 {
let mut handles = Vec::new();
for column_name in column_names {
let ctx = ctx.clone();
let table_name = table_name.to_string();
let column_name = column_name.clone();
let profiler = self.clone_for_parallel();
let handle = tokio::spawn(async move {
profiler
.profile_column(&ctx, &table_name, &column_name)
.await
});
handles.push(handle);
}
let mut results = Vec::new();
for handle in handles {
match handle.await {
Ok(Ok(profile)) => results.push(profile),
Ok(Err(e)) => return Err(e),
Err(e) => {
return Err(AnalyzerError::execution(format!("Task join error: {e}")))
}
}
}
Ok(results)
} else {
let mut results = Vec::new();
for column_name in column_names {
let profile = self.profile_column(ctx, table_name, column_name).await?;
results.push(profile);
}
Ok(results)
}
}
/// Cheap copy handed to each spawned task: the config is cloned and the
/// progress callback's `Arc` reference count is bumped.
fn clone_for_parallel(&self) -> Self {
    let config = self.config.clone();
    let progress_callback = self.progress_callback.clone();
    Self { config, progress_callback }
}
/// Invokes the registered progress callback, if any; a no-op otherwise.
fn report_progress(
    &self,
    current_pass: u8,
    total_passes: u8,
    column_name: &str,
    message: &str,
) {
    let Some(callback) = self.progress_callback.as_ref() else {
        return;
    };
    let progress = ProfilerProgress {
        current_pass,
        total_passes,
        column_name: column_name.to_owned(),
        message: message.to_owned(),
    };
    callback(progress);
}
/// Pass 1: core statistics via three SQL queries against `table_name`.
///
/// 1. samples up to `sample_size` non-null values (used for type detection),
/// 2. counts total / non-null / distinct rows,
/// 3. fetches MIN/MAX (delegated to `get_min_max_values`).
///
/// NOTE(review): `table_name`/`column_name` are interpolated verbatim into
/// the SQL — identifiers must come from a trusted source.
#[instrument(skip(self, ctx))]
async fn execute_pass1(
&self,
ctx: &SessionContext,
table_name: &str,
column_name: &str,
) -> ProfilerResult<BasicStatistics> {
// Query 1: bounded sample of non-null values for type detection.
let sample_sql = format!(
"SELECT {column_name} FROM {table_name} WHERE {column_name} IS NOT NULL LIMIT {}",
self.config.sample_size
);
let sample_df = ctx
.sql(&sample_sql)
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
let sample_batches = sample_df
.collect()
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
// Query 2: exact counts. COUNT(DISTINCT …) is an exact scan even though the
// result lands in a field named `approximate_cardinality`.
let stats_sql = format!(
"SELECT
COUNT(*) as total_count,
COUNT({column_name}) as non_null_count,
COUNT(DISTINCT {column_name}) as distinct_count
FROM {table_name}"
);
let stats_df = ctx
.sql(&stats_sql)
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
let stats_batches = stats_df
.collect()
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
if stats_batches.is_empty() || stats_batches[0].num_rows() == 0 {
return Err(AnalyzerError::invalid_data(
"No data found for statistics computation".to_string(),
));
}
let stats_batch = &stats_batches[0];
let total_count = self.extract_u64(stats_batch, 0, "total_count")?;
let non_null_count = self.extract_u64(stats_batch, 1, "non_null_count")?;
let distinct_count = self.extract_u64(stats_batch, 2, "distinct_count")?;
// COUNT(col) <= COUNT(*) by SQL semantics, so this subtraction cannot underflow.
let null_count = total_count - non_null_count;
// NOTE(review): despite the name, this is a fraction in [0, 1], not 0-100.
let null_percentage = if total_count > 0 {
null_count as f64 / total_count as f64
} else {
0.0
};
// Keep up to 10 stringified values PER batch (not 10 total) as the sample.
let mut sample_values = Vec::new();
for batch in &sample_batches {
if batch.num_rows() > 0 {
let column_data = batch.column(0);
for i in 0..batch.num_rows().min(10) {
if !column_data.is_null(i) {
let value = self.extract_string_value(column_data, i)?;
sample_values.push(value);
}
}
}
}
// Query 3: MIN/MAX rendered as display strings.
let (min_value, max_value) = self
.get_min_max_values(ctx, table_name, column_name)
.await?;
Ok(BasicStatistics {
row_count: total_count,
null_count,
null_percentage,
approximate_cardinality: distinct_count,
min_value,
max_value,
sample_values,
})
}
#[instrument(skip(self, ctx))]
async fn execute_pass2(
&self,
ctx: &SessionContext,
table_name: &str,
column_name: &str,
_basic_stats: &BasicStatistics,
) -> ProfilerResult<CategoricalHistogram> {
let histogram_sql = format!(
"SELECT
{column_name} as value,
COUNT(*) as count
FROM {table_name}
WHERE {column_name} IS NOT NULL
GROUP BY {column_name}
ORDER BY count DESC"
);
let df = ctx
.sql(&histogram_sql)
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
let batches = df
.collect()
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
let mut buckets = Vec::new();
let mut top_values = Vec::new();
let mut total_count = 0u64;
for batch in &batches {
for i in 0..batch.num_rows() {
let value = self.extract_string_value(batch.column(0), i)?;
let count = self.extract_u64(batch, 1, "count")?;
buckets.push(CategoricalBucket {
value: value.clone(),
count,
});
if top_values.len() < 10 {
top_values.push((value, count));
}
total_count += count;
}
}
let entropy = self.calculate_entropy(&buckets, total_count);
Ok(CategoricalHistogram {
buckets,
total_count,
entropy,
top_values,
})
}
/// Pass 3: numeric moments (mean/stddev/variance) plus P50/P90/P95/P99
/// quantiles for a high-cardinality numeric column.
///
/// Percentile failures are swallowed per-quantile, so a missing
/// `approx_percentile` function simply yields an empty or partial map.
#[instrument(skip(self, ctx))]
async fn execute_pass3(
&self,
ctx: &SessionContext,
table_name: &str,
column_name: &str,
_basic_stats: &BasicStatistics,
) -> ProfilerResult<NumericDistribution> {
// Values are CAST to DOUBLE so the same SQL works for integer columns.
let stats_sql = format!(
"SELECT
AVG(CAST({column_name} AS DOUBLE)) as mean,
STDDEV(CAST({column_name} AS DOUBLE)) as std_dev,
VAR_SAMP(CAST({column_name} AS DOUBLE)) as variance
FROM {table_name}
WHERE {column_name} IS NOT NULL"
);
let stats_df = ctx
.sql(&stats_sql)
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
let stats_batches = stats_df
.collect()
.await
.map_err(|e| AnalyzerError::execution(e.to_string()))?;
// Aggregates stay None when no rows came back (or the output column is not
// Float64 — see `extract_optional_f64`).
let mut mean = None;
let mut std_dev = None;
let mut variance = None;
if !stats_batches.is_empty() && stats_batches[0].num_rows() > 0 {
let batch = &stats_batches[0];
mean = self.extract_optional_f64(batch, 0)?;
std_dev = self.extract_optional_f64(batch, 1)?;
variance = self.extract_optional_f64(batch, 2)?;
}
let mut quantiles = HashMap::new();
let percentiles = vec![("P50", 0.5), ("P90", 0.9), ("P95", 0.95), ("P99", 0.99)];
for (name, percentile) in percentiles {
// Best-effort: a failed percentile query is skipped, not propagated.
if let Ok(value) = self
.calculate_percentile(ctx, table_name, column_name, percentile)
.await
{
quantiles.insert(name.to_string(), value);
}
}
// Placeholders: outlier/skewness/kurtosis are not implemented yet and are
// always 0 / None (also flagged on the NumericDistribution struct).
let outlier_count = 0; let skewness = None; let kurtosis = None;
Ok(NumericDistribution {
mean,
std_dev,
variance,
quantiles,
outlier_count,
skewness,
kurtosis,
})
}
/// Picks the dominant [`DetectedDataType`] among the pass-1 sample values.
///
/// Returns `Unknown` for an empty sample. Bug fixed: ties between equally
/// frequent types were previously resolved by `HashMap` iteration order,
/// making the detected type nondeterministic run-to-run; ties now break
/// deterministically toward the more specific type.
///
/// NOTE(review): `async` is unnecessary here (nothing is awaited) but is kept
/// so the existing `.await` call sites remain valid.
async fn detect_data_type(
    &self,
    basic_stats: &BasicStatistics,
) -> ProfilerResult<DetectedDataType> {
    if basic_stats.sample_values.is_empty() {
        return Ok(DetectedDataType::Unknown);
    }
    let mut type_counts: HashMap<DetectedDataType, usize> = HashMap::new();
    for value in &basic_stats.sample_values {
        *type_counts.entry(self.classify_value(value)).or_insert(0) += 1;
    }
    // Deterministic tie-break: more specific types win when counts are equal.
    fn specificity(t: &DetectedDataType) -> u8 {
        match t {
            DetectedDataType::Boolean => 7,
            DetectedDataType::Integer => 6,
            DetectedDataType::Double => 5,
            DetectedDataType::Date => 4,
            DetectedDataType::Timestamp => 3,
            DetectedDataType::String => 2,
            DetectedDataType::Mixed => 1,
            DetectedDataType::Unknown => 0,
        }
    }
    let dominant_type = type_counts
        .into_iter()
        .max_by_key(|(data_type, count)| (*count, specificity(data_type)))
        .map(|(data_type, _)| data_type)
        .unwrap_or(DetectedDataType::Unknown);
    Ok(dominant_type)
}
/// Heuristically classifies a single string value.
///
/// Check order matters: boolean → integer → double → date-like →
/// timestamp-like → string. The date/timestamp checks are shape heuristics,
/// not real parsers.
fn classify_value(&self, value: &str) -> DetectedDataType {
    let trimmed = value.trim();
    if trimmed.eq_ignore_ascii_case("true") || trimmed.eq_ignore_ascii_case("false") {
        return DetectedDataType::Boolean;
    }
    if trimmed.parse::<i64>().is_ok() {
        return DetectedDataType::Integer;
    }
    if trimmed.parse::<f64>().is_ok() {
        return DetectedDataType::Double;
    }
    // e.g. "2024-01-15": exactly 10 chars containing two dashes.
    if trimmed.len() == 10 && trimmed.matches('-').count() == 2 {
        return DetectedDataType::Date;
    }
    // Bug fix: `a || b && c` parses as `a || (b && c)`, so ANY value containing
    // a 'T' (e.g. "Tom") was classified as a timestamp regardless of length.
    // The length guard must apply to both separator checks.
    if (trimmed.contains('T') || trimmed.contains(' ')) && trimmed.len() > 15 {
        return DetectedDataType::Timestamp;
    }
    DetectedDataType::String
}
/// Reads an unsigned integer from ROW 0 of column `col_idx`.
///
/// Only suitable for single-row aggregate results: the row index is
/// hard-coded to 0 (callers iterating multi-row batches must not use this —
/// see `execute_pass2`). `col_name` is used purely for error messages.
///
/// NOTE(review): a negative Int64 wraps via `as u64` instead of erroring —
/// tolerable for COUNT results, which cannot be negative.
fn extract_u64(
&self,
batch: &arrow::record_batch::RecordBatch,
col_idx: usize,
col_name: &str,
) -> ProfilerResult<u64> {
use arrow::array::Array;
let column = batch.column(col_idx);
if column.is_null(0) {
return Err(AnalyzerError::invalid_data(format!(
"Null value in {col_name} column"
)));
}
if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt64Array>() {
Ok(arr.value(0))
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
Ok(arr.value(0) as u64)
} else {
Err(AnalyzerError::invalid_data(format!(
"Expected integer for {col_name}"
)))
}
}
/// Reads an optional f64 from ROW 0 of column `col_idx`.
///
/// Returns `Ok(None)` both for SQL NULL and — silently — for any array type
/// other than Float64 (a Float32 or Decimal result would be dropped).
/// NOTE(review): consider surfacing the unexpected-type case as an error.
fn extract_optional_f64(
&self,
batch: &arrow::record_batch::RecordBatch,
col_idx: usize,
) -> ProfilerResult<Option<f64>> {
use arrow::array::Array;
let column = batch.column(col_idx);
if column.is_null(0) {
return Ok(None);
}
if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
Ok(Some(arr.value(0)))
} else {
Ok(None)
}
}
/// Renders the value at `row_idx` as a display string.
///
/// Supported array types: Utf8, Utf8View, Int64, Float64, Boolean.
/// NOTE(review): SQL NULL becomes the literal string "NULL" and unsupported
/// array types become "UNKNOWN" — both sentinels can collide with real data
/// values of the same text.
fn extract_string_value(
&self,
column: &dyn arrow::array::Array,
row_idx: usize,
) -> ProfilerResult<String> {
if column.is_null(row_idx) {
return Ok("NULL".to_string());
}
if let Some(arr) = column.as_any().downcast_ref::<arrow::array::StringArray>() {
Ok(arr.value(row_idx).to_string())
} else if let Some(arr) = column
.as_any()
.downcast_ref::<arrow::array::StringViewArray>()
{
Ok(arr.value(row_idx).to_string())
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
Ok(arr.value(row_idx).to_string())
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
Ok(arr.value(row_idx).to_string())
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::BooleanArray>() {
Ok(arr.value(row_idx).to_string())
} else {
Ok("UNKNOWN".to_string())
}
}
/// Fetches MIN/MAX of the column as display strings.
///
/// Returns `(None, None)` when the query yields no rows, and `None` per side
/// when the corresponding aggregate value is NULL.
async fn get_min_max_values(
    &self,
    ctx: &SessionContext,
    table_name: &str,
    column_name: &str,
) -> ProfilerResult<(Option<String>, Option<String>)> {
    let sql = format!(
        "SELECT MIN({column_name}) as min_val, MAX({column_name}) as max_val FROM {table_name} WHERE {column_name} IS NOT NULL"
    );
    let batches = ctx
        .sql(&sql)
        .await
        .map_err(|e| AnalyzerError::execution(e.to_string()))?
        .collect()
        .await
        .map_err(|e| AnalyzerError::execution(e.to_string()))?;
    let batch = match batches.first() {
        Some(b) if b.num_rows() > 0 => b,
        _ => return Ok((None, None)),
    };
    let min_val = match batch.column(0).is_null(0) {
        true => None,
        false => Some(self.extract_string_value(batch.column(0), 0)?),
    };
    let max_val = match batch.column(1).is_null(0) {
        true => None,
        false => Some(self.extract_string_value(batch.column(1), 0)?),
    };
    Ok((min_val, max_val))
}
/// Computes a single percentile via DataFusion's `approx_percentile`.
///
/// # Errors
/// Fails when the SQL cannot be planned or executed (e.g. the function is
/// unavailable in this DataFusion build) or when the query yields no usable
/// value. Bug fixed: the planning error was previously discarded and replaced
/// by a fixed "Percentile function not available" message, making a missing
/// function indistinguishable from, say, a typo'd column — the underlying
/// error text is now included.
async fn calculate_percentile(
    &self,
    ctx: &SessionContext,
    table_name: &str,
    column_name: &str,
    percentile: f64,
) -> ProfilerResult<f64> {
    let sql = format!(
        "SELECT approx_percentile(CAST({column_name} AS DOUBLE), {percentile}) as percentile_val
FROM {table_name}
WHERE {column_name} IS NOT NULL"
    );
    match ctx.sql(&sql).await {
        Ok(df) => {
            let batches = df
                .collect()
                .await
                .map_err(|e| AnalyzerError::execution(e.to_string()))?;
            if !batches.is_empty() && batches[0].num_rows() > 0 {
                let batch = &batches[0];
                if let Some(value) = self.extract_optional_f64(batch, 0)? {
                    return Ok(value);
                }
            }
            Err(AnalyzerError::invalid_data(
                "No percentile result".to_string(),
            ))
        }
        Err(e) => Err(AnalyzerError::invalid_data(format!(
            "Percentile query failed: {e}"
        ))),
    }
}
/// Shannon entropy (in bits) of the bucket distribution; `0.0` for an empty
/// or all-zero histogram.
fn calculate_entropy(&self, buckets: &[CategoricalBucket], total_count: u64) -> f64 {
    if total_count == 0 {
        return 0.0;
    }
    let total = total_count as f64;
    buckets
        .iter()
        .filter(|bucket| bucket.count > 0)
        .map(|bucket| {
            let p = bucket.count as f64 / total;
            -p * p.log2()
        })
        .sum()
}
}
impl Default for ColumnProfiler {
/// Equivalent to [`ColumnProfiler::new`].
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // These tests are fully synchronous (nothing is awaited), so the plain
    // `#[test]` attribute replaces the previous `#[tokio::test]` — no runtime
    // needs to be spun up.
    /// Builder setters must be reflected in the final config.
    #[test]
    fn test_profiler_config_builder() {
        let profiler = ColumnProfiler::builder()
            .cardinality_threshold(200)
            .sample_size(5000)
            .max_memory_bytes(1024 * 1024 * 1024)
            .enable_parallel(false)
            .build();
        assert_eq!(profiler.config.cardinality_threshold, 200);
        assert_eq!(profiler.config.sample_size, 5000);
        assert_eq!(profiler.config.max_memory_bytes, 1024 * 1024 * 1024);
        assert!(!profiler.config.enable_parallel);
    }
    /// `classify_value` heuristics for the common cases.
    #[test]
    fn test_data_type_detection() {
        let profiler = ColumnProfiler::new();
        assert_eq!(profiler.classify_value("123"), DetectedDataType::Integer);
        assert_eq!(profiler.classify_value("123.45"), DetectedDataType::Double);
        assert_eq!(profiler.classify_value("true"), DetectedDataType::Boolean);
        assert_eq!(profiler.classify_value("hello"), DetectedDataType::String);
    }
    /// The registered callback must actually receive progress events.
    /// Previously this test only constructed the profiler and asserted nothing.
    #[test]
    fn test_progress_callback() {
        use std::sync::{Arc, Mutex};
        let progress_calls = Arc::new(Mutex::new(Vec::new()));
        let progress_calls_clone = progress_calls.clone();
        let profiler = ColumnProfiler::builder()
            .progress_callback(move |progress| {
                progress_calls_clone.lock().unwrap().push(progress);
            })
            .build();
        profiler.report_progress(1, 3, "col_a", "starting");
        let calls = progress_calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].current_pass, 1);
        assert_eq!(calls[0].total_passes, 3);
        assert_eq!(calls[0].column_name, "col_a");
        assert_eq!(calls[0].message, "starting");
    }
}