use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::utils::scatter;
use arrow::array::{ArrayRef, BooleanArray, new_empty_array};
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
Result, ScalarValue, assert_eq_or_internal_err, exec_err, not_impl_err,
};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::placement::ExpressionPlacement;
use datafusion_expr_common::sort_properties::ExprProperties;
use datafusion_expr_common::statistics::Distribution;
use itertools::izip;
/// Shared, dynamically-typed reference to a [`PhysicalExpr`].
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// A possibly-nested expression that can be evaluated against Arrow
/// [`RecordBatch`]es to produce a [`ColumnarValue`].
///
/// Beyond evaluation, implementations can opt into interval arithmetic
/// ([`Self::evaluate_bounds`] / [`Self::propagate_constraints`]), statistics
/// propagation, and snapshotting for dynamic expressions. The [`DynEq`] /
/// [`DynHash`] supertraits give `dyn PhysicalExpr` values type-aware
/// equality and hashing.
pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// The [`DataType`] this expression produces for the given input schema.
    ///
    /// NOTE: the defaults of `data_type`/`nullable` and `return_field` call
    /// each other, so implementors must override at least one side to break
    /// the mutual recursion.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        Ok(self.return_field(input_schema)?.data_type().to_owned())
    }

    /// Whether the produced values may be null for the given input schema.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        Ok(self.return_field(input_schema)?.is_nullable())
    }

    /// Evaluates this expression against every row of `batch`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;

    /// The output [`Field`] (name, type, nullability) for this expression.
    /// The default names the field after the expression's `Display` form.
    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
        Ok(Arc::new(Field::new(
            format!("{self}"),
            self.data_type(input_schema)?,
            self.nullable(input_schema)?,
        )))
    }

    /// Evaluates this expression only for the rows where `selection` is
    /// `true`, returning a value aligned to the full batch length.
    ///
    /// # Errors
    /// Fails if `selection.len() != batch.num_rows()`.
    fn evaluate_selection(
        &self,
        batch: &RecordBatch,
        selection: &BooleanArray,
    ) -> Result<ColumnarValue> {
        let row_count = batch.num_rows();
        if row_count != selection.len() {
            return exec_err!(
                "Selection array length does not match batch row count: {} != {row_count}",
                selection.len()
            );
        }
        let selection_count = selection.true_count();
        // Fast path: every row is selected, so no filtering is required.
        if selection_count == row_count {
            return self.evaluate(batch);
        }
        let filtered_result = if selection_count == 0 {
            // Nothing selected: produce an empty array of the right type
            // without evaluating the expression at all.
            let datatype = self.data_type(batch.schema_ref().as_ref())?;
            ColumnarValue::Array(new_empty_array(&datatype))
        } else {
            // Evaluate only the selected rows.
            let filtered_batch = filter_record_batch(batch, selection)?;
            self.evaluate(&filtered_batch)?
        };
        match &filtered_result {
            // Scatter the filtered values back to their original row
            // positions; unselected rows become null.
            ColumnarValue::Array(a) => {
                scatter(selection, a.as_ref()).map(ColumnarValue::Array)
            }
            ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
                if let Some(v) = value {
                    if *v {
                        // All selected rows evaluated to `true`, so the
                        // selection mask itself is the full-length result
                        // (unselected rows read as `false`).
                        Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
                    } else {
                        // Scalar `false` already covers every row.
                        Ok(filtered_result)
                    }
                } else {
                    // Scalar NULL boolean: scatter an all-null boolean array
                    // over the selection.
                    let array = BooleanArray::from(vec![None; row_count]);
                    scatter(selection, &array).map(ColumnarValue::Array)
                }
            }
            // Non-boolean scalars apply to all rows unchanged.
            ColumnarValue::Scalar(_) => Ok(filtered_result),
        }
    }

    /// Direct children of this expression (empty for leaf expressions).
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;

    /// Rebuilds this expression with the given children. `children` is
    /// expected to match `self.children()` in length and order.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>>;

    /// Computes output bounds from the children's bounds (interval
    /// arithmetic). Default: not implemented.
    fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
        not_impl_err!("Not implemented for {self}")
    }

    /// Given the parent's interval, computes tightened intervals for the
    /// children. The default returns an empty vector, i.e. no tightening is
    /// computed; `Ok(None)` signals an infeasible constraint (see
    /// [`Self::propagate_statistics`], which maps it to `None`).
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(vec![]))
    }

    /// Derives the output [`Distribution`] from the children's distributions.
    ///
    /// Default: computes output bounds via [`Self::evaluate_bounds`]. Boolean
    /// results become a Bernoulli distribution (p = 1 when certainly true,
    /// p = 0 when certainly false, unknown otherwise); all other types become
    /// an interval-backed distribution.
    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
        let children_ranges = children
            .iter()
            .map(|c| c.range())
            .collect::<Result<Vec<_>>>()?;
        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
        let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
        let dt = output_interval.data_type();
        if dt.eq(&DataType::Boolean) {
            let p = if output_interval.eq(&Interval::TRUE) {
                ScalarValue::new_one(&dt)
            } else if output_interval.eq(&Interval::FALSE) {
                ScalarValue::new_zero(&dt)
            } else {
                // Outcome uncertain: fall back to a typed scalar with no
                // definite value.
                ScalarValue::try_from(&dt)
            }?;
            Distribution::new_bernoulli(p)
        } else {
            Distribution::new_from_interval(output_interval)
        }
    }

    /// Propagates the parent distribution down to the children, tightening
    /// each child's distribution where its interval shrank. Returns
    /// `Ok(None)` when [`Self::propagate_constraints`] reports infeasibility.
    fn propagate_statistics(
        &self,
        parent: &Distribution,
        children: &[&Distribution],
    ) -> Result<Option<Vec<Distribution>>> {
        let children_ranges = children
            .iter()
            .map(|c| c.range())
            .collect::<Result<Vec<_>>>()?;
        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
        let parent_range = parent.range()?;
        let Some(propagated_children) =
            self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
        else {
            return Ok(None);
        };
        izip!(propagated_children.into_iter(), children_ranges, children)
            .map(|(new_interval, old_interval, child)| {
                if new_interval == old_interval {
                    // Interval unchanged: keep the child's distribution.
                    Ok((*child).clone())
                } else if new_interval.data_type().eq(&DataType::Boolean) {
                    let dt = old_interval.data_type();
                    let p = if new_interval.eq(&Interval::TRUE) {
                        ScalarValue::new_one(&dt)
                    } else if new_interval.eq(&Interval::FALSE) {
                        ScalarValue::new_zero(&dt)
                    } else {
                        // A *shrunk* boolean interval must be a certainty;
                        // anything else indicates a logic error upstream.
                        unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
                    }?;
                    Distribution::new_bernoulli(p)
                } else {
                    Distribution::new_from_interval(new_interval)
                }
            })
            .collect::<Result<_>>()
            .map(Some)
    }

    /// Ordering/monotonicity properties of this expression.
    /// Default: unknown.
    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
        Ok(ExprProperties::new_unknown())
    }

    /// Formats this expression as SQL-like text (used by the free function
    /// [`fmt_sql`]).
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;

    /// For dynamic expressions, returns a fixed point-in-time snapshot;
    /// `None` (the default) means the expression is already static.
    fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        Ok(None)
    }

    /// Generation counter for dynamic expressions; `0` (the default) marks a
    /// static node. See the free function [`snapshot_generation`].
    fn snapshot_generation(&self) -> u64 {
        0
    }

    /// Whether this node itself is volatile (may produce different results on
    /// repeated evaluation). Children are not consulted here; see the free
    /// function [`is_volatile`] for the whole-tree check.
    fn is_volatile_node(&self) -> bool {
        false
    }

    /// Preferred placement of this expression during plan rewriting.
    fn placement(&self) -> ExpressionPlacement {
        ExpressionPlacement::KeepInPlace
    }
}
#[deprecated(
since = "50.0.0",
note = "Use `datafusion_expr_common::dyn_eq` instead"
)]
pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
// Equality for trait objects delegates to the type-aware `DynEq` impl so two
// `dyn PhysicalExpr` values compare equal only when their concrete types and
// contents match.
impl PartialEq for dyn PhysicalExpr {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}

