use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, BinaryExpr, Cast, GetIndexedField, Sort, TryCast, WindowFunction,
};
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::binary_operator_data_type;
use crate::{aggregate_function, function, window_function};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
use datafusion_common::{Column, DFField, DFSchema, DataFusionError, ExprSchema, Result};
/// Schema-aware queries and conversions for an expression: its data type,
/// its nullability, its output field, and casting to a requested type.
pub trait ExprSchemable {
/// Given a schema, returns the data type of the expression.
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
/// Given a schema, returns whether the expression may evaluate to NULL.
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
/// Converts the expression to a field (name, type, nullability) with
/// respect to a schema.
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
/// Casts the expression to `cast_to_type` with respect to a schema,
/// consuming the expression.
fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
}
impl ExprSchemable for Expr {
    /// Returns the [`DataType`] this expression evaluates to, resolving column
    /// references against `schema`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `schema`, by the function return-type
    /// resolvers, by the binary-operator type resolver, or by the indexed-field
    /// lookup. Also fails for a placeholder without a known type, for a `CASE`
    /// without any `WHEN`/`THEN` branch, and for wildcard expressions.
    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
        // Resolves the type of every argument, stopping at the first failure.
        let arg_types = |args: &[Expr]| -> Result<Vec<DataType>> {
            args.iter().map(|arg| arg.get_type(schema)).collect()
        };

