use std::sync::Arc;
use std::collections::HashMap;
use std::path::Path;
use crate::error::{Result, Error};
use crate::distributed::config::DistributedConfig;
use crate::distributed::execution::{ExecutionContext, ExecutionPlan, ExecutionResult, Operation, AggregateExpr, JoinType, SortExpr};
use crate::distributed::partition::{PartitionSet, Partition, PartitionMetadata};
use crate::distributed::expr::{ColumnProjection, UdfDefinition, ExprSchema};
use crate::distributed::schema_validator::SchemaValidator;
use crate::distributed::explain::{ExplainOptions, ExplainFormat, explain_plan};
use super::conversion::{dataframe_to_record_batches, record_batches_to_dataframe};
/// Execution context backed by Apache DataFusion for distributed query
/// processing. Datasets are registered as in-memory tables (or CSV/Parquet
/// sources) and plans are translated to SQL before execution.
pub struct DataFusionContext {
// Distributed execution settings (concurrency, optimizer flags, validation).
config: DistributedConfig,
// The underlying DataFusion session; only present when the crate is built
// with the `distributed` feature.
#[cfg(feature = "distributed")]
context: datafusion::execution::context::SessionContext,
// Registered partition sets, keyed by table name.
datasets: HashMap<String, PartitionSet>,
// Validates execution plans against the schemas of registered datasets.
schema_validator: SchemaValidator,
}
impl DataFusionContext {
    /// Creates a new context from the given distributed configuration.
    ///
    /// When the `distributed` feature is enabled, a DataFusion
    /// `SessionContext` is constructed whose target partition count, batch
    /// size, and optimizer settings are derived from `config`.
    pub fn new(config: DistributedConfig) -> Self {
        #[cfg(feature = "distributed")]
        let context = {
            let mut config_builder = datafusion::execution::context::SessionConfig::new()
                .with_target_partitions(config.concurrency())
                .with_batch_size(8192);
            // Parallel Parquet reads are enabled regardless of optimization mode.
            config_builder = config_builder.set_str("parquet.parallel_read", "true");
            if config.enable_optimization() {
                // Forward user-supplied optimizer rule overrides to DataFusion.
                for (rule, value) in config.optimizer_rules() {
                    config_builder = config_builder.set_str(
                        &format!("optimizer.{}", rule),
                        value,
                    );
                }
                config_builder = config_builder
                    .set_str("statistics.enabled", "true")
                    .set_str("optimizer.statistics_based_join_ordering", "true");
            } else {
                config_builder = config_builder
                    .set_str("optimizer.skip_optimize", "true")
                    .set_str("statistics.enabled", "false");
            }
            datafusion::execution::context::SessionContext::new_with_config(config_builder)
        };
        Self {
            config,
            #[cfg(feature = "distributed")]
            context,
            datasets: HashMap::new(),
            schema_validator: SchemaValidator::new(),
        }
    }

    /// Registers a set of in-memory Arrow record batches as a named table.
    ///
    /// All batches are assumed to share the first batch's schema.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` when `batches` is empty, and
    /// `Error::DistributedProcessing` when table creation or registration fails.
    #[cfg(feature = "distributed")]
    fn register_record_batches(
        &mut self,
        name: &str,
        batches: Vec<datafusion::arrow::record_batch::RecordBatch>,
    ) -> Result<()> {
        if batches.is_empty() {
            return Err(Error::InvalidInput(format!(
                "No record batches to register for dataset {}",
                name
            )));
        }
        let schema = batches[0].schema();
        let mem_table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
            .map_err(|e| {
                Error::DistributedProcessing(format!("Failed to create memory table: {}", e))
            })?;
        self.context
            .register_table(name, Arc::new(mem_table))
            .map_err(|e| Error::DistributedProcessing(format!("Failed to register table: {}", e)))?;
        Ok(())
    }

    /// Validates the plan against registered schemas, unless validation has
    /// been disabled in the configuration.
    #[cfg(feature = "distributed")]
    fn validate_plan(&self, plan: &ExecutionPlan) -> Result<()> {
        if self.config.skip_validation() {
            return Ok(());
        }
        self.schema_validator.validate_plan(plan)
    }

