use std::sync::Arc;
use crate::expressions::Literal;
use crate::intervals::cp_solver::PropagationResult;
use crate::physical_expr::PhysicalExpr;
use crate::utils::{ExprTreeNode, build_dag};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::statistics::Distribution;
use datafusion_expr_common::interval_arithmetic::Interval;
use petgraph::Outgoing;
use petgraph::adj::DefaultIx;
use petgraph::prelude::Bfs;
use petgraph::stable_graph::{NodeIndex, StableGraph};
use petgraph::visit::DfsPostOrder;
/// A graph of [`ExprStatisticsGraphNode`]s together with the index of its
/// root node. Instances are created by `try_new`, which delegates the graph
/// construction to `build_dag`.
#[derive(Clone, Debug)]
pub struct ExprStatisticsGraph {
// Node weights pair an expression with its current distribution. Edge
// weights are `usize`; NOTE(review): their meaning is decided by
// `build_dag` and is not visible here -- confirm before relying on them.
graph: StableGraph<ExprStatisticsGraphNode, usize>,
// Index of the root node; evaluation and propagation both start here.
root: NodeIndex,
}
/// Weight stored at every node of an `ExprStatisticsGraph`: an expression
/// paired with the distribution currently associated with it.
#[derive(Clone, Debug)]
pub struct ExprStatisticsGraphNode {
// The physical expression this node stands for.
expr: Arc<dyn PhysicalExpr>,
// The distribution currently attached to `expr`; overwritten by the
// graph's assignment, evaluation and propagation routines.
dist: Distribution,
}
impl ExprStatisticsGraphNode {
    /// Pairs `expr` with a uniform distribution over `interval`.
    ///
    /// Fails if the uniform distribution cannot be constructed.
    fn new_uniform(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Result<Self> {
        let dist = Distribution::new_uniform(interval)?;
        Ok(Self { expr, dist })
    }

    /// Pairs `expr` with a Bernoulli distribution whose parameter is a null
    /// `Float64` scalar.
    fn new_bernoulli(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
        let dist = Distribution::new_bernoulli(ScalarValue::Float64(None))?;
        Ok(Self { expr, dist })
    }

    /// Pairs `expr` with the distribution that `Distribution::new_from_interval`
    /// derives from the unbounded interval of data type `dt`.
    fn new_generic(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
        Interval::make_unbounded(dt)
            .and_then(Distribution::new_from_interval)
            .map(|dist| Self { expr, dist })
    }

    /// Returns the distribution currently stored in this node.
    pub fn distribution(&self) -> &Distribution {
        &self.dist
    }

    /// Builds the graph node for the expression held by `node`.
    ///
    /// The initial distribution is chosen as follows:
    /// - a `Literal` gets a uniform distribution over the single-point
    ///   interval `[value, value]`;
    /// - an expression whose data type under `schema` is `Boolean` gets a
    ///   Bernoulli distribution (see `new_bernoulli`);
    /// - anything else gets the distribution built from the unbounded
    ///   interval of its data type (see `new_generic`).
    ///
    /// Errors from interval/distribution construction or from resolving the
    /// expression's data type are returned to the caller.
    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
        let expr = Arc::clone(&node.expr);

        // Literals are handled up front; everything else is typed via the schema.
        if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
            let point = literal.value().clone();
            let interval = Interval::try_new(point.clone(), point)?;
            return Self::new_uniform(expr, interval);
        }

