use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::utils::scatter;
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::sort_properties::ExprProperties;
use datafusion_expr_common::statistics::Distribution;
use itertools::izip;
/// Shared, reference-counted handle to a dynamically typed [`PhysicalExpr`].
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// An expression that can be evaluated against a [`RecordBatch`].
///
/// Implementors must be thread-safe (`Send + Sync`), printable (`Display` and
/// `Debug`) and must support type-erased equality and hashing through
/// [`DynEq`] and [`DynHash`].
pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
    /// Returns this expression as `&dyn Any` so that callers can downcast it
    /// to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the [`DataType`] of this expression for the given input schema.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType>;

    /// Reports the nullability of this expression for the given input schema.
    fn nullable(&self, input_schema: &Schema) -> Result<bool>;

    /// Evaluates this expression against `batch`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;

    /// Evaluates this expression only on the rows of `batch` picked by
    /// `selection`.
    ///
    /// The batch is filtered with `selection` and then passed to
    /// [`Self::evaluate`]. If the filter removed rows and the result is an
    /// array, the array is passed through `scatter` together with `selection`
    /// before being returned. Non-array results, and results computed when no
    /// row was removed, are returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from filtering, evaluation or scattering.
    fn evaluate_selection(
        &self,
        batch: &RecordBatch,
        selection: &BooleanArray,
    ) -> Result<ColumnarValue> {
        let filtered = filter_record_batch(batch, selection)?;
        // Equal row counts mean the filter kept every row of the input.
        let kept_every_row = filtered.num_rows() == batch.num_rows();
        match self.evaluate(&filtered)? {
            ColumnarValue::Array(values) if !kept_every_row => {
                scatter(selection, values.as_ref()).map(ColumnarValue::Array)
            }
            unchanged => Ok(unchanged),
        }
    }

    /// Returns references to the child expressions of this expression.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;

    /// Produces an expression from this one using the supplied `children`.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>>;

    /// Computes the output interval of this expression from the intervals of
    /// its children.
    ///
    /// The default implementation always returns a "not implemented" error.
    fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
        not_impl_err!("Not implemented for {self}")
    }

    /// Propagates an interval constraint on this expression to its children.
    ///
    /// The default implementation ignores both arguments and returns
    /// `Ok(Some(..))` holding an empty vector.
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(Vec::new()))
    }

    /// Derives the output [`Distribution`] of this expression from the
    /// distributions of its children.
    ///
    /// The default implementation takes the range of every child, feeds those
    /// ranges to [`Self::evaluate_bounds`], and converts the resulting
    /// interval: a boolean interval becomes a Bernoulli distribution, any
    /// other interval is handed to `Distribution::new_from_interval`.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing the child ranges, from
    /// [`Self::evaluate_bounds`], and from building the scalar or the
    /// distribution.
    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
        let mut child_ranges = Vec::with_capacity(children.len());
        for child in children {
            child_ranges.push(child.range()?);
        }
        let range_refs: Vec<&Interval> = child_ranges.iter().collect();
        let bounds = self.evaluate_bounds(&range_refs)?;

        let data_type = bounds.data_type();
        if data_type != DataType::Boolean {
            return Distribution::new_from_interval(bounds);
        }

        // Certain outcomes map to the "one"/"zero" scalars; anything else
        // falls back to the scalar built by `ScalarValue::try_from`.
        let probability = if bounds == Interval::CERTAINLY_TRUE {
            ScalarValue::new_one(&data_type)?
        } else if bounds == Interval::CERTAINLY_FALSE {
            ScalarValue::new_zero(&data_type)?
        } else {
            ScalarValue::try_from(&data_type)?
        };
        Distribution::new_bernoulli(probability)
    }

    /// Propagates the `parent` distribution of this expression down to the
    /// distributions of its children.
    ///
    /// The default implementation converts the parent and every child to a
    /// range and calls [`Self::propagate_constraints`]. `Ok(None)` from that
    /// call is returned as is. Otherwise each returned interval is paired with
    /// the matching child: an interval equal to the child's previous range
    /// keeps the child's distribution, a changed boolean interval becomes a
    /// Bernoulli distribution, and any other changed interval is handed to
    /// `Distribution::new_from_interval`.
    ///
    /// # Panics
    ///
    /// Panics if a changed boolean interval is neither certainly true nor
    /// certainly false.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing ranges, from
    /// [`Self::propagate_constraints`], and from building the scalars or the
    /// distributions.
    fn propagate_statistics(
        &self,
        parent: &Distribution,
        children: &[&Distribution],
    ) -> Result<Option<Vec<Distribution>>> {
        let mut old_ranges = Vec::with_capacity(children.len());
        for child in children {
            old_ranges.push(child.range()?);
        }
        let old_range_refs: Vec<&Interval> = old_ranges.iter().collect();
        let parent_range = parent.range()?;

        let narrowed = match self.propagate_constraints(&parent_range, &old_range_refs)? {
            Some(intervals) => intervals,
            None => return Ok(None),
        };

        let mut updated = Vec::with_capacity(children.len());
        for (new_range, old_range, child) in izip!(narrowed, old_ranges, children) {
            let stats = if new_range == old_range {
                // The range did not change, so the previous statistics are kept.
                (*child).clone()
            } else if new_range.data_type() == DataType::Boolean {
                let data_type = old_range.data_type();
                let probability = if new_range == Interval::CERTAINLY_TRUE {
                    ScalarValue::new_one(&data_type)?
                } else if new_range == Interval::CERTAINLY_FALSE {
                    ScalarValue::new_zero(&data_type)?
                } else {
                    unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
                };
                Distribution::new_bernoulli(probability)?
            } else {
                Distribution::new_from_interval(new_range)?
            };
            updated.push(stats);
        }
        Ok(Some(updated))
    }

    /// Computes the [`ExprProperties`] of this expression from the properties
    /// of its children.
    ///
    /// The default implementation ignores the children and returns
    /// `ExprProperties::new_unknown()`.
    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
        Ok(ExprProperties::new_unknown())
    }

    /// Writes this expression's `fmt_sql` representation into `f`.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;

    /// Returns an optional replacement expression for this node.
    ///
    /// The default implementation returns `Ok(None)`.
    fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        Ok(None)
    }
}
/// Equality against a type-erased value.
pub trait DynEq {
    /// Returns `true` when `other` is equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Blanket implementation for every `'static` type that implements [`Eq`].
impl<T> DynEq for T
where
    T: Eq + Any,
{
    /// Returns `true` only when `other` has the concrete type `T` and
    /// compares equal to `self`; a value of any other type yields `false`.
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<T>() {
            Some(same_type) => same_type == self,
            None => false,
        }
    }
}
impl PartialEq for dyn PhysicalExpr {
fn eq(&self, other: &Self) -> bool {
self.dyn_eq(other.as_any())
}
}
impl Eq for dyn PhysicalExpr {}
/// Hashing into a type-erased [`Hasher`].
pub trait DynHash {
    /// Feeds `self` into `_state`.
    fn dyn_hash(&self, _state: &mut dyn Hasher);
}

