use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use crate::{ArrowError, Field, FieldRef, Fields, UnionFields};
/// Describes the logical type of the values in an Arrow array or field.
///
/// Unit variants are fully described by their name; tuple variants carry the
/// extra parameters (time unit, size, child fields, ...) that complete the
/// type. Child types are held behind `FieldRef`, `Fields`, `UnionFields` or
/// `Box`.
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order, and
/// the optional serde derive serializes variants by name, so reordering or
/// renaming variants is an observable change.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DataType {
/// The null type; it has no primitive width.
Null,
/// A boolean; `primitive_width` reports `None` for it.
Boolean,
/// Signed 8-bit integer.
Int8,
/// Signed 16-bit integer.
Int16,
/// Signed 32-bit integer.
Int32,
/// Signed 64-bit integer.
Int64,
/// Unsigned 8-bit integer.
UInt8,
/// Unsigned 16-bit integer.
UInt16,
/// Unsigned 32-bit integer.
UInt32,
/// Unsigned 64-bit integer.
UInt64,
/// 16-bit floating point number.
Float16,
/// 32-bit floating point number.
Float32,
/// 64-bit floating point number.
Float64,
/// Timestamp with the given unit, stored in 8 bytes.
///
/// NOTE(review): the optional string is presumably the time zone (as in the
/// Arrow format); this file only measures its length in `size` — confirm.
Timestamp(TimeUnit, Option<Arc<str>>),
/// Date stored in 4 bytes.
Date32,
/// Date stored in 8 bytes.
Date64,
/// Time value with the given unit, stored in 4 bytes.
Time32(TimeUnit),
/// Time value with the given unit, stored in 8 bytes.
Time64(TimeUnit),
/// Duration with the given unit, stored in 8 bytes.
Duration(TimeUnit),
/// Interval whose width (4, 8 or 16 bytes) depends on the unit.
Interval(IntervalUnit),
/// Binary data without a fixed primitive width.
Binary,
/// Binary data of a fixed size.
///
/// NOTE(review): the `i32` is presumably the byte width of each value — confirm.
FixedSizeBinary(i32),
/// "Large" counterpart of `Binary`; the layout difference is defined by the
/// Arrow format and is not visible in this file.
LargeBinary,
/// "View" counterpart of `Binary`; the layout difference is defined by the
/// Arrow format and is not visible in this file.
BinaryView,
/// String data without a fixed primitive width.
Utf8,
/// "Large" counterpart of `Utf8`; the layout difference is defined by the
/// Arrow format and is not visible in this file.
LargeUtf8,
/// "View" counterpart of `Utf8`; the layout difference is defined by the
/// Arrow format and is not visible in this file.
Utf8View,
/// List whose elements are described by the given field.
List(FieldRef),
/// "View" counterpart of `List`, with elements described by the given field.
ListView(FieldRef),
/// List of the given element field with a fixed size (the `i32`).
FixedSizeList(FieldRef, i32),
/// "Large" counterpart of `List`, with elements described by the given field.
LargeList(FieldRef),
/// "Large" counterpart of `ListView`, with elements described by the given field.
LargeListView(FieldRef),
/// Struct made of the given child fields.
Struct(Fields),
/// Union of the given (type id, field) pairs, laid out according to the mode.
Union(UnionFields, UnionMode),
/// Dictionary type: key type first, value type second.
Dictionary(Box<DataType>, Box<DataType>),
/// Decimal stored in 16 bytes; the payload is (precision, scale).
Decimal128(u8, i8),
/// Decimal stored in 32 bytes; the payload is (precision, scale).
Decimal256(u8, i8),
/// Map described by the given field; the `bool` is the "sorted" flag.
Map(FieldRef, bool),
/// Run-end encoded type: run-ends field first, values field second.
RunEndEncoded(FieldRef, FieldRef),
}
/// Resolution carried by the time-based [`DataType`] variants (`Timestamp`,
/// `Time32`, `Time64` and `Duration`).
///
/// Variants are declared from coarsest to finest; the derived ordering follows
/// that declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TimeUnit {
/// Values are counted in seconds.
Second,
/// Values are counted in milliseconds.
Millisecond,
/// Values are counted in microseconds.
Microsecond,
/// Values are counted in nanoseconds.
Nanosecond,
}
/// Selects the representation of a [`DataType::Interval`].
///
/// The unit determines the width reported by `DataType::primitive_width`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IntervalUnit {
/// Year/month interval; 4 bytes wide.
YearMonth,
/// Day/time interval; 8 bytes wide.
DayTime,
/// Month/day/nanosecond interval; 16 bytes wide.
MonthDayNano,
}
/// Layout mode carried by [`DataType::Union`].
///
/// Two union types only compare equal (see `DataType::equals_datatype`) when
/// their modes match. The physical meaning of each mode is defined by the
/// Arrow format and is not visible in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum UnionMode {
/// Sparse union layout.
Sparse,
/// Dense union layout.
Dense,
}
/// Displays a [`DataType`] using exactly its `Debug` representation.
impl fmt::Display for DataType {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Fresh format arguments are used on purpose: flags set on `formatter`
        // (width, `#`, ...) are not forwarded to the `Debug` rendering.
        formatter.write_fmt(format_args!("{self:?}"))
    }
}
impl FromStr for DataType {
type Err = ArrowError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
crate::datatype_parse::parse_data_type(s)
}
}
impl TryFrom<&str> for DataType {
type Error = ArrowError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
impl DataType {
/// Returns `true` if this type is numeric or temporal.
#[inline]
pub fn is_primitive(&self) -> bool {
    if self.is_numeric() {
        return true;
    }
    self.is_temporal()
}
#[inline]
pub fn is_numeric(&self) -> bool {
use DataType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal128(_, _)
| Decimal256(_, _)
)
}
#[inline]
pub fn is_temporal(&self) -> bool {
use DataType::*;
matches!(
self,
Date32 | Date64 | Timestamp(_, _) | Time32(_) | Time64(_) | Duration(_) | Interval(_)
)
}
#[inline]
pub fn is_floating(&self) -> bool {
use DataType::*;
matches!(self, Float16 | Float32 | Float64)
}
#[inline]
pub fn is_integer(&self) -> bool {
self.is_signed_integer() || self.is_unsigned_integer()
}
#[inline]
pub fn is_signed_integer(&self) -> bool {
use DataType::*;
matches!(self, Int8 | Int16 | Int32 | Int64)
}
#[inline]
pub fn is_unsigned_integer(&self) -> bool {
use DataType::*;
matches!(self, UInt8 | UInt16 | UInt32 | UInt64)
}
#[inline]
pub fn is_dictionary_key_type(&self) -> bool {
self.is_integer()
}
#[inline]
pub fn is_run_ends_type(&self) -> bool {
use DataType::*;
matches!(self, Int16 | Int32 | Int64)
}
/// Returns `true` if this type is nested, i.e. it is described in terms of
/// child fields (lists, list views, structs, unions and maps).
///
/// `Dictionary` and `RunEndEncoded` are encodings rather than containers, so
/// they are nested exactly when their value type is nested.
///
/// The match is exhaustive on purpose: adding a variant to [`DataType`] must
/// force a decision here instead of silently being treated as "not nested".
#[inline]
pub fn is_nested(&self) -> bool {
    use DataType::*;
    match self {
        // Encodings: look through to the encoded value type.
        Dictionary(_, value) => value.is_nested(),
        RunEndEncoded(_, values) => values.data_type().is_nested(),
        // Containers built from child fields.
        List(_)
        | ListView(_)
        | FixedSizeList(_, _)
        | LargeList(_)
        | LargeListView(_)
        | Struct(_)
        | Union(_, _)
        | Map(_, _) => true,
        // Everything else has no children.
        Null
        | Boolean
        | Int8
        | Int16
        | Int32
        | Int64
        | UInt8
        | UInt16
        | UInt32
        | UInt64
        | Float16
        | Float32
        | Float64
        | Timestamp(_, _)
        | Date32
        | Date64
        | Time32(_)
        | Time64(_)
        | Duration(_)
        | Interval(_)
        | Binary
        | FixedSizeBinary(_)
        | LargeBinary
        | BinaryView
        | Utf8
        | LargeUtf8
        | Utf8View
        | Decimal128(_, _)
        | Decimal256(_, _) => false,
    }
}
#[inline]
pub fn is_null(&self) -> bool {
use DataType::*;
matches!(self, Null)
}
/// Compares two types structurally, ignoring the names of nested fields.
///
/// For nested types only the nullability and the (recursively compared) data
/// type of each child field are considered, plus the variant's own parameters
/// (list size, map sorted flag, union mode and type ids). Union children are
/// matched by type id, so their order does not matter. All remaining types
/// use the derived `==`.
///
/// `ListView` and `LargeListView` are compared like `List` and `LargeList`,
/// so two list views that differ only in the child field name are equal.
pub fn equals_datatype(&self, other: &DataType) -> bool {
    match (self, other) {
        (DataType::List(a), DataType::List(b))
        | (DataType::LargeList(a), DataType::LargeList(b))
        | (DataType::ListView(a), DataType::ListView(b))
        | (DataType::LargeListView(a), DataType::LargeListView(b)) => {
            a.is_nullable() == b.is_nullable() && a.data_type().equals_datatype(b.data_type())
        }
        (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => {
            a_size == b_size
                && a.is_nullable() == b.is_nullable()
                && a.data_type().equals_datatype(b.data_type())
        }
        (DataType::Struct(a), DataType::Struct(b)) => {
            // Children are compared pairwise, so the length check must come first.
            a.len() == b.len()
                && a.iter().zip(b).all(|(a, b)| {
                    a.is_nullable() == b.is_nullable()
                        && a.data_type().equals_datatype(b.data_type())
                })
        }
        (DataType::Map(a_field, a_is_sorted), DataType::Map(b_field, b_is_sorted)) => {
            a_field.is_nullable() == b_field.is_nullable()
                && a_field.data_type().equals_datatype(b_field.data_type())
                && a_is_sorted == b_is_sorted
        }
        (DataType::Dictionary(a_key, a_value), DataType::Dictionary(b_key, b_value)) => {
            a_key.equals_datatype(b_key) && a_value.equals_datatype(b_value)
        }
        (
            DataType::RunEndEncoded(a_run_ends, a_values),
            DataType::RunEndEncoded(b_run_ends, b_values),
        ) => {
            a_run_ends.is_nullable() == b_run_ends.is_nullable()
                && a_run_ends
                    .data_type()
                    .equals_datatype(b_run_ends.data_type())
                && a_values.is_nullable() == b_values.is_nullable()
                && a_values.data_type().equals_datatype(b_values.data_type())
        }
        (
            DataType::Union(a_union_fields, a_union_mode),
            DataType::Union(b_union_fields, b_union_mode),
        ) => {
            // Same number of children and every child of `self` has a
            // counterpart in `other` with the same type id.
            a_union_mode == b_union_mode
                && a_union_fields.len() == b_union_fields.len()
                && a_union_fields.iter().all(|a| {
                    b_union_fields.iter().any(|b| {
                        a.0 == b.0
                            && a.1.is_nullable() == b.1.is_nullable()
                            && a.1.data_type().equals_datatype(b.1.data_type())
                    })
                })
        }
        _ => self == other,
    }
}
#[inline]
pub fn primitive_width(&self) -> Option<usize> {
match self {
DataType::Null => None,
DataType::Boolean => None,
DataType::Int8 | DataType::UInt8 => Some(1),
DataType::Int16 | DataType::UInt16 | DataType::Float16 => Some(2),
DataType::Int32 | DataType::UInt32 | DataType::Float32 => Some(4),
DataType::Int64 | DataType::UInt64 | DataType::Float64 => Some(8),
DataType::Timestamp(_, _) => Some(8),
DataType::Date32 | DataType::Time32(_) => Some(4),
DataType::Date64 | DataType::Time64(_) => Some(8),
DataType::Duration(_) => Some(8),
DataType::Interval(IntervalUnit::YearMonth) => Some(4),
DataType::Interval(IntervalUnit::DayTime) => Some(8),
DataType::Interval(IntervalUnit::MonthDayNano) => Some(16),
DataType::Decimal128(_, _) => Some(16),
DataType::Decimal256(_, _) => Some(32),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => None,
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => None,
DataType::FixedSizeBinary(_) => None,
DataType::List(_)
| DataType::ListView(_)
| DataType::LargeList(_)
| DataType::LargeListView(_)
| DataType::Map(_, _) => None,
DataType::FixedSizeList(_, _) => None,
DataType::Struct(_) => None,
DataType::Union(_, _) => None,
DataType::Dictionary(_, _) => None,
DataType::RunEndEncoded(_, _) => None,
}
}
/// Returns a size in bytes for this value: the in-memory size of the enum
/// itself plus whatever its payload reports (the length of the timestamp
/// string, or the `size()` of child fields and types).
pub fn size(&self) -> usize {
    let payload = match self {
        DataType::Timestamp(_, text) => text.as_deref().map_or(0, str::len),
        DataType::Dictionary(key, value) => key.size() + value.size(),
        DataType::RunEndEncoded(run_ends, values) => {
            // Kept in this exact order; each child contributes its size minus
            // the size of the `FieldRef` handle.
            run_ends.size() - std::mem::size_of_val(run_ends) + values.size()
                - std::mem::size_of_val(values)
        }
        DataType::Struct(children) => children.size(),
        DataType::Union(children, _) => children.size(),
        DataType::List(child)
        | DataType::ListView(child)
        | DataType::LargeList(child)
        | DataType::LargeListView(child)
        | DataType::FixedSizeList(child, _)
        | DataType::Map(child, _) => child.size(),
        // Variants without any child type or string payload add nothing.
        DataType::Null
        | DataType::Boolean
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Float16
        | DataType::Float32
        | DataType::Float64
        | DataType::Date32
        | DataType::Date64
        | DataType::Time32(..)
        | DataType::Time64(..)
        | DataType::Duration(..)
        | DataType::Interval(..)
        | DataType::Binary
        | DataType::FixedSizeBinary(..)
        | DataType::LargeBinary
        | DataType::BinaryView
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Utf8View
        | DataType::Decimal128(..)
        | DataType::Decimal256(..) => 0,
    };
    std::mem::size_of_val(self) + payload
}
/// Checks whether `self` contains `other`.
///
/// Nested types delegate to the `contains` of their child fields (after
/// checking the variant's own parameters such as list size, map sorted flag or
/// union mode); dictionaries check key and value types recursively. Every
/// other pair of types must be equal.
///
/// `ListView` and `LargeListView` delegate to their child field like the other
/// list types, instead of requiring the two types to be strictly equal.
pub fn contains(&self, other: &DataType) -> bool {
    match (self, other) {
        (DataType::List(f1), DataType::List(f2))
        | (DataType::LargeList(f1), DataType::LargeList(f2))
        | (DataType::ListView(f1), DataType::ListView(f2))
        | (DataType::LargeListView(f1), DataType::LargeListView(f2)) => f1.contains(f2),
        (DataType::FixedSizeList(f1, s1), DataType::FixedSizeList(f2, s2)) => {
            s1 == s2 && f1.contains(f2)
        }
        (DataType::Map(f1, s1), DataType::Map(f2, s2)) => s1 == s2 && f1.contains(f2),
        (DataType::Struct(f1), DataType::Struct(f2)) => f1.contains(f2),
        (DataType::Union(f1, s1), DataType::Union(f2, s2)) => {
            // Children are matched by type id, independent of their order.
            s1 == s2
                && f1
                    .iter()
                    .all(|f1| f2.iter().any(|f2| f1.0 == f2.0 && f1.1.contains(f2.1)))
        }
        (DataType::Dictionary(k1, v1), DataType::Dictionary(k2, v2)) => {
            k1.contains(k2) && v1.contains(v2)
        }
        _ => self == other,
    }
}
pub fn new_list(data_type: DataType, nullable: bool) -> Self {
DataType::List(Arc::new(Field::new_list_field(data_type, nullable)))
}
pub fn new_large_list(data_type: DataType, nullable: bool) -> Self {
DataType::LargeList(Arc::new(Field::new_list_field(data_type, nullable)))
}
pub fn new_fixed_size_list(data_type: DataType, size: i32, nullable: bool) -> Self {
DataType::FixedSizeList(Arc::new(Field::new_list_field(data_type, nullable)), size)
}
}
/// Maximum precision of a [`DataType::Decimal128`].
pub const DECIMAL128_MAX_PRECISION: u8 = 38;
/// Maximum scale of a [`DataType::Decimal128`].
pub const DECIMAL128_MAX_SCALE: i8 = 38;
/// Maximum precision of a [`DataType::Decimal256`].
pub const DECIMAL256_MAX_PRECISION: u8 = 76;
/// Maximum scale of a [`DataType::Decimal256`].
pub const DECIMAL256_MAX_SCALE: i8 = 76;
/// Default scale for decimal types.
///
/// NOTE(review): no use of this constant is visible in this file — confirm
/// where the default is applied.
pub const DECIMAL_DEFAULT_SCALE: i8 = 10;
#[cfg(test)]
mod tests {
use super::*;
// Round-trips a nested struct type through JSON and pins the exact serialized
// form. Only built when the optional `serde` feature is enabled.
#[test]
#[cfg(feature = "serde")]
fn serde_struct_type() {
use std::collections::HashMap;
// One field carries a metadata entry, the other an explicitly empty map.
let kv_array = [("k".to_string(), "v".to_string())];
let field_metadata: HashMap<String, String> = kv_array.iter().cloned().collect();
let first_name =
Field::new("first_name", DataType::Utf8, false).with_metadata(field_metadata);
let last_name =
Field::new("last_name", DataType::Utf8, false).with_metadata(HashMap::default());
// The third field is itself a struct, to cover nested serialization.
let person = DataType::Struct(Fields::from(vec![
first_name,
last_name,
Field::new(
"address",
DataType::Struct(Fields::from(vec![
Field::new("street", DataType::Utf8, false),
Field::new("zip", DataType::UInt16, false),
])),
false,
),
]));
let serialized = serde_json::to_string(&person).unwrap();
// Unit variants serialize as plain strings, tuple variants as single-key objects.
assert_eq!(
"{\"Struct\":[\
{\"name\":\"first_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{\"k\":\"v\"}},\
{\"name\":\"last_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{}},\
{\"name\":\"address\",\"data_type\":{\"Struct\":\
[{\"name\":\"street\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{}},\
{\"name\":\"zip\",\"data_type\":\"UInt16\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{}}\
]},\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{}}]}",
serialized
);
// Deserializing the JSON must give back an equal type.
let deserialized = serde_json::from_str(&serialized).unwrap();
assert_eq!(person, deserialized);
}
// Exercises `DataType::equals_datatype` for every nested variant it handles:
// child field names must be ignored, while nullability, child data types and
// variant parameters must be respected.
#[test]
fn test_list_datatype_equality() {
// List: a/b differ only by field name, c by nullability, d by element type.
let list_a = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let list_b = DataType::List(Arc::new(Field::new("array", DataType::Int32, true)));
let list_c = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
let list_d = DataType::List(Arc::new(Field::new("item", DataType::UInt32, true)));
assert!(list_a.equals_datatype(&list_b));
assert!(!list_a.equals_datatype(&list_c));
assert!(!list_b.equals_datatype(&list_c));
assert!(!list_a.equals_datatype(&list_d));
// FixedSizeList: e/f wrap the equal lists above, g has another element type.
let list_e =
DataType::FixedSizeList(Arc::new(Field::new("item", list_a.clone(), false)), 3);
let list_f =
DataType::FixedSizeList(Arc::new(Field::new("array", list_b.clone(), false)), 3);
let list_g = DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::FixedSizeBinary(3), true)),
3,
);
assert!(list_e.equals_datatype(&list_f));
assert!(!list_e.equals_datatype(&list_g));
assert!(!list_f.equals_datatype(&list_g));
// Struct: h/i are equal, j differs by nullability, k/l by one child type,
// and k/m only by child names.
let list_h = DataType::Struct(Fields::from(vec![Field::new("f1", list_e, true)]));
let list_i = DataType::Struct(Fields::from(vec![Field::new("f1", list_f.clone(), true)]));
let list_j = DataType::Struct(Fields::from(vec![Field::new("f1", list_f.clone(), false)]));
let list_k = DataType::Struct(Fields::from(vec![
Field::new("f1", list_f.clone(), false),
Field::new("f2", list_g.clone(), false),
Field::new("f3", DataType::Utf8, true),
]));
let list_l = DataType::Struct(Fields::from(vec![
Field::new("ff1", list_f.clone(), false),
Field::new("ff2", list_g.clone(), false),
Field::new("ff3", DataType::LargeUtf8, true),
]));
let list_m = DataType::Struct(Fields::from(vec![
Field::new("ff1", list_f, false),
Field::new("ff2", list_g, false),
Field::new("ff3", DataType::Utf8, true),
]));
assert!(list_h.equals_datatype(&list_i));
assert!(!list_h.equals_datatype(&list_j));
assert!(!list_k.equals_datatype(&list_l));
assert!(list_k.equals_datatype(&list_m));
// Map: the sorted flag (p), the child type (q) and the nullability (r) all matter.
let list_n = DataType::Map(Arc::new(Field::new("f1", list_a.clone(), true)), true);
let list_o = DataType::Map(Arc::new(Field::new("f2", list_b.clone(), true)), true);
let list_p = DataType::Map(Arc::new(Field::new("f2", list_b.clone(), true)), false);
let list_q = DataType::Map(Arc::new(Field::new("f2", list_c.clone(), true)), true);
let list_r = DataType::Map(Arc::new(Field::new("f1", list_a.clone(), false)), true);
assert!(list_n.equals_datatype(&list_o));
assert!(!list_n.equals_datatype(&list_p));
assert!(!list_n.equals_datatype(&list_q));
assert!(!list_n.equals_datatype(&list_r));
// Dictionary: both the key type (u) and the value type (v) are compared.
let list_s = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(list_a));
let list_t = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(list_b.clone()));
let list_u = DataType::Dictionary(Box::new(DataType::Int8), Box::new(list_b));
let list_v = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(list_c));
assert!(list_s.equals_datatype(&list_t));
assert!(!list_s.equals_datatype(&list_u));
assert!(!list_s.equals_datatype(&list_v));
// Union: children are matched by type id, so b (renamed) and c (reordered)
// equal a, while d (other child type) and e (other nullability) do not.
let union_a = DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("f1", DataType::Utf8, false),
Field::new("f2", DataType::UInt8, false),
],
),
UnionMode::Sparse,
);
let union_b = DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("ff1", DataType::Utf8, false),
Field::new("ff2", DataType::UInt8, false),
],
),
UnionMode::Sparse,
);
let union_c = DataType::Union(
UnionFields::new(
vec![2, 1],
vec![
Field::new("fff2", DataType::UInt8, false),
Field::new("fff1", DataType::Utf8, false),
],
),
UnionMode::Sparse,
);
let union_d = DataType::Union(
UnionFields::new(
vec![2, 1],
vec![
Field::new("fff1", DataType::Int8, false),
Field::new("fff2", DataType::UInt8, false),
],
),
UnionMode::Sparse,
);
let union_e = DataType::Union(
UnionFields::new(
vec![1, 2],
vec![
Field::new("f1", DataType::Utf8, true),
Field::new("f2", DataType::UInt8, false),
],
),
UnionMode::Sparse,
);
assert!(union_a.equals_datatype(&union_b));
assert!(union_a.equals_datatype(&union_c));
assert!(!union_a.equals_datatype(&union_d));
assert!(!union_a.equals_datatype(&union_e));
// RunEndEncoded: x only renames the children, y changes the run-ends type,
// z changes the run-ends nullability.
let list_w = DataType::RunEndEncoded(
Arc::new(Field::new("f1", DataType::Int64, true)),
Arc::new(Field::new("f2", DataType::Utf8, true)),
);
let list_x = DataType::RunEndEncoded(
Arc::new(Field::new("ff1", DataType::Int64, true)),
Arc::new(Field::new("ff2", DataType::Utf8, true)),
);
let list_y = DataType::RunEndEncoded(
Arc::new(Field::new("ff1", DataType::UInt16, true)),
Arc::new(Field::new("ff2", DataType::Utf8, true)),
);
let list_z = DataType::RunEndEncoded(
Arc::new(Field::new("f1", DataType::Int64, false)),
Arc::new(Field::new("f2", DataType::Utf8, true)),
);
assert!(list_w.equals_datatype(&list_x));
assert!(!list_w.equals_datatype(&list_y));
assert!(!list_w.equals_datatype(&list_z));
}
// Smoke test: a struct type containing another struct type can be built.
#[test]
fn create_struct_type() {
    let address = DataType::Struct(Fields::from(vec![
        Field::new("street", DataType::Utf8, false),
        Field::new("zip", DataType::UInt16, false),
    ]));
    let person_fields = vec![
        Field::new("first_name", DataType::Utf8, false),
        Field::new("last_name", DataType::Utf8, false),
        Field::new("address", address, false),
    ];
    let _person = DataType::Struct(Fields::from(person_fields));
}
// `is_nested` is true for a list and for a dictionary of a list, and false
// for flat types and dictionaries of flat types.
#[test]
fn test_nested() {
    let list = DataType::List(Arc::new(Field::new("foo", DataType::Utf8, true)));
    let dictionary_of =
        |value: DataType| DataType::Dictionary(Box::new(DataType::Int32), Box::new(value));

    for flat in [DataType::Boolean, DataType::Int32, DataType::Utf8] {
        assert!(!flat.is_nested());
    }
    assert!(list.is_nested());

    for value in [DataType::Boolean, DataType::Int64, DataType::LargeUtf8] {
        assert!(!dictionary_of(value).is_nested());
    }
    assert!(dictionary_of(list).is_nested());
}
// Checks the integer predicates against one signed, one unsigned and one
// non-integer type.
#[test]
fn test_integer() {
    // (type, is_integer, is_signed_integer, is_unsigned_integer)
    let expectations = [
        (DataType::Int32, true, true, false),
        (DataType::UInt64, true, false, true),
        (DataType::Float16, false, false, false),
    ];
    for (data_type, integer, signed, unsigned) in &expectations {
        assert_eq!(data_type.is_integer(), *integer);
        assert_eq!(data_type.is_signed_integer(), *signed);
        assert_eq!(data_type.is_unsigned_integer(), *unsigned);
        // Dictionary keys are accepted exactly for the integer types.
        assert_eq!(data_type.is_dictionary_key_type(), *integer);
    }
}
// `is_floating` accepts a float type and rejects an integer type.
#[test]
fn test_floating() {
    let (float, integer) = (DataType::Float16, DataType::Int32);
    assert!(float.is_floating());
    assert!(!integer.is_floating());
}
// `is_null` is true only for `DataType::Null`.
#[test]
fn test_datatype_is_null() {
    let (null, integer) = (DataType::Null, DataType::Int32);
    assert!(null.is_null());
    assert!(!integer.is_null());
}
// Guards the in-memory size of `DataType` against accidental growth.
#[test]
fn size_should_not_regress() {
    let actual = std::mem::size_of::<DataType>();
    assert_eq!(actual, 24);
}
// Building union fields with a repeated type id must panic.
#[test]
#[should_panic(expected = "duplicate type id: 1")]
fn test_union_with_duplicated_type_id() {
    let children = vec![
        Field::new("f1", DataType::Int32, false),
        Field::new("f2", DataType::Utf8, false),
    ];
    let union_fields = UnionFields::new(vec![1, 1], children);
    let _union = DataType::Union(union_fields, UnionMode::Dense);
}
// `TryFrom<&str>` parses a type name.
#[test]
fn test_try_from_str() {
    let parsed = DataType::try_from("Int32").unwrap();
    assert_eq!(parsed, DataType::Int32);
}
// `FromStr` (through `str::parse`) parses a type name.
#[test]
fn test_from_str() {
    let parsed = "UInt64".parse::<DataType>().unwrap();
    assert_eq!(parsed, DataType::UInt64);
}
}