use std::sync::Arc;
use arrow::array::new_null_array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::is_volatile;
use crate::PhysicalExpr;
use crate::expressions::{Column, Literal};
pub fn simplify_const_expr(
expr: &Arc<dyn PhysicalExpr>,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
if is_volatile(expr) || has_column_references(expr) {
return Ok(Transformed::no(Arc::clone(expr)));
}
let batch = create_dummy_batch()?;
match expr.evaluate(&batch) {
Ok(ColumnarValue::Scalar(scalar)) => {
Ok(Transformed::yes(Arc::new(Literal::new(scalar))))
}
Ok(ColumnarValue::Array(arr)) if arr.len() == 1 => {
let scalar = ScalarValue::try_from_array(&arr, 0)?;
Ok(Transformed::yes(Arc::new(Literal::new(scalar))))
}
Ok(_) => {
Ok(Transformed::no(Arc::clone(expr)))
}
Err(_) => {
Ok(Transformed::no(Arc::clone(expr)))
}
}
}
/// Builds a one-row [`RecordBatch`] holding a single all-null column named `_`.
///
/// The batch carries no meaningful data. It only supplies a row count of 1 to
/// expressions that are evaluated without referencing any column.
///
/// # Errors
///
/// Propagates any error from [`RecordBatch::try_new`], converted into the
/// crate's `Result` error type.
fn create_dummy_batch() -> Result<RecordBatch> {
    let null_type = DataType::Null;
    let placeholder = new_null_array(&null_type, 1);
    let schema = Schema::new(vec![Field::new("_", null_type, true)]);
    RecordBatch::try_new(Arc::new(schema), vec![placeholder]).map_err(Into::into)
}
/// Returns `true` if any node in `expr`'s tree (including `expr` itself) is a
/// [`Column`].
///
/// The traversal stops at the first column found.
///
/// # Panics
///
/// Panics only if the tree walk itself reports an error. The visitor closure
/// below never returns `Err`, so this is treated as an invariant violation.
pub fn has_column_references(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let mut found_column = false;
    expr.apply(|node| {
        let is_column = node.as_any().is::<Column>();
        found_column |= is_column;
        Ok(if is_column {
            TreeNodeRecursion::Stop
        } else {
            TreeNodeRecursion::Continue
        })
    })
    .expect("apply should not fail");
    found_column
}