use std::fmt::Debug;
use std::sync::Arc;
use crate::expressions::Column;
use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult};
use crate::utils::collect_columns;
use crate::PhysicalExpr;
use arrow::datatypes::Schema;
use datafusion_common::stats::Precision;
use datafusion_common::{
internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval};
/// State consumed and produced by [`analyze`]: the boundaries of the columns
/// under analysis, plus an optional selectivity estimate.
#[derive(Clone, Debug, PartialEq)]
pub struct AnalysisContext {
    /// One boundary entry per column being tracked.
    pub boundaries: Vec<ExprBoundaries>,
    /// Selectivity estimate; `None` until one is attached via
    /// [`AnalysisContext::with_selectivity`].
    pub selectivity: Option<f64>,
}

impl AnalysisContext {
    /// Creates a context over `boundaries` with no selectivity estimate.
    pub fn new(boundaries: Vec<ExprBoundaries>) -> Self {
        let selectivity = None;
        Self {
            boundaries,
            selectivity,
        }
    }

    /// Consumes the context and returns it with `selectivity` recorded.
    pub fn with_selectivity(self, selectivity: f64) -> Self {
        Self {
            selectivity: Some(selectivity),
            ..self
        }
    }

    /// Builds a context from per-column statistics.
    ///
    /// The statistic at position `i` in `statistics` is paired with field `i`
    /// of `input_schema` via [`ExprBoundaries::try_from_column`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced while converting a column's
    /// statistics into boundaries.
    pub fn try_from_statistics(
        input_schema: &Schema,
        statistics: &[ColumnStatistics],
    ) -> Result<Self> {
        let mut boundaries = Vec::with_capacity(statistics.len());
        for (col_index, col_stats) in statistics.iter().enumerate() {
            let bound = ExprBoundaries::try_from_column(input_schema, col_stats, col_index)?;
            boundaries.push(bound);
        }
        Ok(Self::new(boundaries))
    }
}
#[derive(Clone, Debug, PartialEq)]
pub struct ExprBoundaries {
pub column: Column,
pub interval: Interval,
pub distinct_count: Precision<usize>,
}
impl ExprBoundaries {
pub fn try_from_column(
schema: &Schema,
col_stats: &ColumnStatistics,
col_index: usize,
) -> Result<Self> {
let field = &schema.fields()[col_index];
let empty_field =
ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null);
let interval = Interval::try_new(
col_stats
.min_value
.get_value()
.cloned()
.unwrap_or(empty_field.clone()),
col_stats
.max_value
.get_value()
.cloned()
.unwrap_or(empty_field),
)?;
let column = Column::new(field.name(), col_index);
Ok(ExprBoundaries {
column,
interval,
distinct_count: col_stats.distinct_count.clone(),
})
}
pub fn try_new_unbounded(schema: &Schema) -> Result<Vec<Self>> {
schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
Ok(Self {
column: Column::new(field.name(), i),
interval: Interval::make_unbounded(field.data_type())?,
distinct_count: Precision::Absent,
})
})
.collect()
}
}
pub fn analyze(
expr: &Arc<dyn PhysicalExpr>,
context: AnalysisContext,
schema: &Schema,
) -> Result<AnalysisContext> {
let target_boundaries = context.boundaries;
let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?;
let columns = collect_columns(expr)
.into_iter()
.map(|c| Arc::new(c) as _)
.collect::<Vec<_>>();
let target_expr_and_indices = graph.gather_node_indices(columns.as_slice());
let mut target_indices_and_boundaries = target_expr_and_indices
.iter()
.filter_map(|(expr, i)| {
target_boundaries.iter().find_map(|bound| {
expr.as_any()
.downcast_ref::<Column>()
.filter(|expr_column| bound.column.eq(*expr_column))
.map(|_| (*i, bound.interval.clone()))
})
})
.collect::<Vec<_>>();
match graph
.update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)?
{
PropagationResult::Success => {
shrink_boundaries(graph, target_boundaries, target_expr_and_indices)
}
PropagationResult::Infeasible => {
Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0))
}
PropagationResult::CannotPropagate => {
Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0))
}
}
}
/// Replaces each boundary's interval with the interval propagated in `graph`
/// for the matching column node, then attaches the resulting selectivity.
///
/// Nodes in `target_expr_and_indices` that are not `Column`s, or whose column
/// has no entry in `target_boundaries`, are skipped.
///
/// # Errors
///
/// Returns an internal error if the computed selectivity lies outside
/// `[0.0, 1.0]`, which includes a NaN result.
fn shrink_boundaries(
    graph: ExprIntervalGraph,
    mut target_boundaries: Vec<ExprBoundaries>,
    target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
) -> Result<AnalysisContext> {
    // Snapshot the pre-propagation boundaries to measure how much they shrink.
    let initial_boundaries = target_boundaries.clone();

    let column_nodes = target_expr_and_indices.iter().filter_map(|(expr, node_index)| {
        expr.as_any()
            .downcast_ref::<Column>()
            .map(|column| (column, *node_index))
    });
    for (column, node_index) in column_nodes {
        let matching = target_boundaries
            .iter_mut()
            .find(|bound| bound.column == *column);
        if let Some(bound) = matching {
            bound.interval = graph.get_interval(node_index);
        }
    }

    let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries);
    if (0.0..=1.0).contains(&selectivity) {
        Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity))
    } else {
        internal_err!("Selectivity is out of limit: {}", selectivity)
    }
}
/// Estimates overall selectivity as the product, across columns, of
/// `cardinality_ratio(initial, target)`.
///
/// Boundaries are paired positionally, and any surplus entries in the longer
/// slice are ignored. With no pairs the result is `1.0`. Multiplying the
/// per-column ratios treats the columns as independent of one another.
fn calculate_selectivity(
    target_boundaries: &[ExprBoundaries],
    initial_boundaries: &[ExprBoundaries],
) -> f64 {
    let per_column_ratios = initial_boundaries
        .iter()
        .zip(target_boundaries)
        .map(|(initial, target)| cardinality_ratio(&initial.interval, &target.interval));
    per_column_ratios.product::<f64>()
}