use std::fmt;
use serde_derive::{Deserialize, Serialize};
use serde_json::{json, Value, Value::String as VString};
use crate::error::{ArrowError, Result};
use super::Field;
/// The data types this module can describe and convert to and from JSON.
///
/// The per-variant notes quote the JSON members that `DataType::to_json`
/// emits for that variant.
///
/// `PartialOrd`/`Ord` are derived, so the declaration order of the variants
/// below defines their ordering; do not reorder them casually.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum DataType {
/// Null type (JSON name `"null"`).
Null,
/// Boolean type (JSON name `"bool"`).
Boolean,
/// Signed integer, `bitWidth` 8.
Int8,
/// Signed integer, `bitWidth` 16.
Int16,
/// Signed integer, `bitWidth` 32.
Int32,
/// Signed integer, `bitWidth` 64.
Int64,
/// Unsigned integer (`isSigned` false), `bitWidth` 8.
UInt8,
/// Unsigned integer (`isSigned` false), `bitWidth` 16.
UInt16,
/// Unsigned integer (`isSigned` false), `bitWidth` 32.
UInt32,
/// Unsigned integer (`isSigned` false), `bitWidth` 64.
UInt64,
/// Floating point with `precision` `"HALF"`.
Float16,
/// Floating point with `precision` `"SINGLE"`.
Float32,
/// Floating point with `precision` `"DOUBLE"`.
Float64,
/// Timestamp with a unit and an optional time zone string; the time zone
/// is written as the `"timezone"` member only when it is `Some`.
Timestamp(TimeUnit, Option<String>),
/// Date with unit `"DAY"`.
Date32,
/// Date with unit `"MILLISECOND"`.
Date64,
/// Time with `bitWidth` 32 and the given unit.
Time32(TimeUnit),
/// Time with `bitWidth` 64 and the given unit.
Time64(TimeUnit),
/// Duration measured in the given unit.
Duration(TimeUnit),
/// Interval with the given interval unit.
Interval(IntervalUnit),
/// Binary type (JSON name `"binary"`).
Binary,
/// Fixed-size binary; the `i32` is written as `"byteWidth"`.
FixedSizeBinary(i32),
/// Large binary type (JSON name `"largebinary"`).
LargeBinary,
/// UTF-8 string type (JSON name `"utf8"`).
Utf8,
/// Large UTF-8 string type (JSON name `"largeutf8"`).
LargeUtf8,
/// List with a single child field (JSON name `"list"`; the child is not
/// part of the JSON written by `to_json`).
List(Box<Field>),
/// Fixed-size list with a single child field; the `i32` is written as
/// `"listSize"`.
FixedSizeList(Box<Field>, i32),
/// Large list with a single child field (JSON name `"largelist"`).
LargeList(Box<Field>),
/// Struct made of the given fields (JSON name `"struct"`; the fields are
/// not part of the JSON written by `to_json`).
Struct(Vec<Field>),
/// Union over the given fields (JSON name `"union"`).
Union(Vec<Field>),
/// Dictionary parameterized by two data types.
/// NOTE(review): presumably the key type followed by the value type —
/// confirm against the code that constructs this variant.
Dictionary(Box<DataType>, Box<DataType>),
/// Decimal; the first `usize` is written as `"precision"`, the second as
/// `"scale"`.
Decimal(usize, usize),
/// Map with a single child field; the `bool` is written as `"keysSorted"`.
Map(Box<Field>, bool),
}
/// Unit used by the time-related `DataType` variants (`Timestamp`, `Time32`,
/// `Time64`, `Duration`).
///
/// In the JSON form each variant is written as the upper-case string noted
/// below. The derived ordering follows declaration order.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TimeUnit {
/// Written as `"SECOND"`.
Second,
/// Written as `"MILLISECOND"`.
Millisecond,
/// Written as `"MICROSECOND"`.
Microsecond,
/// Written as `"NANOSECOND"`.
Nanosecond,
}
/// Unit used by `DataType::Interval`.
///
/// In the JSON form each variant is written as the upper-case string noted
/// below. The derived ordering follows declaration order.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntervalUnit {
/// Written as `"YEAR_MONTH"`.
YearMonth,
/// Written as `"DAY_TIME"`.
DayTime,
}
/// `Display` for `DataType` prints exactly what `Debug` prints.
impl fmt::Display for DataType {
    /// Writes the `Debug` rendering of `self` into `f`.
    ///
    /// The value is rendered through a fresh set of format arguments, so
    /// width/alignment/alternate flags given to the outer formatter are not
    /// applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
impl DataType {
pub(crate) fn from(json: &Value) -> Result<DataType> {
let default_field = Field::new("", DataType::Boolean, true);
match *json {
Value::Object(ref map) => match map.get("name") {
Some(s) if s == "null" => Ok(DataType::Null),
Some(s) if s == "bool" => Ok(DataType::Boolean),
Some(s) if s == "binary" => Ok(DataType::Binary),
Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
Some(s) if s == "utf8" => Ok(DataType::Utf8),
Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
Some(s) if s == "fixedsizebinary" => {
if let Some(Value::Number(size)) = map.get("byteWidth") {
Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
} else {
Err(ArrowError::ParseError(
"Expecting a byteWidth for fixedsizebinary".to_string(),
))
}
}
Some(s) if s == "decimal" => {
let precision = match map.get("precision") {
Some(p) => Ok(p.as_u64().unwrap() as usize),
None => Err(ArrowError::ParseError(
"Expecting a precision for decimal".to_string(),
)),
};
let scale = match map.get("scale") {
Some(s) => Ok(s.as_u64().unwrap() as usize),
_ => Err(ArrowError::ParseError(
"Expecting a scale for decimal".to_string(),
)),
};
Ok(DataType::Decimal(precision?, scale?))
}
Some(s) if s == "floatingpoint" => match map.get("precision") {
Some(p) if p == "HALF" => Ok(DataType::Float16),
Some(p) if p == "SINGLE" => Ok(DataType::Float32),
Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
_ => Err(ArrowError::ParseError(
"floatingpoint precision missing or invalid".to_string(),
)),
},
Some(s) if s == "timestamp" => {
let unit = match map.get("unit") {
Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
_ => Err(ArrowError::ParseError(
"timestamp unit missing or invalid".to_string(),
)),
};
let tz = match map.get("timezone") {
None => Ok(None),
Some(VString(tz)) => Ok(Some(tz.clone())),
_ => Err(ArrowError::ParseError(
"timezone must be a string".to_string(),
)),
};
Ok(DataType::Timestamp(unit?, tz?))
}
Some(s) if s == "date" => match map.get("unit") {
Some(p) if p == "DAY" => Ok(DataType::Date32),
Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
_ => Err(ArrowError::ParseError(
"date unit missing or invalid".to_string(),
)),
},
Some(s) if s == "time" => {
let unit = match map.get("unit") {
Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
_ => Err(ArrowError::ParseError(
"time unit missing or invalid".to_string(),
)),
};
match map.get("bitWidth") {
Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
_ => Err(ArrowError::ParseError(
"time bitWidth missing or invalid".to_string(),
)),
}
}
Some(s) if s == "duration" => match map.get("unit") {
Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
Some(p) if p == "MILLISECOND" => {
Ok(DataType::Duration(TimeUnit::Millisecond))
}
Some(p) if p == "MICROSECOND" => {
Ok(DataType::Duration(TimeUnit::Microsecond))
}
Some(p) if p == "NANOSECOND" => {
Ok(DataType::Duration(TimeUnit::Nanosecond))
}
_ => Err(ArrowError::ParseError(
"time unit missing or invalid".to_string(),
)),
},
Some(s) if s == "interval" => match map.get("unit") {
Some(p) if p == "DAY_TIME" => {
Ok(DataType::Interval(IntervalUnit::DayTime))
}
Some(p) if p == "YEAR_MONTH" => {
Ok(DataType::Interval(IntervalUnit::YearMonth))
}
_ => Err(ArrowError::ParseError(
"interval unit missing or invalid".to_string(),
)),
},
Some(s) if s == "int" => match map.get("isSigned") {
Some(&Value::Bool(true)) => match map.get("bitWidth") {
Some(&Value::Number(ref n)) => match n.as_u64() {
Some(8) => Ok(DataType::Int8),
Some(16) => Ok(DataType::Int16),
Some(32) => Ok(DataType::Int32),
Some(64) => Ok(DataType::Int64),
_ => Err(ArrowError::ParseError(
"int bitWidth missing or invalid".to_string(),
)),
},
_ => Err(ArrowError::ParseError(
"int bitWidth missing or invalid".to_string(),
)),
},
Some(&Value::Bool(false)) => match map.get("bitWidth") {
Some(&Value::Number(ref n)) => match n.as_u64() {
Some(8) => Ok(DataType::UInt8),
Some(16) => Ok(DataType::UInt16),
Some(32) => Ok(DataType::UInt32),
Some(64) => Ok(DataType::UInt64),
_ => Err(ArrowError::ParseError(
"int bitWidth missing or invalid".to_string(),
)),
},
_ => Err(ArrowError::ParseError(
"int bitWidth missing or invalid".to_string(),
)),
},
_ => Err(ArrowError::ParseError(
"int signed missing or invalid".to_string(),
)),
},
Some(s) if s == "list" => {
Ok(DataType::List(Box::new(default_field)))
}
Some(s) if s == "largelist" => {
Ok(DataType::LargeList(Box::new(default_field)))
}
Some(s) if s == "fixedsizelist" => {
if let Some(Value::Number(size)) = map.get("listSize") {
Ok(DataType::FixedSizeList(
Box::new(default_field),
size.as_i64().unwrap() as i32,
))
} else {
Err(ArrowError::ParseError(
"Expecting a listSize for fixedsizelist".to_string(),
))
}
}
Some(s) if s == "struct" => {
Ok(DataType::Struct(vec![]))
}
Some(s) if s == "map" => {
if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
Ok(DataType::Map(Box::new(default_field), *keys_sorted))
} else {
Err(ArrowError::ParseError(
"Expecting a keysSorted for map".to_string(),
))
}
}
Some(other) => Err(ArrowError::ParseError(format!(
"invalid or unsupported type name: {} in {:?}",
other, json
))),
None => Err(ArrowError::ParseError("type name missing".to_string())),
},
_ => Err(ArrowError::ParseError(
"invalid json value type".to_string(),
)),
}
}
pub fn to_json(&self) -> Value {
match self {
DataType::Null => json!({"name": "null"}),
DataType::Boolean => json!({"name": "bool"}),
DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
DataType::Utf8 => json!({"name": "utf8"}),
DataType::LargeUtf8 => json!({"name": "largeutf8"}),
DataType::Binary => json!({"name": "binary"}),
DataType::LargeBinary => json!({"name": "largebinary"}),
DataType::FixedSizeBinary(byte_width) => {
json!({"name": "fixedsizebinary", "byteWidth": byte_width})
}
DataType::Struct(_) => json!({"name": "struct"}),
DataType::Union(_) => json!({"name": "union"}),
DataType::List(_) => json!({ "name": "list"}),
DataType::LargeList(_) => json!({ "name": "largelist"}),
DataType::FixedSizeList(_, length) => {
json!({"name":"fixedsizelist", "listSize": length})
}
DataType::Time32(unit) => {
json!({"name": "time", "bitWidth": 32, "unit": match unit {
TimeUnit::Second => "SECOND",
TimeUnit::Millisecond => "MILLISECOND",
TimeUnit::Microsecond => "MICROSECOND",
TimeUnit::Nanosecond => "NANOSECOND",
}})
}
DataType::Time64(unit) => {
json!({"name": "time", "bitWidth": 64, "unit": match unit {
TimeUnit::Second => "SECOND",
TimeUnit::Millisecond => "MILLISECOND",
TimeUnit::Microsecond => "MICROSECOND",
TimeUnit::Nanosecond => "NANOSECOND",
}})
}
DataType::Date32 => {
json!({"name": "date", "unit": "DAY"})
}
DataType::Date64 => {
json!({"name": "date", "unit": "MILLISECOND"})
}
DataType::Timestamp(unit, None) => {
json!({"name": "timestamp", "unit": match unit {
TimeUnit::Second => "SECOND",
TimeUnit::Millisecond => "MILLISECOND",
TimeUnit::Microsecond => "MICROSECOND",
TimeUnit::Nanosecond => "NANOSECOND",
}})
}
DataType::Timestamp(unit, Some(tz)) => {
json!({"name": "timestamp", "unit": match unit {
TimeUnit::Second => "SECOND",
TimeUnit::Millisecond => "MILLISECOND",
TimeUnit::Microsecond => "MICROSECOND",
TimeUnit::Nanosecond => "NANOSECOND",
}, "timezone": tz})
}
DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
IntervalUnit::YearMonth => "YEAR_MONTH",
IntervalUnit::DayTime => "DAY_TIME",
}}),
DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
TimeUnit::Second => "SECOND",
TimeUnit::Millisecond => "MILLISECOND",
TimeUnit::Microsecond => "MICROSECOND",
TimeUnit::Nanosecond => "NANOSECOND",
}}),
DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
DataType::Decimal(precision, scale) => {
json!({"name": "decimal", "precision": precision, "scale": scale})
}
DataType::Map(_, keys_sorted) => {
json!({"name": "map", "keysSorted": keys_sorted})
}
}
}
pub fn is_numeric(t: &DataType) -> bool {
use DataType::*;
matches!(
t,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
)
}
pub(crate) fn equals_datatype(&self, other: &DataType) -> bool {
match (&self, other) {
(DataType::List(a), DataType::List(b))
| (DataType::LargeList(a), DataType::LargeList(b)) => {
a.is_nullable() == b.is_nullable()
&& a.data_type().equals_datatype(b.data_type())
}
(DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => {
a_size == b_size
&& a.is_nullable() == b.is_nullable()
&& a.data_type().equals_datatype(b.data_type())
}
(DataType::Struct(a), DataType::Struct(b)) => {
a.len() == b.len()
&& a.iter().zip(b).all(|(a, b)| {
a.is_nullable() == b.is_nullable()
&& a.data_type().equals_datatype(b.data_type())
})
}
(
DataType::Map(a_field, a_is_sorted),
DataType::Map(b_field, b_is_sorted),
) => a_field == b_field && a_is_sorted == b_is_sorted,
_ => self == other,
}
}
}