use std::collections::HashSet;
use arrow::datatypes::{DataType, Field, Schema};
use crate::error::{ExecutionError, Result};
use crate::logicalplan::Expr;
/// Collects the column indices referenced by every expression in `expr` into `accum`.
///
/// Expressions are visited in slice order. The walk stops at the first
/// expression for which [`expr_to_column_indices`] fails, and that error is
/// returned; indices gathered before the failure remain in `accum`.
pub fn exprlist_to_column_indices(
    expr: &[Expr],
    accum: &mut HashSet<usize>,
) -> Result<()> {
    expr.iter()
        .try_for_each(|item| expr_to_column_indices(item, accum))
}
pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) -> Result<()> {
match expr {
Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
Expr::Column(i) => {
accum.insert(*i);
Ok(())
}
Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError(
"Columns need to be resolved before column indexes resolution rule can run"
.to_owned(),
)),
Expr::Literal(_) => {
Ok(())
}
Expr::Not(e) => expr_to_column_indices(e, accum),
Expr::IsNull(e) => expr_to_column_indices(e, accum),
Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
Expr::BinaryExpr { left, right, .. } => {
expr_to_column_indices(left, accum)?;
expr_to_column_indices(right, accum)?;
Ok(())
}
Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}
pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
match e {
Expr::Alias(expr, name) => {
Ok(Field::new(name, expr.get_type(input_schema)?, true))
}
Expr::UnresolvedColumn(name) => Ok(input_schema.field_with_name(&name)?.clone()),
Expr::Column(i) => {
let input_schema_field_count = input_schema.fields().len();
if *i < input_schema_field_count {
Ok(input_schema.fields()[*i].clone())
} else {
Err(ExecutionError::General(format!(
"Column index {} out of bounds for input schema with {} field(s)",
*i, input_schema_field_count
)))
}
}
Expr::Literal(ref lit) => Ok(Field::new("lit", lit.get_datatype(), true)),
Expr::ScalarFunction {
ref name,
ref return_type,
..
} => Ok(Field::new(&name, return_type.clone(), true)),
Expr::AggregateFunction {
ref name,
ref return_type,
..
} => Ok(Field::new(&name, return_type.clone(), true)),
Expr::Cast { ref data_type, .. } => {
Ok(Field::new("cast", data_type.clone(), true))
}
Expr::BinaryExpr {
ref left,
ref right,
..
} => {
let left_type = left.get_type(input_schema)?;
let right_type = right.get_type(input_schema)?;
Ok(Field::new(
"binary_expr",
get_supertype(&left_type, &right_type).unwrap(),
true,
))
}
_ => Err(ExecutionError::NotImplemented(format!(
"Cannot determine schema type for expression {:?}",
e
))),
}
}
/// Converts each expression in `expr` to its output [`Field`] for
/// `input_schema`, preserving order.
///
/// Returns the first error produced by [`expr_to_field`], if any; later
/// expressions are not examined once one fails.
pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Field>> {
    let mut fields = Vec::with_capacity(expr.len());
    for item in expr {
        fields.push(expr_to_field(item, input_schema)?);
    }
    Ok(fields)
}
/// Determines a common supertype for data types `l` and `r`.
///
/// The lookup table in `_get_supertype` is consulted with the arguments as
/// given and, only if that finds nothing, once more with them swapped.
///
/// # Errors
///
/// Returns `ExecutionError::InternalError` when neither ordering has an entry.
pub fn get_supertype(l: &DataType, r: &DataType) -> Result<DataType> {
    _get_supertype(l, r)
        .or_else(|| _get_supertype(r, l))
        .ok_or_else(|| {
            ExecutionError::InternalError(format!(
                "Failed to determine supertype of {:?} and {:?}",
                l, r
            ))
        })
}
/// Looks up the supertype for the *ordered* pair `(l, r)` in a fixed table.
///
/// Returns `Some(type)` when the pair is listed and `None` otherwise.
///
/// The table is not symmetric: for example `(UInt8, Float32)` is listed but
/// `(Float32, UInt8)` is not. `get_supertype` compensates by retrying with the
/// arguments swapped when the first lookup yields `None`.
///
/// Some combinations are absent in both orders (e.g. `Int8` with `UInt16`),
/// so they yield `None` regardless of argument order.
fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (l, r) {
// Unsigned left, signed right: result is the signed type. Only pairs where
// the signed type is at least as wide as the unsigned one are listed.
// NOTE(review): for equal widths (e.g. UInt8/Int8 -> Int8) the signed type
// cannot represent the full unsigned range — confirm this is intended.
(UInt8, Int8) => Some(Int8),
(UInt8, Int16) => Some(Int16),
(UInt8, Int32) => Some(Int32),
(UInt8, Int64) => Some(Int64),
(UInt16, Int16) => Some(Int16),
(UInt16, Int32) => Some(Int32),
(UInt16, Int64) => Some(Int64),
(UInt32, Int32) => Some(Int32),
(UInt32, Int64) => Some(Int64),
(UInt64, Int64) => Some(Int64),
// Signed left, unsigned right: mirror of the section above.
(Int8, UInt8) => Some(Int8),
(Int16, UInt8) => Some(Int16),
(Int16, UInt16) => Some(Int16),
(Int32, UInt8) => Some(Int32),
(Int32, UInt16) => Some(Int32),
(Int32, UInt32) => Some(Int32),
(Int64, UInt8) => Some(Int64),
(Int64, UInt16) => Some(Int64),
(Int64, UInt32) => Some(Int64),
(Int64, UInt64) => Some(Int64),
// Unsigned left with unsigned right (wider type wins) or float right
// (the float type wins).
(UInt8, UInt8) => Some(UInt8),
(UInt8, UInt16) => Some(UInt16),
(UInt8, UInt32) => Some(UInt32),
(UInt8, UInt64) => Some(UInt64),
(UInt8, Float32) => Some(Float32),
(UInt8, Float64) => Some(Float64),
(UInt16, UInt8) => Some(UInt16),
(UInt16, UInt16) => Some(UInt16),
(UInt16, UInt32) => Some(UInt32),
(UInt16, UInt64) => Some(UInt64),
(UInt16, Float32) => Some(Float32),
(UInt16, Float64) => Some(Float64),
(UInt32, UInt8) => Some(UInt32),
(UInt32, UInt16) => Some(UInt32),
(UInt32, UInt32) => Some(UInt32),
(UInt32, UInt64) => Some(UInt64),
(UInt32, Float32) => Some(Float32),
(UInt32, Float64) => Some(Float64),
(UInt64, UInt8) => Some(UInt64),
(UInt64, UInt16) => Some(UInt64),
(UInt64, UInt32) => Some(UInt64),
(UInt64, UInt64) => Some(UInt64),
(UInt64, Float32) => Some(Float32),
(UInt64, Float64) => Some(Float64),
// Signed left with signed right (wider type wins) or float right
// (the float type wins).
(Int8, Int8) => Some(Int8),
(Int8, Int16) => Some(Int16),
(Int8, Int32) => Some(Int32),
(Int8, Int64) => Some(Int64),
(Int8, Float32) => Some(Float32),
(Int8, Float64) => Some(Float64),
(Int16, Int8) => Some(Int16),
(Int16, Int16) => Some(Int16),
(Int16, Int32) => Some(Int32),
(Int16, Int64) => Some(Int64),
(Int16, Float32) => Some(Float32),
(Int16, Float64) => Some(Float64),
(Int32, Int8) => Some(Int32),
(Int32, Int16) => Some(Int32),
(Int32, Int32) => Some(Int32),
(Int32, Int64) => Some(Int64),
(Int32, Float32) => Some(Float32),
(Int32, Float64) => Some(Float64),
(Int64, Int8) => Some(Int64),
(Int64, Int16) => Some(Int64),
(Int64, Int32) => Some(Int64),
(Int64, Int64) => Some(Int64),
(Int64, Float32) => Some(Float32),
(Int64, Float64) => Some(Float64),
// Float with float: Float64 wins unless both sides are Float32.
(Float32, Float32) => Some(Float32),
(Float32, Float64) => Some(Float64),
(Float64, Float32) => Some(Float64),
(Float64, Float64) => Some(Float64),
// Utf8 paired with any type, on either side, resolves to Utf8. These
// wildcard arms must stay after the specific arms above.
(Utf8, _) => Some(Utf8),
(_, Utf8) => Some(Utf8),
// Boolean has an entry only with itself (apart from the Utf8 arms).
(Boolean, Boolean) => Some(Boolean),
// Every other combination has no entry in this ordering.
_ => None,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logicalplan::Expr;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    /// Collecting from the same cast-wrapped column twice must leave exactly
    /// one entry (index 3) in the accumulator.
    #[test]
    fn test_collect_expr() -> Result<()> {
        let mut seen: HashSet<usize> = HashSet::new();

        for _ in 0..2 {
            let cast_of_column = Expr::Cast {
                expr: Box::new(Expr::Column(3)),
                data_type: DataType::Float64,
            };
            expr_to_column_indices(&cast_of_column, &mut seen)?;
        }

        assert_eq!(1, seen.len());
        assert!(seen.contains(&3));
        Ok(())
    }
}