use arrow_schema::DataType;
use datafusion::logical_expr::{
expr::ScalarFunction, BinaryExpr, BuiltinScalarFunction, GetFieldAccess, GetIndexedField,
Operator,
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use lance_arrow::DataTypeExt;
use lance_datafusion::expr::safe_coerce_scalar;
use crate::datatypes::Schema;
use crate::{Error, Result};
use snafu::{location, Location};
fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
match expr {
Expr::Literal(scalar_value) => {
Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::IO {
message: format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
location: location!(),
})?))
}
_ => Err(Error::IO {
message: format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
location: location!(),
}),
}
}
/// Look up the data type of the (possibly nested) column referenced by `expr`.
///
/// `expr` must be either a plain column reference or a chain of named struct
/// field accesses rooted at a column (e.g. `st.inner.leaf`).
///
/// Returns `None` when:
/// * `expr` is any other kind of expression,
/// * a field access is not a named struct field with a UTF-8 name,
/// * the root column or one of the nested fields is missing from `schema`,
/// * an intermediate field on the path is not a struct.
pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
    // Names are collected leaf-first while peeling accessors off the
    // expression; the root column name therefore ends up last.
    let mut names: Vec<&str> = Vec::new();
    let mut cursor = expr;
    while let Expr::GetIndexedField(GetIndexedField { expr: inner, field }) = cursor {
        match field {
            GetFieldAccess::NamedStructField {
                name: ScalarValue::Utf8(Some(name)),
            } => names.push(name.as_str()),
            _ => return None,
        }
        cursor = inner.as_ref();
    }
    match cursor {
        Expr::Column(column) => names.push(column.name.as_str()),
        _ => return None,
    }

    // Popping from the back walks the path root-first.
    let root = names.pop()?;
    let mut current = schema.field(root)?;
    while let Some(name) = names.pop() {
        if !current.data_type().is_struct() {
            return None;
        }
        current = current.children.iter().find(|child| child.name == name)?;
    }
    Some(current.data_type())
}
/// Rewrite the literals of a filter expression so their types match the
/// columns they are compared against in `schema`.
///
/// Handling by shape of `expr`:
/// * `a AND b` / `a OR b`: both sides are resolved recursively.
/// * `column <op> literal` and `literal <op> column`: the literal is coerced
///   to the column's type; operand order is preserved.
/// * `column <op> (binary expr)`: every literal inside the right-hand binary
///   expression is coerced to the column's type via [`coerce_expr`].
/// * Anything else, or a column whose type cannot be resolved by
///   [`resolve_column_type`]: a clone of `expr` is returned unchanged.
///
/// A "column" here is either `Expr::Column` or `Expr::GetIndexedField`.
///
/// # Errors
///
/// Propagates the error from `resolve_value` when a literal cannot be
/// converted to the resolved column type.
pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
            // Logical connectives carry no literal of their own; descend into
            // both operands.
            if matches!(op, Operator::And | Operator::Or) {
                return Ok(Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(resolve_expr(left.as_ref(), schema)?),
                    op: *op,
                    right: Box::new(resolve_expr(right.as_ref(), schema)?),
                }));
            }
            match (left.as_ref(), right.as_ref()) {
                (Expr::Column(_) | Expr::GetIndexedField(_), Expr::Literal(_)) => {
                    if let Some(resolved_type) = resolve_column_type(left.as_ref(), schema) {
                        Ok(Expr::BinaryExpr(BinaryExpr {
                            left: left.clone(),
                            op: *op,
                            right: Box::new(resolve_value(right.as_ref(), &resolved_type)?),
                        }))
                    } else {
                        Ok(expr.clone())
                    }
                }
                (Expr::Literal(_), Expr::Column(_) | Expr::GetIndexedField(_)) => {
                    if let Some(resolved_type) = resolve_column_type(right.as_ref(), schema) {
                        // The literal is the LEFT operand here. Coerce it and
                        // keep the operands in place so that non-commutative
                        // operators (`<`, `-`, ...) keep their meaning.
                        Ok(Expr::BinaryExpr(BinaryExpr {
                            left: Box::new(resolve_value(left.as_ref(), &resolved_type)?),
                            op: *op,
                            right: right.clone(),
                        }))
                    } else {
                        Ok(expr.clone())
                    }
                }
                (Expr::Column(_) | Expr::GetIndexedField(_), Expr::BinaryExpr(r)) => {
                    if let Some(resolved_type) = resolve_column_type(left.as_ref(), schema) {
                        Ok(Expr::BinaryExpr(BinaryExpr {
                            left: left.clone(),
                            op: *op,
                            right: Box::new(Expr::BinaryExpr(BinaryExpr {
                                left: coerce_expr(&r.left, &resolved_type).map(Box::new)?,
                                op: r.op,
                                right: coerce_expr(&r.right, &resolved_type).map(Box::new)?,
                            })),
                        }))
                    } else {
                        Ok(expr.clone())
                    }
                }
                _ => Ok(expr.clone()),
            }
        }
        _ => Ok(expr.clone()),
    }
}
/// Coerce every literal found in `expr` to `dtype`.
///
/// Binary expressions are traversed recursively on both sides, keeping the
/// operator. Literals are converted through `resolve_value`. Any other
/// expression is returned as an unchanged clone.
///
/// # Errors
///
/// Propagates the error from `resolve_value` when a literal cannot be
/// converted to `dtype`.
pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
            left: Box::new(coerce_expr(left, dtype)?),
            op: *op,
            right: Box::new(coerce_expr(right, dtype)?),
        })),
        // `expr` already is the literal expression, so pass it through
        // directly instead of cloning the scalar into a temporary `Expr`.
        Expr::Literal(_) => resolve_value(expr, dtype),
        _ => Ok(expr.clone()),
    }
}
/// Adapt a filter expression so that it can be used as a predicate.
///
/// A call to the built-in `RegexpMatch` scalar function is wrapped in
/// `IS NOT NULL`; every other expression is returned untouched.
///
/// NOTE(review): presumably `RegexpMatch` yields the matches (null when
/// there is none) rather than a boolean, hence the wrapping — confirm
/// against the DataFusion version in use.
///
/// This function currently never returns an error.
pub fn coerce_filter_type_to_boolean(expr: Expr) -> Result<Expr> {
    let is_regexp_match = matches!(
        &expr,
        Expr::ScalarFunction(ScalarFunction {
            fun: BuiltinScalarFunction::RegexpMatch,
            ..
        })
    );

    if is_regexp_match {
        Ok(Expr::IsNotNull(Box::new(expr)))
    } else {
        Ok(expr)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion::scalar::ScalarValue;

    /// Lance schema holding a single non-nullable column `a` of `data_type`.
    fn single_column_schema(data_type: DataType) -> Schema {
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", data_type, false)]);
        Schema::try_from(&arrow_schema).unwrap()
    }

    /// Reference to the column `a`.
    fn column_a() -> Expr {
        Expr::Column("a".to_string().into())
    }

    /// Builds `left <op> right` as a binary expression.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// A `Utf8` literal compared with a `LargeUtf8` column becomes `LargeUtf8`.
    #[test]
    fn test_resolve_large_utf8() {
        let schema = single_column_schema(DataType::LargeUtf8);
        let filter = binary(
            column_a(),
            Operator::Eq,
            Expr::Literal(ScalarValue::Utf8(Some("a".to_string()))),
        );

        let resolved = resolve_expr(&filter, &schema).unwrap();

        let expected = Expr::Literal(ScalarValue::LargeUtf8(Some("a".to_string())));
        if let Expr::BinaryExpr(outer) = resolved {
            assert_eq!(outer.right.as_ref(), &expected);
        } else {
            unreachable!("Expected BinaryExpr");
        }
    }

    /// Literals nested in a right-hand binary expression take the column type.
    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = single_column_schema(DataType::Float64);
        let arithmetic = binary(
            Expr::Literal(ScalarValue::Int64(Some(2))),
            Operator::Minus,
            Expr::Literal(ScalarValue::Int64(Some(-1))),
        );
        let filter = binary(column_a(), Operator::Eq, arithmetic);

        let resolved = resolve_expr(&filter, &schema).unwrap();

        let rhs = match resolved {
            Expr::BinaryExpr(outer) => outer.right,
            _ => panic!("Expected BinaryExpr"),
        };
        match rhs.as_ref() {
            Expr::BinaryExpr(inner) => {
                assert_eq!(
                    inner.left.as_ref(),
                    &Expr::Literal(ScalarValue::Float64(Some(2.0)))
                );
                assert_eq!(
                    inner.right.as_ref(),
                    &Expr::Literal(ScalarValue::Float64(Some(-1.0)))
                );
            }
            _ => panic!("Expected BinaryExpr"),
        }
    }

    /// Column types resolve through nested structs; unknown paths give `None`.
    #[test]
    fn test_resolve_column_type() {
        // Schema layout: int: Int32, st: { str: Utf8, st: { float: Float64 } }
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let resolvable = vec![
            (col("int"), DataType::Int32),
            (col("st").field("str"), DataType::Utf8),
            (col("st").field("st").field("float"), DataType::Float64),
        ];
        for (expr, expected) in resolvable {
            assert_eq!(resolve_column_type(&expr, &schema), Some(expected));
        }

        // Missing column, nested names used at the top level, and an
        // expression that is not a column reference at all.
        let unresolvable = vec![
            col("x"),
            col("str"),
            col("float"),
            col("st").field("str").eq(lit("x")),
        ];
        for expr in unresolvable {
            assert_eq!(resolve_column_type(&expr, &schema), None);
        }
    }
}