use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types;
use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFField,
ExprSchema, Result,
};
use std::collections::HashMap;
use std::sync::Arc;
/// Derives schema information (data type, nullability, metadata and output
/// field) for an expression, relative to an input [`ExprSchema`].
pub trait ExprSchemable {
    /// Returns the data type the expression produces against `schema`.
    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;

    /// Returns whether the expression may evaluate to NULL against
    /// `input_schema`.
    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;

    /// Returns the key/value metadata attached to the expression's output.
    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;

    /// Builds the [`DFField`] that describes the expression's output column.
    fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>;

    /// Consumes the expression and returns one whose type is `cast_to_type`,
    /// or an error when no such conversion is possible.
    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;

    /// Returns the data type and the nullability of the expression as a pair.
    fn data_type_and_nullable(
        &self,
        schema: &dyn ExprSchema,
    ) -> Result<(DataType, bool)>;
}
impl ExprSchemable for Expr {
fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
None => schema.data_type(&Column::from_name(name)).cloned(),
Some(dt) => Ok(dt.clone()),
},
_ => expr.get_type(schema),
},
Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema),
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::Literal(l) => Ok(l.data_type()),
Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
Expr::Unnest(Unnest { exprs }) => {
let arg_data_types = exprs
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
let arg_data_type = arg_data_types[0].clone();
match arg_data_type{
DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) =>{
Ok(field.data_type().clone())
}
DataType::Struct(_) => {
not_impl_err!("unnest() does not support struct yet")
}
DataType::Null => {
not_impl_err!("unnest() does not support null yet")
}
_ => {
plan_err!("unnest() can only be applied to array, struct and null")
}
}
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
data_types(&arg_data_types, &fun.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
&format!("{fun}"),
fun.signature(),
&arg_data_types,
)
)
})?;
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
data_types(&arg_data_types, fun.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
fun.name(),
fun.signature().clone(),
&arg_data_types,
)
)
})?;
Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
}
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
}
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
fun.return_type(&data_types)
}
AggregateFunctionDefinition::UDF(fun) => {
Ok(fun.return_type(&data_types)?)
}
AggregateFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
}
}
Expr::Not(_)
| Expr::IsNull(_)
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::Between { .. }
| Expr::InList { .. }
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).data_type().clone())
}
Expr::BinaryExpr(BinaryExpr {
ref left,
ref right,
ref op,
}) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
Expr::Placeholder(Placeholder { data_type, .. }) => {
data_type.clone().ok_or_else(|| {
plan_datafusion_err!("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.")
})
}
Expr::Wildcard { qualifier } => {
match qualifier {
Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"),
None => Ok(DataType::Null)
}
}
Expr::GroupingSet(_) => {
Ok(DataType::Null)
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
field_for_index(expr, field, schema).map(|x| x.data_type().clone())
}
}
}
fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
match self {
Expr::Alias(Alias { expr, .. })
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema),
Expr::InList(InList { expr, list, .. }) => {
const MAX_INSPECT_LIMIT: usize = 6;
let has_nullable = std::iter::once(expr.as_ref())
.chain(list)
.take(MAX_INSPECT_LIMIT)
.find_map(|e| {
e.nullable(input_schema)
.map(|nullable| if nullable { Some(()) } else { None })
.transpose()
})
.transpose()?;
Ok(match has_nullable {
Some(_) => true,
None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
_ => false,
})
}
Expr::Between(Between {
expr, low, high, ..
}) => Ok(expr.nullable(input_schema)?
|| low.nullable(input_schema)?
|| high.nullable(input_schema)?),
Expr::Column(c) => input_schema.nullable(c),
Expr::OuterReferenceColumn(_, _) => Ok(true),
Expr::Literal(value) => Ok(value.is_null()),
Expr::Case(case) => {
let then_nullable = case
.when_then_expr
.iter()
.map(|(_, t)| t.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
if then_nullable.contains(&true) {
Ok(true)
} else if let Some(e) = &case.else_expr {
e.nullable(input_schema)
} else {
Ok(true)
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::Unnest(_)
| Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Exists { .. } => Ok(false),
Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).is_nullable())
}
Expr::BinaryExpr(BinaryExpr {
ref left,
ref right,
..
}) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
Expr::Like(Like { expr, pattern, .. })
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
}
Expr::Wildcard { .. } => internal_err!(
"Wildcard expressions are not valid in a logical query plan"
),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
if let Expr::Column(col) = expr.as_ref() {
if input_schema.nullable(col)? {
return Ok(true);
}
}
field_for_index(expr, field, input_schema).map(|x| x.is_nullable())
}
Expr::GroupingSet(_) => {
Ok(true)
}
}
}
fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
match self {
Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
_ => Ok(HashMap::new()),
}
}
fn data_type_and_nullable(
&self,
schema: &dyn ExprSchema,
) -> Result<(DataType, bool)> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
None => schema
.data_type_and_nullable(&Column::from_name(name))
.map(|(d, n)| (d.clone(), n)),
Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
},
_ => expr.data_type_and_nullable(schema),
},
Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => {
expr.data_type_and_nullable(schema)
}
Expr::Column(c) => schema
.data_type_and_nullable(c)
.map(|(d, n)| (d.clone(), n)),
Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
Expr::IsNull(_)
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Exists { .. } => Ok((DataType::Boolean, false)),
Expr::ScalarSubquery(subquery) => Ok((
subquery.subquery.schema().field(0).data_type().clone(),
subquery.subquery.schema().field(0).is_nullable(),
)),
Expr::BinaryExpr(BinaryExpr {
ref left,
ref right,
ref op,
}) => {
let left = left.data_type_and_nullable(schema)?;
let right = right.data_type_and_nullable(schema)?;
Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1))
}
_ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
}
}
fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> {
match self {
Expr::Column(c) => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok(
DFField::new(c.relation.clone(), &c.name, data_type, nullable)
.with_metadata(self.metadata(input_schema)?),
)
}
Expr::Alias(Alias { relation, name, .. }) => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok(DFField::new(relation.clone(), name, data_type, nullable)
.with_metadata(self.metadata(input_schema)?))
}
_ => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok(
DFField::new_unqualified(&self.display_name()?, data_type, nullable)
.with_metadata(self.metadata(input_schema)?),
)
}
}
}
fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
return Ok(self);
}
if can_cast_types(&this_type, cast_to_type) {
match self {
Expr::ScalarSubquery(subquery) => {
Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
}
_ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
}
} else {
plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
}
}
}
/// Resolves the Arrow [`Field`] that the access `field` selects from the value
/// produced by `expr`.
///
/// The type of `expr` is computed first. Then the types of any key or range
/// sub-expressions are computed, and the result is handed to
/// `GetFieldAccessSchema::get_accessed_field`.
fn field_for_index(
    expr: &Expr,
    field: &GetFieldAccess,
    schema: &dyn ExprSchema,
) -> Result<Field> {
    let base_type = expr.get_type(schema)?;

    let access_schema = match field {
        GetFieldAccess::NamedStructField { name } => {
            GetFieldAccessSchema::NamedStructField { name: name.clone() }
        }
        GetFieldAccess::ListIndex { key } => {
            let key_dt = key.get_type(schema)?;
            GetFieldAccessSchema::ListIndex { key_dt }
        }
        GetFieldAccess::ListRange {
            start,
            stop,
            stride,
        } => {
            let start_dt = start.get_type(schema)?;
            let stop_dt = stop.get_type(schema)?;
            let stride_dt = stride.get_type(schema)?;
            GetFieldAccessSchema::ListRange {
                start_dt,
                stop_dt,
                stride_dt,
            }
        }
    };

    access_schema.get_accessed_field(&base_type)
}
/// Casts the single output column of `subquery` to `cast_to_type`.
///
/// - If the column already has that type, the subquery is returned unchanged.
/// - If the inner plan is a projection, its first expression is replaced by a
///   cast of that expression.
/// - Otherwise, a new projection that casts the plan's first column is
///   stacked on top of the plan.
///
/// Outer-reference columns are carried over unchanged.
pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
    let current_type = subquery.subquery.schema().field(0).data_type();
    if current_type == cast_to_type {
        return Ok(subquery);
    }

    let Subquery {
        subquery: inner_plan,
        outer_ref_columns,
    } = subquery;

    let plan = inner_plan.as_ref();
    let projected = if let LogicalPlan::Projection(projection) = plan {
        // Cast the existing first expression against the projection's input.
        let casted = projection.expr[0]
            .clone()
            .cast_to(cast_to_type, projection.input.schema())?;
        Projection::try_new(vec![casted], projection.input.clone())?
    } else {
        // Project the plan's first column through a cast on top of the plan.
        let first_column = plan.schema().field(0).qualified_column();
        let casted =
            Expr::Column(first_column).cast_to(cast_to_type, inner_plan.schema())?;
        Projection::try_new(vec![casted], inner_plan)?
    };

    Ok(Subquery {
        subquery: Arc::new(LogicalPlan::Projection(projected)),
        outer_ref_columns,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};
    use arrow::datatypes::{DataType, Fields};
    use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};

    /// Builds a mock schema that reports `data_type` and `nullable` for every
    /// column.
    fn mock_schema(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let expr = col("foo").eq(lit(1));
        assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
        assert!(expr
            .nullable(&MockExprSchema::new().with_nullable(true))
            .unwrap());

        // IS [NOT] NULL / TRUE / FALSE / UNKNOWN never yield NULL, even when
        // applied to a NULL literal.
        let null = || lit(ScalarValue::Null);
        let predicates = [
            null().is_null(),
            null().is_not_null(),
            null().is_true(),
            null().is_not_true(),
            null().is_false(),
            null().is_not_false(),
            null().is_unknown(),
            null().is_not_unknown(),
        ];
        for predicate in &predicates {
            assert!(
                !predicate.nullable(&MockExprSchema::new()).unwrap(),
                "{predicate:?} should not be nullable"
            );
        }
    }

    #[test]
    fn test_between_nullability() {
        let expr = col("foo").between(lit(1), lit(2));
        assert!(!expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());
        assert!(expr.nullable(&mock_schema(DataType::Int32, true)).unwrap());

        // Any NULL bound makes BETWEEN nullable, even over a non-null column.
        let null = lit(ScalarValue::Int32(None));
        let with_null_bounds = [
            col("foo").between(null.clone(), lit(2)),
            col("foo").between(lit(1), null.clone()),
            col("foo").between(null.clone(), null),
        ];
        for expr in &with_null_bounds {
            assert!(expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        // The tested expression plus five items fits the inspection limit.
        let expr = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());
        assert!(expr.nullable(&mock_schema(DataType::Int32, true)).unwrap());
        // Schema errors are propagated rather than swallowed.
        assert!(expr
            .nullable(&mock_schema(DataType::Int32, false).with_error_on_nullable(true))
            .is_err());

        let null = lit(ScalarValue::Int32(None));
        let expr = col("foo").in_list(vec![null, lit(1)], false);
        assert!(expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());

        // Past the inspection limit the result is conservatively nullable.
        let expr = col("foo").in_list(vec![lit(1); 6], false);
        assert!(expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let expr = col("foo").like(lit("bar"));
        assert!(!expr.nullable(&mock_schema(DataType::Utf8, false)).unwrap());
        assert!(expr.nullable(&mock_schema(DataType::Utf8, true)).unwrap());

        let expr = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(expr.nullable(&mock_schema(DataType::Utf8, false)).unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let data_type = col("foo").get_type(&schema).unwrap();
        assert_eq!(DataType::Utf8, data_type);
    }

    #[test]
    fn test_expr_metadata() {
        let mut meta = HashMap::new();
        meta.insert("bar".to_string(), "buzz".to_string());

        let expr = col("foo");
        let schema = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // Columns, and aliases of columns, report the column's metadata...
        assert_eq!(meta, expr.metadata(&schema).unwrap());
        assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());

        // ...but a cast drops it.
        let casted = expr.clone().cast_to(&DataType::Int64, &schema).unwrap();
        assert_eq!(HashMap::new(), casted.metadata(&schema).unwrap());

        // The metadata also ends up on the field produced by `to_field`.
        let field = DFField::new_unqualified("foo", DataType::Int32, true)
            .with_metadata(meta.clone());
        let schema = DFSchema::new_with_metadata(vec![field], HashMap::new()).unwrap();
        assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata());
    }

    #[test]
    fn test_nested_schema_nullability() {
        let child = Field::new("child", DataType::Int64, false);
        let parent = DFField::new(
            Some(TableReference::Bare {
                table: "table_name".into(),
            }),
            "parent",
            DataType::Struct(Fields::from(vec![child])),
            true,
        );
        let schema = DFSchema::new_with_metadata(vec![parent], HashMap::new()).unwrap();

        // The child is declared non-nullable, but its parent is nullable.
        let expr = col("parent").field("child");
        assert!(expr.nullable(&schema).unwrap());
    }

    /// Minimal [`ExprSchema`] that answers identically for every column.
    #[derive(Debug)]
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
        error_on_nullable: bool,
        metadata: HashMap<String, String>,
    }

    impl MockExprSchema {
        fn new() -> Self {
            Self {
                nullable: false,
                data_type: DataType::Null,
                error_on_nullable: false,
                metadata: HashMap::new(),
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self { metadata, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if !self.error_on_nullable {
                return Ok(self.nullable);
            }
            internal_err!("nullable error")
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }

        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
            Ok(&self.metadata)
        }

        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
            let data_type = self.data_type(col)?;
            let nullable = self.nullable(col)?;
            Ok((data_type, nullable))
        }
    }
}