    /// Translates an `ExecutionPlan` operation into a SQL string executable
    /// by DataFusion.
    ///
    /// For the custom `create_udf` operation the returned string contains one
    /// statement per UDF, separated by `";\n"`; every other operation
    /// produces a single query.
    ///
    /// # Errors
    /// Returns `Error::InvalidOperation` for malformed operations and
    /// `Error::NotImplemented` for custom operations with no SQL mapping.
    #[cfg(feature = "distributed")]
    fn convert_operation_to_sql(&self, plan: &ExecutionPlan) -> Result<String> {
        let mut sql_components = Vec::new();
        // NOTE(review): all non-UDF branches index `plan.inputs()[0]`; plans
        // reaching this point are assumed to carry at least one input table
        // name — confirm against plan construction/validation.
        match &plan.operation() {
            Operation::Select { columns } => {
                let column_list = if columns.is_empty() {
                    "*".to_string()
                } else {
                    columns.join(", ")
                };
                sql_components.push(format!("SELECT {}", column_list));
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
            },
            Operation::Filter { predicate } => {
                sql_components.push("SELECT *".to_string());
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
                sql_components.push(format!("WHERE {}", predicate));
            },
            Operation::Join {
                right,
                join_type,
                left_keys,
                right_keys
            } => {
                sql_components.push("SELECT *".to_string());
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
                let join_type_str = match join_type {
                    JoinType::Inner => "INNER JOIN",
                    JoinType::Left => "LEFT JOIN",
                    JoinType::Right => "RIGHT JOIN",
                    JoinType::Full => "FULL OUTER JOIN",
                    JoinType::Cross => "CROSS JOIN",
                };
                let on_clause = if *join_type == JoinType::Cross {
                    // Cross joins carry no ON clause; any provided keys are ignored.
                    String::new()
                } else {
                    // A non-cross join needs a well-formed ON clause. Reject
                    // empty or mismatched key lists instead of silently
                    // dropping surplus left keys or emitting an invalid bare
                    // "ON " (the previous behavior for both cases).
                    if left_keys.is_empty() || left_keys.len() != right_keys.len() {
                        return Err(Error::InvalidOperation(format!(
                            "Join requires matching non-empty key lists (got {} left, {} right)",
                            left_keys.len(),
                            right_keys.len()
                        )));
                    }
                    let join_conditions: Vec<String> = left_keys
                        .iter()
                        .zip(right_keys.iter())
                        .map(|(left_key, right_key)| {
                            format!(
                                "{}.{} = {}.{}",
                                plan.inputs()[0],
                                left_key,
                                right,
                                right_key
                            )
                        })
                        .collect();
                    format!(" ON {}", join_conditions.join(" AND "))
                };
                sql_components.push(format!("{} {}{}", join_type_str, right, on_clause));
            },
            Operation::GroupBy {
                keys,
                aggregates
            } => {
                // Select the grouping keys followed by the aggregate expressions.
                let mut select_parts: Vec<String> = keys.clone();
                for agg in aggregates {
                    select_parts.push(format!(
                        "{}({}) as {}",
                        agg.function, agg.input, agg.output
                    ));
                }
                sql_components.push(format!("SELECT {}", select_parts.join(", ")));
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
                sql_components.push(format!("GROUP BY {}", keys.join(", ")));
            },
            Operation::OrderBy {
                sort_exprs
            } => {
                sql_components.push("SELECT *".to_string());
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
                let mut order_parts = Vec::new();
                for expr in sort_exprs {
                    let direction = if expr.ascending { "ASC" } else { "DESC" };
                    let nulls = if expr.nulls_first { "NULLS FIRST" } else { "NULLS LAST" };
                    order_parts.push(format!("{} {} {}", expr.column, direction, nulls));
                }
                sql_components.push(format!("ORDER BY {}", order_parts.join(", ")));
            },
            Operation::Limit {
                limit
            } => {
                sql_components.push("SELECT *".to_string());
                sql_components.push(format!("FROM {}", plan.inputs()[0]));
                sql_components.push(format!("LIMIT {}", limit));
            },
            Operation::Window {
                window_functions,
            } => {
                // Keep every input column and append the window expressions.
                let mut select_parts = Vec::new();
                select_parts.push("input_table.*".to_string());
                for wf in window_functions {
                    select_parts.push(wf.to_sql());
                }
                sql_components.push(format!("SELECT {}", select_parts.join(", ")));
                sql_components.push(format!("FROM {} AS input_table", plan.inputs()[0]));
            },
            Operation::Custom {
                name,
                params
            } => {
                match name.as_str() {
                    // Projection with arbitrary expressions (JSON-encoded list).
                    "select_expr" => {
                        if let Some(projections_json) = params.get("projections") {
                            let projections: Vec<ColumnProjection> =
                                serde_json::from_str(projections_json).map_err(|e| {
                                    Error::DistributedProcessing(format!(
                                        "Failed to parse projections: {}",
                                        e
                                    ))
                                })?;
                            let select_parts: Vec<String> =
                                projections.iter().map(|p| p.to_sql()).collect();
                            sql_components.push(format!("SELECT {}", select_parts.join(", ")));
                            sql_components.push(format!("FROM {}", plan.inputs()[0]));
                        } else {
                            return Err(Error::InvalidOperation(
                                "select_expr operation requires projections parameter".to_string()
                            ));
                        }
                    },
                    // Append a single derived column to the input table.
                    "with_column" => {
                        let column_name = params.get("column_name")
                            .ok_or_else(|| Error::InvalidOperation(
                                "with_column operation requires column_name parameter".to_string()
                            ))?;
                        if let Some(projection_json) = params.get("projection") {
                            let projection: ColumnProjection =
                                serde_json::from_str(projection_json).map_err(|e| {
                                    Error::DistributedProcessing(format!(
                                        "Failed to parse projection: {}",
                                        e
                                    ))
                                })?;
                            sql_components.push(format!(
                                "SELECT *, {} AS {}",
                                projection.expr, column_name
                            ));
                            sql_components.push(format!("FROM {}", plan.inputs()[0]));
                        } else {
                            return Err(Error::InvalidOperation(
                                "with_column operation requires projection parameter".to_string()
                            ));
                        }
                    },
                    // UDF creation short-circuits: one statement per UDF,
                    // joined by ";\n" so the caller can split and run each.
                    "create_udf" => {
                        if let Some(udfs_json) = params.get("udfs") {
                            let udfs: Vec<UdfDefinition> = serde_json::from_str(udfs_json)
                                .map_err(|e| Error::DistributedProcessing(
                                    format!("Failed to parse UDFs: {}", e)
                                ))?;
                            let udf_statements: Vec<String> =
                                udfs.iter().map(|udf| udf.to_sql()).collect();
                            return Ok(udf_statements.join(";\n"));
                        } else {
                            return Err(Error::InvalidOperation(
                                "create_udf operation requires udfs parameter".to_string()
                            ));
                        }
                    },
                    _ => {
                        return Err(Error::NotImplemented(
                            format!("Custom operation '{}' cannot be converted to SQL", name)
                        ));
                    }
                }
            },
        }
        Ok(sql_components.join(" "))
    }
}
/// Builds an `ExecutionResult` from collected Arrow record batches.
///
/// Converts the batches into a pandrs DataFrame, derives rough execution
/// metrics (byte counts are estimated from fixed per-type widths), and
/// re-partitions the result. Shared by `execute` and `sql`, which previously
/// duplicated this logic verbatim; unlike the old `execute` path, an empty
/// result (e.g. UDF creation) no longer panics when reading the schema of
/// the first result batch.
#[cfg(feature = "distributed")]
fn build_execution_result(
    arrow_batches: Vec<datafusion::arrow::record_batch::RecordBatch>,
    execution_time_ms: u64,
) -> Result<ExecutionResult> {
    let pandrs_df = record_batches_to_dataframe(&arrow_batches)?;
    let batches_count = arrow_batches.len();
    let row_count = pandrs_df.shape()?.0;

    // Rough size estimate: a fixed width per column type, times the row count.
    let bytes_processed = if batches_count > 0 && row_count > 0 {
        let schema = arrow_batches[0].schema();
        let bytes_per_row = schema
            .fields()
            .iter()
            .map(|f| match f.data_type() {
                arrow::datatypes::DataType::Int64 => 8,
                arrow::datatypes::DataType::Float64 => 8,
                // Average guess for variable-length strings.
                arrow::datatypes::DataType::Utf8 => 24,
                arrow::datatypes::DataType::Boolean => 1,
                // Default estimate for any other type.
                _ => 8,
            })
            .sum::<usize>();
        bytes_per_row * row_count
    } else {
        0
    };

    let metrics = crate::distributed::execution::ExecutionMetrics::new(
        execution_time_ms,
        row_count,
        batches_count,
        bytes_processed,
        bytes_processed,
    );

    // Keep roughly the input batch granularity; `.max(1)` guards against a
    // zero batch size when there are more batches than rows (or no rows),
    // which the previous code would pass straight to the converter.
    let batch_size = if batches_count > 0 {
        (row_count / batches_count).max(1)
    } else {
        row_count.max(1)
    };
    let result_batches = dataframe_to_record_batches(&pandrs_df, batch_size)?;

    let mut partitions = Vec::with_capacity(result_batches.len());
    for (i, batch) in result_batches.iter().enumerate() {
        partitions.push(Arc::new(Partition::new(i, batch.clone())));
    }
    // Fall back to an empty schema when the query produced no batches,
    // instead of panicking on `result_batches[0]`.
    let schema = match result_batches.first() {
        Some(batch) => batch.schema(),
        None => Arc::new(arrow::datatypes::Schema::empty()),
    };
    Ok(ExecutionResult::new(PartitionSet::new(partitions, schema), metrics))
}

