use std::collections::HashMap;
use std::sync::Arc;

use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::{info, instrument};

use crate::analyzers::errors::AnalyzerError;

pub type ProfilerResult<T> = Result<T, AnalyzerError>;

/// Tuning knobs for the three-pass column profiler.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Columns with at most this many distinct values get a categorical histogram.
    pub cardinality_threshold: u64,
    /// Maximum number of non-null rows sampled for type detection.
    pub sample_size: u64,
    /// Memory budget for profiling, in bytes.
    pub max_memory_bytes: u64,
    /// Profile multiple columns concurrently on tokio tasks.
    pub enable_parallel: bool,
}

impl Default for ProfilerConfig {
    fn default() -> Self {
        Self {
            cardinality_threshold: 100,
            sample_size: 10_000,
            max_memory_bytes: 512 * 1024 * 1024, // 512 MiB
            enable_parallel: true,
        }
    }
}
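
// Sketch: the threshold drives pass selection; distinct_count <= cardinality_threshold
// routes a column to the categorical pass. The decision rule, illustrated:
//
//     let config = ProfilerConfig::default();
//     let is_categorical = |distinct: u64| distinct <= config.cardinality_threshold;
//     assert!(is_categorical(100)); // the boundary is inclusive
//     assert!(!is_categorical(101));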

/// Shared, thread-safe callback invoked as each profiling pass starts.
pub type ProgressCallback = Arc<dyn Fn(ProfilerProgress) + Send + Sync>;

#[derive(Debug, Clone)]
pub struct ProfilerProgress {
    pub current_pass: u8,
    pub total_passes: u8,
    pub column_name: String,
    pub message: String,
}
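
// A minimal sketch of a progress callback; the closure only needs to satisfy
// `Fn(ProfilerProgress) + Send + Sync`:
//
//     let callback: ProgressCallback = Arc::new(|p: ProfilerProgress| {
//         eprintln!("[{}/{}] {}: {}", p.current_pass, p.total_passes,
//                   p.column_name, p.message);
//     });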

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectedDataType {
    Boolean,
    Integer,
    Double,
    Date,
    Timestamp,
    String,
    Mixed,
    Unknown,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicStatistics {
    pub row_count: u64,
    pub null_count: u64,
    /// Percentage of rows that are NULL, in the range 0.0-100.0.
    pub null_percentage: f64,
    pub approximate_cardinality: u64,
    pub min_value: Option<String>,
    pub max_value: Option<String>,
    pub sample_values: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoricalBucket {
    pub value: String,
    pub count: u64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoricalHistogram {
    pub buckets: Vec<CategoricalBucket>,
    pub total_count: u64,
    /// Shannon entropy of the value distribution, in bits.
    pub entropy: f64,
    /// Up to ten (value, count) pairs, most frequent first.
    pub top_values: Vec<(String, u64)>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumericDistribution {
    pub mean: Option<f64>,
    pub std_dev: Option<f64>,
    pub variance: Option<f64>,
    /// Approximate quantiles keyed by name ("P50", "P90", "P95", "P99").
    pub quantiles: HashMap<String, f64>,
    pub outlier_count: u64,
    pub skewness: Option<f64>,
    pub kurtosis: Option<f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnProfile {
    pub column_name: String,
    pub data_type: DetectedDataType,
    pub basic_stats: BasicStatistics,
    pub categorical_histogram: Option<CategoricalHistogram>,
    pub numeric_distribution: Option<NumericDistribution>,
    pub profiling_time_ms: u64,
    pub passes_executed: Vec<u8>,
}
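
// ColumnProfile derives Serialize/Deserialize, so results can be persisted or
// shipped elsewhere (sketch; assumes serde_json is available as a dependency):
//
//     let json = serde_json::to_string_pretty(&profile)?;
//     let restored: ColumnProfile = serde_json::from_str(&json)?;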

pub struct ColumnProfilerBuilder {
    config: ProfilerConfig,
    progress_callback: Option<ProgressCallback>,
}

impl ColumnProfilerBuilder {
    pub fn cardinality_threshold(mut self, threshold: u64) -> Self {
        self.config.cardinality_threshold = threshold;
        self
    }

    pub fn sample_size(mut self, size: u64) -> Self {
        self.config.sample_size = size;
        self
    }

    pub fn max_memory_bytes(mut self, bytes: u64) -> Self {
        self.config.max_memory_bytes = bytes;
        self
    }

    pub fn enable_parallel(mut self, enable: bool) -> Self {
        self.config.enable_parallel = enable;
        self
    }

    pub fn progress_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(ProfilerProgress) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Arc::new(callback));
        self
    }

    pub fn build(self) -> ColumnProfiler {
        ColumnProfiler {
            config: self.config,
            progress_callback: self.progress_callback,
        }
    }
}
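
// Typical construction through the builder (values are illustrative, not tuned
// recommendations):
//
//     let profiler = ColumnProfiler::builder()
//         .cardinality_threshold(50)
//         .sample_size(1_000)
//         .enable_parallel(false)
//         .build();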

pub struct ColumnProfiler {
    config: ProfilerConfig,
    progress_callback: Option<ProgressCallback>,
}

impl ColumnProfiler {
    pub fn builder() -> ColumnProfilerBuilder {
        ColumnProfilerBuilder {
            config: ProfilerConfig::default(),
            progress_callback: None,
        }
    }

    pub fn new() -> Self {
        Self::builder().build()
    }

    #[instrument(skip(self, ctx))]
    pub async fn profile_column(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
    ) -> ProfilerResult<ColumnProfile> {
        let start_time = std::time::Instant::now();
        let mut passes_executed = Vec::new();

        info!(
            table = table_name,
            column = column_name,
            "Starting three-pass column profiling"
        );

        // Pass 1: basic statistics plus sample-based type detection.
        self.report_progress(
            1,
            3,
            column_name,
            "Computing basic statistics and type detection",
        );
        let basic_stats = self.execute_pass1(ctx, table_name, column_name).await?;
        let data_type = self.detect_data_type(&basic_stats)?;
        passes_executed.push(1);

        let mut categorical_histogram = None;
        let mut numeric_distribution = None;

        if basic_stats.approximate_cardinality <= self.config.cardinality_threshold {
            // Pass 2: low-cardinality columns get a full categorical histogram.
            self.report_progress(2, 3, column_name, "Computing categorical histogram");
            categorical_histogram = Some(
                self.execute_pass2(ctx, table_name, column_name, &basic_stats)
                    .await?,
            );
            passes_executed.push(2);
        } else if matches!(
            data_type,
            DetectedDataType::Integer | DetectedDataType::Double
        ) {
            // Pass 3: high-cardinality numeric columns get distribution statistics.
            self.report_progress(3, 3, column_name, "Analyzing numeric distribution");
            numeric_distribution = Some(
                self.execute_pass3(ctx, table_name, column_name, &basic_stats)
                    .await?,
            );
            passes_executed.push(3);
        }

        let profiling_time_ms = start_time.elapsed().as_millis() as u64;

        info!(
            table = table_name,
            column = column_name,
            time_ms = profiling_time_ms,
            passes = ?passes_executed,
            "Completed column profiling"
        );

        Ok(ColumnProfile {
            column_name: column_name.to_string(),
            data_type,
            basic_stats,
            categorical_histogram,
            numeric_distribution,
            profiling_time_ms,
            passes_executed,
        })
    }
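
    // End-to-end sketch (table, file, and column names are illustrative):
    //
    //     let ctx = SessionContext::new();
    //     ctx.register_csv("events", "events.csv", CsvReadOptions::new()).await?;
    //     let profile = ColumnProfiler::new()
    //         .profile_column(&ctx, "events", "status")
    //         .await?;
    //     println!("{:?}, ~{} distinct", profile.data_type,
    //              profile.basic_stats.approximate_cardinality);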

    #[instrument(skip(self, ctx))]
    pub async fn profile_columns(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_names: &[String],
    ) -> ProfilerResult<Vec<ColumnProfile>> {
        if self.config.enable_parallel && column_names.len() > 1 {
            // One tokio task per column; SessionContext clones share state cheaply.
            let mut handles = Vec::new();

            for column_name in column_names {
                let ctx = ctx.clone();
                let table_name = table_name.to_string();
                let column_name = column_name.clone();
                let profiler = self.clone_for_parallel();

                let handle = tokio::spawn(async move {
                    profiler
                        .profile_column(&ctx, &table_name, &column_name)
                        .await
                });
                handles.push(handle);
            }

            let mut results = Vec::new();
            for handle in handles {
                match handle.await {
                    Ok(Ok(profile)) => results.push(profile),
                    Ok(Err(e)) => return Err(e),
                    Err(e) => {
                        return Err(AnalyzerError::execution(format!("Task join error: {e}")))
                    }
                }
            }

            Ok(results)
        } else {
            let mut results = Vec::new();
            for column_name in column_names {
                let profile = self.profile_column(ctx, table_name, column_name).await?;
                results.push(profile);
            }
            Ok(results)
        }
    }
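
    // Sketch of batch usage (column names are illustrative):
    //
    //     let profiles = profiler
    //         .profile_columns(&ctx, "events", &["user_id".to_string(), "status".to_string()])
    //         .await?;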

    fn clone_for_parallel(&self) -> Self {
        Self {
            config: self.config.clone(),
            progress_callback: self.progress_callback.clone(),
        }
    }

    fn report_progress(
        &self,
        current_pass: u8,
        total_passes: u8,
        column_name: &str,
        message: &str,
    ) {
        if let Some(callback) = &self.progress_callback {
            callback(ProfilerProgress {
                current_pass,
                total_passes,
                column_name: column_name.to_string(),
                message: message.to_string(),
            });
        }
    }

    #[instrument(skip(self, ctx))]
    async fn execute_pass1(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
    ) -> ProfilerResult<BasicStatistics> {
        use arrow::array::Array;

        // Sample non-null values for type detection and example output.
        let sample_sql = format!(
            "SELECT {column_name} FROM {table_name} WHERE {column_name} IS NOT NULL LIMIT {}",
            self.config.sample_size
        );

        let sample_df = ctx
            .sql(&sample_sql)
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;
        let sample_batches = sample_df
            .collect()
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;

        // Row, null, and distinct counts in a single scan.
        let stats_sql = format!(
            "SELECT
                COUNT(*) as total_count,
                COUNT({column_name}) as non_null_count,
                COUNT(DISTINCT {column_name}) as distinct_count
             FROM {table_name}"
        );

        let stats_df = ctx
            .sql(&stats_sql)
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;
        let stats_batches = stats_df
            .collect()
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;

        if stats_batches.is_empty() || stats_batches[0].num_rows() == 0 {
            return Err(AnalyzerError::invalid_data(
                "No data found for statistics computation".to_string(),
            ));
        }

        let stats_batch = &stats_batches[0];
        let total_count = self.extract_u64(stats_batch, 0, 0, "total_count")?;
        let non_null_count = self.extract_u64(stats_batch, 1, 0, "non_null_count")?;
        let distinct_count = self.extract_u64(stats_batch, 2, 0, "distinct_count")?;

        let null_count = total_count - non_null_count;
        let null_percentage = if total_count > 0 {
            (null_count as f64 / total_count as f64) * 100.0
        } else {
            0.0
        };

        // Keep at most 10 representative values across all sample batches.
        let mut sample_values = Vec::new();
        'outer: for batch in &sample_batches {
            let column_data = batch.column(0);
            for i in 0..batch.num_rows() {
                if sample_values.len() >= 10 {
                    break 'outer;
                }
                if !column_data.is_null(i) {
                    let value = self.extract_string_value(column_data, i)?;
                    sample_values.push(value);
                }
            }
        }

        let (min_value, max_value) = self
            .get_min_max_values(ctx, table_name, column_name)
            .await?;

        Ok(BasicStatistics {
            row_count: total_count,
            null_count,
            null_percentage,
            approximate_cardinality: distinct_count,
            min_value,
            max_value,
            sample_values,
        })
    }

    #[instrument(skip(self, ctx))]
    async fn execute_pass2(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
        _basic_stats: &BasicStatistics,
    ) -> ProfilerResult<CategoricalHistogram> {
        // A full GROUP BY is affordable here: pass 2 only runs for columns whose
        // distinct count is at or below the configured threshold.
        let histogram_sql = format!(
            "SELECT
                {column_name} as value,
                COUNT(*) as count
             FROM {table_name}
             WHERE {column_name} IS NOT NULL
             GROUP BY {column_name}
             ORDER BY count DESC"
        );

        let df = ctx
            .sql(&histogram_sql)
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;
        let batches = df
            .collect()
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;

        let mut buckets = Vec::new();
        let mut top_values = Vec::new();
        let mut total_count = 0u64;

        for batch in &batches {
            for i in 0..batch.num_rows() {
                let value = self.extract_string_value(batch.column(0), i)?;
                let count = self.extract_u64(batch, 1, i, "count")?;

                buckets.push(CategoricalBucket {
                    value: value.clone(),
                    count,
                });

                // Rows arrive ordered by count DESC, so the first ten are the top values.
                if top_values.len() < 10 {
                    top_values.push((value, count));
                }

                total_count += count;
            }
        }

        let entropy = self.calculate_entropy(&buckets, total_count);

        Ok(CategoricalHistogram {
            buckets,
            total_count,
            entropy,
            top_values,
        })
    }

    #[instrument(skip(self, ctx))]
    async fn execute_pass3(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
        _basic_stats: &BasicStatistics,
    ) -> ProfilerResult<NumericDistribution> {
        let stats_sql = format!(
            "SELECT
                AVG(CAST({column_name} AS DOUBLE)) as mean,
                STDDEV(CAST({column_name} AS DOUBLE)) as std_dev,
                VAR_SAMP(CAST({column_name} AS DOUBLE)) as variance
             FROM {table_name}
             WHERE {column_name} IS NOT NULL"
        );

        let stats_df = ctx
            .sql(&stats_sql)
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;
        let stats_batches = stats_df
            .collect()
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;

        let mut mean = None;
        let mut std_dev = None;
        let mut variance = None;

        if !stats_batches.is_empty() && stats_batches[0].num_rows() > 0 {
            let batch = &stats_batches[0];
            mean = self.extract_optional_f64(batch, 0)?;
            std_dev = self.extract_optional_f64(batch, 1)?;
            variance = self.extract_optional_f64(batch, 2)?;
        }

        // Quantiles are computed one at a time; a failure for any single
        // percentile (e.g. an unavailable function) simply leaves it out.
        let mut quantiles = HashMap::new();
        let percentiles = vec![("P50", 0.5), ("P90", 0.9), ("P95", 0.95), ("P99", 0.99)];

        for (name, percentile) in percentiles {
            if let Ok(value) = self
                .calculate_percentile(ctx, table_name, column_name, percentile)
                .await
            {
                quantiles.insert(name.to_string(), value);
            }
        }

        // Outlier, skewness, and kurtosis computation is not implemented yet.
        let outlier_count = 0;
        let skewness = None;
        let kurtosis = None;

        Ok(NumericDistribution {
            mean,
            std_dev,
            variance,
            quantiles,
            outlier_count,
            skewness,
            kurtosis,
        })
    }

    fn detect_data_type(
        &self,
        basic_stats: &BasicStatistics,
    ) -> ProfilerResult<DetectedDataType> {
        if basic_stats.sample_values.is_empty() {
            return Ok(DetectedDataType::Unknown);
        }

        let mut type_counts = HashMap::new();

        for value in &basic_stats.sample_values {
            let detected_type = self.classify_value(value);
            *type_counts.entry(detected_type).or_insert(0) += 1;
        }

        let dominant_type = type_counts
            .into_iter()
            .max_by_key(|(_, count)| *count)
            .map(|(data_type, _)| data_type)
            .unwrap_or(DetectedDataType::Unknown);

        Ok(dominant_type)
    }
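
    // Majority vote, illustrated: samples ["1", "2", "abc"] tally as
    // {Integer: 2, String: 1}, so the column is detected as Integer.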

    fn classify_value(&self, value: &str) -> DetectedDataType {
        let trimmed = value.trim();

        if trimmed.eq_ignore_ascii_case("true") || trimmed.eq_ignore_ascii_case("false") {
            return DetectedDataType::Boolean;
        }

        if trimmed.parse::<i64>().is_ok() {
            return DetectedDataType::Integer;
        }

        if trimmed.parse::<f64>().is_ok() {
            return DetectedDataType::Double;
        }

        // Heuristic: ISO-style date, e.g. "2024-01-15".
        if trimmed.len() == 10 && trimmed.matches('-').count() == 2 {
            return DetectedDataType::Date;
        }

        // Heuristic: timestamp-like strings. Parenthesized so that short words
        // containing 'T' are not misclassified.
        if (trimmed.contains('T') || trimmed.contains(' ')) && trimmed.len() > 15 {
            return DetectedDataType::Timestamp;
        }

        DetectedDataType::String
    }
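
    // Classification precedence, illustrated:
    //     "42"                    -> Integer (i64 parse is tried before f64)
    //     "3.14"                  -> Double
    //     "2024-01-15"            -> Date (10 chars with two '-')
    //     "2024-01-15T10:30:00Z"  -> Timestamp ('T' or ' ' and length > 15)
    //     "Tornado"               -> String (contains 'T' but is too short)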

    fn extract_u64(
        &self,
        batch: &arrow::record_batch::RecordBatch,
        col_idx: usize,
        row_idx: usize,
        col_name: &str,
    ) -> ProfilerResult<u64> {
        use arrow::array::Array;

        let column = batch.column(col_idx);
        if column.is_null(row_idx) {
            return Err(AnalyzerError::invalid_data(format!(
                "Null value in {col_name} column"
            )));
        }

        if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt64Array>() {
            Ok(arr.value(row_idx))
        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
            Ok(arr.value(row_idx) as u64)
        } else {
            Err(AnalyzerError::invalid_data(format!(
                "Expected integer for {col_name}"
            )))
        }
    }

    fn extract_optional_f64(
        &self,
        batch: &arrow::record_batch::RecordBatch,
        col_idx: usize,
    ) -> ProfilerResult<Option<f64>> {
        use arrow::array::Array;

        let column = batch.column(col_idx);
        if column.is_null(0) {
            return Ok(None);
        }

        if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
            Ok(Some(arr.value(0)))
        } else {
            Ok(None)
        }
    }

    fn extract_string_value(
        &self,
        column: &dyn arrow::array::Array,
        row_idx: usize,
    ) -> ProfilerResult<String> {
        use arrow::array::Array;

        if column.is_null(row_idx) {
            return Ok("NULL".to_string());
        }

        if let Some(arr) = column.as_any().downcast_ref::<arrow::array::StringArray>() {
            Ok(arr.value(row_idx).to_string())
        } else if let Some(arr) = column
            .as_any()
            .downcast_ref::<arrow::array::StringViewArray>()
        {
            Ok(arr.value(row_idx).to_string())
        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
            Ok(arr.value(row_idx).to_string())
        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
            Ok(arr.value(row_idx).to_string())
        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::BooleanArray>() {
            Ok(arr.value(row_idx).to_string())
        } else {
            // Unrecognized array type; fall back to a placeholder rather than fail.
            Ok("UNKNOWN".to_string())
        }
    }

    async fn get_min_max_values(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
    ) -> ProfilerResult<(Option<String>, Option<String>)> {
        use arrow::array::Array;

        let sql = format!(
            "SELECT MIN({column_name}) as min_val, MAX({column_name}) as max_val FROM {table_name} WHERE {column_name} IS NOT NULL"
        );

        let df = ctx
            .sql(&sql)
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;
        let batches = df
            .collect()
            .await
            .map_err(|e| AnalyzerError::execution(e.to_string()))?;

        if batches.is_empty() || batches[0].num_rows() == 0 {
            return Ok((None, None));
        }

        let batch = &batches[0];
        let min_val = if batch.column(0).is_null(0) {
            None
        } else {
            Some(self.extract_string_value(batch.column(0), 0)?)
        };

        let max_val = if batch.column(1).is_null(0) {
            None
        } else {
            Some(self.extract_string_value(batch.column(1), 0)?)
        };

        Ok((min_val, max_val))
    }

    async fn calculate_percentile(
        &self,
        ctx: &SessionContext,
        table_name: &str,
        column_name: &str,
        percentile: f64,
    ) -> ProfilerResult<f64> {
        let sql = format!(
            "SELECT approx_percentile(CAST({column_name} AS DOUBLE), {percentile}) as percentile_val
             FROM {table_name}
             WHERE {column_name} IS NOT NULL"
        );

        match ctx.sql(&sql).await {
            Ok(df) => {
                let batches = df
                    .collect()
                    .await
                    .map_err(|e| AnalyzerError::execution(e.to_string()))?;

                if !batches.is_empty() && batches[0].num_rows() > 0 {
                    let batch = &batches[0];
                    if let Some(value) = self.extract_optional_f64(batch, 0)? {
                        return Ok(value);
                    }
                }

                Err(AnalyzerError::invalid_data(
                    "No percentile result".to_string(),
                ))
            }
            // Planning failed, most likely because the percentile function is
            // unavailable in this engine; the caller treats quantiles as optional.
            Err(e) => Err(AnalyzerError::invalid_data(format!(
                "Percentile query failed: {e}"
            ))),
        }
    }

    /// Shannon entropy H = -sum(p_i * log2(p_i)) over the value distribution,
    /// in bits. A 50/50 split over two values gives 1 bit; a constant column gives 0.
    fn calculate_entropy(&self, buckets: &[CategoricalBucket], total_count: u64) -> f64 {
        if total_count == 0 {
            return 0.0;
        }

        let mut entropy = 0.0;
        for bucket in buckets {
            if bucket.count > 0 {
                let probability = bucket.count as f64 / total_count as f64;
                entropy -= probability * probability.log2();
            }
        }
        entropy
    }
}

impl Default for ColumnProfiler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_profiler_config_builder() {
        let profiler = ColumnProfiler::builder()
            .cardinality_threshold(200)
            .sample_size(5000)
            .max_memory_bytes(1024 * 1024 * 1024)
            .enable_parallel(false)
            .build();

        assert_eq!(profiler.config.cardinality_threshold, 200);
        assert_eq!(profiler.config.sample_size, 5000);
        assert_eq!(profiler.config.max_memory_bytes, 1024 * 1024 * 1024);
        assert!(!profiler.config.enable_parallel);
    }

    #[test]
    fn test_data_type_detection() {
        let profiler = ColumnProfiler::new();

        assert_eq!(profiler.classify_value("123"), DetectedDataType::Integer);
        assert_eq!(profiler.classify_value("123.45"), DetectedDataType::Double);
        assert_eq!(profiler.classify_value("true"), DetectedDataType::Boolean);
        assert_eq!(profiler.classify_value("hello"), DetectedDataType::String);
    }
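
    // Added check (a small sketch): the entropy helper should report 1 bit for
    // a uniform two-value split and 0 for a constant column.
    #[test]
    fn test_entropy_calculation() {
        let profiler = ColumnProfiler::new();

        let uniform = vec![
            CategoricalBucket {
                value: "a".to_string(),
                count: 5,
            },
            CategoricalBucket {
                value: "b".to_string(),
                count: 5,
            },
        ];
        assert!((profiler.calculate_entropy(&uniform, 10) - 1.0).abs() < 1e-9);

        let constant = vec![CategoricalBucket {
            value: "a".to_string(),
            count: 10,
        }];
        assert_eq!(profiler.calculate_entropy(&constant, 10), 0.0);
    }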

    #[test]
    fn test_progress_callback() {
        use std::sync::{Arc, Mutex};

        let progress_calls = Arc::new(Mutex::new(Vec::new()));
        let progress_calls_clone = progress_calls.clone();

        let profiler = ColumnProfiler::builder()
            .progress_callback(move |progress| {
                progress_calls_clone.lock().unwrap().push(progress);
            })
            .build();

        // Drive the callback directly so the test verifies delivery.
        profiler.report_progress(1, 3, "test_column", "starting");

        let calls = progress_calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].current_pass, 1);
        assert_eq!(calls[0].column_name, "test_column");
    }
}