/// Blanket implementation for every `'static` type that implements [`Hash`].
impl<T> DynHash for T
where
    T: Hash + Any,
{
    /// Hashes the `TypeId` of `T` first and the value itself second.
    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
        let type_tag = Any::type_id(self);
        Hash::hash(&type_tag, &mut state);
        Hash::hash(self, &mut state);
    }
}
impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
}
}
/// Returns `expr` with `children` installed, reusing `expr` itself when
/// nothing changed.
///
/// `expr` is returned untouched only when `children` is non-empty and every
/// entry is pointer-identical (`Arc::ptr_eq`) to the current child at the same
/// position. In every other case, including an empty child list,
/// `with_new_children` is called.
///
/// # Errors
///
/// Returns an internal error when `children` does not have the same length as
/// the current child list, and propagates any error from `with_new_children`.
pub fn with_new_children_if_necessary(
    expr: Arc<dyn PhysicalExpr>,
    children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    let current = expr.children();
    if current.len() != children.len() {
        return internal_err!("PhysicalExpr: Wrong number of children");
    }

    let unchanged = !children.is_empty()
        && children
            .iter()
            .zip(current.iter())
            .all(|(incoming, existing)| Arc::ptr_eq(incoming, existing));

    if unchanged {
        Ok(expr)
    } else {
        expr.with_new_children(children)
    }
}
/// Unwraps a type-erased `Arc<dyn PhysicalExpr>` or `Box<dyn PhysicalExpr>`
/// to the `&dyn Any` of the expression it holds.
///
/// When `any` is neither of those two wrapper types it is returned unchanged.
#[deprecated(since = "44.0.0")]
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(shared) = any.downcast_ref::<Arc<dyn PhysicalExpr>>() {
        shared.as_any()
    } else if let Some(boxed) = any.downcast_ref::<Box<dyn PhysicalExpr>>() {
        boxed.as_any()
    } else {
        any
    }
}
/// Returns a [`Display`] value that renders `exprs` as a bracketed,
/// comma-separated list such as `[a, b, c]`.
///
/// An empty input renders as `[]`. The iterator must be `Clone` because it is
/// cloned and walked again every time the returned value is formatted.
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    /// Holds the iterator until formatting time.
    struct ListDisplay<I>(I)
    where
        I: Iterator + Clone,
        I::Item: Display;

    impl<I> Display for ListDisplay<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // Walk a clone so the stored iterator stays reusable.
            let mut remaining = self.0.clone();
            f.write_str("[")?;
            // The first item carries no separator; every later one is prefixed.
            if let Some(first) = remaining.next() {
                write!(f, "{}", first)?;
            }
            remaining.try_for_each(|item| write!(f, ", {}", item))?;
            f.write_str("]")
        }
    }

    ListDisplay(exprs.into_iter())
}
/// Wraps `expr` in a [`Display`] adapter whose output is produced by the
/// expression's `PhysicalExpr::fmt_sql` method.
pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
    /// Borrowing adapter that forwards `Display::fmt` to `fmt_sql`.
    struct SqlDisplay<'e>(&'e dyn PhysicalExpr);

    impl Display for SqlDisplay<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            self.0.fmt_sql(f)
        }
    }

    SqlDisplay(expr)
}
/// Rewrites `expr` with `transform_up`, replacing every node whose
/// `snapshot()` returns `Some` by that snapshot.
///
/// Nodes whose `snapshot()` returns `None` are kept as they are.
///
/// # Errors
///
/// Propagates any error returned by `snapshot()` or by the traversal.
pub fn snapshot_physical_expr(
    expr: Arc<dyn PhysicalExpr>,
) -> Result<Arc<dyn PhysicalExpr>> {
    expr.transform_up(|node| {
        let snapshot = node.snapshot()?;
        Ok(match snapshot {
            Some(replacement) => Transformed::yes(replacement),
            None => Transformed::no(node),
        })
    })
    .data()
}