use crate::core::report_assembler::ReportAssembler;
use crate::engines::columnar::RecordBatchAnalyzer;
use crate::types::{DataSource, ExecutionMetadata, ProfileReport, QueryEngine};
use anyhow::{Context, Result};
use datafusion::prelude::*;
use futures::stream::{Stream, StreamExt};
use std::time::Instant;
/// Loads data files (CSV, Parquet, NDJSON) into an embedded DataFusion
/// `SessionContext` and profiles the results of SQL queries run against
/// the registered tables.
pub struct DataFusionLoader {
    // Target number of rows per Arrow record batch; 8192 by default (see `new`).
    batch_size: usize,
    // Embedded DataFusion session that holds all registered tables.
    ctx: SessionContext,
}
impl Default for DataFusionLoader {
fn default() -> Self {
Self::new()
}
}
impl DataFusionLoader {
    /// Default number of rows per Arrow record batch.
    const DEFAULT_BATCH_SIZE: usize = 8192;

    /// Creates a loader backed by a fresh `SessionContext` configured with
    /// the default batch size.
    pub fn new() -> Self {
        let config = SessionConfig::new().with_batch_size(Self::DEFAULT_BATCH_SIZE);
        let ctx = SessionContext::new_with_config(config);
        Self {
            batch_size: Self::DEFAULT_BATCH_SIZE,
            ctx,
        }
    }

