//! dataprof 0.7.1
//!
//! High-performance data profiler with ISO 8000/25012 quality metrics
//! for CSV, JSON/JSONL, and Parquet files.
//!
//! DataFusion SQL query engine integration
//!
//! Provides SQL-based data profiling using Apache DataFusion.
//! DataFusion is a fast, extensible query engine built on Apache Arrow.

use crate::core::report_assembler::ReportAssembler;
use crate::engines::columnar::RecordBatchAnalyzer;
use crate::types::{DataSource, ExecutionMetadata, ProfileReport, QueryEngine};

use anyhow::{Context, Result};
use datafusion::prelude::*;
use futures::stream::{Stream, StreamExt};
use std::time::Instant;

/// DataFusion loader for profiling SQL queries using Arrow integration
pub struct DataFusionLoader {
    /// Arrow record-batch size requested for query execution.
    batch_size: usize,
    /// DataFusion session context holding registered tables and configuration.
    ctx: SessionContext,
}

impl Default for DataFusionLoader {
    fn default() -> Self {
        Self::new()
    }
}

impl DataFusionLoader {
    /// Create a loader with the default Arrow batch size (8192 rows).
    pub fn new() -> Self {
        // Single source of truth for the default batch size; previously the
        // literal 8192 was duplicated in the config and the struct field.
        const DEFAULT_BATCH_SIZE: usize = 8192;

        let config = SessionConfig::new().with_batch_size(DEFAULT_BATCH_SIZE);

        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            ctx: SessionContext::new_with_config(config),
        }
    }

    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Get the underlying SessionContext for advanced configuration
    /// (e.g. registering additional tables or functions directly).
    pub fn context(&self) -> &SessionContext {
        &self.ctx
    }

    /// Get mutable access to the SessionContext for advanced configuration.
    pub fn context_mut(&mut self) -> &mut SessionContext {
        &mut self.ctx
    }

    /// Register a CSV file as a table for querying
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or registered.
    pub async fn register_csv(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_csv(table_name, path, CsvReadOptions::default())
            .await
            // with_context is lazy: the message is only formatted on failure,
            // avoiding an allocation on the success path.
            .with_context(|| {
                format!("Failed to register CSV file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Register a Parquet file as a table for querying
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or registered.
    pub async fn register_parquet(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_parquet(table_name, path, ParquetReadOptions::default())
            .await
            // with_context is lazy: the message is only formatted on failure,
            // avoiding an allocation on the success path.
            .with_context(|| {
                format!("Failed to register Parquet file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Register a newline-delimited JSON (NDJSON) file as a table for querying
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or registered.
    pub async fn register_json(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_json(table_name, path, NdJsonReadOptions::default())
            .await
            // with_context is lazy: the message is only formatted on failure,
            // avoiding an allocation on the success path.
            .with_context(|| {
                format!("Failed to register JSON file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Execute a SQL query and profile the results using Arrow
    ///
    /// Runs the query to completion, feeds every non-empty `RecordBatch`
    /// into a [`RecordBatchAnalyzer`], and assembles a single
    /// [`ProfileReport`] covering the full result set.
    ///
    /// # Errors
    /// Returns an error if the query fails to plan/execute, if results
    /// cannot be collected, or if a batch cannot be analyzed.
    pub async fn profile_query(&self, query: &str) -> Result<ProfileReport> {
        let start = Instant::now();
        log::info!("DataFusion: Preparing query");

        // Plan and execute the query; lazy context avoids formatting the
        // error message on the success path.
        let df = self
            .ctx
            .sql(query)
            .await
            .with_context(|| format!("Failed to execute query: '{}'", query))?;

        // Materialize all result batches in memory.
        let batches = df
            .collect()
            .await
            .context("Failed to collect query results")?;

        let mut analyzer = RecordBatchAnalyzer::new();
        let mut batch_count = 0;

        // Skip empty batches: they carry no data and would skew the count.
        for record_batch in batches {
            if record_batch.num_rows() > 0 {
                batch_count += 1;
                analyzer.process_batch(&record_batch)?;
            }
        }

        let total_rows = analyzer.total_rows();
        log::info!(
            "DataFusion: Processed {} rows in {} batches",
            total_rows,
            batch_count
        );

        // Assemble the final report from the accumulated statistics.
        let column_profiles = analyzer.to_profiles(false, false, None);
        let sample_columns = analyzer.create_sample_columns();

        let scan_time_ms = start.elapsed().as_millis();
        let num_columns = column_profiles.len();

        Ok(ReportAssembler::new(
            DataSource::Query {
                engine: QueryEngine::DataFusion,
                statement: query.to_string(),
                database: None,
                execution_id: None,
            },
            ExecutionMetadata::new(total_rows, num_columns, scan_time_ms),
        )
        .columns(column_profiles)
        .with_quality_data(sample_columns)
        .build())
    }

    /// Profile a registered table directly
    ///
    /// The table name is validated (non-empty, alphanumeric or '_' only)
    /// before being interpolated into the query, preventing SQL injection.
    ///
    /// # Errors
    /// Returns an error for an invalid table name or if profiling fails.
    pub async fn profile_table(&self, table_name: &str) -> Result<ProfileReport> {
        // Reject empty names explicitly: `all()` on an empty iterator is
        // vacuously true, which previously let "" through to a confusing
        // SQL syntax error downstream.
        let is_valid = !table_name.is_empty()
            && table_name.chars().all(|c| c.is_alphanumeric() || c == '_');
        if !is_valid {
            return Err(anyhow::anyhow!("Invalid table name: {}", table_name));
        }

        let query = format!("SELECT * FROM {}", table_name);
        self.profile_query(&query).await
    }

    /// Execute a SQL query and emit incremental profiling reports.
    ///
    /// This method is designed for real-time monitoring of data streams.
    /// It emits a [`ProfileReport`] after processing each batch, where each
    /// report contains **cumulative** statistics from all batches processed
    /// so far.
    ///
    /// For batch processing where you only need the final result,
    /// use [`DataFusionLoader::profile_query`] instead.
    ///
    /// # Errors
    /// Returns an error if the query fails to plan or the result stream
    /// cannot be started; individual stream items are themselves `Result`s
    /// and may carry batch-level failures.
    ///
    /// # Example
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = loader.profile_query_incremental("SELECT * FROM data").await?;
    /// while let Some(report) = stream.next().await {
    ///     let report = report?;
    ///     println!("Processed {} rows so far", report.execution.rows_processed);
    /// }
    /// ```
    pub async fn profile_query_incremental(
        &self,
        query: &str,
    ) -> Result<impl Stream<Item = Result<ProfileReport>>> {
        let start = Instant::now();
        log::info!("DataFusion: Preparing query (incremental)");

        // Plan the query; lazy context avoids formatting the error message
        // on the success path.
        let df = self
            .ctx
            .sql(query)
            .await
            .with_context(|| format!("Failed to execute query: '{}'", query))?;

        // The analyzer accumulates state across batches; it is moved into
        // the stream closure below.
        let mut analyzer = RecordBatchAnalyzer::new();

        // Own the query string for the closure
        let query_owned = query.to_string();

        // Stream batches and process each one, emitting cumulative reports
        let stream = df
            .execute_stream()
            .await
            .context("Failed to execute query stream")?
            .map(move |batch| {
                let batch = batch.context("Failed to fetch batch")?;
                if batch.num_rows() > 0 {
                    analyzer.process_batch(&batch)?;
                }
                // Snapshot cumulative statistics after every batch so each
                // emitted report reflects everything processed so far.
                let column_profiles = analyzer.to_profiles(false, false, None);
                let sample_columns = analyzer.create_sample_columns();

                let total_rows = analyzer.total_rows();

                let scan_time_ms = start.elapsed().as_millis();
                let num_columns = column_profiles.len();

                Ok(ReportAssembler::new(
                    DataSource::Query {
                        engine: QueryEngine::DataFusion,
                        statement: query_owned.clone(),
                        database: None,
                        execution_id: None,
                    },
                    ExecutionMetadata::new(total_rows, num_columns, scan_time_ms),
                )
                .columns(column_profiles)
                .with_quality_data(sample_columns)
                .build())
            });

        Ok(stream)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::Builder;

    /// Write `contents` into a fresh `.csv` temp file and return its handle.
    fn write_csv(contents: &str) -> Result<tempfile::NamedTempFile> {
        let mut file = Builder::new().suffix(".csv").tempfile()?;
        write!(file, "{}", contents)?;
        file.flush()?;
        Ok(file)
    }

    #[tokio::test]
    async fn test_datafusion_csv_profiling() -> Result<()> {
        let temp_file =
            write_csv("name,age,salary\nAlice,25,50000.0\nBob,30,60000.5\nCharlie,35,70000.0\n")?;

        let loader = DataFusionLoader::new();
        loader
            .register_csv("test_table", temp_file.path().to_str().unwrap())
            .await?;

        let report = loader.profile_query("SELECT * FROM test_table").await?;

        assert_eq!(report.column_profiles.len(), 3);
        assert_eq!(report.execution.rows_processed, 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_datafusion_sql_aggregation() -> Result<()> {
        let temp_file = write_csv("category,value\nA,10\nB,20\nA,30\nB,40\n")?;

        let loader = DataFusionLoader::new();
        loader
            .register_csv("data", temp_file.path().to_str().unwrap())
            .await?;

        let report = loader
            .profile_query("SELECT category, SUM(value) as total FROM data GROUP BY category")
            .await?;

        assert_eq!(report.column_profiles.len(), 2);
        assert_eq!(report.execution.rows_processed, 2); // 2 groups

        Ok(())
    }

    #[tokio::test]
    async fn test_invalid_table_name() -> Result<()> {
        let loader = DataFusionLoader::new();
        assert!(loader.profile_table("invalid-table-name").await.is_err());
        Ok(())
    }
}