use crate::expressions::Column;
use crate::intervals::cp_solver::PropagationResult;
use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval, IntervalBound};
use crate::utils::collect_columns;
use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData};
use arrow::compute::{and_kleene, filter_record_batch, is_not_null, SlicesIterator};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::DataPtr;
use datafusion_common::{ColumnStatistics, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use std::any::Any;
use std::fmt::{Debug, Display};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
fn as_any(&self) -> &dyn Any;
fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
fn nullable(&self, input_schema: &Schema) -> Result<bool>;
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
fn evaluate_selection(
&self,
batch: &RecordBatch,
selection: &BooleanArray,
) -> Result<ColumnarValue> {
let tmp_batch = filter_record_batch(batch, selection)?;
let tmp_result = self.evaluate(&tmp_batch)?;
if batch.num_rows() == tmp_batch.num_rows() {
return Ok(tmp_result);
}
if let ColumnarValue::Array(a) = tmp_result {
let result = scatter(selection, a.as_ref())?;
Ok(ColumnarValue::Array(result))
} else {
Ok(tmp_result)
}
}
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>>;
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>>;
fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
Err(DataFusionError::NotImplemented(format!(
"Not implemented for {self}"
)))
}
fn propagate_constraints(
&self,
_interval: &Interval,
_children: &[&Interval],
) -> Result<Vec<Option<Interval>>> {
Err(DataFusionError::NotImplemented(format!(
"Not implemented for {self}"
)))
}
fn dyn_hash(&self, _state: &mut dyn Hasher);
}
/// Analyzes the predicate `expr` against the column boundaries in `context`.
///
/// An interval graph is built for `expr`, the nodes corresponding to columns
/// are seeded with the matching boundaries from `context`, and the ranges are
/// propagated through the graph:
///
/// * on success, the boundaries are shrunk and a selectivity is computed by
///   `shrink_boundaries`;
/// * if propagation reports the predicate as infeasible, the input boundaries
///   are returned with selectivity `0.0`;
/// * if propagation is not possible, the input boundaries are returned with
///   selectivity `1.0`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `context.boundaries` is `None`, and
/// forwards any error from graph construction, range propagation or boundary
/// shrinking.
pub fn analyze(
    expr: &Arc<dyn PhysicalExpr>,
    context: AnalysisContext,
) -> Result<AnalysisContext> {
    let target_boundaries = match context.boundaries {
        Some(boundaries) => boundaries,
        None => {
            return Err(DataFusionError::Internal(
                "No column exists at the input to filter".to_string(),
            ))
        }
    };

    let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr))?;

    // Locate the graph nodes that correspond to the columns referenced by `expr`.
    let column_exprs: Vec<Arc<dyn PhysicalExpr>> = collect_columns(expr)
        .into_iter()
        .map(|column| Arc::new(column) as Arc<dyn PhysicalExpr>)
        .collect();
    let target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> =
        graph.gather_node_indices(column_exprs.as_slice());

    // Pair every column node with the first known boundary for that column;
    // nodes that are not columns, or have no boundary, are skipped.
    let mut seeded_intervals: Vec<(usize, Interval)> = target_expr_and_indices
        .iter()
        .filter_map(|(node_expr, node_index)| {
            let column = node_expr.as_any().downcast_ref::<Column>()?;
            target_boundaries
                .iter()
                .find(|bound| bound.column.eq(column))
                .map(|bound| (*node_index, bound.interval.clone()))
        })
        .collect();

    let fixed_selectivity = match graph.update_ranges(&mut seeded_intervals)? {
        PropagationResult::Success => {
            return shrink_boundaries(
                expr,
                graph,
                target_boundaries,
                target_expr_and_indices,
            );
        }
        PropagationResult::Infeasible => 0.0,
        PropagationResult::CannotPropagate => 1.0,
    };
    Ok(AnalysisContext::new(target_boundaries).with_selectivity(fixed_selectivity))
}
/// Replaces the column boundaries with the intervals stored in `graph` and
/// derives a selectivity estimate for the predicate `expr`.
///
/// For every `(expression, node index)` pair whose expression is a [`Column`]
/// with a matching entry in `target_boundaries`, that entry's interval is
/// overwritten with the interval the graph holds for the node. The interval of
/// the root node (the node for `expr`) and the before/after boundaries are then
/// handed to `calculate_selectivity`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when the graph has no node for `expr`
/// or when the computed selectivity is not within `0.0..=1.0`; errors from
/// `calculate_selectivity` are forwarded.
fn shrink_boundaries(
    expr: &Arc<dyn PhysicalExpr>,
    mut graph: ExprIntervalGraph,
    mut target_boundaries: Vec<ExprBoundaries>,
    target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
) -> Result<AnalysisContext> {
    // Keep the untouched boundaries: selectivity is derived from the change.
    let initial_boundaries = target_boundaries.clone();

    for (node_expr, node_index) in &target_expr_and_indices {
        let column = match node_expr.as_any().downcast_ref::<Column>() {
            Some(column) => column,
            None => continue,
        };
        let matching_bound = target_boundaries
            .iter_mut()
            .find(|bound| bound.column.eq(column));
        if let Some(bound) = matching_bound {
            bound.interval = graph.get_interval(*node_index);
        }
    }

    let root_nodes = graph.gather_node_indices(&[Arc::clone(expr)]);
    let root_index = match root_nodes.first() {
        Some((_, index)) => *index,
        None => {
            return Err(DataFusionError::Internal(
                "Error in constructing predicate graph".to_string(),
            ))
        }
    };
    let root_interval = graph.get_interval(root_index);

    let selectivity = calculate_selectivity(
        &root_interval.lower.value,
        &root_interval.upper.value,
        &target_boundaries,
        &initial_boundaries,
    )?;

    // The range test also rejects NaN, which is contained in no range.
    if (0.0..=1.0).contains(&selectivity) {
        Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity))
    } else {
        Err(DataFusionError::Internal(format!(
            "Selectivity is out of limit: {}",
            selectivity
        )))
    }
}
fn calculate_selectivity(
lower_value: &ScalarValue,
upper_value: &ScalarValue,
target_boundaries: &[ExprBoundaries],
initial_boundaries: &[ExprBoundaries],
) -> Result<f64> {
match (lower_value, upper_value) {
(ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(true))) => Ok(1.0),
(ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(false))) => Ok(0.0),
_ => {
target_boundaries.iter().enumerate().try_fold(
1.0,
|acc, (i, ExprBoundaries { interval, .. })| {
let temp =
cardinality_ratio(&initial_boundaries[i].interval, interval)?;
Ok(acc * temp)
},
)
}
}
}
impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
}
}
/// Shorthand for a shared, reference-counted [`PhysicalExpr`] trait object.
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// Input to and output of predicate analysis (see `analyze`): the known column
/// boundaries together with an optional selectivity estimate.
#[derive(Clone, Debug, PartialEq)]
pub struct AnalysisContext {
/// Boundaries of the input columns. `analyze` returns an internal error when
/// this is `None`.
pub boundaries: Option<Vec<ExprBoundaries>>,
/// Selectivity estimate; `None` until set through `with_selectivity`.
/// Values produced by `analyze` lie within `0.0..=1.0`.
pub selectivity: Option<f64>,
}
impl AnalysisContext {
    /// Creates a context holding `boundaries` and no selectivity estimate.
    pub fn new(boundaries: Vec<ExprBoundaries>) -> Self {
        let boundaries = Some(boundaries);
        Self {
            boundaries,
            selectivity: None,
        }
    }