impl ExecutionContext for DataFusionContext {
    /// Executes an execution plan by translating it to SQL and running it on
    /// the embedded DataFusion session.
    ///
    /// # Errors
    /// Fails when plan validation, SQL translation, or query execution fails,
    /// or when the `distributed` feature is not compiled in.
    fn execute(&self, plan: &ExecutionPlan) -> Result<ExecutionResult> {
        #[cfg(feature = "distributed")]
        {
            use std::time::Instant;
            let start_time = Instant::now();
            self.validate_plan(plan)?;
            // UDF creation is a batch of DDL-like statements with no result set.
            let is_udf_creation = matches!(
                plan.operation(),
                Operation::Custom { name, .. } if name == "create_udf"
            );
            let sql = self.convert_operation_to_sql(plan)?;
            let arrow_batches = if is_udf_creation {
                // `convert_operation_to_sql` joins UDF statements with ";\n".
                for stmt in sql.split(";\n") {
                    if !stmt.trim().is_empty() {
                        futures::executor::block_on(self.context.sql(stmt)).map_err(|e| {
                            Error::DistributedProcessing(format!(
                                "Failed to execute UDF creation: {}",
                                e
                            ))
                        })?;
                    }
                }
                Vec::new()
            } else {
                let df = futures::executor::block_on(self.context.sql(&sql)).map_err(|e| {
                    Error::DistributedProcessing(format!("Failed to execute SQL query: {}", e))
                })?;
                futures::executor::block_on(df.collect()).map_err(|e| {
                    Error::DistributedProcessing(format!("Failed to collect query results: {}", e))
                })?
            };
            // Only query time is measured; conversion cost is excluded, as before.
            build_execution_result(arrow_batches, start_time.elapsed().as_millis() as u64)
        }
        #[cfg(not(feature = "distributed"))]
        {
            Err(Error::FeatureNotAvailable(
                "Distributed processing is not available. Recompile with the 'distributed' feature flag.".to_string()
            ))
        }
    }

