use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use crate::operator::Operator;
use arrow::array::{Array, new_empty_array};
use arrow::compute::can_cast_types;
use arrow::datatypes::IntervalUnit::MonthDayNano;
use arrow::datatypes::TimeUnit::*;
use arrow::datatypes::{
DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION,
DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, Field, FieldRef, Fields,
TimeUnit,
};
use datafusion_common::types::NativeType;
use datafusion_common::{
Diagnostic, Result, Span, Spans, exec_err, internal_err, not_impl_err,
plan_datafusion_err, plan_err,
};
use itertools::Itertools;
/// The outcome of coercing a binary expression: the types both operands
/// are coerced to, plus the type the operation produces.
struct Signature {
/// Type the left operand is coerced to.
lhs: DataType,
/// Type the right operand is coerced to.
rhs: DataType,
/// Type produced by the operation.
ret: DataType,
}
impl Signature {
    /// Builds a signature in which both operands and the result all have
    /// type `t`.
    fn uniform(t: DataType) -> Self {
        let (lhs, rhs) = (t.clone(), t.clone());
        Self { lhs, rhs, ret: t }
    }

    /// Builds a signature in which both operands have type `t` and the
    /// result is `Boolean`.
    fn comparison(t: DataType) -> Self {
        let lhs = t.clone();
        Self {
            lhs,
            rhs: t,
            ret: DataType::Boolean,
        }
    }
}
/// Computes the coerced operand types and the result type of the binary
/// expression `lhs op rhs`.
///
/// The span fields are only used to attach source locations to the
/// diagnostic produced when coercion fails.
pub struct BinaryTypeCoercer<'a> {
/// Type of the left operand.
lhs: &'a DataType,
/// The binary operator.
op: &'a Operator,
/// Type of the right operand.
rhs: &'a DataType,
/// Source spans of the left operand (empty unless set).
lhs_spans: Spans,
/// Source spans of the operator (empty unless set).
op_spans: Spans,
/// Source spans of the right operand (empty unless set).
rhs_spans: Spans,
}
// Coercion logic for a single binary expression. The public entry points are
// `get_result_type` and `get_input_types`; both are thin views over the
// `Signature` computed by `signature`.
impl<'a> BinaryTypeCoercer<'a> {
/// Creates a coercer for `lhs op rhs`. All span collections start out
/// empty; use the `set_*_spans` methods to provide source locations.
pub fn new(lhs: &'a DataType, op: &'a Operator, rhs: &'a DataType) -> Self {
Self {
lhs,
op,
rhs,
lhs_spans: Spans::new(),
op_spans: Spans::new(),
rhs_spans: Spans::new(),
}
}
/// Replaces the spans recorded for the left operand.
pub fn set_lhs_spans(&mut self, spans: Spans) {
self.lhs_spans = spans;
}
/// Replaces the spans recorded for the operator.
pub fn set_op_spans(&mut self, spans: Spans) {
self.op_spans = spans;
}
/// Replaces the spans recorded for the right operand.
pub fn set_rhs_spans(&mut self, spans: Spans) {
self.rhs_spans = spans;
}
/// Union of the first left-operand span and the first right-operand span,
/// skipping whichever of the two is absent.
fn span(&self) -> Option<Span> {
Span::union_iter(
[self.lhs_spans.first(), self.rhs_spans.first()]
.iter()
.copied()
.flatten(),
)
}
/// Computes the full signature (coerced operand types and result type)
/// for `self.lhs self.op self.rhs`.
///
/// # Errors
/// Returns an error when no coercion rule applies to the operand types.
fn signature(&'a self) -> Result<Signature> {
// `NULL <arithmetic op> NULL`: both operands and the result are typed
// as Int64.
if matches!((self.lhs, self.rhs), (DataType::Null, DataType::Null))
&& self.op.is_numerical_operators()
{
return Ok(Signature::uniform(DataType::Int64));
}
// `null_coercion` succeeded: both operands take the coerced type.
if let Some(coerced) = null_coercion(self.lhs, self.rhs) {
// Arithmetic on a non-temporal coerced type: ask the Arrow kernel
// directly for the result type.
if self.op.is_numerical_operators() && !coerced.is_temporal() {
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for arithmetic operation {coerced} {} {coerced}: {e}",
self.op
)
})?;
return Ok(Signature {
lhs: coerced.clone(),
rhs: coerced,
ret,
});
}
// Every other operator (and temporal arithmetic) goes through the
// regular per-operator rules with the coerced type on both sides.
return self.signature_inner(&coerced, &coerced);
}
self.signature_inner(self.lhs, self.rhs)
}
/// Determines the result type of the arithmetic operator `self.op` applied
/// to `lhs` and `rhs` by running the matching Arrow numeric kernel on two
/// empty arrays and reading the data type of the output.
///
/// # Panics
/// Hits `unreachable!()` if `self.op` is not one of `Plus`, `Minus`,
/// `Multiply`, `Divide` or `Modulo`.
fn get_result(
&self,
lhs: &DataType,
rhs: &DataType,
) -> arrow::error::Result<DataType> {
use arrow::compute::kernels::numeric::*;
let l = new_empty_array(lhs);
let r = new_empty_array(rhs);
let result = match self.op {
Operator::Plus => add_wrapping(&l, &r),
Operator::Minus => sub_wrapping(&l, &r),
Operator::Multiply => mul_wrapping(&l, &r),
Operator::Divide => div(&l, &r),
Operator::Modulo => rem(&l, &r),
_ => unreachable!(),
};
result.map(|x| x.data_type().clone())
}
/// Applies the per-operator coercion rules to `lhs` and `rhs`.
///
/// Note that error messages are formatted with `self.lhs` / `self.rhs`
/// (the types this coercer was built with), which may differ from the
/// `lhs` / `rhs` arguments when the caller already coerced them.
///
/// # Errors
/// On failure the error carries a `Diagnostic` that points at the operand
/// spans and notes each operand's type.
fn signature_inner(&'a self, lhs: &DataType, rhs: &DataType) -> Result<Signature> {
use Operator::*;
use arrow::datatypes::DataType::*;
let result = match self.op {
// Comparisons: both sides take the common comparison type, result is Boolean.
Eq |
NotEq |
Lt |
LtEq |
Gt |
GtEq |
IsDistinctFrom |
IsNotDistinctFrom => {
comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for comparison operation {} {} {}",
self.lhs,
self.op,
self.rhs
)
})
}
// Logical operators accept only Boolean or Null operands.
And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) {
Ok(Signature::uniform(Boolean))
} else {
plan_err!(
"Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs
)
}
RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
// LIKE-family operators share `regex_coercion` (and its error text)
// with the regex operators above.
LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
StringConcat => {
string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
// Try array coercion first, then fall back to LIKE-style coercion.
AtArrow | ArrowAt => {
array_coercion(lhs, rhs)
.or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
AtAt => {
like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
// Date32 - Date32 and Date64 - Date64 keep their operand types and
// yield Int64. This arm returns directly, bypassing the diagnostic
// wrapping at the end of the function (which only affects errors).
Minus if is_date_minus_date(lhs, rhs) => {
return Ok(Signature {
lhs: lhs.clone(),
rhs: rhs.clone(),
ret: Int64,
});
}
// Arithmetic: the rules below are tried in order and the first one
// that applies wins.
Plus | Minus | Multiply | Divide | Modulo => {
// 1. The Arrow kernel accepts the operand types unchanged.
if let Ok(ret) = self.get_result(lhs, rhs) {
Ok(Signature{
lhs: lhs.clone(),
rhs: rhs.clone(),
ret,
})
// 2. Temporal math coercion, which may pick a different type per side.
} else if let Some((lhs, rhs)) = temporal_math_coercion(lhs, rhs) {
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for temporal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature {
lhs,
rhs,
ret,
})
// 3. Strict-timezone temporal coercion to one common type.
} else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op
)
})?;
Ok(Signature{
lhs: coerced.clone(),
rhs: coerced,
ret,
})
// 4. Decimal coercion, which may pick a different type per side.
} else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature{
lhs,
rhs,
ret,
})
// 5. Plain numeric coercion: operands and result share one type.
} else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
Ok(Signature::uniform(numeric))
} else {
plan_err!(
"Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs
)
}
},
// `Colon` performs no coercion; the result takes the left operand's type.
Colon => {
Ok(Signature { lhs: lhs.clone(), rhs: rhs.clone(), ret: lhs.clone() })
},
IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow
| HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => {
not_impl_err!("Operator {} is not yet supported", self.op)
}
};
// Attach a diagnostic describing both operand types to any error.
result.map_err(|err| {
let diagnostic =
Diagnostic::new_error("expressions have incompatible types", self.span())
.with_note(format!("has type {}", self.lhs), self.lhs_spans.first())
.with_note(format!("has type {}", self.rhs), self.rhs_spans.first());
err.with_diagnostic(diagnostic)
})
}
/// Returns the type produced by `lhs op rhs` after coercion.
///
/// # Errors
/// Fails when the operand types cannot be coerced for this operator.
pub fn get_result_type(&'a self) -> Result<DataType> {
self.signature().map(|sig| sig.ret)
}
/// Returns the `(lhs, rhs)` types the operands must be coerced to.
///
/// # Errors
/// Fails when the operand types cannot be coerced for this operator.
pub fn get_input_types(&'a self) -> Result<(DataType, DataType)> {
self.signature().map(|sig| (sig.lhs, sig.rhs))
}
}
/// Returns `true` when both operands are `Date32`, or both are `Date64`.
fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool {
    let lhs_is_date = matches!(lhs, DataType::Date32 | DataType::Date64);
    lhs_is_date && lhs == rhs
}
/// Coercion rule for arithmetic where at least one side is a decimal.
///
/// Returns the `(lhs, rhs)` types to cast the operands to, or `None` when
/// the rule does not apply. Dictionary operands are handled by recursing on
/// their value type.
fn math_decimal_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<(DataType, DataType)> {
    use arrow::datatypes::DataType::*;

    /// Converts `integer` to a decimal of the same variant as `decimal`.
    fn integer_as_decimal_like(
        decimal: &DataType,
        integer: &DataType,
    ) -> Option<DataType> {
        match decimal {
            Decimal32(..) => coerce_numeric_type_to_decimal32(integer),
            Decimal64(..) => coerce_numeric_type_to_decimal64(integer),
            Decimal128(..) => coerce_numeric_type_to_decimal128(integer),
            Decimal256(..) => coerce_numeric_type_to_decimal256(integer),
            _ => None,
        }
    }

    match (lhs_type, rhs_type) {
        // Look through dictionaries on either side.
        (Dictionary(_, values), _) => math_decimal_coercion(values, rhs_type),
        (_, Dictionary(_, values)) => math_decimal_coercion(lhs_type, values),

        // NULL adopts the decimal type of the other side.
        (Null, Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..)) => {
            Some((rhs_type.clone(), rhs_type.clone()))
        }
        (Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..), Null) => {
            Some((lhs_type.clone(), lhs_type.clone()))
        }

        // Same decimal variant on both sides: leave both untouched.
        (Decimal32(..), Decimal32(..))
        | (Decimal64(..), Decimal64(..))
        | (Decimal128(..), Decimal128(..))
        | (Decimal256(..), Decimal256(..)) => Some((lhs_type.clone(), rhs_type.clone())),

        // Different decimal variants: both sides take the wider common type.
        (l, r)
            if l.is_decimal()
                && r.is_decimal()
                && std::mem::discriminant(l) != std::mem::discriminant(r) =>
        {
            get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
                .map(|wider| (wider.clone(), wider))
        }

        // Decimal with integer: only the integer side is converted.
        (
            Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..),
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
        ) => integer_as_decimal_like(lhs_type, rhs_type)
            .map(|converted| (lhs_type.clone(), converted)),
        (
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
            Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..),
        ) => integer_as_decimal_like(rhs_type, lhs_type)
            .map(|converted| (converted, rhs_type.clone())),

        _ => None,
    }
}
/// Returns the common type both operands of a bitwise operator are coerced
/// to, or `None` when the operands are not accepted by
/// `both_numeric_or_null_and_numeric`.
fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
if !both_numeric_or_null_and_numeric(left_type, right_type) {
return None;
}
// Identical integer types, or identical dictionaries with integer values,
// are returned unchanged.
let is_integer_dictionary =
matches!(left_type, Dictionary(_, value_type) if value_type.is_integer());
if left_type == right_type && (left_type.is_integer() || is_integer_dictionary) {
return Some(left_type.clone());
}
// Arms are order-dependent: the first matching pattern wins, and the
// wildcard arms rely on the more specific arms above them.
match (left_type, right_type) {
(UInt64, _) | (_, UInt64) => Some(UInt64),
(Int64, _)
| (_, Int64)
| (UInt32, Int8)
| (Int8, UInt32)
| (UInt32, Int16)
| (Int16, UInt32)
| (UInt32, Int32)
| (Int32, UInt32) => Some(Int64),
(Int32, _)
| (_, Int32)
| (UInt16, Int16)
| (Int16, UInt16)
| (UInt16, Int8)
| (Int8, UInt16) => Some(Int32),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
}
/// Coarse classification of a `DataType`, used by `type_union_resolution`
/// to reject inputs that mix incompatible kinds of types.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum TypeCategory {
// List, FixedSizeList and LargeList.
Array,
// The Boolean type.
Boolean,
// Any type for which `DataType::is_numeric` is true.
Numeric,
// Date, Time, Timestamp, Interval and Duration types.
DateTime,
// Map, Struct and Union.
Composite,
// Utf8, LargeUtf8, Utf8View and Null.
Unknown,
// Everything not covered by the categories above.
NotSupported,
}
impl From<&DataType> for TypeCategory {
    /// Classifies `data_type`. A dictionary is classified by its value type.
    fn from(data_type: &DataType) -> Self {
        match data_type {
            DataType::Dictionary(_, values) => TypeCategory::from(values.as_ref()),
            // The numeric test comes first, ahead of the structural patterns.
            numeric if numeric.is_numeric() => TypeCategory::Numeric,
            DataType::Boolean => TypeCategory::Boolean,
            DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => {
                TypeCategory::Array
            }
            DataType::Utf8
            | DataType::LargeUtf8
            | DataType::Utf8View
            | DataType::Null => TypeCategory::Unknown,
            DataType::Date32
            | DataType::Date64
            | DataType::Time32(_)
            | DataType::Time64(_)
            | DataType::Timestamp(_, _)
            | DataType::Interval(_)
            | DataType::Duration(_) => TypeCategory::DateTime,
            DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => {
                TypeCategory::Composite
            }
            _ => TypeCategory::NotSupported,
        }
    }
}
/// Finds a single type that every entry of `data_types` can be coerced to.
///
/// Returns `None` for an empty slice, when any non-null type is in the
/// `NotSupported` category, when the non-null types span more than one
/// category (ignoring `Unknown`), or when a pairwise coercion fails.
/// `Null` entries do not take part in the pairwise coercion.
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
    let (first, rest) = data_types.split_first()?;

    // Every type is identical: nothing to resolve.
    if rest.iter().all(|t| t == first) {
        return Some(first.clone());
    }

    // NOTE: an all-Null slice already satisfied the identical-types check
    // above, so this branch is not reachable from here.
    if data_types.iter().all(|t| *t == DataType::Null) {
        return Some(DataType::Utf8View);
    }

    // Collect the distinct categories of the non-null types, ignoring
    // `Unknown`; bail out on anything unsupported.
    let mut categories: HashSet<TypeCategory> = HashSet::new();
    for data_type in data_types.iter().filter(|t| **t != DataType::Null) {
        match TypeCategory::from(data_type) {
            TypeCategory::NotSupported => return None,
            TypeCategory::Unknown => {}
            category => {
                categories.insert(category);
            }
        }
    }
    if categories.len() > 1 {
        return None;
    }

    // Fold the non-null types pairwise, starting from the first one.
    let mut non_null = data_types.iter().filter(|t| **t != DataType::Null);
    let seed = non_null.next()?.clone();
    non_null.try_fold(seed, |candidate, data_type| {
        type_union_resolution_coercion(data_type, &candidate)
    })
}
/// Pairwise step of `type_union_resolution`: finds a type both `lhs_type`
/// and `rhs_type` can be coerced to, or `None` if there is none.
fn type_union_resolution_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    /// Coerces `lhs_field`'s type against the first field in `rhs` that has
    /// the same name; `None` if no such field exists or coercion fails.
    fn coerce_against_namesake(lhs_field: &FieldRef, rhs: &Fields) -> Option<DataType> {
        let namesake = rhs
            .iter()
            .find(|candidate| candidate.name() == lhs_field.name())?;
        type_union_resolution_coercion(lhs_field.data_type(), namesake.data_type())
    }

    if lhs_type == rhs_type {
        return Some(lhs_type.clone());
    }

    match (lhs_type, rhs_type) {
        // Two dictionaries: coerce index and value types independently.
        (
            DataType::Dictionary(lhs_index, lhs_values),
            DataType::Dictionary(rhs_index, rhs_values),
        ) => {
            let index = type_union_resolution_coercion(lhs_index, rhs_index)?;
            let values = type_union_resolution_coercion(lhs_values, rhs_values)?;
            Some(DataType::Dictionary(Box::new(index), Box::new(values)))
        }
        // One dictionary: coerce its value type with the other side and keep
        // the dictionary wrapper, except that a Utf8View result is returned
        // bare.
        (DataType::Dictionary(index_type, value_type), other_type)
        | (other_type, DataType::Dictionary(index_type, value_type)) => {
            let merged = type_union_resolution_coercion(value_type, other_type)?;
            Some(match merged {
                DataType::Utf8View => DataType::Utf8View,
                values => DataType::Dictionary(index_type.clone(), Box::new(values)),
            })
        }
        // Structs of equal arity: each lhs field's type is coerced with the
        // same-named rhs field, while name and nullability are taken from
        // the positional (lhs, rhs) pair.
        (DataType::Struct(lhs), DataType::Struct(rhs)) => {
            if lhs.len() != rhs.len() {
                return None;
            }
            let fields = lhs
                .iter()
                .zip(rhs.iter())
                .map(|(lhs_field, rhs_field)| {
                    coerce_against_namesake(lhs_field, rhs)
                        .map(|datatype| coerce_fields(datatype, lhs_field, rhs_field))
                })
                .collect::<Option<Vec<FieldRef>>>()?;
            Some(DataType::Struct(fields.into()))
        }
        // Everything else: first applicable rule wins.
        _ => binary_numeric_coercion(lhs_type, rhs_type)
            .or_else(|| list_coercion(lhs_type, rhs_type))
            .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
            .or_else(|| string_coercion(lhs_type, rhs_type))
            .or_else(|| numeric_string_coercion(lhs_type, rhs_type))
            .or_else(|| binary_coercion(lhs_type, rhs_type)),
    }
}
/// Resolves a common type for `data_types` and returns one coerced type per
/// input.
///
/// Struct-specific resolution is attempted first; if that fails, the general
/// `type_union_resolution` is used and its result is repeated for every
/// input.
///
/// # Errors
/// Returns an execution error, embedding the struct-resolution failure, when
/// neither strategy finds a common type.
pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
    let struct_failure = match try_type_union_resolution_with_struct(data_types) {
        Ok(resolved) => return Ok(resolved),
        Err(e) => Some(e),
    };

    match type_union_resolution(data_types) {
        Some(common) => Ok(vec![common; data_types.len()]),
        None => exec_err!(
            "Fail to find the coerced type, errors: {:?}",
            struct_failure
        ),
    }
}
pub fn try_type_union_resolution_with_struct(
data_types: &[DataType],
) -> Result<Vec<DataType>> {
let mut keys_string: Option<String> = None;
for data_type in data_types {
if let DataType::Struct(fields) = data_type {
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
if let Some(ref k) = keys_string {
if *k != keys {
return exec_err!(
"Expect same keys for struct type but got mismatched pair {} and {}",
*k,
keys
);
}
} else {
keys_string = Some(keys);
}
} else {
return exec_err!("Expect to get struct but got {data_type}");
}
}
let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
{
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!(
"Struct type is checked is the previous function, so this should be unreachable"
);
};
for data_type in data_types.iter().skip(1) {
if let DataType::Struct(fields) = data_type {
let incoming_struct_types: Vec<DataType> =
fields.iter().map(|f| f.data_type().to_owned()).collect();
for (lhs_type, rhs_type) in
struct_types.iter_mut().zip(incoming_struct_types.iter())
{
if let Some(coerced_type) =
type_union_resolution_coercion(lhs_type, rhs_type)
{
*lhs_type = coerced_type;
} else {
return exec_err!(
"Fail to find the coerced type for {} and {}",
lhs_type,
rhs_type
);
}
}
} else {
return exec_err!("Expect to get struct but got {data_type}");
}
}
let mut final_struct_types = vec![];
for s in data_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(struct_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}
Ok(final_struct_types)
}
/// Finds the common type two operands are coerced to before a comparison,
/// or `None` if no rule applies.
///
/// Types that are equal according to `DataType::equals_datatype` are
/// returned as-is. Otherwise the rules below are tried in order and the
/// first one that yields a type wins, so the order of the chain is
/// significant.
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if lhs_type.equals_datatype(rhs_type) {
return Some(lhs_type.clone());
}
binary_numeric_coercion(lhs_type, rhs_type)
// `true`: keep the dictionary / run-end encoding when the other side
// already has the encoded value type.
.or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true))
.or_else(|| ree_comparison_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| list_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
// string vs. numeric compares as a string here (see
// `comparison_coercion_numeric` for the numeric-preferring variant).
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
.or_else(|| string_temporal_coercion(lhs_type, rhs_type))
.or_else(|| binary_coercion(lhs_type, rhs_type))
.or_else(|| struct_coercion(lhs_type, rhs_type))
.or_else(|| map_coercion(lhs_type, rhs_type))
}
/// Variant of `comparison_coercion` that resolves a string/numeric pair to
/// the numeric type instead of the string type.
///
/// Equal types are returned as-is; otherwise the rules below are tried in
/// order and the first one that yields a type wins. Returns `None` if no
/// rule applies.
pub fn comparison_coercion_numeric(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
if lhs_type == rhs_type {
return Some(lhs_type.clone());
}
binary_numeric_coercion(lhs_type, rhs_type)
// `true`: keep the dictionary / run-end encoding when the other side
// already has the encoded value type.
.or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true))
.or_else(|| ree_comparison_coercion_numeric(lhs_type, rhs_type, true))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type))
}
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Utf8, _) if rhs_type.is_numeric() => Some(Utf8),
(LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8),
(Utf8View, _) if rhs_type.is_numeric() => Some(Utf8View),
(_, Utf8) if lhs_type.is_numeric() => Some(Utf8),
(_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8),
(_, Utf8View) if lhs_type.is_numeric() => Some(Utf8View),
_ => None,
}
}
/// Coerces a string/numeric pair to the numeric type of the pair.
///
/// The classification uses the `NativeType` of each side; returns `None`
/// unless exactly one side is numeric and the other is a string.
fn string_numeric_coercion_as_numeric(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    let lhs_native = NativeType::from(lhs_type);
    let rhs_native = NativeType::from(rhs_type);

    if lhs_native.is_numeric() && rhs_native == NativeType::String {
        Some(lhs_type.clone())
    } else if rhs_native.is_numeric() && lhs_native == NativeType::String {
        Some(rhs_type.clone())
    } else {
        None
    }
}
/// Coerces a string/temporal pair (in either order) to a temporal type.
///
/// Dates are kept as-is, time types are kept when
/// `is_time_with_valid_unit` accepts them, and timestamps become
/// nanosecond timestamps with the same timezone. Returns `None` otherwise.
fn string_temporal_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    /// Applies the rule with `text` as the string side and `temporal` as the
    /// other side.
    fn string_with(text: &DataType, temporal: &DataType) -> Option<DataType> {
        if !matches!(text, Utf8 | LargeUtf8 | Utf8View) {
            return None;
        }
        match temporal {
            Date32 | Date64 => Some(temporal.clone()),
            Time32(_) | Time64(_) if is_time_with_valid_unit(temporal) => {
                Some(temporal.to_owned())
            }
            Timestamp(_, tz) => Some(Timestamp(Nanosecond, tz.clone())),
            _ => None,
        }
    }

    string_with(lhs_type, rhs_type).or_else(|| string_with(rhs_type, lhs_type))
}
/// Finds a common numeric type for two numeric operands.
///
/// Returns `None` unless both sides are numeric. Equal types are returned
/// unchanged; otherwise decimal coercion is tried before plain numeric
/// coercion.
pub fn binary_numeric_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    let both_numeric = lhs_type.is_numeric() && rhs_type.is_numeric();
    if !both_numeric {
        return None;
    }
    if lhs_type == rhs_type {
        return Some(lhs_type.clone());
    }
    decimal_coercion(lhs_type, rhs_type)
        .or_else(|| numerical_coercion(lhs_type, rhs_type))
}
/// Finds a common decimal type when at least one operand is a decimal.
///
/// Two decimals of the same variant are widened within that variant, two
/// decimals of different variants are widened across variants, and a
/// decimal paired with a non-decimal goes through
/// `get_common_decimal_type`. Returns `None` when neither side is a decimal
/// or no common type exists.
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    match (lhs_type.is_decimal(), rhs_type.is_decimal()) {
        (true, true) => {
            let same_variant =
                std::mem::discriminant(lhs_type) == std::mem::discriminant(rhs_type);
            if same_variant {
                get_wider_decimal_type(lhs_type, rhs_type)
            } else {
                get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
            }
        }
        // The decimal side is always passed first.
        (true, false) => get_common_decimal_type(lhs_type, rhs_type),
        (false, true) => get_common_decimal_type(rhs_type, lhs_type),
        (false, false) => None,
    }
}
/// Finds a decimal type able to hold values of two decimals of *different*
/// variants (e.g. `Decimal32` and `Decimal128`).
///
/// The result uses the wider of the two variants, with scale
/// `max(s1, s2)` and precision `max(p1 - s1, p2 - s2) + max(s1, s2)`.
/// Returns `None` when either input is not a decimal, when both inputs have
/// the same variant, or when the required precision exceeds the maximum of
/// the wider variant.
fn get_wider_decimal_type_cross_variant(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    /// Extracts `(precision, scale)` from any decimal variant.
    fn precision_and_scale(decimal: &DataType) -> Option<(u8, i8)> {
        match decimal {
            Decimal32(p, s) | Decimal64(p, s) | Decimal128(p, s) | Decimal256(p, s) => {
                Some((*p, *s))
            }
            _ => None,
        }
    }

    let (p1, s1) = precision_and_scale(lhs_type)?;
    let (p2, s2) = precision_and_scale(rhs_type)?;

    // max(s1, s2)
    let s = s1.max(s2);

    // max(p1 - s1, p2 - s2) + max(s1, s2), computed in i16: with negative
    // scales the intermediate values can exceed i8::MAX (e.g. precision 76
    // with scale -20 next to scale 38), which would overflow i8 arithmetic.
    let range = (i16::from(p1) - i16::from(s1)).max(i16::from(p2) - i16::from(s2));
    // A value that does not fit in u8 exceeds every decimal's max precision.
    let required_precision = u8::try_from(range + i16::from(s)).ok()?;

    match (lhs_type, rhs_type) {
        (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
            if required_precision <= DECIMAL64_MAX_PRECISION =>
        {
            Some(Decimal64(required_precision, s))
        }
        (Decimal32(_, _), Decimal128(_, _))
        | (Decimal128(_, _), Decimal32(_, _))
        | (Decimal64(_, _), Decimal128(_, _))
        | (Decimal128(_, _), Decimal64(_, _))
            if required_precision <= DECIMAL128_MAX_PRECISION =>
        {
            Some(Decimal128(required_precision, s))
        }
        (Decimal32(_, _), Decimal256(_, _))
        | (Decimal256(_, _), Decimal32(_, _))
        | (Decimal64(_, _), Decimal256(_, _))
        | (Decimal256(_, _), Decimal64(_, _))
        | (Decimal128(_, _), Decimal256(_, _))
        | (Decimal256(_, _), Decimal128(_, _))
            if required_precision <= DECIMAL256_MAX_PRECISION =>
        {
            Some(Decimal256(required_precision, s))
        }
        _ => None,
    }
}
fn get_common_decimal_type(
decimal_type: &DataType,
other_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match decimal_type {
Decimal32(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal64(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal128(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal256(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
_ => None,
}
}
/// Finds a decimal type able to hold values of two decimals of the *same*
/// variant.
///
/// The result has scale `max(s1, s2)` and precision
/// `max(p1 - s1, p2 - s2) + max(s1, s2)`, both clamped to the variant's
/// maximum by the `create_decimal*_type` helpers. Returns `None` when the
/// inputs are not two decimals of the same variant.
fn get_wider_decimal_type(
    lhs_decimal_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    /// Computes the (unclamped) precision and scale of the wider type.
    ///
    /// The arithmetic is done in i16 because the intermediate sum can exceed
    /// i8::MAX for valid inputs (e.g. `Decimal256(76, 0)` with
    /// `Decimal256(76, 76)` needs 152 digits before clamping), which would
    /// overflow i8 arithmetic and panic in debug builds.
    fn wider_precision_and_scale(p1: u8, s1: i8, p2: u8, s2: i8) -> (u8, i8) {
        let scale = s1.max(s2);
        let range = (i16::from(p1) - i16::from(s1)).max(i16::from(p2) - i16::from(s2));
        // Saturate: anything above u8::MAX is clamped to the variant's
        // maximum precision by the caller anyway.
        let precision = u8::try_from(range + i16::from(scale)).unwrap_or(u8::MAX);
        (precision, scale)
    }

    match (lhs_decimal_type, rhs_type) {
        (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => {
            let (precision, scale) = wider_precision_and_scale(*p1, *s1, *p2, *s2);
            Some(create_decimal32_type(precision, scale))
        }
        (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => {
            let (precision, scale) = wider_precision_and_scale(*p1, *s1, *p2, *s2);
            Some(create_decimal64_type(precision, scale))
        }
        (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => {
            let (precision, scale) = wider_precision_and_scale(*p1, *s1, *p2, *s2);
            Some(create_decimal128_type(precision, scale))
        }
        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
            let (precision, scale) = wider_precision_and_scale(*p1, *s1, *p2, *s2);
            Some(create_decimal256_type(precision, scale))
        }
        (_, _) => None,
    }
}
fn coerce_numeric_type_to_decimal32(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match numeric_type {
Int8 | UInt8 => Some(Decimal32(3, 0)),
Int16 | UInt16 => Some(Decimal32(5, 0)),
Float16 => Some(Decimal32(6, 3)),
_ => None,
}
}
fn coerce_numeric_type_to_decimal64(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match numeric_type {
Int8 | UInt8 => Some(Decimal64(3, 0)),
Int16 | UInt16 => Some(Decimal64(5, 0)),
Int32 | UInt32 => Some(Decimal64(10, 0)),
Float16 => Some(Decimal64(6, 3)),
Float32 => Some(Decimal64(14, 7)),
_ => None,
}
}
fn coerce_numeric_type_to_decimal128(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match numeric_type {
Int8 | UInt8 => Some(Decimal128(3, 0)),
Int16 | UInt16 => Some(Decimal128(5, 0)),
Int32 | UInt32 => Some(Decimal128(10, 0)),
Int64 | UInt64 => Some(Decimal128(20, 0)),
Float16 => Some(Decimal128(6, 3)),
Float32 => Some(Decimal128(14, 7)),
Float64 => Some(Decimal128(30, 15)),
_ => None,
}
}
fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match numeric_type {
Int8 | UInt8 => Some(Decimal256(3, 0)),
Int16 | UInt16 => Some(Decimal256(5, 0)),
Int32 | UInt32 => Some(Decimal256(10, 0)),
Int64 | UInt64 => Some(Decimal256(20, 0)),
Float16 => Some(Decimal256(6, 3)),
Float32 => Some(Decimal256(14, 7)),
Float64 => Some(Decimal256(30, 15)),
_ => None,
}
}
fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Struct(lhs_fields), Struct(rhs_fields)) => {
if lhs_fields.len() != rhs_fields.len() {
return None;
}
if fields_have_same_names(lhs_fields, rhs_fields) {
return coerce_struct_by_name(lhs_fields, rhs_fields);
}
coerce_struct_by_position(lhs_fields, rhs_fields)
}
_ => None,
}
}
/// Returns `true` when every field name of `lhs_fields` also appears in
/// `rhs_fields`.
///
/// # Panics
/// In builds with debug assertions enabled, panics if either side contains
/// duplicate field names.
fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool {
    #[cfg(debug_assertions)]
    {
        // Checked in lhs-then-rhs order.
        for fields in [lhs_fields, rhs_fields] {
            let distinct_names: HashSet<_> = fields.iter().map(|f| f.name()).collect();
            assert_eq!(
                distinct_names.len(),
                fields.len(),
                "Struct has duplicate field names (should be caught by Arrow schema validation)"
            );
        }
    }

    let rhs_names: HashSet<&str> = rhs_fields.iter().map(|f| f.name().as_str()).collect();
    lhs_fields
        .iter()
        .map(|f| f.name().as_str())
        .all(|name| rhs_names.contains(name))
}
fn coerce_struct_by_name(lhs_fields: &Fields, rhs_fields: &Fields) -> Option<DataType> {
use arrow::datatypes::DataType::*;
let rhs_by_name: HashMap<&str, &FieldRef> =
rhs_fields.iter().map(|f| (f.name().as_str(), f)).collect();
let mut coerced: Vec<FieldRef> = Vec::with_capacity(lhs_fields.len());
for lhs in lhs_fields.iter() {
let rhs = rhs_by_name.get(lhs.name().as_str()).unwrap(); let coerced_type = comparison_coercion(lhs.data_type(), rhs.data_type())?;
let is_nullable = lhs.is_nullable() || rhs.is_nullable();
coerced.push(Arc::new(Field::new(
lhs.name().clone(),
coerced_type,
is_nullable,
)));
}
Some(Struct(coerced.into()))
}
fn coerce_struct_by_position(
lhs_fields: &Fields,
rhs_fields: &Fields,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
let coerced_types: Vec<DataType> = lhs_fields
.iter()
.zip(rhs_fields.iter())
.map(|(l, r)| comparison_coercion(l.data_type(), r.data_type()))
.collect::<Option<Vec<DataType>>>()?;
let orig_pairs = lhs_fields.iter().zip(rhs_fields.iter());
let fields: Vec<FieldRef> = coerced_types
.into_iter()
.zip(orig_pairs)
.map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs))
.collect();
Some(Struct(fields.into()))
}
/// Builds the coerced field for a pair of fields: it takes the name of the
/// lhs field, has type `common_type`, and is nullable if either input is.
fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef {
    let nullable = lhs.is_nullable() || rhs.is_nullable();
    Arc::new(Field::new(lhs.name(), common_type, nullable))
}
fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => {
struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map(
|key_value_type| {
Map(
Arc::new((**lhs_field).clone().with_data_type(key_value_type)),
*lhs_ordered && *rhs_ordered,
)
},
)
}
_ => None,
}
}
fn mathematics_numerical_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) {
return None;
};
match (lhs_type, rhs_type) {
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
mathematics_numerical_coercion(lhs_value_type, rhs_value_type)
}
(Dictionary(_, value_type), _) => {
mathematics_numerical_coercion(value_type, rhs_type)
}
(_, Dictionary(_, value_type)) => {
mathematics_numerical_coercion(lhs_type, value_type)
}
_ => numerical_coercion(lhs_type, rhs_type),
}
}
/// Picks the common type for a pair of integer / floating-point types, or
/// `None` when neither side matches any arm.
///
/// The arms are order-dependent: the first matching pattern wins, and the
/// wildcard arms rely on the more specific arms above them.
fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// Any float wins over integers; the widest float wins over narrower ones.
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(_, Float16) | (Float16, _) => Some(Float16),
// UInt64 mixed with a signed integer is resolved to Decimal128(20, 0).
(UInt64, Int64 | Int32 | Int16 | Int8)
| (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
(UInt64, _) | (_, UInt64) => Some(UInt64),
// A narrower unsigned type mixed with a signed type is resolved to the
// next wider signed type.
(Int64, _)
| (_, Int64)
| (UInt32, Int32 | Int16 | Int8)
| (Int32 | Int16 | Int8, UInt32) => Some(Int64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => {
Some(Int32)
}
(UInt16, _) | (_, UInt16) => Some(UInt16),
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
}
/// Builds a `Decimal32` type, capping precision and scale at the
/// `Decimal32` maximums.
fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
    let capped_precision = precision.min(DECIMAL32_MAX_PRECISION);
    let capped_scale = scale.min(DECIMAL32_MAX_SCALE);
    DataType::Decimal32(capped_precision, capped_scale)
}
/// Builds a `Decimal64` type, capping precision and scale at the
/// `Decimal64` maximums.
fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
    let capped_precision = precision.min(DECIMAL64_MAX_PRECISION);
    let capped_scale = scale.min(DECIMAL64_MAX_SCALE);
    DataType::Decimal64(capped_precision, capped_scale)
}
/// Builds a `Decimal128` type, capping precision and scale at the
/// `Decimal128` maximums.
fn create_decimal128_type(precision: u8, scale: i8) -> DataType {
    let capped_precision = precision.min(DECIMAL128_MAX_PRECISION);
    let capped_scale = scale.min(DECIMAL128_MAX_SCALE);
    DataType::Decimal128(capped_precision, capped_scale)
}
/// Builds a `Decimal256` type, capping precision and scale at the
/// `Decimal256` maximums.
fn create_decimal256_type(precision: u8, scale: i8) -> DataType {
    let capped_precision = precision.min(DECIMAL256_MAX_PRECISION);
    let capped_scale = scale.min(DECIMAL256_MAX_SCALE);
    DataType::Decimal256(capped_precision, capped_scale)
}
fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(_, Null) => lhs_type.is_numeric(),
(Null, _) => rhs_type.is_numeric(),
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
lhs_value_type.is_numeric() && rhs_value_type.is_numeric()
}
(Dictionary(_, value_type), _) => {
value_type.is_numeric() && rhs_type.is_numeric()
}
(_, Dictionary(_, value_type)) => {
lhs_type.is_numeric() && value_type.is_numeric()
}
_ => lhs_type.is_numeric() && rhs_type.is_numeric(),
}
}
fn dictionary_comparison_coercion_generic(
lhs_type: &DataType,
rhs_type: &DataType,
preserve_dictionaries: bool,
coerce_fn: fn(&DataType, &DataType) -> Option<DataType>,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(
Dictionary(_lhs_index_type, lhs_value_type),
Dictionary(_rhs_index_type, rhs_value_type),
) => coerce_fn(lhs_value_type, rhs_value_type),
(d @ Dictionary(_, value_type), other_type)
| (other_type, d @ Dictionary(_, value_type))
if preserve_dictionaries && value_type.as_ref() == other_type =>
{
Some(d.clone())
}
(Dictionary(_index_type, value_type), _) => coerce_fn(value_type, rhs_type),
(_, Dictionary(_index_type, value_type)) => coerce_fn(lhs_type, value_type),
_ => None,
}
}
/// Dictionary-aware comparison coercion that resolves value types with
/// `comparison_coercion`.
///
/// Delegates to `dictionary_comparison_coercion_generic`, forwarding
/// `preserve_dictionaries` unchanged.
fn dictionary_comparison_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_dictionaries: bool,
) -> Option<DataType> {
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion;
    dictionary_comparison_coercion_generic(
        lhs_type,
        rhs_type,
        preserve_dictionaries,
        coerce_values,
    )
}
/// Dictionary-aware comparison coercion that resolves value types with
/// `comparison_coercion_numeric`.
///
/// Delegates to `dictionary_comparison_coercion_generic`, forwarding
/// `preserve_dictionaries` unchanged.
fn dictionary_comparison_coercion_numeric(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_dictionaries: bool,
) -> Option<DataType> {
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion_numeric;
    dictionary_comparison_coercion_generic(
        lhs_type,
        rhs_type,
        preserve_dictionaries,
        coerce_values,
    )
}
/// Shared implementation of comparison coercion when at least one operand is
/// `RunEndEncoded`.
///
/// * Both operands run-end encoded: the result is `coerce_fn` applied to the
///   data types of the two values fields.
/// * Exactly one operand run-end encoded: if `preserve_ree` is set and the
///   values field's data type equals the other operand's type, the
///   run-end-encoded type itself is returned; otherwise `coerce_fn` is applied
///   to the values type and the other operand, keeping left/right order.
/// * Neither operand run-end encoded: `None`.
fn ree_comparison_coercion_generic(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_ree: bool,
    coerce_fn: fn(&DataType, &DataType) -> Option<DataType>,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    // Classify the operands first; only the "exactly one REE" case needs the
    // follow-up logic below.
    let (encoded_type, values_type, plain_type, encoded_on_lhs): (
        &DataType,
        &DataType,
        &DataType,
        bool,
    ) = match (lhs_type, rhs_type) {
        (RunEndEncoded(_, lhs_values), RunEndEncoded(_, rhs_values)) => {
            return coerce_fn(lhs_values.data_type(), rhs_values.data_type());
        }
        (encoded @ RunEndEncoded(_, values), plain) => {
            (encoded, values.data_type(), plain, true)
        }
        (plain, encoded @ RunEndEncoded(_, values)) => {
            (encoded, values.data_type(), plain, false)
        }
        _ => return None,
    };

    if preserve_ree && values_type == plain_type {
        return Some(encoded_type.clone());
    }

    // Keep the operand order the caller supplied when delegating.
    if encoded_on_lhs {
        coerce_fn(values_type, plain_type)
    } else {
        coerce_fn(plain_type, values_type)
    }
}
/// Run-end-encoded-aware comparison coercion that resolves value types with
/// `comparison_coercion`.
///
/// Delegates to `ree_comparison_coercion_generic`, forwarding `preserve_ree`
/// unchanged.
fn ree_comparison_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_ree: bool,
) -> Option<DataType> {
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion;
    ree_comparison_coercion_generic(lhs_type, rhs_type, preserve_ree, coerce_values)
}
/// Run-end-encoded-aware comparison coercion that resolves value types with
/// `comparison_coercion_numeric`.
///
/// Delegates to `ree_comparison_coercion_generic`, forwarding `preserve_ree`
/// unchanged.
fn ree_comparison_coercion_numeric(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_ree: bool,
) -> Option<DataType> {
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion_numeric;
    ree_comparison_coercion_generic(lhs_type, rhs_type, preserve_ree, coerce_values)
}
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) {
(Utf8View, from_type) | (from_type, Utf8View) => {
string_concat_internal_coercion(from_type, &Utf8View)
}
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
string_coercion(lhs_value_type, rhs_value_type).or(None)
}
_ => None,
})
}
/// Returns a copy of `lhs_type` when `lhs_type.equals_datatype(rhs_type)`
/// holds, and `None` otherwise.
fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    lhs_type
        .equals_datatype(rhs_type)
        .then(|| lhs_type.clone())
}
/// Returns a copy of `to_type` when `can_cast_types(from_type, to_type)`
/// reports the cast as supported, and `None` otherwise.
fn string_concat_internal_coercion(
    from_type: &DataType,
    to_type: &DataType,
) -> Option<DataType> {
    can_cast_types(from_type, to_type).then(|| to_type.clone())
}
pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
Some(Utf8View)
}
(LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
(Utf8, Utf8) => Some(Utf8),
_ => None,
}
}
/// Coerces a (string, numeric) pair — in either order — to the numeric type.
///
/// Returns a copy of the numeric operand's type when one operand is `Utf8`,
/// `LargeUtf8` or `Utf8View` and the other satisfies `is_numeric()`;
/// returns `None` for every other combination.
fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    let is_string = |t: &DataType| matches!(t, Utf8 | LargeUtf8 | Utf8View);

    if is_string(lhs_type) && rhs_type.is_numeric() {
        Some(rhs_type.clone())
    } else if is_string(rhs_type) && lhs_type.is_numeric() {
        Some(lhs_type.clone())
    } else {
        None
    }
}
fn coerce_list_children(lhs_field: &FieldRef, rhs_field: &FieldRef) -> Option<FieldRef> {
let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()];
Some(Arc::new(
(**lhs_field)
.clone()
.with_data_type(type_union_resolution(&data_types)?)
.with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()),
))
}
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => {
if ls == rs {
Some(FixedSizeList(
coerce_list_children(lhs_field, rhs_field)?,
*rs,
))
} else {
Some(List(coerce_list_children(lhs_field, rhs_field)?))
}
}
(
LargeList(lhs_field),
List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _),
)
| (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => {
Some(LargeList(coerce_list_children(lhs_field, rhs_field)?))
}
(List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _))
| (FixedSizeList(lhs_field, _), List(rhs_field)) => {
Some(List(coerce_list_children(lhs_field, rhs_field)?))
}
_ => None,
}
}
pub fn binary_to_string_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Binary, Utf8) => Some(Utf8),
(Binary, LargeUtf8) => Some(LargeUtf8),
(BinaryView, Utf8) => Some(Utf8View),
(BinaryView, LargeUtf8) => Some(LargeUtf8),
(LargeBinary, Utf8) => Some(LargeUtf8),
(LargeBinary, LargeUtf8) => Some(LargeUtf8),
(Utf8, Binary) => Some(Utf8),
(Utf8, LargeBinary) => Some(LargeUtf8),
(Utf8, BinaryView) => Some(Utf8View),
(LargeUtf8, Binary) => Some(LargeUtf8),
(LargeUtf8, LargeBinary) => Some(LargeUtf8),
(LargeUtf8, BinaryView) => Some(LargeUtf8),
_ => None,
}
}
/// Finds the common binary type for pairs that involve a binary operand.
///
/// * `BinaryView` paired with any binary or string type, or with
///   `FixedSizeBinary`, gives `BinaryView`.
/// * `LargeBinary` paired with `LargeBinary`, `Binary` or any string type
///   gives `LargeBinary`; so does `Binary` paired with `Utf8View` or
///   `LargeUtf8`.
/// * `Binary` paired with `Utf8` or `FixedSizeBinary` gives `Binary`.
///
/// Every other combination returns `None`.
fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    // The arms are mutually exclusive, so they are grouped by result type.
    match (lhs_type, rhs_type) {
        (
            BinaryView,
            BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View
            | FixedSizeBinary(_),
        )
        | (
            Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View | FixedSizeBinary(_),
            BinaryView,
        ) => Some(BinaryView),
        (LargeBinary, LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View)
        | (Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary)
        | (Binary, Utf8View | LargeUtf8)
        | (Utf8View | LargeUtf8, Binary) => Some(LargeBinary),
        (Binary, Utf8 | FixedSizeBinary(_)) | (Utf8 | FixedSizeBinary(_), Binary) => {
            Some(Binary)
        }
        _ => None,
    }
}
/// Determines the common type for the operands of a `LIKE`-style expression.
///
/// The following rules are tried in order and the first one that yields a
/// type wins: `string_coercion`, `binary_to_string_coercion`,
/// `dictionary_comparison_coercion` (dictionaries not preserved),
/// `ree_comparison_coercion` (run-end encoding not preserved),
/// `regex_null_coercion`, and finally `null_coercion`.
pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    // Each group is a lazy closure so later rules only run when every earlier
    // rule returned `None`.
    let string_like = || {
        string_coercion(lhs_type, rhs_type)
            .or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
    };
    let encoded = || {
        dictionary_comparison_coercion(lhs_type, rhs_type, false)
            .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, false))
    };
    let with_null = || {
        regex_null_coercion(lhs_type, rhs_type)
            .or_else(|| null_coercion(lhs_type, rhs_type))
    };

    string_like().or_else(encoded).or_else(with_null)
}
/// Coercion for string-pattern operators when at least one operand is `Null`.
///
/// * `Null` with `Utf8View` / `Utf8` / `LargeUtf8` (either order) gives a
///   copy of the string operand's type.
/// * `Null` with `Null` gives `Utf8`.
/// * Anything else gives `None`.
fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    let is_string = |t: &DataType| matches!(t, Utf8View | Utf8 | LargeUtf8);

    match (lhs_type, rhs_type) {
        (Null, Null) => Some(Utf8),
        (Null, string_type) | (string_type, Null) if is_string(string_type) => {
            Some(string_type.clone())
        }
        _ => None,
    }
}
/// Determines the common type for the operands of a regex match expression.
///
/// Tries, in order, `string_coercion`, `dictionary_comparison_coercion`
/// (dictionaries not preserved) and `regex_null_coercion`, returning the first
/// type found.
pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    if let Some(common) = string_coercion(lhs_type, rhs_type) {
        return Some(common);
    }
    if let Some(common) = dictionary_comparison_coercion(lhs_type, rhs_type, false) {
        return Some(common);
    }
    regex_null_coercion(lhs_type, rhs_type)
}
/// Returns `true` for `Time32` with a `Second` or `Millisecond` unit and for
/// `Time64` with a `Microsecond` or `Nanosecond` unit; `false` for any other
/// data type or unit pairing.
fn is_time_with_valid_unit(datatype: &DataType) -> bool {
    match datatype {
        DataType::Time32(unit) => matches!(unit, Second | Millisecond),
        DataType::Time64(unit) => matches!(unit, Microsecond | Nanosecond),
        _ => false,
    }
}
/// Temporal coercion that never rejects a pair of timestamps because of
/// their timezones.
///
/// For two `Timestamp`s the result uses the unit chosen by
/// `timeunit_coercion` and the left timezone if present, otherwise the right
/// one (if any). Every other pairing is delegated to `temporal_coercion`.
fn temporal_coercion_nonstrict_timezone(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    if let (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) =
        (lhs_type, rhs_type)
    {
        // Left timezone wins whenever it exists, even if the right differs.
        let tz = lhs_tz.as_ref().or(rhs_tz.as_ref()).map(Arc::clone);
        let unit = timeunit_coercion(lhs_unit, rhs_unit);
        return Some(Timestamp(unit, tz));
    }

    temporal_coercion(lhs_type, rhs_type)
}
/// Temporal coercion that refuses to combine timestamps with different
/// timezones.
///
/// For two `Timestamp`s:
/// * if both carry a timezone, the two must be textually equal, or be the
///   pair `"UTC"` / `"+00:00"` (either order); otherwise `None` is returned;
/// * the result uses the left timezone if present, otherwise the right one
///   (if any), together with the unit chosen by `timeunit_coercion`.
///
/// Every other pairing is delegated to `temporal_coercion`.
fn temporal_coercion_strict_timezone(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    if let (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) =
        (lhs_type, rhs_type)
    {
        if let (Some(lhs_zone), Some(rhs_zone)) = (lhs_tz, rhs_tz) {
            let (lhs_name, rhs_name): (&str, &str) =
                (lhs_zone.as_ref(), rhs_zone.as_ref());
            let utc_aliases = matches!(
                (lhs_name, rhs_name),
                ("UTC", "+00:00") | ("+00:00", "UTC")
            );
            if !utc_aliases && lhs_name != rhs_name {
                return None;
            }
        }

        // Compatible (or at most one) timezone: prefer the left one.
        let tz = lhs_tz.as_ref().or(rhs_tz.as_ref()).map(Arc::clone);
        let unit = timeunit_coercion(lhs_unit, rhs_unit);
        return Some(Timestamp(unit, tz));
    }

    temporal_coercion(lhs_type, rhs_type)
}
/// Chooses the pair of types `(lhs, rhs)` that the operands of a temporal
/// arithmetic expression should be cast to.
///
/// * Date with integer: the date keeps its type, the integer becomes
///   `Interval(MonthDayNano)`.
/// * Date with time: the date becomes `Timestamp(Nanosecond, None)`, the time
///   becomes `Duration(Nanosecond)`.
/// * Timestamp with duration: the timestamp is kept as is, the duration
///   adopts the timestamp's unit.
/// * Time with time, time with interval, and interval with an integer or
///   float: both sides become `Interval(MonthDayNano)`.
///
/// Each rule applies in either operand order, and the returned tuple keeps the
/// operand order. Unlisted pairings return `None`.
fn temporal_math_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<(DataType, DataType)> {
    use DataType::*;

    // The arms are mutually exclusive, so equivalent cases are merged.
    match (lhs_type, rhs_type) {
        (
            date @ (Date32 | Date64),
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
        ) => Some((date.clone(), Interval(MonthDayNano))),
        (
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
            date @ (Date32 | Date64),
        ) => Some((Interval(MonthDayNano), date.clone())),

        (Date32 | Date64, Time32(_) | Time64(_)) => {
            Some((Timestamp(Nanosecond, None), Duration(Nanosecond)))
        }
        (Time32(_) | Time64(_), Date32 | Date64) => {
            Some((Duration(Nanosecond), Timestamp(Nanosecond, None)))
        }

        (Timestamp(unit, tz), Duration(_)) => {
            Some((Timestamp(*unit, tz.clone()), Duration(*unit)))
        }
        (Duration(_), Timestamp(unit, tz)) => {
            Some((Duration(*unit), Timestamp(*unit, tz.clone())))
        }

        (Time32(_) | Time64(_), Time32(_) | Time64(_) | Interval(_))
        | (Interval(_), Time32(_) | Time64(_))
        | (
            Interval(_),
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float16
            | Float32 | Float64,
        )
        | (
            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float16
            | Float32 | Float64,
            Interval(_),
        ) => Some((Interval(MonthDayNano), Interval(MonthDayNano))),

        _ => None,
    }
}
fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
use arrow::datatypes::IntervalUnit::*;
use arrow::datatypes::TimeUnit::*;
match (lhs_type, rhs_type) {
(Interval(_) | Duration(_), Interval(_) | Duration(_)) => {
Some(Interval(MonthDayNano))
}
(Date32, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64)
| (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date32) => {
Some(Date32)
}
(Date64, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64)
| (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date64) => {
Some(Date64)
}
(Date64, Date32) | (Date32, Date64) => Some(Date64),
(Date32, Time32(_)) | (Time32(_), Date32) => Some(Timestamp(Nanosecond, None)),
(Date32, Time64(_)) | (Time64(_), Date32) => Some(Timestamp(Nanosecond, None)),
(Date64, Time32(_)) | (Time32(_), Date64) => Some(Timestamp(Nanosecond, None)),
(Date64, Time64(_)) | (Time64(_), Date64) => Some(Timestamp(Nanosecond, None)),
(Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => {
Some(Timestamp(Nanosecond, None))
}
(Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => {
Some(Timestamp(Nanosecond, None))
}
(Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
Some(Timestamp(Nanosecond, None))
}
(Timestamp(_, _tz), Date32) | (Date32, Timestamp(_, _tz)) => {
Some(Timestamp(Nanosecond, None))
}
_ => None,
}
}
/// Picks the common time unit for two operands: the coarser of the two
/// (`Second` is coarsest, `Nanosecond` finest). Equal units are returned
/// unchanged.
fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit {
    use arrow::datatypes::TimeUnit::*;

    // Smaller rank means a coarser unit.
    let rank = |unit: &TimeUnit| -> u8 {
        match unit {
            Second => 0,
            Millisecond => 1,
            Microsecond => 2,
            Nanosecond => 3,
        }
    };

    if rank(lhs_unit) <= rank(rhs_unit) {
        *lhs_unit
    } else {
        *rhs_unit
    }
}
/// Coerces a pair in which at least one operand is `Null` to the other
/// operand's type.
///
/// When the left operand is `Null` the right operand's type is the candidate;
/// otherwise, when the right operand is `Null`, the left operand's type is.
/// The candidate is returned only if `can_cast_types` reports that `Null` can
/// be cast to it. Pairs without a `Null` operand return `None`.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    let candidate = match (lhs_type, rhs_type) {
        (DataType::Null, other) | (other, DataType::Null) => other,
        _ => return None,
    };
    can_cast_types(&DataType::Null, candidate).then(|| candidate.clone())
}
// Unit tests are declared in the separate `tests` module, which is compiled
// only for test builds.
#[cfg(test)]
mod tests;