use super::Expr;
use crate::field_util::get_indexed_field;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, window_functions,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result};
/// Schema-aware inspection and conversion of logical expressions.
///
/// Every method resolves information about an expression relative to a schema
/// describing the columns the expression may reference.
pub trait ExprSchemable {
    /// Resolves the [`DataType`] this expression produces when evaluated
    /// against `schema`.
    fn get_type<S>(&self, schema: &S) -> Result<DataType>
    where
        S: ExprSchema;

    /// Reports whether this expression may produce NULL values when evaluated
    /// against `input_schema`.
    fn nullable<S>(&self, input_schema: &S) -> Result<bool>
    where
        S: ExprSchema;

    /// Builds the [`DFField`] that describes this expression's output column
    /// in `input_schema`.
    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;

    /// Consumes this expression and returns an expression producing
    /// `cast_to_type`, or an error if that conversion is not possible.
    fn cast_to<S>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>
    where
        S: ExprSchema;
}
impl ExprSchemable for Expr {
    /// Returns the Arrow [`DataType`] this expression produces when evaluated
    /// against `schema`.
    ///
    /// # Errors
    ///
    /// Propagates errors from resolving columns in `schema` and from
    /// return-type resolution of functions/UDFs. Returns an internal error for
    /// `Expr::Wildcard`, which is not valid in a logical plan, and for a
    /// `CASE` expression that has no `WHEN ... THEN` branch.
    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
        match self {
            Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
                expr.get_type(schema)
            }
            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
            Expr::ScalarVariable(_) => Ok(DataType::Utf8),
            Expr::Literal(l) => Ok(l.get_datatype()),
            Expr::Case { when_then_expr, .. } => {
                // The result type is taken from the first THEN branch. Use
                // `first()` instead of indexing so a malformed CASE with no
                // branches produces an error rather than a panic.
                let (_, first_then) = when_then_expr.first().ok_or_else(|| {
                    DataFusionError::Internal(
                        "CASE expression must have at least one WHEN/THEN branch"
                            .to_owned(),
                    )
                })?;
                first_then.get_type(schema)
            }
            Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
                Ok(data_type.clone())
            }
            Expr::ScalarUDF { fun, args } => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
                Ok((fun.return_type)(&data_types)?.as_ref().clone())
            }
            Expr::ScalarFunction { fun, args } => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
                functions::return_type(fun, &data_types)
            }
            Expr::WindowFunction { fun, args, .. } => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
                window_functions::return_type(fun, &data_types)
            }
            Expr::AggregateFunction { fun, args, .. } => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
                aggregates::return_type(fun, &data_types)
            }
            Expr::AggregateUDF { fun, args, .. } => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
                Ok((fun.return_type)(&data_types)?.as_ref().clone())
            }
            Expr::Not(_)
            | Expr::IsNull(_)
            | Expr::Between { .. }
            | Expr::InList { .. }
            | Expr::IsNotNull(_) => Ok(DataType::Boolean),
            Expr::BinaryExpr {
                ref left,
                ref right,
                ref op,
            } => binary_operator_data_type(
                &left.get_type(schema)?,
                op,
                &right.get_type(schema)?,
            ),
            Expr::Wildcard => Err(DataFusionError::Internal(
                "Wildcard expressions are not valid in a logical query plan".to_owned(),
            )),
            Expr::GetIndexedField { ref expr, key } => {
                let data_type = expr.get_type(schema)?;
                get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
            }
        }
    }

    /// Returns whether this expression may evaluate to NULL against
    /// `input_schema`.
    ///
    /// The answer is conservative. Variables, `TRY_CAST`, scalar, window and
    /// aggregate functions, and UDFs are always reported as nullable.
    ///
    /// # Errors
    ///
    /// Returns an internal error for `Expr::Wildcard`. Propagates errors from
    /// resolving columns and nested expressions.
    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
        match self {
            Expr::Alias(expr, _)
            | Expr::Not(expr)
            | Expr::Negative(expr)
            | Expr::Sort { expr, .. } => expr.nullable(input_schema),
            Expr::Between {
                expr, low, high, ..
            } => {
                // `x BETWEEN low AND high` may be NULL if any operand is NULL.
                // This matches how the equivalent `x >= low AND x <= high`
                // binary expression is treated below.
                Ok(expr.nullable(input_schema)?
                    || low.nullable(input_schema)?
                    || high.nullable(input_schema)?)
            }
            Expr::InList { expr, list, .. } => {
                // `x IN (a, NULL)` yields NULL when `x` matches no non-NULL
                // item, so a nullable list item makes the result nullable too.
                if expr.nullable(input_schema)? {
                    return Ok(true);
                }
                for item in list {
                    if item.nullable(input_schema)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            Expr::Column(c) => input_schema.nullable(c),
            Expr::Literal(value) => Ok(value.is_null()),
            Expr::Case {
                when_then_expr,
                else_expr,
                ..
            } => {
                let then_nullable = when_then_expr
                    .iter()
                    .map(|(_, t)| t.nullable(input_schema))
                    .collect::<Result<Vec<_>>>()?;
                if then_nullable.contains(&true) {
                    Ok(true)
                } else if let Some(e) = else_expr {
                    e.nullable(input_schema)
                } else {
                    // Without an ELSE branch, a row that matches none of the
                    // WHEN conditions produces NULL.
                    Ok(true)
                }
            }
            Expr::Cast { expr, .. } => expr.nullable(input_schema),
            Expr::ScalarVariable(_)
            | Expr::TryCast { .. }
            | Expr::ScalarFunction { .. }
            | Expr::ScalarUDF { .. }
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::AggregateUDF { .. } => Ok(true),
            Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
            Expr::BinaryExpr {
                ref left,
                ref right,
                ..
            } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
            Expr::Wildcard => Err(DataFusionError::Internal(
                "Wildcard expressions are not valid in a logical query plan".to_owned(),
            )),
            Expr::GetIndexedField { ref expr, key } => {
                let data_type = expr.get_type(input_schema)?;
                get_indexed_field(&data_type, key).map(|x| x.is_nullable())
            }
        }
    }

    /// Builds the [`DFField`] describing this expression's output.
    ///
    /// A column reference keeps its relation qualifier and name. Any other
    /// expression becomes an unqualified field named by `Expr::name`. In both
    /// cases the type and nullability come from `get_type` and `nullable`.
    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
        match self {
            Expr::Column(c) => Ok(DFField::new(
                c.relation.as_deref(),
                &c.name,
                self.get_type(input_schema)?,
                self.nullable(input_schema)?,
            )),
            _ => Ok(DFField::new(
                None,
                &self.name(input_schema)?,
                self.get_type(input_schema)?,
                self.nullable(input_schema)?,
            )),
        }
    }

    /// Converts this expression so that it produces `cast_to_type`.
    ///
    /// The expression is returned unchanged if it already has that type. It
    /// is wrapped in `Expr::Cast` if Arrow's `can_cast_types` allows the
    /// conversion.
    ///
    /// # Errors
    ///
    /// Returns a plan error if the cast is not supported. Propagates errors
    /// from `get_type`.
    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
        let this_type = self.get_type(schema)?;
        if this_type == *cast_to_type {
            Ok(self)
        } else if can_cast_types(&this_type, cast_to_type) {
            Ok(Expr::Cast {
                expr: Box::new(self),
                data_type: cast_to_type.clone(),
            })
        } else {
            Err(DataFusionError::Plan(format!(
                "Cannot automatically convert {:?} to {:?}",
                this_type, cast_to_type
            )))
        }
    }
}