    /// Registers an in-memory partition set as a queryable table and records
    /// its schema with the validator.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` when no partition carries data.
    fn register_dataset(&mut self, name: &str, partitions: PartitionSet) -> Result<()> {
        #[cfg(feature = "distributed")]
        {
            let mut batches = Vec::new();
            for partition in partitions.partitions() {
                if let Some(data) = partition.data() {
                    batches.push(data.clone());
                }
            }
            if batches.is_empty() {
                return Err(Error::InvalidInput(format!(
                    "No data partitions found for dataset {}",
                    name
                )));
            }
            self.register_record_batches(name, batches)?;
            let schema = partitions.schema();
            self.schema_validator.register_arrow_schema(name, schema.clone())?;
            self.datasets.insert(name.to_string(), partitions);
            Ok(())
        }
        #[cfg(not(feature = "distributed"))]
        {
            // Without DataFusion we only track the dataset locally.
            self.datasets.insert(name.to_string(), partitions);
            Ok(())
        }
    }

    /// Registers a CSV file on disk as a queryable table.
    ///
    /// # Errors
    /// Returns `Error::IoError` when the file does not exist.
    fn register_csv(&mut self, name: &str, path: &str) -> Result<()> {
        #[cfg(feature = "distributed")]
        {
            if !Path::new(path).exists() {
                return Err(Error::IoError(format!("CSV file not found: {}", path)));
            }
            futures::executor::block_on(
                self.context
                    .register_csv(name, path, datafusion::prelude::CsvReadOptions::new()),
            )
            .map_err(|e| Error::DistributedProcessing(format!("Failed to register CSV: {}", e)))?;
            Ok(())
        }
        #[cfg(not(feature = "distributed"))]
        {
            Err(Error::FeatureNotAvailable(
                "CSV registration is not available. Recompile with the 'distributed' feature flag.".to_string()
            ))
        }
    }

