use super::{
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
TypeSignature,
};
use crate::error::{Result, _internal_err};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
};
use std::{fmt::Display, sync::Arc};
/// Logical type that abstracts over the physical arrow `DataType`.
///
/// Several `DataType`s that the `From<DataType>` impl in this file treats
/// alike (for example `Utf8`, `LargeUtf8` and `Utf8View`) collapse into a
/// single variant here.
///
/// NOTE: `PartialOrd`/`Ord` are derived, so comparisons follow the declaration
/// order of the variants below; reordering variants changes that ordering.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum NativeType {
/// Logical counterpart of `DataType::Null`.
Null,
/// Logical counterpart of `DataType::Boolean`.
Boolean,
/// Logical counterpart of `DataType::Int8`.
Int8,
/// Logical counterpart of `DataType::Int16`.
Int16,
/// Logical counterpart of `DataType::Int32`.
Int32,
/// Logical counterpart of `DataType::Int64`.
Int64,
/// Logical counterpart of `DataType::UInt8`.
UInt8,
/// Logical counterpart of `DataType::UInt16`.
UInt16,
/// Logical counterpart of `DataType::UInt32`.
UInt32,
/// Logical counterpart of `DataType::UInt64`.
UInt64,
/// Logical counterpart of `DataType::Float16`.
Float16,
/// Logical counterpart of `DataType::Float32`.
Float32,
/// Logical counterpart of `DataType::Float64`.
Float64,
/// Logical counterpart of `DataType::Timestamp`; both payload values are
/// copied over unchanged (the optional string is presumably the time zone,
/// as it is bound to `tz` elsewhere in this file).
Timestamp(TimeUnit, Option<Arc<str>>),
/// Date, covering both `DataType::Date32` and `DataType::Date64`.
Date,
/// Time of day, covering both `DataType::Time32` and `DataType::Time64`.
Time(TimeUnit),
/// Logical counterpart of `DataType::Duration`.
Duration(TimeUnit),
/// Logical counterpart of `DataType::Interval`.
Interval(IntervalUnit),
/// Binary data, covering `DataType::Binary`, `DataType::LargeBinary` and
/// `DataType::BinaryView`.
Binary,
/// Logical counterpart of `DataType::FixedSizeBinary`; the payload is the
/// same size value.
FixedSizeBinary(i32),
/// String data, covering `DataType::Utf8`, `DataType::LargeUtf8` and
/// `DataType::Utf8View`.
String,
/// List of the given element field, covering `DataType::List`,
/// `DataType::ListView`, `DataType::LargeList` and `DataType::LargeListView`.
List(LogicalFieldRef),
/// Logical counterpart of `DataType::FixedSizeList`: element field and size.
FixedSizeList(LogicalFieldRef, i32),
/// Logical counterpart of `DataType::Struct`.
Struct(LogicalFields),
/// Logical counterpart of `DataType::Union`; the union mode is not kept.
Union(LogicalUnionFields),
/// Decimal, covering `DataType::Decimal128` and `DataType::Decimal256`; the
/// two payload values are copied from those variants (precision and scale).
Decimal(u8, i8),
/// Logical counterpart of `DataType::Map`; the second (`sorted`) value of the
/// physical type is not kept.
Map(LogicalFieldRef),
}
impl Display for NativeType {
    /// Writes the `Debug` rendering of the type prefixed with `NativeType::`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting goes through fresh `format_args!`, so flags set on `f`
        // (width, alternate, ...) are not forwarded to the `Debug` output.
        f.write_fmt(format_args!("NativeType::{:?}", self))
    }
}
impl LogicalType for NativeType {
    /// A native type is its own native representation.
    fn native(&self) -> &NativeType {
        self
    }

