use super::PhysicalExpr;
use crate::expressions::try_cast;
use arrow::datatypes::Schema;
use datafusion_common::Result;
use datafusion_expr::{type_coercion::functions::data_types, Signature};
use std::{sync::Arc, vec};
pub fn coerce(
expressions: &[Arc<dyn PhysicalExpr>],
schema: &Schema,
signature: &Signature,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
if expressions.is_empty() {
return Ok(vec![]);
}
let current_types = expressions
.iter()
.map(|e| e.data_type(schema))
.collect::<Result<Vec<_>>>()?;
let new_types = data_types(¤t_types, signature)?;
expressions
.iter()
.enumerate()
.map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone()))
.collect::<Result<Vec<_>>>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::col;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::DataFusionError;
    use datafusion_expr::Volatility;

    /// Checks that `coerce` casts inputs to the types chosen for a signature,
    /// and that it errors when no valid coercion exists.
    #[test]
    fn test_coerce() -> Result<()> {
        // Builds a schema with one nullable column `c{i}` per data type.
        let schema = |types: Vec<DataType>| {
            Schema::new(
                types
                    .iter()
                    .enumerate()
                    .map(|(i, t)| Field::new(format!("c{i}"), t.clone(), true))
                    .collect(),
            )
        };
        // Builds `try_cast(c{i}, types[i])` for every requested type.
        let expressions = |types: Vec<DataType>, schema: &Schema| -> Result<Vec<_>> {
            types
                .iter()
                .enumerate()
                .map(|(i, t)| try_cast(col(&format!("c{i}"), schema)?, schema, t.clone()))
                .collect::<Result<Vec<_>>>()
        };
        // Produces (input expressions, schema, signature, expected expressions).
        let case = |observed: Vec<DataType>,
                    signature: Signature,
                    expected: Vec<DataType>|
         -> Result<_> {
            let schema = schema(observed.clone());
            let observed = expressions(observed, &schema)?;
            let expected = expressions(expected, &schema)?;
            Ok((observed, schema, signature, expected))
        };

        let valid_cases = vec![
            case(
                vec![DataType::UInt16],
                Signature::uniform(1, vec![DataType::UInt32], Volatility::Immutable),
                vec![DataType::UInt32],
            )?,
            case(
                vec![DataType::UInt32, DataType::UInt32],
                Signature::uniform(2, vec![DataType::UInt32], Volatility::Immutable),
                vec![DataType::UInt32, DataType::UInt32],
            )?,
            case(
                vec![DataType::UInt32],
                Signature::uniform(
                    1,
                    vec![DataType::Float32, DataType::Float64],
                    Volatility::Immutable,
                ),
                vec![DataType::Float32],
            )?,
            case(
                vec![DataType::UInt32, DataType::UInt32],
                Signature::variadic(vec![DataType::Float32], Volatility::Immutable),
                vec![DataType::Float32, DataType::Float32],
            )?,
            case(
                vec![DataType::Float32, DataType::UInt32],
                Signature::variadic_equal(Volatility::Immutable),
                vec![DataType::Float32, DataType::Float32],
            )?,
            case(
                vec![DataType::UInt32, DataType::UInt64],
                Signature::variadic(
                    vec![DataType::UInt32, DataType::UInt64],
                    Volatility::Immutable,
                ),
                vec![DataType::UInt64, DataType::UInt64],
            )?,
            case(
                vec![DataType::Float32],
                Signature::any(1, Volatility::Immutable),
                vec![DataType::Float32],
            )?,
        ];
        // Physical expressions have no convenient equality, so compare their
        // Debug renderings.
        for (exprs, schema, signature, expected) in valid_cases {
            let observed = format!("{:?}", coerce(&exprs, &schema, &signature)?);
            let expected = format!("{expected:?}");
            assert_eq!(observed, expected);
        }

        let invalid_cases = vec![
            case(
                vec![DataType::Boolean],
                Signature::uniform(1, vec![DataType::UInt16], Volatility::Immutable),
                vec![],
            )?,
            case(
                vec![DataType::UInt32, DataType::Boolean],
                Signature::variadic_equal(Volatility::Immutable),
                vec![],
            )?,
            case(
                vec![DataType::Boolean, DataType::Boolean],
                Signature::variadic(vec![DataType::UInt32], Volatility::Immutable),
                vec![],
            )?,
            case(
                vec![DataType::UInt32],
                Signature::any(2, Volatility::Immutable),
                vec![],
            )?,
        ];
        for invalid in invalid_cases {
            if coerce(&invalid.0, &invalid.1, &invalid.2).is_ok() {
                return Err(DataFusionError::Plan(format!(
                    "Error was expected in {invalid:?}"
                )));
            }
        }
        Ok(())
    }

    /// An empty argument list short-circuits before signature matching, so
    /// even a signature requiring one argument yields an empty result.
    #[test]
    fn test_coerce_empty() -> Result<()> {
        let signature =
            Signature::uniform(1, vec![DataType::UInt32], Volatility::Immutable);
        let coerced = coerce(&[], &Schema::empty(), &signature)?;
        assert!(coerced.is_empty());
        Ok(())
    }
}