impl Eq for dyn PhysicalExpr {}

// Hashing delegates to `DynHash`, keeping it consistent with `dyn_eq` above
// so trait objects can live in hash-based collections.
impl Hash for dyn PhysicalExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state);
    }
}
/// Rebuilds `expr` with `children`, but returns the original expression
/// untouched when every new child is pointer-identical to the old one.
///
/// # Errors
/// Returns an internal error if `children.len()` differs from the
/// expression's current child count, or if `with_new_children` fails.
pub fn with_new_children_if_necessary(
    expr: Arc<dyn PhysicalExpr>,
    children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    let old_children = expr.children();
    assert_eq_or_internal_err!(
        children.len(),
        old_children.len(),
        "PhysicalExpr: Wrong number of children"
    );
    // Only reuse `expr` when there is at least one child and all of them are
    // the very same Arc allocations as before.
    let unchanged = !children.is_empty()
        && children
            .iter()
            .zip(old_children)
            .all(|(new, old)| Arc::ptr_eq(new, old));
    if unchanged {
        Ok(expr)
    } else {
        expr.with_new_children(children)
    }
}
/// Returns a lazily-formatted, comma-separated, bracketed view of `exprs`,
/// e.g. `[a@0, b@1]`. The iterator must be `Clone` because `Display::fmt`
/// may be invoked more than once.
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    // Adapter that renders the wrapped iterator on demand.
    struct DisplayWrapper<I>(I)
    where
        I: Iterator + Clone,
        I::Item: Display;
    impl<I> Display for DisplayWrapper<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("[")?;
            // Separator is empty for the first element, ", " afterwards.
            let mut separator = "";
            for expr in self.0.clone() {
                write!(f, "{separator}{expr}")?;
                separator = ", ";
            }
            f.write_str("]")
        }
    }
    DisplayWrapper(exprs.into_iter())
}
/// Returns a [`Display`]able view of `expr` that renders via
/// [`PhysicalExpr::fmt_sql`] instead of the expression's own `Display` impl.
pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
    // Adapter routing `Display` formatting to `fmt_sql`.
    struct SqlDisplay<'a>(&'a dyn PhysicalExpr);
    impl Display for SqlDisplay<'_> {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            self.0.fmt_sql(f)
        }
    }
    SqlDisplay(expr)
}
/// Like [`snapshot_physical_expr_opt`], but discards the `Transformed`
/// bookkeeping and returns only the resulting expression.
pub fn snapshot_physical_expr(
    expr: Arc<dyn PhysicalExpr>,
) -> Result<Arc<dyn PhysicalExpr>> {
    snapshot_physical_expr_opt(expr).map(|transformed| transformed.data)
}
/// Replaces every dynamic node in the tree with its current snapshot
/// (bottom-up), reporting via [`Transformed`] whether anything changed.
///
/// # Errors
/// Propagates any error raised by a node's [`PhysicalExpr::snapshot`].
pub fn snapshot_physical_expr_opt(
    expr: Arc<dyn PhysicalExpr>,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
    expr.transform_up(|e| match e.snapshot()? {
        // Dynamic node: substitute its point-in-time snapshot.
        Some(snapshot) => Ok(Transformed::yes(snapshot)),
        // Static node: keep it. The closure owns `e`, so no `Arc::clone`
        // is needed here (the original cloned redundantly).
        None => Ok(Transformed::no(e)),
    })
}
/// Sums the snapshot generation of every node in the tree (wrapping on
/// overflow). A nonzero total means at least one node is dynamic; see
/// [`is_dynamic_physical_expr`].
pub fn snapshot_generation(expr: &Arc<dyn PhysicalExpr>) -> u64 {
    let mut total: u64 = 0;
    expr.apply(|node| {
        total = total.wrapping_add(node.snapshot_generation());
        Ok(TreeNodeRecursion::Continue)
    })
    // The closure never returns an error, so the traversal cannot fail.
    .expect("this traversal is infallible");
    total
}
/// Returns `true` when the tree's summed snapshot generation is nonzero,
/// i.e. the expression contains at least one dynamic node.
pub fn is_dynamic_physical_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
    snapshot_generation(expr) != 0
}
/// Returns `true` if any node in the expression tree reports itself volatile
/// via [`PhysicalExpr::is_volatile_node`].
///
/// The traversal short-circuits as soon as a volatile node is found. Note:
/// `TreeNode::apply` visits the root first, so the original's separate
/// up-front root check was redundant and has been removed.
pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let mut is_volatile = false;
    expr.apply(|e| {
        if e.is_volatile_node() {
            is_volatile = true;
            Ok(TreeNodeRecursion::Stop)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    })
    .expect("infallible closure should not fail");
    is_volatile
}
#[cfg(test)]
mod test {
    use crate::physical_expr::PhysicalExpr;
    use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Schema};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::fmt::{Display, Formatter};
    use std::sync::Arc;

    // Minimal leaf expression used to exercise `evaluate_selection`: it
    // ignores the batch's columns and returns an Int64 column of `1`s, one
    // value per input row.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestExpr {}

    impl PhysicalExpr for TestExpr {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int64)
        }

        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
            Ok(false)
        }

        // Produces `[1; batch.num_rows()]` regardless of the batch contents.
        fn evaluate(
            &self,
            batch: &RecordBatch,
        ) -> datafusion_common::Result<ColumnarValue> {
            let data = vec![1; batch.num_rows()];
            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
        }

        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn PhysicalExpr>>,
        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
            Ok(Arc::new(Self {}))
        }

        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestExpr")
        }
    }

    impl Display for TestExpr {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.fmt_sql(f)
        }
    }

    // Asserts two `ColumnarValue`s are equal after materializing both to
    // arrays ($EXPECTED is expanded to one row when scalar, $ACTUAL to the
    // expected array's length).
    macro_rules! assert_arrays_eq {
        ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
            let expected = $EXPECTED.to_array(1).unwrap();
            let actual = $ACTUAL;
            let actual_array = actual.to_array(expected.len()).unwrap();
            let actual_ref = actual_array.as_ref();
            let expected_ref = expected.as_ref();
            assert!(
                actual_ref == expected_ref,
                "{}: expected: {:?}, actual: {:?}",
                $MESSAGE,
                $EXPECTED,
                actual_ref
            );
        };
    }

    // Runs `evaluate_selection` on `TestExpr` and checks the result against
    // `expected`; when the selection covers every row, additionally checks
    // consistency with the unfiltered `evaluate` result.
    fn test_evaluate_selection(
        batch: &RecordBatch,
        selection: &BooleanArray,
        expected: &ColumnarValue,
    ) {
        let expr = TestExpr {};
        let selection_result = expr.evaluate_selection(batch, selection).unwrap();
        assert_eq!(
            expected.to_array(1).unwrap().len(),
            selection_result.to_array(1).unwrap().len(),
            "evaluate_selection should output row count should match input record batch"
        );
        assert_arrays_eq!(
            expected,
            &selection_result,
            "evaluate_selection returned unexpected value"
        );
        // All-true selection must be equivalent to plain `evaluate`.
        if (0..batch.num_rows())
            .all(|row_idx| row_idx < selection.len() && selection.value(row_idx))
        {
            let empty_result = expr.evaluate(batch).unwrap();
            assert_arrays_eq!(
                empty_result,
                &selection_result,
                "evaluate_selection does not match unfiltered evaluate result"
            );
        }
    }

    // Expects `evaluate_selection` to fail (used for mismatched selection
    // lengths).
    fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
        let expr = TestExpr {};
        let selection_result = expr.evaluate_selection(batch, selection);
        assert!(selection_result.is_err(), "evaluate_selection should fail");
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch() {
        test_evaluate_selection(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 0]),
            &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![true; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch() {
        test_evaluate_selection(
            // SAFETY: the batch has zero columns, so there are no column
            // lengths that could contradict the declared row count of 10.
            &unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
            &BooleanArray::from(vec![true; 10]),
            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection()
    {
        test_evaluate_selection_error(
            // SAFETY: zero columns, so no column length can conflict with 10.
            &unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
            &BooleanArray::from(vec![false; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection()
    {
        test_evaluate_selection_error(
            // SAFETY: zero columns, so no column length can conflict with 10.
            &unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
            &BooleanArray::from(vec![true; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection()
    {
        test_evaluate_selection_error(
            // SAFETY: zero columns, so no column length can conflict with 10.
            &unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
            &BooleanArray::from(vec![false; 5]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection()
    {
        test_evaluate_selection_error(
            // SAFETY: zero columns, so no column length can conflict with 10.
            &unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) },
            &BooleanArray::from(vec![true; 5]),
        );
    }
}