    /// Sets the record-batch size used for query execution.
    ///
    /// Fix: the previous implementation only stored `batch_size` without
    /// applying it to the session config, so the setting had no effect on
    /// execution. The context is now rebuilt with the requested size; call
    /// this *before* registering tables, because rebuilding the session
    /// discards any prior registrations.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        let config = SessionConfig::new().with_batch_size(batch_size);
        self.ctx = SessionContext::new_with_config(config);
        self
    }

    /// Returns a shared reference to the underlying DataFusion session.
    pub fn context(&self) -> &SessionContext {
        &self.ctx
    }

    /// Returns an exclusive reference to the underlying DataFusion session.
    pub fn context_mut(&mut self) -> &mut SessionContext {
        &mut self.ctx
    }

    /// Registers a CSV file under `table_name` so SQL queries can reference it.
    ///
    /// # Errors
    /// Fails when DataFusion cannot register the file (e.g. unreadable path).
    pub async fn register_csv(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_csv(table_name, path, CsvReadOptions::default())
            .await
            // with_context: build the message lazily, only on failure.
            .with_context(|| {
                format!("Failed to register CSV file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Registers a Parquet file under `table_name`.
    ///
    /// # Errors
    /// Fails when DataFusion cannot register the file.
    pub async fn register_parquet(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_parquet(table_name, path, ParquetReadOptions::default())
            .await
            .with_context(|| {
                format!("Failed to register Parquet file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Registers a newline-delimited JSON file under `table_name`.
    ///
    /// # Errors
    /// Fails when DataFusion cannot register the file.
    pub async fn register_json(&self, table_name: &str, path: &str) -> Result<()> {
        self.ctx
            .register_json(table_name, path, NdJsonReadOptions::default())
            .await
            .with_context(|| {
                format!("Failed to register JSON file '{}' as '{}'", path, table_name)
            })?;
        Ok(())
    }

    /// Executes `query` eagerly, collects all result batches, and profiles
    /// every column of the result set.
    ///
    /// # Errors
    /// Fails if the SQL cannot be planned/executed, results cannot be
    /// collected, or a batch cannot be analyzed.
    pub async fn profile_query(&self, query: &str) -> Result<ProfileReport> {
        let start = Instant::now();
        log::info!("DataFusion: Preparing query");
        let df = self
            .ctx
            .sql(query)
            .await
            .with_context(|| format!("Failed to execute query: '{}'", query))?;
        let batches = df
            .collect()
            .await
            .context("Failed to collect query results")?;

        let mut analyzer = RecordBatchAnalyzer::new();
        let mut batch_count = 0;
        for record_batch in batches {
            // Empty batches carry no data; skip them so they don't inflate counts.
            if record_batch.num_rows() > 0 {
                batch_count += 1;
                analyzer.process_batch(&record_batch)?;
            }
        }
        let total_rows = analyzer.total_rows();
        log::info!(
            "DataFusion: Processed {} rows in {} batches",
            total_rows,
            batch_count
        );

        let column_profiles = analyzer.to_profiles(false, false, None);
        let sample_columns = analyzer.create_sample_columns();
        let scan_time_ms = start.elapsed().as_millis();
        let num_columns = column_profiles.len();
        Ok(ReportAssembler::new(
            DataSource::Query {
                engine: QueryEngine::DataFusion,
                statement: query.to_string(),
                database: None,
                execution_id: None,
            },
            ExecutionMetadata::new(total_rows, num_columns, scan_time_ms),
        )
        .columns(column_profiles)
        .with_quality_data(sample_columns)
        .build())
    }

    /// Profiles every row and column of a registered table.
    ///
    /// The name is validated against a conservative identifier pattern
    /// (non-empty, alphanumerics and underscores only) to prevent SQL
    /// injection through the interpolated query below. The emptiness check
    /// fixes a gap where `""` passed the vacuous `.all()` and produced a
    /// malformed `SELECT * FROM ` statement.
    pub async fn profile_table(&self, table_name: &str) -> Result<ProfileReport> {
        let is_valid = !table_name.is_empty()
            && table_name.chars().all(|c| c.is_alphanumeric() || c == '_');
        if !is_valid {
            return Err(anyhow::anyhow!("Invalid table name: {}", table_name));
        }
        let query = format!("SELECT * FROM {}", table_name);
        self.profile_query(&query).await
    }

    /// Executes `query` as a stream and yields a cumulative `ProfileReport`
    /// after each record batch, so callers can render progressive results
    /// for long-running queries. Each report reflects all batches seen so
    /// far; the last item covers the full result set.
    ///
    /// # Errors
    /// The outer `Result` fails if the SQL cannot be planned or the stream
    /// cannot start; each stream item fails if a batch cannot be fetched or
    /// analyzed.
    pub async fn profile_query_incremental(
        &self,
        query: &str,
    ) -> Result<impl Stream<Item = Result<ProfileReport>>> {
        let start = Instant::now();
        log::info!("DataFusion: Preparing query (incremental)");
        let df = self
            .ctx
            .sql(query)
            .await
            .with_context(|| format!("Failed to execute query: '{}'", query))?;

        // The analyzer is moved into the closure and accumulates state
        // across batches, which is what makes each emitted report cumulative.
        let mut analyzer = RecordBatchAnalyzer::new();
        let query_owned = query.to_string();
        let stream = df
            .execute_stream()
            .await
            .context("Failed to execute query stream")?
            .map(move |batch| {
                let batch = batch.context("Failed to fetch batch")?;
                if batch.num_rows() > 0 {
                    analyzer.process_batch(&batch)?;
                }
                let column_profiles = analyzer.to_profiles(false, false, None);
                let sample_columns = analyzer.create_sample_columns();
                let total_rows = analyzer.total_rows();
                let scan_time_ms = start.elapsed().as_millis();
                let num_columns = column_profiles.len();
                Ok(ReportAssembler::new(
                    DataSource::Query {
                        engine: QueryEngine::DataFusion,
                        statement: query_owned.clone(),
                        database: None,
                        execution_id: None,
                    },
                    ExecutionMetadata::new(total_rows, num_columns, scan_time_ms),
                )
                .columns(column_profiles)
                .with_quality_data(sample_columns)
                .build())
            });
        Ok(stream)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::Builder;

    /// End-to-end check: a three-row CSV profiled through SQL yields
    /// three column profiles and three processed rows.
    #[tokio::test]
    async fn test_datafusion_csv_profiling() -> Result<()> {
        let mut csv = Builder::new().suffix(".csv").tempfile()?;
        write!(
            csv,
            "name,age,salary\nAlice,25,50000.0\nBob,30,60000.5\nCharlie,35,70000.0\n"
        )?;
        csv.flush()?;

        let loader = DataFusionLoader::new();
        let path = csv.path().to_str().unwrap();
        loader.register_csv("test_table", path).await?;

        let report = loader.profile_query("SELECT * FROM test_table").await?;
        assert_eq!(report.column_profiles.len(), 3);
        assert_eq!(report.execution.rows_processed, 3);
        Ok(())
    }

    /// A GROUP BY over four input rows collapses to two result rows with
    /// two output columns.
    #[tokio::test]
    async fn test_datafusion_sql_aggregation() -> Result<()> {
        let mut csv = Builder::new().suffix(".csv").tempfile()?;
        write!(csv, "category,value\nA,10\nB,20\nA,30\nB,40\n")?;
        csv.flush()?;

        let loader = DataFusionLoader::new();
        let path = csv.path().to_str().unwrap();
        loader.register_csv("data", path).await?;

        let sql = "SELECT category, SUM(value) as total FROM data GROUP BY category";
        let report = loader.profile_query(sql).await?;
        assert_eq!(report.column_profiles.len(), 2);
        assert_eq!(report.execution.rows_processed, 2);
        Ok(())
    }

    /// Names containing characters outside [A-Za-z0-9_] must be rejected
    /// before any SQL is built.
    #[tokio::test]
    async fn test_invalid_table_name() -> Result<()> {
        let loader = DataFusionLoader::new();
        assert!(loader.profile_table("invalid-table-name").await.is_err());
        Ok(())
    }
}