        let dt = expr.data_type(schema)?;
        if matches!(dt, DataType::Boolean) {
            Self::new_bernoulli(expr)
        } else {
            Self::new_generic(expr, &dt)
        }
    }
}
impl ExprStatisticsGraph {
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
let (root, graph) = build_dag(expr, &|node| {
ExprStatisticsGraphNode::make_node(node, schema)
})?;
Ok(Self { graph, root })
}
pub fn assign_statistics(&mut self, assignments: &[(usize, Distribution)]) {
for (index, stats) in assignments {
let node_index = NodeIndex::from(*index as DefaultIx);
self.graph[node_index].dist = stats.clone();
}
}
pub fn evaluate_statistics(&mut self) -> Result<&Distribution> {
let mut dfs = DfsPostOrder::new(&self.graph, self.root);
while let Some(idx) = dfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(idx, Outgoing);
let mut children_statistics = neighbors
.map(|child| self.graph[child].distribution())
.collect::<Vec<_>>();
if !children_statistics.is_empty() {
children_statistics.reverse();
self.graph[idx].dist = self.graph[idx]
.expr
.evaluate_statistics(&children_statistics)?;
}
}
Ok(self.graph[self.root].distribution())
}
pub fn propagate_statistics(
&mut self,
given_stats: Distribution,
) -> Result<PropagationResult> {
let root_range = self.graph[self.root].dist.range()?;
let given_range = given_stats.range()?;
if let Some(interval) = root_range.intersect(&given_range)? {
if interval != root_range {
let subset = root_range.contains(given_range)?;
self.graph[self.root].dist = if subset == Interval::TRUE {
given_stats
} else {
Distribution::new_from_interval(interval)?
};
}
} else {
return Ok(PropagationResult::Infeasible);
}
let mut bfs = Bfs::new(&self.graph, self.root);
while let Some(node) = bfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(node, Outgoing);
let mut children = neighbors.collect::<Vec<_>>();
if children.is_empty() {
continue;
}
children.reverse();
let children_stats = children
.iter()
.map(|child| self.graph[*child].distribution())
.collect::<Vec<_>>();
let node_statistics = self.graph[node].distribution();
let propagated_statistics = self.graph[node]
.expr
.propagate_statistics(node_statistics, &children_stats)?;
if let Some(propagated_stats) = propagated_statistics {
for (child_idx, stats) in children.into_iter().zip(propagated_stats) {
self.graph[child_idx].dist = stats;
}
} else {
return Ok(PropagationResult::Infeasible);
}
}
Ok(PropagationResult::Success)
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expressions::{Column, binary, try_cast};
    use crate::intervals::cp_solver::PropagationResult;
    use crate::statistics::stats_solver::ExprStatisticsGraph;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr_common::interval_arithmetic::Interval;
    use datafusion_expr_common::operator::Operator;
    use datafusion_expr_common::statistics::Distribution;
    use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Builds the binary expression `left op right`, first wrapping each
    /// operand in a `try_cast` to the input type that `BinaryTypeCoercer`
    /// reports for the operand types and `op`.
    pub fn binary_expr(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let lhs_input = left.data_type(schema)?;
        let rhs_input = right.data_type(schema)?;
        let (lhs_target, rhs_target) =
            BinaryTypeCoercer::new(&lhs_input, &op, &rhs_input).get_input_types()?;
        binary(
            try_cast(left, schema, lhs_target)?,
            op,
            try_cast(right, schema, rhs_target)?,
            schema,
        )
    }

    /// End-to-end check on `(a + b) = (c - d)`: assigns uniform distributions
    /// to four nodes, evaluates the root, then propagates a Bernoulli
    /// distribution with parameter one from the root.
    #[test]
    fn test_stats_integration() -> Result<()> {
        let fields = ["a", "b", "c", "d"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Float64, false))
            .collect::<Vec<_>>();
        let schema = Schema::new(fields);
        let schema = &schema;

        let col = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
            Arc::new(Column::new(name, index))
        };
        let sum = binary_expr(col("a", 0), Operator::Plus, col("b", 1), schema)?;
        let diff = binary_expr(col("c", 2), Operator::Minus, col("d", 3), schema)?;
        let predicate = binary_expr(sum, Operator::Eq, diff, schema)?;

        let mut graph = ExprStatisticsGraph::try_new(predicate, schema)?;

        let uniform = |lo: f64, hi: f64| -> Result<Distribution> {
            Distribution::new_uniform(Interval::make(Some(lo), Some(hi))?)
        };
        // Node indices 0, 1, 3 and 4 receive the distributions; index 2 is
        // left at whatever `try_new` assigned.
        let assignments = [
            (0usize, uniform(0., 1.)?),
            (1usize, uniform(0., 2.)?),
            (3usize, uniform(1., 3.)?),
            (4usize, uniform(1., 5.)?),
        ];
        graph.assign_statistics(&assignments);

        let evaluated = graph.evaluate_statistics()?;
        let expected_root = Distribution::new_bernoulli(ScalarValue::Float64(None))?;
        assert_eq!(evaluated, &expected_root);

        let one = ScalarValue::new_one(&DataType::Float64)?;
        let certain = Distribution::new_bernoulli(one)?;
        let outcome = graph.propagate_statistics(certain)?;
        assert_eq!(outcome, PropagationResult::Success);
        Ok(())
    }
}