    /// Returns this context with its selectivity set to `selectivity`.
    pub fn with_selectivity(self, selectivity: f64) -> Self {
        Self {
            selectivity: Some(selectivity),
            ..self
        }
    }

    /// Builds a context with one boundary per entry of `statistics`.
    ///
    /// The entry at position `idx` is paired with the name of the field at
    /// position `idx` of `input_schema`.
    ///
    /// # Panics
    ///
    /// Panics if `statistics` has more entries than `input_schema` has fields.
    pub fn from_statistics(
        input_schema: &Schema,
        statistics: &[ColumnStatistics],
    ) -> Self {
        let column_boundaries = statistics
            .iter()
            .enumerate()
            .map(|(idx, stats)| {
                let field_name = input_schema.fields()[idx].name().clone();
                ExprBoundaries::from_column(stats, field_name, idx)
            })
            .collect();
        Self::new(column_boundaries)
    }
}
/// The value range known for a single column, plus its distinct-value count
/// when available.
#[derive(Clone, Debug, PartialEq)]
pub struct ExprBoundaries {
/// The column these boundaries describe.
pub column: Column,
/// Interval of values for the column; `from_column` fills it from the
/// column statistics' minimum and maximum.
pub interval: Interval,
/// Distinct-value count, copied from the column statistics by `from_column`.
pub distinct_count: Option<usize>,
}
impl ExprBoundaries {
    /// Creates the boundaries of the column named `col` at position `index`
    /// from its statistics.
    ///
    /// The interval runs from `stats.min_value` to `stats.max_value`; a missing
    /// minimum or maximum is represented by `ScalarValue::Null`. The distinct
    /// count is copied from `stats`.
    pub fn from_column(stats: &ColumnStatistics, col: String, index: usize) -> Self {
        let column = Column::new(&col, index);
        // NOTE(review): the `false` flag presumably marks each bound as closed
        // (inclusive) - confirm against `IntervalBound::new`.
        let lower = IntervalBound::new(
            stats.min_value.clone().unwrap_or(ScalarValue::Null),
            false,
        );
        let upper = IntervalBound::new(
            stats.max_value.clone().unwrap_or(ScalarValue::Null),
            false,
        );
        Self {
            column,
            interval: Interval::new(lower, upper),
            distinct_count: stats.distinct_count,
        }
    }
}
/// Returns `expr` with `children` as its children, rebuilding it only when
/// that is needed.
///
/// `expr` is returned untouched when it has at least one child and every entry
/// of `children` points to the same data as the current child at the same
/// position. Otherwise (including when there are no children at all)
/// `expr.with_new_children(children)` is called.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `children` does not have the same
/// length as the current children of `expr`; errors from `with_new_children`
/// are forwarded.
pub fn with_new_children_if_necessary(
    expr: Arc<dyn PhysicalExpr>,
    children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    let current_children = expr.children();
    if current_children.len() != children.len() {
        return Err(DataFusionError::Internal(
            "PhysicalExpr: Wrong number of children".to_string(),
        ));
    }

    let all_children_unchanged = !children.is_empty()
        && children
            .iter()
            .zip(current_children.iter())
            .all(|(new_child, old_child)| Arc::data_ptr_eq(new_child, old_child));

    if all_children_unchanged {
        Ok(expr)
    } else {
        expr.with_new_children(children)
    }
}
/// Looks through a smart pointer wrapping a `dyn PhysicalExpr`.
///
/// When `any` is an `Arc<dyn PhysicalExpr>` or a `Box<dyn PhysicalExpr>`, the
/// result is the `as_any()` of the expression inside it; any other value is
/// returned unchanged.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(shared) = any.downcast_ref::<Arc<dyn PhysicalExpr>>() {
        shared.as_any()
    } else if let Some(boxed) = any.downcast_ref::<Box<dyn PhysicalExpr>>() {
        boxed.as_any()
    } else {
        any
    }
}
/// Spreads the values of `truthy` over the positions of `mask` that are true.
///
/// The result has the length of `mask`. Values of `truthy` are consumed in
/// order, one per true position of the mask; every position where the mask is
/// false or null becomes null in the result.
fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result<ArrayRef> {
    let source = truthy.to_data();

    // Turn null mask entries into `false` so that the slice iteration below
    // only ever sees positions that are really selected.
    let mask_validity = is_not_null(mask)?;
    let selection = and_kleene(mask, &mask_validity)?;
    let total_len = selection.len();

    let mut builder = MutableArrayData::new(vec![&source], true, total_len);
    // Number of output slots produced so far.
    let mut written = 0;
    // Index of the next unread value in `source`.
    let mut next_source = 0;

    for (start, end) in SlicesIterator::new(&selection) {
        // Positions skipped since the previous run were not selected.
        if start > written {
            builder.extend_nulls(start - written);
        }
        let run_len = end - start;
        builder.extend(0, next_source, next_source + run_len);
        next_source += run_len;
        written = end;
    }

    // Anything after the last selected run was not selected either.
    if written < total_len {
        builder.extend_nulls(total_len - written);
    }

    Ok(make_array(builder.freeze()))
}
/// Unwraps an `Option` in the middle of an analysis function.
///
/// Expands to a `match` on `$expr`: `Some(value)` evaluates to `value`, while
/// `None` returns early from the enclosing function with
/// `Ok($context.with_boundaries(None))`.
///
/// NOTE(review): no `with_boundaries` method is defined on `AnalysisContext`
/// in this file - confirm that the type passed as `$context` provides one, or
/// whether this macro is still in use.
#[macro_export]
macro_rules! analysis_expect {
($context: ident, $expr: expr) => {
match $expr {
Some(expr) => expr,
// Bail out of the enclosing function when the value is missing.
None => return Ok($context.with_boundaries(None)),
}
};
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use arrow::array::Int32Array;
    use datafusion_common::{
        cast::{as_boolean_array, as_int32_array},
        Result,
    };

    /// Selected positions take the source values in order; the others are null.
    #[test]
    fn scatter_int() -> Result<()> {
        let source = Arc::new(Int32Array::from(vec![1, 10, 11, 100]));
        let selection = BooleanArray::from(vec![true, true, false, false, true]);

        let scattered = scatter(&selection, source.as_ref())?;

        let want = Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]);
        assert_eq!(&want, as_int32_array(&scattered)?);
        Ok(())
    }

    /// Trailing unselected positions are filled with nulls.
    #[test]
    fn scatter_int_end_with_false() -> Result<()> {
        let source = Arc::new(Int32Array::from(vec![1, 10, 11, 100]));
        let selection =
            BooleanArray::from(vec![true, false, true, false, false, false]);

        let scattered = scatter(&selection, source.as_ref())?;

        let want =
            Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]);
        assert_eq!(&want, as_int32_array(&scattered)?);
        Ok(())
    }

    /// Null mask entries are treated like `false`.
    #[test]
    fn scatter_with_null_mask() -> Result<()> {
        let source = Arc::new(Int32Array::from(vec![1, 10, 11]));
        let selection = vec![Some(false), None, Some(true), Some(true), None]
            .into_iter()
            .collect::<BooleanArray>();

        let scattered = scatter(&selection, source.as_ref())?;

        let want = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]);
        assert_eq!(&want, as_int32_array(&scattered)?);
        Ok(())
    }

    /// Scattering works for boolean value arrays as well.
    #[test]
    fn scatter_boolean() -> Result<()> {
        let source = Arc::new(BooleanArray::from(vec![false, false, false, true]));
        let selection = BooleanArray::from(vec![true, true, false, false, true]);

        let scattered = scatter(&selection, source.as_ref())?;

        let want = BooleanArray::from_iter(vec![
            Some(false),
            Some(false),
            None,
            None,
            Some(false),
        ]);
        assert_eq!(&want, as_boolean_array(&scattered)?);
        Ok(())
    }
}