use std::any::Any;
use std::fmt::{Debug, Display};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::intervals::Interval;
use crate::sort_properties::SortProperties;
use crate::utils::scatter;
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::DataPtr;
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use itertools::izip;
pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
fn as_any(&self) -> &dyn Any;
fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
fn nullable(&self, input_schema: &Schema) -> Result<bool>;
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
fn evaluate_selection(
&self,
batch: &RecordBatch,
selection: &BooleanArray,
) -> Result<ColumnarValue> {
let tmp_batch = filter_record_batch(batch, selection)?;
let tmp_result = self.evaluate(&tmp_batch)?;
if batch.num_rows() == tmp_batch.num_rows() {
Ok(tmp_result)
} else if let ColumnarValue::Array(a) = tmp_result {
scatter(selection, a.as_ref()).map(ColumnarValue::Array)
} else {
Ok(tmp_result)
}
}
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>>;
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>>;
fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
not_impl_err!("Not implemented for {self}")
}
fn propagate_constraints(
&self,
_interval: &Interval,
_children: &[&Interval],
) -> Result<Vec<Option<Interval>>> {
not_impl_err!("Not implemented for {self}")
}
fn dyn_hash(&self, _state: &mut dyn Hasher);
fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties {
SortProperties::Unordered
}
}
impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
}
}
/// Shorthand for a shared, reference-counted [`PhysicalExpr`] trait object.
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// Returns `expr` itself when `children` are pointer-identical to its current
/// children, and otherwise rebuilds it via `with_new_children`.
///
/// An expression without children is always passed to `with_new_children`.
///
/// # Errors
///
/// Returns an internal error when `children` does not have the same length as
/// the expression's current children, and propagates any error produced by
/// `with_new_children`.
pub fn with_new_children_if_necessary(
    expr: Arc<dyn PhysicalExpr>,
    children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    let current = expr.children();
    if current.len() != children.len() {
        return internal_err!("PhysicalExpr: Wrong number of children");
    }

    // The expression can be reused only if it has children and each supplied
    // child points at the same data as the child it would replace.
    let reusable = !children.is_empty()
        && children
            .iter()
            .zip(current.iter())
            .all(|(new_child, old_child)| Arc::data_ptr_eq(new_child, old_child));

    if reusable {
        Ok(expr)
    } else {
        expr.with_new_children(children)
    }
}
/// Unwraps an `Arc<dyn PhysicalExpr>` or `Box<dyn PhysicalExpr>` hidden behind
/// a `&dyn Any`, returning the `as_any()` view of the wrapped expression.
///
/// Any other value is returned unchanged.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(shared) = any.downcast_ref::<Arc<dyn PhysicalExpr>>() {
        return shared.as_any();
    }
    if let Some(boxed) = any.downcast_ref::<Box<dyn PhysicalExpr>>() {
        return boxed.as_any();
    }
    any
}
/// Returns `true` if any element of `physical_exprs` compares equal to `expr`.
pub fn physical_exprs_contains(
    physical_exprs: &[Arc<dyn PhysicalExpr>],
    expr: &Arc<dyn PhysicalExpr>,
) -> bool {
    for candidate in physical_exprs {
        if candidate.eq(expr) {
            return true;
        }
    }
    false
}
/// Returns `true` if at least one expression of `lhs` is also contained in
/// `rhs`.
pub fn have_common_entries(
    lhs: &[Arc<dyn PhysicalExpr>],
    rhs: &[Arc<dyn PhysicalExpr>],
) -> bool {
    for candidate in lhs {
        if physical_exprs_contains(rhs, candidate) {
            return true;
        }
    }
    false
}
/// Returns `true` if `lhs` and `rhs` have the same length and their elements
/// compare equal position by position.
pub fn physical_exprs_equal(
    lhs: &[Arc<dyn PhysicalExpr>],
    rhs: &[Arc<dyn PhysicalExpr>],
) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    for (left, right) in izip!(lhs, rhs) {
        if !left.eq(right) {
            return false;
        }
    }
    true
}
/// Returns `true` if `lhs` and `rhs` are equal as bags (multisets): they hold
/// the same expressions with the same multiplicities, in any order.
///
/// Every element of `lhs` is matched against, and removed from, a working copy
/// of `rhs`, so the comparison is quadratic in the input length.
pub fn physical_exprs_bag_equal(
    lhs: &[Arc<dyn PhysicalExpr>],
    rhs: &[Arc<dyn PhysicalExpr>],
) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }

    // Entries of `rhs` that have not been paired with an entry of `lhs` yet.
    let mut unmatched = rhs.to_vec();
    lhs.iter().all(|expr| {
        match unmatched.iter().position(|candidate| expr.eq(candidate)) {
            Some(found) => {
                // Remove the partner so it cannot be matched a second time.
                unmatched.swap_remove(found);
                true
            }
            None => false,
        }
    })
}
/// Removes duplicate expressions from `exprs` in place, keeping the first
/// occurrence of each distinct expression.
///
/// Duplicates are dropped with `swap_remove`, which moves the last element
/// into the freed slot, so the relative order of the surviving expressions is
/// not necessarily preserved. The pairwise comparison is quadratic in the
/// length of `exprs`.
pub fn deduplicate_physical_exprs(exprs: &mut Vec<Arc<dyn PhysicalExpr>>) {
    let mut kept = 0;
    while kept < exprs.len() {
        let mut probe = kept + 1;
        while let Some(candidate) = exprs.get(probe) {
            if exprs[kept].eq(candidate) {
                // Stay on `probe`: it now holds the element swapped in from
                // the end, which has not been compared yet.
                exprs.swap_remove(probe);
            } else {
                probe += 1;
            }
        }
        kept += 1;
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::expressions::{Column, Literal};
    use crate::physical_expr::{
        deduplicate_physical_exprs, have_common_entries, physical_exprs_bag_equal,
        physical_exprs_contains, physical_exprs_equal, PhysicalExpr,
    };

    use datafusion_common::ScalarValue;

    /// Builds a boolean literal expression.
    fn bool_lit(value: bool) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Boolean(Some(value))))
    }

    /// Builds an `Int32` literal expression.
    fn int_lit(value: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(value))))
    }

    /// Builds a column expression with the given name and index.
    fn col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    #[test]
    fn test_physical_exprs_contains() {
        let lit_true = bool_lit(true);
        let lit_false = bool_lit(false);
        let lit4 = int_lit(4);
        let lit2 = int_lit(2);
        let lit1 = int_lit(1);
        let col_a = col("a", 0);
        let col_b = col("b", 1);
        let col_c = col("c", 2);

        let haystack: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::clone(&lit_true),
            Arc::clone(&lit_false),
            Arc::clone(&lit4),
            Arc::clone(&lit2),
            Arc::clone(&col_a),
            Arc::clone(&col_b),
        ];

        // Expressions that were put into the list are found ...
        assert!(physical_exprs_contains(&haystack, &lit_true));
        assert!(physical_exprs_contains(&haystack, &lit2));
        assert!(physical_exprs_contains(&haystack, &col_b));
        // ... and expressions that were left out are not.
        assert!(!physical_exprs_contains(&haystack, &col_c));
        assert!(!physical_exprs_contains(&haystack, &lit1));
    }

    #[test]
    fn test_have_common_entries() {
        let lit_true = bool_lit(true);
        let lit_false = bool_lit(false);
        let lit2 = int_lit(2);
        let lit1 = int_lit(1);
        let col_b = col("b", 1);

        let booleans = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)];
        let true_and_col = vec![Arc::clone(&lit_true), Arc::clone(&col_b)];
        let integers = vec![Arc::clone(&lit2), Arc::clone(&lit1)];

        // Only the first two lists share an entry (`lit_true`).
        assert!(have_common_entries(&booleans, &true_and_col));
        assert!(!have_common_entries(&booleans, &integers));
        assert!(!have_common_entries(&true_and_col, &integers));
    }

    #[test]
    fn test_physical_exprs_equal() {
        let lit_true = bool_lit(true);
        let lit_false = bool_lit(false);
        let lit1 = int_lit(1);
        let lit2 = int_lit(2);
        let col_b = col("b", 1);

        let booleans = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)];
        let true_and_col = vec![Arc::clone(&lit_true), Arc::clone(&col_b)];
        let integers = vec![Arc::clone(&lit2), Arc::clone(&lit1)];
        let booleans_again = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)];

        // Identical contents in identical order are equal both ways.
        assert!(physical_exprs_equal(&booleans, &booleans));
        assert!(physical_exprs_equal(&booleans, &booleans_again));
        assert!(physical_exprs_bag_equal(&booleans, &booleans));
        assert!(physical_exprs_bag_equal(&booleans, &booleans_again));
        // Differing contents are unequal both ways.
        assert!(!physical_exprs_equal(&booleans, &true_and_col));
        assert!(!physical_exprs_equal(&booleans, &integers));
        assert!(!physical_exprs_bag_equal(&booleans, &true_and_col));
        assert!(!physical_exprs_bag_equal(&booleans, &integers));
    }

    #[test]
    fn test_physical_exprs_set_equal() {
        // Same distinct columns, but with different multiplicities.
        let list1 = vec![col("a", 0), col("a", 0), col("b", 1)];
        let list2 = vec![col("b", 1), col("b", 1), col("a", 0)];

        assert!(!physical_exprs_bag_equal(&list1, &list2));
        assert!(!physical_exprs_bag_equal(&list2, &list1));
        assert!(!physical_exprs_equal(&list1, &list2));
        assert!(!physical_exprs_equal(&list2, &list1));

        // Same columns with the same multiplicities, in a different order.
        let list3 = vec![
            col("a", 0),
            col("b", 1),
            col("c", 2),
            col("a", 0),
            col("b", 1),
        ];
        let list4 = vec![
            col("b", 1),
            col("b", 1),
            col("a", 0),
            col("c", 2),
            col("a", 0),
        ];

        assert!(physical_exprs_bag_equal(&list3, &list4));
        assert!(physical_exprs_bag_equal(&list4, &list3));
        assert!(physical_exprs_bag_equal(&list3, &list3));
        assert!(physical_exprs_bag_equal(&list4, &list4));
        assert!(!physical_exprs_equal(&list3, &list4));
        assert!(!physical_exprs_equal(&list4, &list3));
        assert!(physical_exprs_bag_equal(&list3, &list3));
        assert!(physical_exprs_bag_equal(&list4, &list4));
    }

    #[test]
    fn test_deduplicate_physical_exprs() {
        let lit_true = bool_lit(true);
        let lit_false = bool_lit(false);
        let lit4 = int_lit(4);
        let lit2 = int_lit(2);
        let col_a = col("a", 0);
        let col_b = col("b", 1);

        // Each case pairs an input list with the expected result. The expected
        // order reflects the `swap_remove`-based removal of duplicates.
        let cases = vec![
            (
                vec![
                    &lit_true, &lit_false, &lit4, &lit2, &col_a, &col_a, &col_b,
                    &lit_true, &lit2,
                ],
                vec![&lit_true, &lit_false, &lit4, &lit2, &col_a, &col_b],
            ),
            (
                vec![&lit_true, &lit_true, &lit_false, &lit4],
                vec![&lit_true, &lit4, &lit_false],
            ),
        ];

        for (input, expected) in cases {
            let mut actual: Vec<Arc<dyn PhysicalExpr>> =
                input.into_iter().map(Arc::clone).collect();
            let expected: Vec<Arc<dyn PhysicalExpr>> =
                expected.into_iter().map(Arc::clone).collect();
            deduplicate_physical_exprs(&mut actual);
            assert!(physical_exprs_equal(&actual, &expected));
        }
    }
}