use crate::expressions::Column;
use crate::intervals::cp_solver::PropagationResult;
use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval, IntervalBound};
use crate::utils::collect_columns;
use crate::PhysicalExpr;
use arrow::datatypes::Schema;
use datafusion_common::{
internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue,
};
use std::fmt::Debug;
use std::sync::Arc;
/// State threaded through expression analysis: the known column boundaries
/// and, once computed, a selectivity estimate for the analyzed expression.
#[derive(Clone, Debug, PartialEq)]
pub struct AnalysisContext {
/// Boundaries of the input columns. `analyze` returns an internal error
/// when this is `None`.
pub boundaries: Option<Vec<ExprBoundaries>>,
/// Selectivity estimate, `None` until set via `with_selectivity`.
/// `analyze` only ever produces values within `0.0..=1.0`.
pub selectivity: Option<f64>,
}
impl AnalysisContext {
    /// Creates a context that holds `boundaries` and has no selectivity
    /// estimate yet.
    pub fn new(boundaries: Vec<ExprBoundaries>) -> Self {
        let boundaries = Some(boundaries);
        Self {
            boundaries,
            selectivity: None,
        }
    }

    /// Consumes the context and returns it with `selectivity` recorded,
    /// leaving the boundaries untouched.
    pub fn with_selectivity(self, selectivity: f64) -> Self {
        Self {
            selectivity: Some(selectivity),
            ..self
        }
    }