    /// Registers a Parquet file on disk as a queryable table.
    ///
    /// # Errors
    /// Returns `Error::IoError` when the file does not exist.
    fn register_parquet(&mut self, name: &str, path: &str) -> Result<()> {
        #[cfg(feature = "distributed")]
        {
            if !Path::new(path).exists() {
                return Err(Error::IoError(format!("Parquet file not found: {}", path)));
            }
            futures::executor::block_on(
                self.context
                    .register_parquet(name, path, datafusion::prelude::ParquetReadOptions::default()),
            )
            .map_err(|e| Error::DistributedProcessing(format!("Failed to register Parquet: {}", e)))?;
            Ok(())
        }
        #[cfg(not(feature = "distributed"))]
        {
            Err(Error::FeatureNotAvailable(
                "Parquet registration is not available. Recompile with the 'distributed' feature flag.".to_string()
            ))
        }
    }

    /// Renders a textual explanation of the plan, appending DataFusion's
    /// optimized plan when optimization is enabled in the configuration.
    fn explain_plan(&self, plan: &ExecutionPlan, with_statistics: bool) -> Result<String> {
        #[cfg(feature = "distributed")]
        {
            let options = ExplainOptions {
                format: ExplainFormat::Text,
                with_statistics,
                optimized: self.config.enable_optimization(),
                analyze: false,
            };
            // Free function from `crate::distributed::explain`, not this method.
            let explanation = explain_plan(plan, &options)?;
            if !self.config.enable_optimization() {
                return Ok(explanation);
            }
            // Ask DataFusion for its optimized plan and append it.
            let sql = self.convert_operation_to_sql(plan)?;
            let df = futures::executor::block_on(self.context.sql(&format!("EXPLAIN {}", sql)))
                .map_err(|e| Error::DistributedProcessing(
                    format!("Failed to explain query: {}", e)))?;
            let batches = futures::executor::block_on(df.collect())
                .map_err(|e| Error::DistributedProcessing(
                    format!("Failed to collect explain results: {}", e)))?;
            // EXPLAIN output arrives as rows of a string column.
            let mut optimized_plan = String::new();
            for batch in &batches {
                if let Some(array) = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>()
                {
                    for i in 0..array.len() {
                        if !array.is_null(i) {
                            optimized_plan.push_str(array.value(i));
                            optimized_plan.push('\n');
                        }
                    }
                }
            }
            let mut result = String::new();
            result.push_str("=== Logical Plan ===\n");
            result.push_str(&explanation);
            result.push_str("\n\n=== Optimized Plan ===\n");
            result.push_str(&optimized_plan);
            Ok(result)
        }
        #[cfg(not(feature = "distributed"))]
        {
            Err(Error::FeatureNotAvailable(
                "Plan explanation is not available. Recompile with the 'distributed' feature flag.".to_string()
            ))
        }
    }

    /// Executes a raw SQL query against the registered tables.
    fn sql(&self, query: &str) -> Result<ExecutionResult> {
        #[cfg(feature = "distributed")]
        {
            use std::time::Instant;
            let start_time = Instant::now();
            let df = futures::executor::block_on(self.context.sql(query)).map_err(|e| {
                Error::DistributedProcessing(format!("Failed to execute SQL query: {}", e))
            })?;
            let arrow_batches = futures::executor::block_on(df.collect()).map_err(|e| {
                Error::DistributedProcessing(format!("Failed to collect query results: {}", e))
            })?;
            build_execution_result(arrow_batches, start_time.elapsed().as_millis() as u64)
        }
        #[cfg(not(feature = "distributed"))]
        {
            Err(Error::FeatureNotAvailable(
                "SQL execution is not available. Recompile with the 'distributed' feature flag.".to_string()
            ))
        }
    }
}