use anyhow::{Result, anyhow};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{Case, Expr as DfExpr, ExprSchemable, Operator};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
/// Result of classifying a pair of Arrow operand types with [`type_compat`].
///
/// The variant decides how an operation between the two operands is lowered to a
/// DataFusion expression (see `build_cypher_comparison` and `build_cypher_in_list`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum TypeCompat {
    /// Both operand types are exactly equal.
    Same,
    /// Both operands are numeric but of different types; carries the common type
    /// returned by `df_expr::wider_numeric_type`, which both sides get cast to.
    NumericWidening(DataType),
    /// Both operands are datetime structs (per `uni_common::core::schema::is_datetime_struct`).
    DateTimeStruct,
    /// Both operands are time structs (per `uni_common::core::schema::is_time_struct`).
    TimeStruct,
    /// Both operands are string types (`Utf8` / `LargeUtf8`) but not the same one.
    StringCompat,
    /// Both operands are `Boolean`. `type_compat` checks type equality first and returns
    /// `Same` in that case, so it never produces this variant itself.
    BooleanCompat,
    /// At least one operand has the Arrow `Null` type.
    NullInvolved,
    /// None of the compatibility rules in `type_compat` matched.
    Incomparable,
    /// At least one operand is `LargeBinary`, or both are (non-temporal) structs; such
    /// comparisons are routed to a comparison UDF instead of a native operator.
    Dynamic,
}
pub(crate) fn type_compat(left: &DataType, right: &DataType) -> TypeCompat {
use TypeCompat::*;
if left == right {
return Same;
}
if matches!(left, DataType::Null) || matches!(right, DataType::Null) {
return NullInvolved;
}
if is_numeric_type(left) && is_numeric_type(right) {
let wider = super::df_expr::wider_numeric_type(left, right);
return NumericWidening(wider);
}
if is_string_type(left) && is_string_type(right) {
return StringCompat;
}
if matches!(left, DataType::Boolean) && matches!(right, DataType::Boolean) {
return BooleanCompat;
}
if uni_common::core::schema::is_datetime_struct(left)
&& uni_common::core::schema::is_datetime_struct(right)
{
return DateTimeStruct;
}
if uni_common::core::schema::is_time_struct(left)
&& uni_common::core::schema::is_time_struct(right)
{
return TimeStruct;
}
if matches!(left, DataType::LargeBinary) || matches!(right, DataType::LargeBinary) {
return Dynamic;
}
if let (DataType::Struct(_), DataType::Struct(_)) = (left, right) {
return Dynamic;
}
Incomparable
}
/// Returns `true` for the 8–64 bit signed and unsigned Arrow integer types and for
/// `Float32` / `Float64`.
///
/// `Float16` and decimal types are not considered numeric by this helper.
pub(crate) fn is_numeric_type(dt: &DataType) -> bool {
    let is_signed =
        matches!(dt, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64);
    let is_unsigned =
        matches!(dt, DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64);
    let is_float = matches!(dt, DataType::Float32 | DataType::Float64);
    is_signed || is_unsigned || is_float
}
/// Returns `true` when `dt` is one of the Arrow string types handled here:
/// `Utf8` or `LargeUtf8`.
pub(crate) fn is_string_type(dt: &DataType) -> bool {
    *dt == DataType::Utf8 || *dt == DataType::LargeUtf8
}
pub(crate) fn build_cypher_comparison(
left: DfExpr,
left_type: &DataType,
right: DfExpr,
right_type: &DataType,
op: Operator,
) -> DfExpr {
use TypeCompat::*;
match type_compat(left_type, right_type) {
Same | StringCompat | BooleanCompat => binary_expr(left, op, right),
NumericWidening(common) => {
let left_cast = super::df_expr::cast_expr(left, common.clone());
let right_cast = super::df_expr::cast_expr(right, common);
binary_expr(left_cast, op, right_cast)
}
DateTimeStruct => {
let left_nanos = super::df_expr::extract_datetime_nanos(left);
let right_nanos = super::df_expr::extract_datetime_nanos(right);
binary_expr(left_nanos, op, right_nanos)
}
TimeStruct => {
let left_nanos = super::df_expr::extract_time_nanos(left);
let right_nanos = super::df_expr::extract_time_nanos(right);
binary_expr(left_nanos, op, right_nanos)
}
NullInvolved => lit(ScalarValue::Boolean(None)),
Incomparable => match op {
Operator::Eq => lit(false),
Operator::NotEq => lit(true),
_ => lit(ScalarValue::Boolean(None)),
},
Dynamic => {
let udf_name = super::df_expr::comparison_udf_name(op)
.expect("comparison operator should have UDF mapping");
super::df_expr::dummy_udf_expr(udf_name, vec![left, right])
}
}
}
/// Builds the DataFusion expression for the Cypher `left + right` operator.
///
/// Dispatch order (first match wins), mirroring Cypher's `+` semantics:
/// 1. a `Null`-typed operand → NULL literal;
/// 2. a `LargeBinary` (dynamic) operand → `_cypher_add` UDF;
/// 3. two numeric operands → both cast to the wider numeric type, then `+`;
/// 4. a `List` on either side → list concatenation / element append UDFs. Lists are
///    checked *before* strings because in Cypher `'a' + [1]` is a list, not a string;
/// 5. a string on either side → string concatenation (the non-string side is cast to
///    `Utf8`), using the null-propagating `||` operator so `'a' + null` is NULL;
/// 6. a temporal / interval / struct operand → `_cypher_add` UDF.
///
/// # Errors
/// Returns an error when no rule applies to the operand type pair.
pub(crate) fn build_cypher_plus(
    left: DfExpr,
    left_type: &DataType,
    right: DfExpr,
    right_type: &DataType,
) -> Result<DfExpr> {
    if matches!(left_type, DataType::Null) || matches!(right_type, DataType::Null) {
        return Ok(lit(ScalarValue::Null));
    }
    if matches!(left_type, DataType::LargeBinary) || matches!(right_type, DataType::LargeBinary) {
        return Ok(super::df_expr::dummy_udf_expr(
            "_cypher_add",
            vec![left, right],
        ));
    }
    if is_numeric_type(left_type) && is_numeric_type(right_type) {
        let common = super::df_expr::wider_numeric_type(left_type, right_type);
        let left_cast = super::df_expr::cast_expr(left, common.clone());
        let right_cast = super::df_expr::cast_expr(right, common);
        return Ok(binary_expr(left_cast, Operator::Plus, right_cast));
    }

    // List handling must come before string handling: a list on either side turns `+`
    // into a list operation even when the other operand is a string.
    let left_is_list = matches!(left_type, DataType::List(_));
    let right_is_list = matches!(right_type, DataType::List(_));
    if left_is_list && right_is_list {
        return Ok(super::df_expr::dummy_udf_expr(
            "_cypher_list_concat",
            vec![left, right],
        ));
    }
    if left_is_list {
        return Ok(super::df_expr::dummy_udf_expr(
            "_cypher_list_append",
            vec![left, right],
        ));
    }
    if right_is_list {
        // NOTE(review): `x + list` passes the list first, exactly like `list + x`, so the
        // UDF cannot distinguish prepend from append. Cypher defines `x + [a]` as `[x, a]`;
        // confirm `_cypher_list_append`'s contract or introduce a prepend variant.
        return Ok(super::df_expr::dummy_udf_expr(
            "_cypher_list_append",
            vec![right, left],
        ));
    }

    let left_is_string = is_string_type(left_type);
    let right_is_string = is_string_type(right_type);
    if left_is_string || right_is_string {
        let left_str = if left_is_string {
            left
        } else {
            to_string_expr(left, left_type)?
        };
        let right_str = if right_is_string {
            right
        } else {
            to_string_expr(right, right_type)?
        };
        // `||` yields NULL when either side is NULL (Cypher semantics); DataFusion's
        // `concat` function would instead silently skip NULL arguments.
        return Ok(binary_expr(left_str, Operator::StringConcat, right_str));
    }

    if is_temporal_or_interval(left_type) || is_temporal_or_interval(right_type) {
        return Ok(super::df_expr::dummy_udf_expr(
            "_cypher_add",
            vec![left, right],
        ));
    }
    Err(anyhow!(
        "Incompatible types for Plus operator: {:?} + {:?}",
        left_type,
        right_type
    ))
}
fn is_temporal_or_interval(dt: &datafusion::arrow::datatypes::DataType) -> bool {
use datafusion::arrow::datatypes::DataType;
matches!(
dt,
DataType::Date32
| DataType::Date64
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Interval(_)
| DataType::Duration(_)
) || matches!(dt, DataType::Struct(_)) }
/// Converts `expr` to a `Utf8` expression via `df_expr::cast_expr`.
///
/// The source type is currently unused — every input is cast the same way — and the
/// function always returns `Ok`.
fn to_string_expr(expr: DfExpr, _expr_type: &DataType) -> Result<DfExpr> {
    let as_text = super::df_expr::cast_expr(expr, DataType::Utf8);
    Ok(as_text)
}
/// Rewrites a simple `CASE operand WHEN v THEN r ... END` into a searched
/// `CASE WHEN operand = v THEN r ... END`.
///
/// Each `WHEN` equality is built with [`build_cypher_comparison`], so Cypher comparison
/// semantics (numeric widening, incomparable types, temporal structs, ...) apply.
/// The operand expression is cloned into every generated condition.
///
/// # Errors
/// Returns an error if the type of the operand or of any `WHEN` value cannot be resolved
/// against `schema`.
pub(crate) fn rewrite_simple_case_to_generic(
    operand: DfExpr,
    when_then_expr: Vec<(Box<DfExpr>, Box<DfExpr>)>,
    else_expr: Option<Box<DfExpr>>,
    schema: &datafusion::common::DFSchema,
) -> Result<Case> {
    let operand_type = operand
        .get_type(schema)
        .map_err(|e| anyhow!("Failed to get operand type: {}", e))?;

    let mut searched_arms = Vec::with_capacity(when_then_expr.len());
    for (when_value, then_result) in when_then_expr {
        let when_type = when_value
            .get_type(schema)
            .map_err(|e| anyhow!("Failed to get WHEN type: {}", e))?;
        let condition = build_cypher_comparison(
            operand.clone(),
            &operand_type,
            *when_value,
            &when_type,
            Operator::Eq,
        );
        searched_arms.push((Box::new(condition), then_result));
    }

    // `expr: None` marks the CASE as a searched (generic) CASE.
    Ok(Case {
        expr: None,
        when_then_expr: searched_arms,
        else_expr,
    })
}
/// Picks a single result type for a set of CASE branch types.
///
/// Rules (first match wins), ignoring `Null` entries:
/// - no types at all → `Utf8`; only `Null` types → `Null`;
/// - all types equal → that type;
/// - all numeric → the widest type per `df_expr::wider_numeric_type`;
/// - any `LargeBinary` (dynamic value) → `LargeBinary`. This is checked before the
///   string rule so a dynamic branch is converted into the binary encoding rather than
///   cast to a string, consistent with `type_compat` / `build_cypher_plus`, which also
///   give `LargeBinary` priority;
/// - any string → `Utf8`;
/// - all datetime structs / all time structs → the first such type;
/// - otherwise → `Utf8`.
///
/// `_schema` is currently unused.
pub(crate) fn find_common_result_type(
    types: &[DataType],
    _schema: &datafusion::common::DFSchema,
) -> DataType {
    if types.is_empty() {
        return DataType::Utf8;
    }
    let non_null_types: Vec<&DataType> = types
        .iter()
        .filter(|t| !matches!(t, DataType::Null))
        .collect();
    let Some(&first) = non_null_types.first() else {
        // Every branch is NULL.
        return DataType::Null;
    };
    if non_null_types.iter().all(|t| *t == first) {
        return first.clone();
    }
    if non_null_types.iter().all(|t| is_numeric_type(t)) {
        // Seed the fold with an actual input type rather than an arbitrary narrow type,
        // so the result only ever reflects types that appear in the CASE.
        return non_null_types
            .iter()
            .skip(1)
            .fold(first.clone(), |widest, t| {
                super::df_expr::wider_numeric_type(&widest, t)
            });
    }
    if non_null_types
        .iter()
        .any(|t| matches!(t, DataType::LargeBinary))
    {
        return DataType::LargeBinary;
    }
    if non_null_types.iter().any(|t| is_string_type(t)) {
        return DataType::Utf8;
    }
    if non_null_types
        .iter()
        .all(|t| uni_common::core::schema::is_datetime_struct(t))
    {
        return first.clone();
    }
    if non_null_types
        .iter()
        .all(|t| uni_common::core::schema::is_time_struct(t))
    {
        return first.clone();
    }
    DataType::Utf8
}
/// Converts one CASE branch expression from `from_type` to `target_type`.
///
/// When the target is `LargeBinary` and the branch is not already binary, the value is
/// encoded with the dedicated helper (`list_to_large_binary_expr` for `List`/`LargeList`,
/// `scalar_to_large_binary_expr` for everything else). All other conversions, including
/// `LargeBinary` → `LargeBinary`, go through `df_expr::cast_expr`.
fn coerce_branch_to(expr: DfExpr, from_type: &DataType, target_type: &DataType) -> DfExpr {
    let wants_binary = matches!(target_type, DataType::LargeBinary);
    let already_binary = matches!(from_type, DataType::LargeBinary);
    if !wants_binary || already_binary {
        return super::df_expr::cast_expr(expr, target_type.clone());
    }
    match from_type {
        DataType::List(_) | DataType::LargeList(_) => {
            super::df_expr::list_to_large_binary_expr(expr)
        }
        _ => super::df_expr::scalar_to_large_binary_expr(expr),
    }
}
pub(crate) fn coerce_case_results(
case: &mut Case,
schema: &datafusion::common::DFSchema,
) -> Result<()> {
let mut types = Vec::new();
for (_, then_expr) in &case.when_then_expr {
let then_type = then_expr
.get_type(schema)
.map_err(|e| anyhow!("Failed to get THEN type: {}", e))?;
types.push(then_type);
}
if let Some(else_expr) = &case.else_expr {
let else_type = else_expr
.get_type(schema)
.map_err(|e| anyhow!("Failed to get ELSE type: {}", e))?;
types.push(else_type);
}
let common_type = find_common_result_type(&types, schema);
for (_, then_expr) in &mut case.when_then_expr {
let then_type = then_expr
.get_type(schema)
.map_err(|e| anyhow!("Failed to get THEN type for cast: {}", e))?;
if then_type != common_type {
**then_expr = coerce_branch_to((**then_expr).clone(), &then_type, &common_type);
}
}
if let Some(else_expr) = &mut case.else_expr {
let else_type = else_expr
.get_type(schema)
.map_err(|e| anyhow!("Failed to get ELSE type for cast: {}", e))?;
if else_type != common_type {
**else_expr = coerce_branch_to((**else_expr).clone(), &else_type, &common_type);
}
}
Ok(())
}
/// Builds a Cypher-semantics `expr [NOT] IN [item, ...]` predicate.
///
/// - If any item is a datetime/time struct or dynamic relative to `expr`, the whole
///   predicate is lowered to an OR chain of [`build_cypher_comparison`] equalities.
/// - Otherwise a native `InList` is built. Numeric items are cast to the widened type.
///   Items whose type is incomparable with `expr` are dropped, because Cypher `=` between
///   incomparable types is `false` (matching `build_cypher_comparison`), so they can never
///   match.
/// - If no item is left (empty list, or every item incomparable), no element can match and
///   the result is the constant `negated` (`false` for IN, `true` for NOT IN).
/// - If a `Null`-typed item is present, the result is wrapped so that "no match" yields
///   NULL: `CASE WHEN expr IN (..) THEN <!negated> ELSE NULL END`.
///
/// # Errors
/// Returns an error if an item's type cannot be resolved against `schema`.
pub(crate) fn build_cypher_in_list(
    expr: DfExpr,
    expr_type: &DataType,
    list: Vec<DfExpr>,
    negated: bool,
    schema: &datafusion::common::DFSchema,
) -> Result<DfExpr> {
    // Resolve every item's compatibility with `expr` exactly once.
    let classified = list
        .into_iter()
        .map(|item| {
            let item_type = item
                .get_type(schema)
                .map_err(|e| anyhow!("Failed to get IN list item type: {}", e))?;
            let compat = type_compat(expr_type, &item_type);
            Ok((item, compat))
        })
        .collect::<Result<Vec<_>>>()?;

    // Temporal structs and dynamic values need Cypher-aware per-item comparisons that a
    // native `InList` cannot express.
    let needs_or_chain = classified.iter().any(|(_, compat)| {
        matches!(
            compat,
            TypeCompat::DateTimeStruct | TypeCompat::TimeStruct | TypeCompat::Dynamic
        )
    });
    if needs_or_chain {
        let items = classified.into_iter().map(|(item, _)| item).collect();
        return build_in_as_or_chain(expr, expr_type, items, negated, schema);
    }

    let mut compatible = Vec::with_capacity(classified.len());
    let mut has_null = false;
    for (item, compat) in classified {
        match compat {
            TypeCompat::Same | TypeCompat::StringCompat | TypeCompat::BooleanCompat => {
                compatible.push(item);
            }
            TypeCompat::NumericWidening(common) => {
                compatible.push(super::df_expr::cast_expr(item, common));
            }
            TypeCompat::NullInvolved => {
                has_null = true;
                compatible.push(item);
            }
            // `expr = item` is always false for incomparable types, so the item can never
            // contribute a match (nor a NULL); drop it.
            TypeCompat::Incomparable => {}
            TypeCompat::DateTimeStruct | TypeCompat::TimeStruct | TypeCompat::Dynamic => {
                return Err(anyhow!("Unexpected type compat in second pass"));
            }
        }
    }

    if compatible.is_empty() {
        return Ok(lit(ScalarValue::Boolean(Some(negated))));
    }
    if !has_null {
        return Ok(DfExpr::InList(InList {
            expr: Box::new(expr),
            list: compatible,
            negated,
        }));
    }

    // With a NULL item, a match gives IN → true / NOT IN → false, and anything else is
    // NULL. The inner membership test must therefore be the *positive* IN; only the
    // THEN value depends on `negated`.
    let in_expr = DfExpr::InList(InList {
        expr: Box::new(expr),
        list: compatible,
        negated: false,
    });
    let on_match = lit(ScalarValue::Boolean(Some(!negated)));
    Ok(when(in_expr, on_match).otherwise(lit(ScalarValue::Boolean(None)))?)
}
/// Lowers `expr [NOT] IN [items]` to `expr = i1 OR expr = i2 OR ...` (optionally wrapped
/// in `NOT`), building every equality with [`build_cypher_comparison`].
///
/// An empty list yields the constant `negated`. `expr` is cloned into every equality.
///
/// # Errors
/// Returns an error if an item's type cannot be resolved against `schema`.
fn build_in_as_or_chain(
    expr: DfExpr,
    expr_type: &DataType,
    list: Vec<DfExpr>,
    negated: bool,
    schema: &datafusion::common::DFSchema,
) -> Result<DfExpr> {
    let mut disjunction: Option<DfExpr> = None;
    for candidate in list {
        let candidate_type = candidate
            .get_type(schema)
            .map_err(|e| anyhow!("Failed to get item type in OR chain: {}", e))?;
        let equality = build_cypher_comparison(
            expr.clone(),
            expr_type,
            candidate,
            &candidate_type,
            Operator::Eq,
        );
        // Left-associative chain: ((e1 OR e2) OR e3) ...
        disjunction = Some(match disjunction {
            Some(acc) => binary_expr(acc, Operator::Or, equality),
            None => equality,
        });
    }

    let Some(any_equal) = disjunction else {
        return Ok(lit(ScalarValue::Boolean(Some(negated))));
    };
    Ok(if negated { not(any_equal) } else { any_equal })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_compat_same() {
        let dt = DataType::Int64;
        assert_eq!(type_compat(&dt, &dt), TypeCompat::Same);
    }

    #[test]
    fn test_type_compat_numeric_widening() {
        let left = DataType::Int64;
        let right = DataType::Float64;
        match type_compat(&left, &right) {
            TypeCompat::NumericWidening(common) => assert_eq!(common, DataType::Float64),
            _ => panic!("Expected NumericWidening"),
        }
    }

    #[test]
    fn test_type_compat_incomparable() {
        let left = DataType::Utf8;
        let right = DataType::Int64;
        assert_eq!(type_compat(&left, &right), TypeCompat::Incomparable);
    }

    #[test]
    fn test_type_compat_null_involved() {
        let left = DataType::Null;
        let right = DataType::Int64;
        assert_eq!(type_compat(&left, &right), TypeCompat::NullInvolved);
    }

    /// Distinct string types must classify as `StringCompat` in both directions.
    #[test]
    fn test_type_compat_string_compat() {
        let left = DataType::Utf8;
        let right = DataType::LargeUtf8;
        assert_eq!(type_compat(&left, &right), TypeCompat::StringCompat);
        assert_eq!(type_compat(&right, &left), TypeCompat::StringCompat);
    }

    /// Identical string types short-circuit to `Same` before the string rule applies.
    #[test]
    fn test_type_compat_identical_strings_are_same() {
        assert_eq!(
            type_compat(&DataType::Utf8, &DataType::Utf8),
            TypeCompat::Same
        );
    }

    #[test]
    fn test_build_cypher_comparison_incomparable_eq_returns_false() {
        let left = lit(ScalarValue::Utf8(Some("hello".to_string())));
        let right = lit(ScalarValue::Int64(Some(42)));
        let result =
            build_cypher_comparison(left, &DataType::Utf8, right, &DataType::Int64, Operator::Eq);
        match result {
            DfExpr::Literal(ScalarValue::Boolean(Some(false)), _) => {}
            _ => panic!(
                "Expected false literal for incomparable Eq, got {:?}",
                result
            ),
        }
    }

    #[test]
    fn test_build_cypher_comparison_incomparable_not_eq_returns_true() {
        let left = lit(ScalarValue::Utf8(Some("hello".to_string())));
        let right = lit(ScalarValue::Int64(Some(42)));
        let result = build_cypher_comparison(
            left,
            &DataType::Utf8,
            right,
            &DataType::Int64,
            Operator::NotEq,
        );
        match result {
            DfExpr::Literal(ScalarValue::Boolean(Some(true)), _) => {}
            _ => panic!(
                "Expected true literal for incomparable NotEq, got {:?}",
                result
            ),
        }
    }

    #[test]
    fn test_build_cypher_comparison_incomparable_ordering_returns_null() {
        for op in [Operator::Lt, Operator::LtEq, Operator::Gt, Operator::GtEq] {
            let left = lit(ScalarValue::Utf8(Some("hello".to_string())));
            let right = lit(ScalarValue::Int64(Some(42)));
            let result =
                build_cypher_comparison(left, &DataType::Utf8, right, &DataType::Int64, op);
            match result {
                DfExpr::Literal(ScalarValue::Boolean(None), _) => {}
                _ => panic!(
                    "Expected null for incomparable ordering op {:?}, got {:?}",
                    op, result
                ),
            }
        }
    }

    #[test]
    fn test_build_cypher_comparison_list_vs_bool_eq_returns_false() {
        use datafusion::arrow::datatypes::Field;
        use std::sync::Arc;
        let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int64, true)));
        // The expression value is irrelevant: the result is decided purely by the types.
        let left = lit(ScalarValue::Null);
        let right = lit(ScalarValue::Boolean(Some(true)));
        let result =
            build_cypher_comparison(left, &list_type, right, &DataType::Boolean, Operator::Eq);
        match result {
            DfExpr::Literal(ScalarValue::Boolean(Some(false)), _) => {}
            _ => panic!("Expected false for List vs Boolean Eq, got {:?}", result),
        }
    }

    #[test]
    fn test_build_cypher_comparison_null_involved() {
        let left = lit(ScalarValue::Null);
        let right = lit(ScalarValue::Int64(Some(42)));
        let result =
            build_cypher_comparison(left, &DataType::Null, right, &DataType::Int64, Operator::Eq);
        match result {
            DfExpr::Literal(ScalarValue::Boolean(None), _) => {}
            _ => panic!("Expected null literal for null involved"),
        }
    }

    #[test]
    fn test_build_cypher_plus_null_propagation() {
        let left = lit(ScalarValue::Null);
        let right = lit(ScalarValue::Int64(Some(42)));
        let result = build_cypher_plus(left, &DataType::Null, right, &DataType::Int64);
        match result {
            Ok(DfExpr::Literal(ScalarValue::Null, _)) => {}
            _ => panic!("Expected null for null propagation"),
        }
    }

    #[test]
    fn test_find_common_result_type_all_same() {
        let types = vec![DataType::Int64, DataType::Int64, DataType::Int64];
        let schema = datafusion::common::DFSchema::empty();
        let common = find_common_result_type(&types, &schema);
        assert_eq!(common, DataType::Int64);
    }

    #[test]
    fn test_find_common_result_type_numeric_widening() {
        let types = vec![DataType::Int64, DataType::Float64, DataType::Int32];
        let schema = datafusion::common::DFSchema::empty();
        let common = find_common_result_type(&types, &schema);
        assert_eq!(common, DataType::Float64);
    }

    #[test]
    fn test_find_common_result_type_any_string() {
        let types = vec![DataType::Int64, DataType::Utf8, DataType::Float64];
        let schema = datafusion::common::DFSchema::empty();
        let common = find_common_result_type(&types, &schema);
        assert_eq!(common, DataType::Utf8);
    }

    #[test]
    fn test_find_common_result_type_empty_defaults_to_utf8() {
        let schema = datafusion::common::DFSchema::empty();
        assert_eq!(find_common_result_type(&[], &schema), DataType::Utf8);
    }

    #[test]
    fn test_find_common_result_type_all_null() {
        let types = vec![DataType::Null, DataType::Null];
        let schema = datafusion::common::DFSchema::empty();
        assert_eq!(find_common_result_type(&types, &schema), DataType::Null);
    }

    #[test]
    fn test_find_common_result_type_ignores_nulls() {
        let types = vec![DataType::Null, DataType::Int64, DataType::Null];
        let schema = datafusion::common::DFSchema::empty();
        assert_eq!(find_common_result_type(&types, &schema), DataType::Int64);
    }

    #[test]
    fn test_type_compat_one_struct_one_non_struct_incomparable() {
        use datafusion::arrow::datatypes::Field;
        let struct_type = DataType::Struct(vec![Field::new("a", DataType::Int64, true)].into());
        let non_struct = DataType::Int64;
        assert_eq!(
            type_compat(&struct_type, &non_struct),
            TypeCompat::Incomparable
        );
        assert_eq!(
            type_compat(&non_struct, &struct_type),
            TypeCompat::Incomparable
        );
    }

    #[test]
    fn test_type_compat_both_non_temporal_structs_dynamic() {
        use datafusion::arrow::datatypes::Field;
        let struct1 = DataType::Struct(vec![Field::new("a", DataType::Int64, true)].into());
        let struct2 = DataType::Struct(vec![Field::new("b", DataType::Utf8, true)].into());
        assert_eq!(type_compat(&struct1, &struct2), TypeCompat::Dynamic);
    }
}