    /// Builds a context with one [`ExprBoundaries`] per entry of `statistics`.
    ///
    /// The entry at position `idx` is paired with the field at position `idx`
    /// of `input_schema`, whose name is used for the resulting column.
    ///
    /// # Panics
    ///
    /// Panics if `statistics` has more entries than `input_schema` has fields.
    pub fn from_statistics(
        input_schema: &Schema,
        statistics: &[ColumnStatistics],
    ) -> Self {
        let fields = input_schema.fields();
        let column_boundaries: Vec<ExprBoundaries> = statistics
            .iter()
            .enumerate()
            .map(|(idx, stats)| {
                let name = fields[idx].name().clone();
                ExprBoundaries::from_column(stats, name, idx)
            })
            .collect();
        Self::new(column_boundaries)
    }
}
/// Value boundaries known for a single column.
#[derive(Clone, Debug, PartialEq)]
pub struct ExprBoundaries {
/// The column these boundaries describe.
pub column: Column,
/// Interval of values the column may take. `from_column` builds it as a
/// closed interval from the column's min/max statistics.
pub interval: Interval,
/// Distinct-value count of the column, if available.
pub distinct_count: Option<usize>,
}
impl ExprBoundaries {
    /// Derives the boundaries of one column from its statistics.
    ///
    /// * `stats` - statistics providing the min/max values and distinct count.
    /// * `col` - name of the column.
    /// * `index` - index of the column, passed on to [`Column::new`].
    ///
    /// The resulting interval is closed on both ends. A missing min or max
    /// value is represented by `ScalarValue::Null`.
    pub fn from_column(stats: &ColumnStatistics, col: String, index: usize) -> Self {
        // Fall back to a null scalar when a statistic is absent.
        let lower = stats.min_value.clone().unwrap_or(ScalarValue::Null);
        let upper = stats.max_value.clone().unwrap_or(ScalarValue::Null);
        let interval = Interval::new(
            IntervalBound::new_closed(lower),
            IntervalBound::new_closed(upper),
        );

        Self {
            column: Column::new(&col, index),
            interval,
            distinct_count: stats.distinct_count,
        }
    }
}
/// Analyzes `expr` against the column boundaries in `context` and returns a
/// new context carrying a selectivity estimate.
///
/// An [`ExprIntervalGraph`] is built for `expr`, the graph nodes of the
/// columns referenced by `expr` are seeded with the matching intervals from
/// `context`, and the ranges are propagated through the graph:
///
/// * on success, the result is produced by `shrink_boundaries`;
/// * if propagation is infeasible, the original boundaries are returned with
///   selectivity `0.0`;
/// * if propagation is not possible, the original boundaries are returned
///   with selectivity `1.0`.
///
/// # Errors
///
/// Returns an internal error when `context.boundaries` is `None`, and
/// forwards any error raised while building the graph, updating its ranges,
/// or shrinking the boundaries.
pub fn analyze(
    expr: &Arc<dyn PhysicalExpr>,
    context: AnalysisContext,
) -> Result<AnalysisContext> {
    let target_boundaries = match context.boundaries {
        Some(boundaries) => boundaries,
        None => {
            return Err(DataFusionError::Internal(
                "No column exists at the input to filter".to_string(),
            ));
        }
    };

    let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr))?;

    // Every column referenced by the expression, as a physical expression.
    let mut columns: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
    for column in collect_columns(expr) {
        columns.push(Arc::new(column) as Arc<dyn PhysicalExpr>);
    }

    let target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> =
        graph.gather_node_indices(&columns);

    // Pair each column node with the interval of the first boundary that
    // refers to the same column; nodes without a matching boundary are skipped.
    let mut seeds: Vec<(usize, Interval)> = Vec::new();
    for (node_expr, node_index) in &target_expr_and_indices {
        if let Some(column) = node_expr.as_any().downcast_ref::<Column>() {
            let matching = target_boundaries
                .iter()
                .find(|bound| bound.column.eq(column));
            if let Some(bound) = matching {
                seeds.push((*node_index, bound.interval.clone()));
            }
        }
    }

    let fixed_selectivity = match graph.update_ranges(&mut seeds)? {
        PropagationResult::Success => {
            return shrink_boundaries(
                expr,
                graph,
                target_boundaries,
                target_expr_and_indices,
            );
        }
        PropagationResult::Infeasible => 0.0,
        PropagationResult::CannotPropagate => 1.0,
    };
    Ok(AnalysisContext::new(target_boundaries).with_selectivity(fixed_selectivity))
}
/// Replaces the interval of every boundary whose column has a node in `graph`
/// with that node's current interval, then estimates the selectivity of `expr`.
///
/// * `expr` - the analyzed expression; its graph node supplies the final
///   result interval.
/// * `graph` - the interval graph to read the node intervals from.
/// * `target_boundaries` - boundaries to update; returned in the result.
/// * `target_expr_and_indices` - column expressions paired with their node
///   indices in `graph`.
///
/// If `calculate_selectivity` fails, a selectivity of `1.0` is used instead.
///
/// # Errors
///
/// Returns an internal error when no graph node is found for `expr`, or when
/// the computed selectivity is not within `0.0..=1.0`.
fn shrink_boundaries(
    expr: &Arc<dyn PhysicalExpr>,
    mut graph: ExprIntervalGraph,
    mut target_boundaries: Vec<ExprBoundaries>,
    target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
) -> Result<AnalysisContext> {
    // Snapshot taken before any interval is overwritten, for the ratio below.
    let initial_boundaries = target_boundaries.clone();

    for (node_expr, node_index) in &target_expr_and_indices {
        if let Some(column) = node_expr.as_any().downcast_ref::<Column>() {
            let matching = target_boundaries
                .iter_mut()
                .find(|bound| bound.column.eq(column));
            if let Some(bound) = matching {
                bound.interval = graph.get_interval(*node_index);
            }
        }
    }

    let graph_nodes = graph.gather_node_indices(&[Arc::clone(expr)]);
    let root_index = match graph_nodes.first() {
        Some((_, index)) => *index,
        None => {
            return Err(DataFusionError::Internal(
                "Error in constructing predicate graph".to_string(),
            ));
        }
    };
    let final_result = graph.get_interval(root_index);

    let estimate = calculate_selectivity(
        &final_result.lower.value,
        &final_result.upper.value,
        &target_boundaries,
        &initial_boundaries,
    );
    // A failed estimate falls back to 1.0.
    let selectivity = match estimate {
        Ok(value) => value,
        Err(_) => 1.0,
    };

    if (0.0..=1.0).contains(&selectivity) {
        Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity))
    } else {
        internal_err!("Selectivity is out of limit: {}", selectivity)
    }
}
/// Computes a selectivity from the bounds of the result interval and the
/// column boundaries before and after range propagation.
///
/// * When both `lower_value` and `upper_value` are boolean `true`, the result
///   is `1.0`; when both are boolean `false`, it is `0.0`.
/// * Otherwise the result is the product, over all entries of
///   `target_boundaries`, of `cardinality_ratio` applied to the interval at
///   the same position in `initial_boundaries` and the entry's own interval.
///
/// # Errors
///
/// Forwards the first error returned by `cardinality_ratio`.
///
/// # Panics
///
/// Panics if `initial_boundaries` is shorter than `target_boundaries`.
fn calculate_selectivity(
    lower_value: &ScalarValue,
    upper_value: &ScalarValue,
    target_boundaries: &[ExprBoundaries],
    initial_boundaries: &[ExprBoundaries],
) -> Result<f64> {
    match (lower_value, upper_value) {
        // The result interval collapsed to a single definite boolean value.
        (ScalarValue::Boolean(Some(lower)), ScalarValue::Boolean(Some(upper)))
            if lower == upper =>
        {
            Ok(if *lower { 1.0 } else { 0.0 })
        }
        _ => {
            let mut product = 1.0;
            for (position, updated) in target_boundaries.iter().enumerate() {
                let ratio = cardinality_ratio(
                    &initial_boundaries[position].interval,
                    &updated.interval,
                )?;
                product = product * ratio;
            }
            Ok(product)
        }
    }
}