        match self {
            Expr::Alias(expr, name) => match &**expr {
                // An aliased placeholder without an explicit type takes the
                // type of the schema column named after the alias.
                Expr::Placeholder { data_type, .. } => match &data_type {
                    None => schema.data_type(&Column::from_name(name)).cloned(),
                    Some(dt) => Ok(dt.clone()),
                },
                _ => expr.get_type(schema),
            },
            Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema),
            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
            Expr::Literal(l) => Ok(l.get_datatype()),
            // The type of a CASE is taken from its first THEN expression. An
            // empty branch list is reported as an error instead of panicking
            // on an out-of-bounds index.
            Expr::Case(case) => match case.when_then_expr.first() {
                Some((_, then_expr)) => then_expr.get_type(schema),
                None => Err(DataFusionError::Internal(
                    "CASE expression must have at least one WHEN/THEN branch".to_owned(),
                )),
            },
            Expr::Cast(Cast { data_type, .. })
            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
            Expr::ScalarUDF { fun, args } => {
                let data_types = arg_types(args)?;
                Ok((fun.return_type)(&data_types)?.as_ref().clone())
            }
            Expr::ScalarFunction { fun, args } => {
                let data_types = arg_types(args)?;
                function::return_type(fun, &data_types)
            }
            Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
                let data_types = arg_types(args)?;
                window_function::return_type(fun, &data_types)
            }
            Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
                let data_types = arg_types(args)?;
                aggregate_function::return_type(fun, &data_types)
            }
            Expr::AggregateUDF { fun, args, .. } => {
                let data_types = arg_types(args)?;
                Ok((fun.return_type)(&data_types)?.as_ref().clone())
            }
            // Predicates always evaluate to a boolean.
            Expr::Not(_)
            | Expr::IsNull(_)
            | Expr::Exists { .. }
            | Expr::InSubquery { .. }
            | Expr::Between { .. }
            | Expr::InList { .. }
            | Expr::IsNotNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Like { .. }
            | Expr::ILike { .. }
            | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
            // A scalar subquery has the type of the first field of its plan.
            Expr::ScalarSubquery(subquery) => {
                Ok(subquery.subquery.schema().field(0).data_type().clone())
            }
            Expr::BinaryExpr(BinaryExpr { left, right, op }) => binary_operator_data_type(
                &left.get_type(schema)?,
                op,
                &right.get_type(schema)?,
            ),
            Expr::Placeholder { data_type, .. } => data_type.clone().ok_or_else(|| {
                DataFusionError::Plan("Placeholder type could not be resolved".to_owned())
            }),
            Expr::Wildcard => Err(DataFusionError::Internal(
                "Wildcard expressions are not valid in a logical query plan".to_owned(),
            )),
            Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
                "QualifiedWildcard expressions are not valid in a logical query plan"
                    .to_owned(),
            )),
            // Grouping sets are reported with the `Null` type.
            Expr::GroupingSet(_) => Ok(DataType::Null),
            Expr::GetIndexedField(GetIndexedField { key, expr }) => {
                let data_type = expr.get_type(schema)?;
                get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
            }
        }
    }

    /// Returns whether this expression may evaluate to NULL, resolving column
    /// references against `input_schema`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `input_schema` or by the indexed-field
    /// lookup, and fails for wildcard expressions.
    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
        match self {
            // These report the nullability of the wrapped expression.
            Expr::Alias(expr, _)
            | Expr::Not(expr)
            | Expr::Negative(expr)
            | Expr::Sort(Sort { expr, .. })
            | Expr::InList { expr, .. } => expr.nullable(input_schema),
            Expr::Between(Between { expr, .. }) => expr.nullable(input_schema),
            Expr::Column(c) => input_schema.nullable(c),
            Expr::Literal(value) => Ok(value.is_null()),
            Expr::Case(case) => {
                // Nullable if any THEN expression is nullable.
                let then_nullable = case
                    .when_then_expr
                    .iter()
                    .map(|(_, t)| t.nullable(input_schema))
                    .collect::<Result<Vec<_>>>()?;
                if then_nullable.contains(&true) {
                    Ok(true)
                } else if let Some(e) = &case.else_expr {
                    e.nullable(input_schema)
                } else {
                    // Without an ELSE, CASE yields NULL when no branch matches.
                    Ok(true)
                }
            }
            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
            // Nullability is not tracked for these, so they are conservatively
            // treated as nullable.
            Expr::ScalarVariable(_, _)
            | Expr::TryCast { .. }
            | Expr::ScalarFunction { .. }
            | Expr::ScalarUDF { .. }
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::AggregateUDF { .. }
            | Expr::Placeholder { .. } => Ok(true),
            // `IS [NOT] NULL/TRUE/FALSE/UNKNOWN` and `EXISTS` always produce a
            // definite true or false, even for NULL input, so they are never
            // nullable.
            Expr::IsNull(_)
            | Expr::IsNotNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Exists { .. } => Ok(false),
            Expr::InSubquery { expr, .. } => expr.nullable(input_schema),
            // A scalar subquery has the nullability of the first field of its plan.
            Expr::ScalarSubquery(subquery) => {
                Ok(subquery.subquery.schema().field(0).is_nullable())
            }
            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
                Ok(left.nullable(input_schema)? || right.nullable(input_schema)?)
            }
            Expr::Like(Like { expr, .. })
            | Expr::ILike(Like { expr, .. })
            | Expr::SimilarTo(Like { expr, .. }) => expr.nullable(input_schema),
            Expr::Wildcard => Err(DataFusionError::Internal(
                "Wildcard expressions are not valid in a logical query plan".to_owned(),
            )),
            Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
                "QualifiedWildcard expressions are not valid in a logical query plan"
                    .to_owned(),
            )),
            Expr::GetIndexedField(GetIndexedField { key, expr }) => {
                let data_type = expr.get_type(input_schema)?;
                get_indexed_field(&data_type, key).map(|x| x.is_nullable())
            }
            // Grouping sets are reported as nullable.
            Expr::GroupingSet(_) => Ok(true),
        }
    }

    /// Builds the [`DFField`] describing this expression's output with respect
    /// to `input_schema`.
    ///
    /// A column keeps its own relation qualifier and name; every other
    /// expression becomes an unqualified field named by `display_name()`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `display_name`, `get_type` and `nullable`.
    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
        match self {
            Expr::Column(c) => Ok(DFField::new(
                c.relation.as_deref(),
                &c.name,
                self.get_type(input_schema)?,
                self.nullable(input_schema)?,
            )),
            _ => Ok(DFField::new(
                None,
                &self.display_name()?,
                self.get_type(input_schema)?,
                self.nullable(input_schema)?,
            )),
        }
    }

    /// Returns an expression of type `cast_to_type`, consuming `self`.
    ///
    /// The expression is returned unchanged when it already has the requested
    /// type; otherwise it is wrapped in an [`Expr::Cast`] if
    /// `can_cast_types` accepts the conversion.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_type`, and returns a
    /// [`DataFusionError::Plan`] when the conversion is not supported.
    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
        let this_type = self.get_type(schema)?;
        if this_type == *cast_to_type {
            Ok(self)
        } else if can_cast_types(&this_type, cast_to_type) {
            Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone())))
        } else {
            Err(DataFusionError::Plan(format!(
                "Cannot automatically convert {this_type:?} to {cast_to_type:?}"
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};
    use arrow::datatypes::DataType;
    use datafusion_common::Column;

    /// Schema stub that reports the same nullability and data type for every
    /// column it is asked about.
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
    }

    impl MockExprSchema {
        /// Creates a stub whose columns are non-nullable and `Null`-typed.
        fn new() -> Self {
            Self {
                nullable: false,
                data_type: DataType::Null,
            }
        }

        /// Returns the stub with the reported nullability replaced.
        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        /// Returns the stub with the reported data type replaced.
        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            Ok(self.nullable)
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }
    }

    /// A comparison is nullable exactly when one of its operands is.
    #[test]
    fn expr_schema_nullability() {
        let predicate = col("foo").eq(lit(1));

        let non_nullable_schema = MockExprSchema::new();
        assert!(!predicate.nullable(&non_nullable_schema).unwrap());

        let nullable_schema = MockExprSchema::new().with_nullable(true);
        assert!(predicate.nullable(&nullable_schema).unwrap());
    }

    /// A column reference takes its type from the schema.
    #[test]
    fn expr_schema_data_type() {
        let utf8_schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let resolved = col("foo").get_type(&utf8_schema).unwrap();
        assert_eq!(DataType::Utf8, resolved);
    }
}