use std::fmt::Debug;
use std::sync::Arc;
use crate::expressions::Column;
use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult};
use crate::utils::collect_columns;
use crate::PhysicalExpr;
use arrow::datatypes::Schema;
use datafusion_common::stats::Precision;
use datafusion_common::{
internal_datafusion_err, internal_err, ColumnStatistics, Result, ScalarValue,
};
use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval};
/// The shared context used while analyzing an expression. Holds the known
/// boundaries for every column of the input, plus the selectivity estimate
/// produced by the analysis.
#[derive(Clone, Debug, PartialEq)]
pub struct AnalysisContext {
    /// Per-column boundaries. By convention (see `analyze`), every entry
    /// having `interval == None` denotes an empty table.
    pub boundaries: Vec<ExprBoundaries>,
    /// Estimated fraction of rows selected by the analyzed predicate,
    /// expected to lie in `[0.0, 1.0]`; `None` until analysis sets it.
    pub selectivity: Option<f64>,
}
impl AnalysisContext {
pub fn new(boundaries: Vec<ExprBoundaries>) -> Self {
Self {
boundaries,
selectivity: None,
}
}
pub fn with_selectivity(mut self, selectivity: f64) -> Self {
self.selectivity = Some(selectivity);
self
}
pub fn try_from_statistics(
input_schema: &Schema,
statistics: &[ColumnStatistics],
) -> Result<Self> {
statistics
.iter()
.enumerate()
.map(|(idx, stats)| ExprBoundaries::try_from_column(input_schema, stats, idx))
.collect::<Result<Vec<_>>>()
.map(Self::new)
}
}
/// Represents the boundaries of a column/expression: its column identity,
/// the interval of values it may take, and its distinct-value count.
#[derive(Clone, Debug, PartialEq)]
pub struct ExprBoundaries {
    /// The column these boundaries describe.
    pub column: Column,
    /// Value range of the column; `None` denotes an empty table
    /// (see the convention enforced in `analyze`).
    pub interval: Option<Interval>,
    /// Number of distinct values, with the statistics' precision attached.
    pub distinct_count: Precision<usize>,
}
impl ExprBoundaries {
    /// Creates an `ExprBoundaries` for the column at `col_index` of `schema`
    /// from its statistics. Missing min/max statistics fall back to a
    /// type-appropriate null scalar.
    ///
    /// Returns an error if `col_index` is out of bounds for the schema or if
    /// the interval cannot be constructed from the statistics.
    pub fn try_from_column(
        schema: &Schema,
        col_stats: &ColumnStatistics,
        col_index: usize,
    ) -> Result<Self> {
        let Some(field) = schema.fields().get(col_index) else {
            return Err(internal_datafusion_err!(
                "Could not create `ExprBoundaries`: in `try_from_column` `col_index`
has gone out of bounds with a value of {col_index}, the schema has {} columns.",
                schema.fields.len()
            ));
        };
        // Fallback scalar when a min/max statistic is absent; `Null` is the
        // last resort for types without a typed null representation.
        let null_value =
            ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null);
        let lower = match col_stats.min_value.get_value() {
            Some(value) => value.clone(),
            None => null_value.clone(),
        };
        let upper = match col_stats.max_value.get_value() {
            Some(value) => value.clone(),
            None => null_value,
        };
        Ok(ExprBoundaries {
            column: Column::new(field.name(), col_index),
            interval: Some(Interval::try_new(lower, upper)?),
            distinct_count: col_stats.distinct_count,
        })
    }

    /// Creates unbounded boundaries (and absent distinct counts) for every
    /// field of `schema`, in field order.
    pub fn try_new_unbounded(schema: &Schema) -> Result<Vec<Self>> {
        let mut bounds = Vec::with_capacity(schema.fields().len());
        for (idx, field) in schema.fields().iter().enumerate() {
            bounds.push(ExprBoundaries {
                column: Column::new(field.name(), idx),
                interval: Some(Interval::make_unbounded(field.data_type())?),
                distinct_count: Precision::Absent,
            });
        }
        Ok(bounds)
    }
}
/// Analyzes `expr` (assumed to evaluate to `true`) against the column
/// boundaries in `context`, returning a new context with shrunken per-column
/// intervals and a selectivity estimate.
///
/// Empty-table convention: if *every* boundary has `interval == None`, the
/// context represents an empty table and is returned unchanged (after
/// consistency checks). A mix of `None` and `Some` intervals is an internal
/// error.
pub fn analyze(
    expr: &Arc<dyn PhysicalExpr>,
    context: AnalysisContext,
    schema: &Schema,
) -> Result<AnalysisContext> {
    let initial_boundaries = &context.boundaries;
    let missing = initial_boundaries
        .iter()
        .filter(|bound| bound.interval.is_none())
        .count();
    if missing == initial_boundaries.len() {
        // All intervals are `None`: this context describes an empty table.
        // Verify the rest of the statistics agree before handing it back.
        if initial_boundaries
            .iter()
            .any(|bound| bound.distinct_count != Precision::Exact(0))
        {
            return internal_err!(
                "ExprBoundaries has a non-zero distinct count although it represents an empty table"
            );
        }
        if context.selectivity != Some(0.0) {
            return internal_err!(
                "AnalysisContext has a non-zero selectivity although it represents an empty table"
            );
        }
        Ok(context)
    } else if missing > 0 {
        internal_err!(
            "AnalysisContext is an inconsistent state. Some columns represent empty table while others don't"
        )
    } else {
        let mut target_boundaries = context.boundaries;
        let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr), schema)?;
        let column_exprs = collect_columns(expr)
            .into_iter()
            .map(|c| Arc::new(c) as _)
            .collect::<Vec<_>>();
        let target_expr_and_indices = graph.gather_node_indices(column_exprs.as_slice());
        // Pair each column node in the graph with its known interval.
        let mut target_indices_and_boundaries = target_expr_and_indices
            .iter()
            .filter_map(|(node, index)| {
                node.as_any().downcast_ref::<Column>().and_then(|column| {
                    target_boundaries
                        .iter()
                        .find(|bound| bound.column == *column)
                        // Unwrap is safe: we checked above that no interval is `None`.
                        .map(|bound| (*index, bound.interval.as_ref().unwrap().clone()))
                })
            })
            .collect::<Vec<_>>();
        match graph
            .update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)?
        {
            PropagationResult::Success => {
                shrink_boundaries(graph, target_boundaries, target_expr_and_indices)
            }
            PropagationResult::Infeasible => {
                // The predicate can never hold: mark every column as empty
                // and report zero selectivity.
                for bound in target_boundaries.iter_mut() {
                    bound.interval = None;
                }
                Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0))
            }
            PropagationResult::CannotPropagate => {
                // No refinement possible; conservatively assume everything passes.
                Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0))
            }
        }
    }
}
/// Writes the post-propagation intervals from `graph` back into the matching
/// column boundaries, then derives the overall selectivity by comparing the
/// shrunken boundaries against a snapshot of the originals.
///
/// Returns an error if the computed selectivity falls outside `[0.0, 1.0]`.
fn shrink_boundaries(
    graph: ExprIntervalGraph,
    mut target_boundaries: Vec<ExprBoundaries>,
    target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
) -> Result<AnalysisContext> {
    // Snapshot the pre-shrink boundaries for the cardinality comparison below.
    let initial_boundaries = target_boundaries.clone();
    for (node, node_index) in &target_expr_and_indices {
        let Some(column) = node.as_any().downcast_ref::<Column>() else {
            continue;
        };
        if let Some(bound) = target_boundaries
            .iter_mut()
            .find(|bound| bound.column.eq(column))
        {
            bound.interval = Some(graph.get_interval(*node_index));
        }
    }
    let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries)?;
    if !(0.0..=1.0).contains(&selectivity) {
        return internal_err!("Selectivity is out of limit: {}", selectivity);
    }
    Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity))
}
/// Computes selectivity as the product of per-column cardinality ratios
/// between the initial and the (shrunken) target boundaries.
///
/// A `None` target interval means the predicate is infeasible for that
/// column, so the overall selectivity is `0.0`. A `None` initial interval
/// paired with a `Some` target is an inconsistency and yields an error, as
/// does a length mismatch between the two slices.
fn calculate_selectivity(
    target_boundaries: &[ExprBoundaries],
    initial_boundaries: &[ExprBoundaries],
) -> Result<f64> {
    if target_boundaries.len() != initial_boundaries.len() {
        return Err(internal_datafusion_err!(
            "The number of columns in the initial and target boundaries should be the same"
        ));
    }
    let mut selectivity = 1.0_f64;
    for (initial, target) in initial_boundaries.iter().zip(target_boundaries) {
        match (initial.interval.as_ref(), target.interval.as_ref()) {
            (Some(before), Some(after)) => {
                selectivity *= cardinality_ratio(before, after);
            }
            (None, Some(_)) => {
                return internal_err!(
                    "Initial boundary cannot be None while having a Some() target boundary"
                );
            }
            // Target interval is `None`: the column became empty, so no rows pass.
            _ => return Ok(0.0),
        }
    }
    Ok(selectivity)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{assert_contains, DFSchema};
    use datafusion_expr::{
        col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, Expr,
    };

    use crate::{create_physical_expr, AnalysisContext};

    use super::{analyze, ExprBoundaries};

    /// Convenience constructor for a non-nullable field.
    fn make_field(name: &str, data_type: DataType) -> Field {
        let nullable = false;
        Field::new(name, data_type, nullable)
    }

    /// Verifies that `analyze` shrinks an initially unbounded interval on
    /// column `a` to the exact range implied by each comparison predicate
    /// (strict bounds tightened by one for integers, inclusive bounds kept).
    #[test]
    fn test_analyze_boundary_exprs() {
        let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));

        // (predicate, expected lower bound, expected upper bound);
        // `None` means the corresponding side stays unbounded.
        type TestCase = (Expr, Option<i32>, Option<i32>);
        let test_cases: Vec<TestCase> = vec![
            // a > 10
            (col("a").gt(lit(10)), Some(11), None),
            // a < 20
            (col("a").lt(lit(20)), None, Some(19)),
            // a > 10 AND a < 20
            (
                col("a").gt(lit(10)).and(col("a").lt(lit(20))),
                Some(11),
                Some(19),
            ),
            // a >= 10
            (col("a").gt_eq(lit(10)), Some(10), None),
            // a <= 20
            (col("a").lt_eq(lit(20)), None, Some(20)),
            // a >= 10 AND a <= 20
            (
                col("a").gt_eq(lit(10)).and(col("a").lt_eq(lit(20))),
                Some(10),
                Some(20),
            ),
            // a > 10 AND a < 20 AND a < 15
            (
                col("a")
                    .gt(lit(10))
                    .and(col("a").lt(lit(20)))
                    .and(col("a").lt(lit(15))),
                Some(11),
                Some(14),
            ),
            // a > 10 AND a < 20 AND a > 15 AND a < 25
            (
                col("a")
                    .gt(lit(10))
                    .and(col("a").lt(lit(20)))
                    .and(col("a").gt(lit(15)))
                    .and(col("a").lt(lit(25))),
                Some(16),
                Some(19),
            ),
        ];
        for (expr, lower, upper) in test_cases {
            let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap();
            let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
            let physical_expr =
                create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();
            let analysis_result = analyze(
                &physical_expr,
                AnalysisContext::new(boundaries),
                df_schema.as_ref(),
            )
            .unwrap();
            let Some(actual) = &analysis_result.boundaries[0].interval else {
                panic!("The analysis result should contain non-empty intervals for all columns");
            };
            let expected = Interval::make(lower, upper).unwrap();
            assert_eq!(
                &expected, actual,
                "did not get correct interval for SQL expression: {expr:?}"
            );
        }
    }

    /// Verifies that contradictory predicates (an empty result set) make
    /// `analyze` clear every column's interval to `None`.
    #[test]
    fn test_analyze_empty_set_boundary_exprs() {
        let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));

        let test_cases: Vec<Expr> = vec![
            // a > 10 AND a < 10 — contradictory
            col("a").gt(lit(10)).and(col("a").lt(lit(10))),
            // a > 10 AND a < 20 AND a > 20 AND a < 30 — contradictory
            col("a")
                .gt(lit(10))
                .and(col("a").lt(lit(20)))
                .and(col("a").gt(lit(20)))
                .and(col("a").lt(lit(30))),
        ];

        for expr in test_cases {
            let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap();
            let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
            let physical_expr =
                create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();
            let analysis_result = analyze(
                &physical_expr,
                AnalysisContext::new(boundaries),
                df_schema.as_ref(),
            )
            .unwrap();

            for boundary in analysis_result.boundaries {
                assert!(boundary.interval.is_none());
            }
        }
    }

    /// Verifies that predicates unsupported by interval propagation (`OR`)
    /// surface the expected error from `analyze`.
    #[test]
    fn test_analyze_invalid_boundary_exprs() {
        let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)]));
        // a < 10 OR a > 20 — OR cannot propagate true intervals yet.
        let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20)));
        let expected_error = "OR operator cannot yet propagate true intervals";
        let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap();
        let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap();
        let physical_expr =
            create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap();
        let analysis_error = analyze(
            &physical_expr,
            AnalysisContext::new(boundaries),
            df_schema.as_ref(),
        )
        .unwrap_err();
        assert_contains!(analysis_error.to_string(), expected_error);
    }
}