use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type};
use crate::variation_const::DEFAULT_TYPE_VARIATION_REF;
use datafusion::common::{DFSchemaRef, ScalarValue};
use datafusion::logical_expr::{Cast, Expr, TryCast};
use substrait::proto::Expression;
use substrait::proto::expression::cast::FailureBehavior;
use substrait::proto::expression::literal::LiteralType;
use substrait::proto::expression::{Literal, RexType};
/// Converts a DataFusion [`Cast`] expression into a Substrait [`Expression`].
///
/// Substrait null literals must carry a type. When the operand is the untyped
/// [`ScalarValue::Null`] literal, `CAST(NULL AS t)` is therefore emitted as a
/// null literal of type `t` (nullable, default type variation) rather than as a
/// cast node. Already-typed operands, including typed nulls like
/// `ScalarValue::Int64(None)`, are encoded as a Substrait `Cast` whose failure
/// behavior is `ThrowException`.
///
/// # Errors
/// Returns an error if `data_type` cannot be converted to a Substrait type, or
/// if the operand expression cannot be converted by `producer`.
pub fn from_cast(
    producer: &mut impl SubstraitProducer,
    cast: &Cast,
    schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
    let Cast { expr, data_type } = cast;

    // The target type is needed on both paths; it is resolved before the
    // operand is converted, so a type error surfaces first.
    let target_type = to_substrait_type(producer, data_type, true)?;

    let rex_type = if matches!(expr.as_ref(), Expr::Literal(ScalarValue::Null, _)) {
        // Only the untyped null needs folding: every other null literal
        // already carries a type that Substrait can express.
        RexType::Literal(Literal {
            nullable: true,
            type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
            literal_type: Some(LiteralType::Null(target_type)),
        })
    } else {
        let operand = producer.handle_expr(expr, schema)?;
        RexType::Cast(Box::new(substrait::proto::expression::Cast {
            r#type: Some(target_type),
            input: Some(Box::new(operand)),
            failure_behavior: FailureBehavior::ThrowException.into(),
        }))
    };

    Ok(Expression {
        rex_type: Some(rex_type),
    })
}
/// Converts a DataFusion [`TryCast`] expression into a Substrait [`Expression`].
///
/// The result is a Substrait `Cast` node targeting `data_type` (converted as a
/// nullable type). Its failure behavior is `ReturnNull`, so a value that cannot
/// be converted yields null instead of an error.
///
/// # Errors
/// Returns an error if `data_type` cannot be converted to a Substrait type, or
/// if the operand expression cannot be converted by `producer`.
pub fn from_try_cast(
    producer: &mut impl SubstraitProducer,
    cast: &TryCast,
    schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
    let TryCast { expr, data_type } = cast;

    // Resolve the target type before converting the operand.
    let target_type = to_substrait_type(producer, data_type, true)?;
    let operand = producer.handle_expr(expr, schema)?;

    let substrait_cast = substrait::proto::expression::Cast {
        r#type: Some(target_type),
        input: Some(Box::new(operand)),
        failure_behavior: FailureBehavior::ReturnNull.into(),
    };

    Ok(Expression {
        rex_type: Some(RexType::Cast(Box::new(substrait_cast))),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::producer::{
        DefaultSubstraitProducer, to_substrait_extended_expr,
    };
    use datafusion::arrow::datatypes::{DataType, Field};
    use datafusion::common::DFSchema;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::logical_expr::ExprSchemable;
    use substrait::proto::expression_reference::ExprType;

    /// Builds the nullable, default-variation Substrait null literal typed as
    /// `data_type`. This is the shape `from_cast` produces for a folded
    /// `CAST(NULL AS data_type)`.
    fn typed_null_literal(
        producer: &mut impl SubstraitProducer,
        data_type: &DataType,
    ) -> Expression {
        Expression {
            rex_type: Some(RexType::Literal(Literal {
                nullable: true,
                type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
                literal_type: Some(LiteralType::Null(
                    to_substrait_type(producer, data_type, true).unwrap(),
                )),
            })),
        }
    }

    /// Casting the untyped `NULL` folds into a typed null literal. Casting an
    /// already-typed null (`Int64(None)`) stays an explicit
    /// `ThrowException` cast.
    #[tokio::test]
    async fn fold_cast_null() {
        let state = SessionStateBuilder::default().build();
        let empty_schema = DFSchemaRef::new(DFSchema::empty());
        let field = Field::new("out", DataType::Int32, false);
        let mut producer = DefaultSubstraitProducer::new(&state);

        // Encodes `expr` as a one-element extended expression and returns the
        // Substrait expression it refers to.
        let encode = |expr: &Expr| -> Expression {
            let extended =
                to_substrait_extended_expr(&[(expr, &field)], &empty_schema, &state)
                    .unwrap();
            match extended.referred_expr[0].expr_type.as_ref().unwrap() {
                ExprType::Expression(encoded) => encoded.clone(),
                _ => panic!("Expected expression type"),
            }
        };

        // Untyped NULL: folded into a typed null literal.
        let untyped_null_cast = Expr::Literal(ScalarValue::Null, None)
            .cast_to(&DataType::Int32, &empty_schema)
            .unwrap();
        let actual = encode(&untyped_null_cast);
        let expected = typed_null_literal(&mut producer, &DataType::Int32);
        assert_eq!(actual, expected);

        // Typed NULL: kept as a cast wrapping the typed null operand.
        let typed_null_cast = Expr::Literal(ScalarValue::Int64(None), None)
            .cast_to(&DataType::Int32, &empty_schema)
            .unwrap();
        let actual = encode(&typed_null_cast);
        let target_type =
            to_substrait_type(&mut producer, &DataType::Int32, true).unwrap();
        let operand = typed_null_literal(&mut producer, &DataType::Int64);
        let expected = Expression {
            rex_type: Some(RexType::Cast(Box::new(
                substrait::proto::expression::Cast {
                    r#type: Some(target_type),
                    input: Some(Box::new(operand)),
                    failure_behavior: FailureBehavior::ThrowException as i32,
                },
            ))),
        };
        assert_eq!(actual, expected);
    }
}