use std::{
any::Any,
hash::{Hash, Hasher},
sync::Arc,
};
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::Result;
use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr};
use crate::utils::down_cast_any_ref;
#[derive(Debug, Hash)]
pub struct IfExpr {
if_expr: Arc<dyn PhysicalExpr>,
true_expr: Arc<dyn PhysicalExpr>,
false_expr: Arc<dyn PhysicalExpr>,
case_expr: Arc<CaseExpr>,
}
impl std::fmt::Display for IfExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"If [if: {}, true_expr: {}, false_expr: {}]",
self.if_expr, self.true_expr, self.false_expr
)
}
}
impl IfExpr {
pub fn new(
if_expr: Arc<dyn PhysicalExpr>,
true_expr: Arc<dyn PhysicalExpr>,
false_expr: Arc<dyn PhysicalExpr>,
) -> Self {
Self {
if_expr: Arc::clone(&if_expr),
true_expr: Arc::clone(&true_expr),
false_expr: Arc::clone(&false_expr),
case_expr: Arc::new(
CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
),
}
}
}
impl PhysicalExpr for IfExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
let data_type = self.true_expr.data_type(input_schema)?;
Ok(data_type)
}
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
if self.true_expr.nullable(_input_schema)? || self.true_expr.nullable(_input_schema)? {
Ok(true)
} else {
Ok(false)
}
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
self.case_expr.evaluate(batch)
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.if_expr, &self.true_expr, &self.false_expr]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(IfExpr::new(
Arc::clone(&children[0]),
Arc::clone(&children[1]),
Arc::clone(&children[2]),
)))
}
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.if_expr.hash(&mut s);
self.true_expr.hash(&mut s);
self.false_expr.hash(&mut s);
self.hash(&mut s);
}
}
impl PartialEq<dyn Any> for IfExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.if_expr.eq(&x.if_expr)
&& self.true_expr.eq(&x.true_expr)
&& self.false_expr.eq(&x.false_expr)
})
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use arrow::{array::StringArray, datatypes::*};
use arrow_array::Int32Array;
use datafusion::logical_expr::Operator;
use datafusion_common::cast::as_int32_array;
use datafusion_physical_expr::expressions::{binary, col, lit};
use super::*;
fn if_fn(
if_expr: Arc<dyn PhysicalExpr>,
true_expr: Arc<dyn PhysicalExpr>,
false_expr: Arc<dyn PhysicalExpr>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr)))
}
#[test]
fn test_if_1() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
let schema_ref = batch.schema();
let if_expr = binary(
col("a", &schema_ref)?,
Operator::Eq,
lit("foo"),
&schema_ref,
)?;
let true_expr = lit(123i32);
let false_expr = lit(999i32);
let expr = if_fn(if_expr, true_expr, false_expr);
let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_int32_array(&result)?;
let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_if_2() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
let schema_ref = batch.schema();
let if_expr = binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?;
let true_expr = lit(123i32);
let false_expr = lit(999i32);
let expr = if_fn(if_expr, true_expr, false_expr);
let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_int32_array(&result)?;
let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_if_children() {
let if_expr = lit(true);
let true_expr = lit(123i32);
let false_expr = lit(999i32);
let expr = if_fn(if_expr, true_expr, false_expr).unwrap();
let children = expr.children();
assert_eq!(children.len(), 3);
assert_eq!(children[0].to_string(), "true");
assert_eq!(children[1].to_string(), "123");
assert_eq!(children[2].to_string(), "999");
}
}