    /// Wraps `self` in `TypeSignature::Native`.
    fn signature(&self) -> TypeSignature<'_> {
        TypeSignature::Native(self)
    }

    /// Chooses the physical `DataType` that a value of physical type `origin`
    /// should be cast to in order to match this logical type.
    ///
    /// # Errors
    ///
    /// Returns an internal error when no rule below applies to the
    /// `(self, origin)` pair, or when a nested field cannot be resolved.
    fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
        use DataType::*;

        // Builds the physical field for `target`, deriving its data type from
        // `source_type` while keeping the logical field's name and nullability.
        fn field_for(target: &LogicalField, source_type: &DataType) -> Result<FieldRef> {
            let name = target.name.clone();
            let data_type = target.logical_type.default_cast_for(source_type)?;
            Ok(Arc::new(Field::new(name, data_type, target.nullable)))
        }

        // Same as `field_for`, taking the source type from a physical field.
        fn field_from(target: &LogicalField, source: &Field) -> Result<FieldRef> {
            field_for(target, source.data_type())
        }

        // NOTE: arm order matters; guarded and specific arms must stay ahead of
        // the broader fallbacks that follow them.
        let physical = match (self, origin) {
            // Types with a single physical representation ignore `origin`.
            (Self::Null, _) => Null,
            (Self::Boolean, _) => Boolean,
            (Self::Int8, _) => Int8,
            (Self::Int16, _) => Int16,
            (Self::Int32, _) => Int32,
            (Self::Int64, _) => Int64,
            (Self::UInt8, _) => UInt8,
            (Self::UInt16, _) => UInt16,
            (Self::UInt32, _) => UInt32,
            (Self::UInt64, _) => UInt64,
            (Self::Float16, _) => Float16,
            (Self::Float32, _) => Float32,
            (Self::Float64, _) => Float64,

            // Precisions up to 38 use Decimal128, anything larger Decimal256.
            (Self::Decimal(precision, scale), _) if *precision <= 38 => {
                Decimal128(*precision, *scale)
            }
            (Self::Decimal(precision, scale), _) => Decimal256(*precision, *scale),

            (Self::Timestamp(unit, tz), _) => Timestamp(*unit, tz.clone()),

            // An origin that is already a date keeps its width; otherwise Date32.
            (Self::Date, Date32 | Date64) => origin.clone(),
            (Self::Date, _) => Date32,

            // The time unit decides between Time32 and Time64.
            (Self::Time(unit), _) => match unit {
                TimeUnit::Second | TimeUnit::Millisecond => Time32(*unit),
                TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*unit),
            },
            (Self::Duration(unit), _) => Duration(*unit),
            (Self::Interval(unit), _) => Interval(*unit),

            // Binary: mirror the flavour of a large/view string origin, then
            // try BinaryView, LargeBinary and Binary in that order.
            (Self::Binary, LargeUtf8) => LargeBinary,
            (Self::Binary, Utf8View) => BinaryView,
            (Self::Binary, _) if can_cast_types(origin, &BinaryView) => BinaryView,
            (Self::Binary, _) if can_cast_types(origin, &LargeBinary) => LargeBinary,
            (Self::Binary, _) if can_cast_types(origin, &Binary) => Binary,
            (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),

            // String: mirror the flavour of a large/view binary origin, leave an
            // existing string type untouched, then try Utf8View, LargeUtf8 and
            // Utf8 in that order.
            (Self::String, LargeBinary) => LargeUtf8,
            (Self::String, BinaryView) => Utf8View,
            (Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(),
            (Self::String, _) if can_cast_types(origin, &Utf8View) => Utf8View,
            (Self::String, _) if can_cast_types(origin, &LargeUtf8) => LargeUtf8,
            (Self::String, _) if can_cast_types(origin, &Utf8) => Utf8,

            // List: keep the list flavour of the origin (a fixed-size list
            // becomes a plain List) and cast the element field.
            (Self::List(target), List(source) | FixedSizeList(source, _)) => {
                List(field_from(target, source)?)
            }
            (Self::List(target), LargeList(source)) => {
                LargeList(field_from(target, source)?)
            }
            (Self::List(target), ListView(source)) => {
                ListView(field_from(target, source)?)
            }
            (Self::List(target), LargeListView(source)) => {
                LargeListView(field_from(target, source)?)
            }
            // Any other origin is wrapped: the element type is derived from the
            // origin type itself.
            (Self::List(target), _) => List(field_for(target, origin)?),

            // FixedSizeList: an origin of the same size or any variable-size
            // list has its element field cast.
            (Self::FixedSizeList(target, size), FixedSizeList(source, source_size))
                if source_size == size =>
            {
                FixedSizeList(field_from(target, source)?, *size)
            }
            (
                Self::FixedSizeList(target, size),
                List(source)
                | LargeList(source)
                | ListView(source)
                | LargeListView(source),
            ) => FixedSizeList(field_from(target, source)?, *size),
            // Everything else (including a fixed-size list of another size) is
            // wrapped, deriving the element type from the origin type itself.
            (Self::FixedSizeList(target, size), _) => {
                FixedSizeList(field_for(target, origin)?, *size)
            }

            // Struct: fields are paired by position when the counts agree.
            (Self::Struct(targets), Struct(sources)) if sources.len() == targets.len() => {
                let fields = targets
                    .iter()
                    .zip(sources.iter())
                    .map(|(target, source)| field_from(target, source))
                    .collect::<Result<Fields>>()?;
                Struct(fields)
            }
            // A Null origin yields the struct with every field derived from Null.
            (Self::Struct(targets), Null) => {
                let fields = targets
                    .iter()
                    .map(|target| field_for(target, &Null))
                    .collect::<Result<Fields>>()?;
                Struct(fields)
            }

            // Map: keep the origin's `sorted` flag; a Null origin is unsorted.
            (Self::Map(target), Map(source, sorted)) => {
                Map(field_from(target, source)?, *sorted)
            }
            (Self::Map(target), Null) => Map(field_for(target, &Null)?, false),

            // Union: fields are paired by position when the counts agree; type
            // ids come from the logical side, the mode from the origin.
            (Self::Union(targets), Union(sources, mode)) if sources.len() == targets.len() => {
                let fields = targets
                    .iter()
                    .zip(sources.iter())
                    .map(|((type_id, target), (_, source))| {
                        Ok((*type_id, field_from(target, source)?))
                    })
                    .collect::<Result<UnionFields>>()?;
                Union(fields, *mode)
            }

            _ => {
                return _internal_err!(
                    "Unavailable default cast for native type {:?} from physical type {:?}",
                    self,
                    origin
                )
            }
        };
        Ok(physical)
    }
}
impl From<&DataType> for NativeType {
    /// Converts a borrowed physical type by cloning it and delegating to the
    /// owned `From<DataType>` conversion.
    fn from(value: &DataType) -> Self {
        let owned: DataType = value.clone();
        Self::from(owned)
    }
}
impl From<DataType> for NativeType {
fn from(value: DataType) -> Self {
use NativeType::*;
match value {
DataType::Null => Null,
DataType::Boolean => Boolean,
DataType::Int8 => Int8,
DataType::Int16 => Int16,
DataType::Int32 => Int32,
DataType::Int64 => Int64,
DataType::UInt8 => UInt8,
DataType::UInt16 => UInt16,
DataType::UInt32 => UInt32,
DataType::UInt64 => UInt64,
DataType::Float16 => Float16,
DataType::Float32 => Float32,
DataType::Float64 => Float64,
DataType::Timestamp(tu, tz) => Timestamp(tu, tz),
DataType::Date32 | DataType::Date64 => Date,
DataType::Time32(tu) | DataType::Time64(tu) => Time(tu),
DataType::Duration(tu) => Duration(tu),
DataType::Interval(iu) => Interval(iu),
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary,
DataType::FixedSizeBinary(size) => FixedSizeBinary(size),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => String,
DataType::List(field)
| DataType::ListView(field)
| DataType::LargeList(field)
| DataType::LargeListView(field) => List(Arc::new(field.as_ref().into())),
DataType::FixedSizeList(field, size) => {
FixedSizeList(Arc::new(field.as_ref().into()), size)
}
DataType::Struct(fields) => Struct(LogicalFields::from(&fields)),
DataType::Union(union_fields, _) => {
Union(LogicalUnionFields::from(&union_fields))
}
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
}
}
}
impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal(_, _)
)
}
#[inline]
pub fn is_integer(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
)
}
#[inline]
pub fn is_timestamp(&self) -> bool {
matches!(self, NativeType::Timestamp(_, _))
}
#[inline]
pub fn is_date(&self) -> bool {
matches!(self, NativeType::Date)
}
#[inline]
pub fn is_time(&self) -> bool {
matches!(self, NativeType::Time(_))
}
#[inline]
pub fn is_interval(&self) -> bool {
matches!(self, NativeType::Interval(_))
}
#[inline]
pub fn is_duration(&self) -> bool {
matches!(self, NativeType::Duration(_))
}
}