use arrow_format::ipc::planus::ReadAsRoot;
use crate::{
datatypes::{
get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema,
TimeUnit, UnionMode,
},
error::{ArrowError, Result},
};
use super::{
super::{IpcField, IpcSchema},
StreamMetadata,
};
/// Unzips an iterator of fallible pairs into two vectors, short-circuiting
/// on the first error encountered.
fn try_unzip_vec<A, B, I: Iterator<Item = Result<(A, B)>>>(iter: I) -> Result<(Vec<A>, Vec<B>)> {
    // Collect all pairs first (propagating any error), then split them.
    let pairs = iter.collect::<Result<Vec<(A, B)>>>()?;
    Ok(pairs.into_iter().unzip())
}
/// Deserializes a single IPC field into an arrow [`Field`] together with its
/// IPC-specific counterpart ([`IpcField`], which tracks dictionary ids).
fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> Result<(Field, IpcField)> {
    // Metadata is read first because it may declare an extension type that
    // wraps the physical data type below.
    let metadata = read_metadata(&ipc_field)?;
    let extension = get_extension(&metadata);
    let (data_type, ipc_repr) = get_data_type(ipc_field, extension, true)?;
    // A name is mandatory for every field in the IPC format.
    let name = ipc_field
        .name()?
        .ok_or_else(|| ArrowError::oos("Every field in IPC must have a name"))?
        .to_string();
    let is_nullable = ipc_field.nullable()?;
    Ok((
        Field {
            name,
            data_type,
            is_nullable,
            metadata,
        },
        ipc_repr,
    ))
}
/// Reads the custom key/value metadata attached to an IPC field.
/// Entries whose key or value is absent are silently skipped.
fn read_metadata(field: &arrow_format::ipc::FieldRef) -> Result<Metadata> {
    let mut metadata = Metadata::default();
    if let Some(pairs) = field.custom_metadata()? {
        for pair in pairs {
            let pair = pair?;
            // Both key and value must be present for the entry to be kept.
            if let Some(key) = pair.key()? {
                if let Some(value) = pair.value()? {
                    metadata.insert(key.to_string(), value.to_string());
                }
            }
        }
    }
    Ok(metadata)
}
/// Maps an IPC integer declaration (bit width + signedness) to an
/// [`IntegerType`], erroring on unsupported bit widths.
fn deserialize_integer(int: arrow_format::ipc::IntRef) -> Result<IntegerType> {
    let signed = int.is_signed()?;
    let integer = match int.bit_width()? {
        8 if signed => IntegerType::Int8,
        8 => IntegerType::UInt8,
        16 if signed => IntegerType::Int16,
        16 => IntegerType::UInt16,
        32 if signed => IntegerType::Int32,
        32 => IntegerType::UInt32,
        64 if signed => IntegerType::Int64,
        64 => IntegerType::UInt64,
        _ => {
            return Err(ArrowError::oos(
                "IPC: indexType can only be 8, 16, 32 or 64.",
            ))
        }
    };
    Ok(integer)
}
/// Converts an IPC time unit into the crate's [`TimeUnit`].
/// The mapping is total, so this never returns an error in practice;
/// the `Result` keeps the signature uniform with its sibling helpers.
fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> Result<TimeUnit> {
    let unit = match time_unit {
        arrow_format::ipc::TimeUnit::Second => TimeUnit::Second,
        arrow_format::ipc::TimeUnit::Millisecond => TimeUnit::Millisecond,
        arrow_format::ipc::TimeUnit::Microsecond => TimeUnit::Microsecond,
        arrow_format::ipc::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
    };
    Ok(unit)
}
/// Resolves the [`DataType`] and the IPC-specific [`IpcField`] of an IPC field.
///
/// The resolution has three layers, applied in order:
/// 1. if the field declares a dictionary encoding and `may_be_dictionary` is
///    true, recurse once with `may_be_dictionary = false` to resolve the
///    values' type, then wrap it in `DataType::Dictionary` and record the
///    dictionary id on the returned `IpcField`;
/// 2. if `extension` carries an extension name/metadata, recurse without the
///    extension to resolve the storage type, then wrap it in
///    `DataType::Extension`;
/// 3. otherwise dispatch on the flatbuffer type tag.
///
/// Nested types (List, Struct, Union, Map, ...) recurse via
/// `deserialize_field` on their children and collect the children's
/// `IpcField`s into `IpcField::fields`.
fn get_data_type(
    field: arrow_format::ipc::FieldRef,
    extension: Extension,
    may_be_dictionary: bool,
) -> Result<(DataType, IpcField)> {
    // Layer 1: dictionary encoding. The guard prevents infinite recursion,
    // since the recursive call sees the same `field.dictionary()`.
    if let Some(dictionary) = field.dictionary()? {
        if may_be_dictionary {
            let int = dictionary
                .index_type()?
                .ok_or_else(|| ArrowError::oos("indexType is mandatory in Dictionary."))?;
            let index_type = deserialize_integer(int)?;
            let (inner, mut ipc_field) = get_data_type(field, extension, false)?;
            ipc_field.dictionary_id = Some(dictionary.id()?);
            return Ok((
                DataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
                ipc_field,
            ));
        }
    }
    // Layer 2: extension type. Resolve the storage type (passing `None` so
    // this branch is not taken again) and wrap it.
    if let Some(extension) = extension {
        let (name, metadata) = extension;
        let (data_type, fields) = get_data_type(field, None, false)?;
        return Ok((
            DataType::Extension(name, Box::new(data_type), metadata),
            fields,
        ));
    }
    // Layer 3: dispatch on the concrete flatbuffer type tag.
    let type_ = field
        .type_()?
        .ok_or_else(|| ArrowError::oos("IPC: field type is mandatory"))?;
    use arrow_format::ipc::TypeRef::*;
    Ok(match type_ {
        // Primitive types carry no children and no dictionary id, hence
        // `IpcField::default()`.
        Null(_) => (DataType::Null, IpcField::default()),
        Bool(_) => (DataType::Boolean, IpcField::default()),
        Int(int) => {
            let data_type = deserialize_integer(int)?.into();
            (data_type, IpcField::default())
        }
        Binary(_) => (DataType::Binary, IpcField::default()),
        LargeBinary(_) => (DataType::LargeBinary, IpcField::default()),
        Utf8(_) => (DataType::Utf8, IpcField::default()),
        LargeUtf8(_) => (DataType::LargeUtf8, IpcField::default()),
        FixedSizeBinary(fixed) => (
            DataType::FixedSizeBinary(fixed.byte_width()? as usize),
            IpcField::default(),
        ),
        FloatingPoint(float) => {
            let data_type = match float.precision()? {
                arrow_format::ipc::Precision::Half => DataType::Float16,
                arrow_format::ipc::Precision::Single => DataType::Float32,
                arrow_format::ipc::Precision::Double => DataType::Float64,
            };
            (data_type, IpcField::default())
        }
        Date(date) => {
            // Date32 is days since epoch; Date64 is milliseconds since epoch.
            let data_type = match date.unit()? {
                arrow_format::ipc::DateUnit::Day => DataType::Date32,
                arrow_format::ipc::DateUnit::Millisecond => DataType::Date64,
            };
            (data_type, IpcField::default())
        }
        Time(time) => {
            let unit = deserialize_timeunit(time.unit()?)?;
            // Only (32, s/ms) and (64, us/ns) combinations are valid per the
            // Arrow spec; anything else is rejected.
            let data_type = match (time.bit_width()?, unit) {
                (32, TimeUnit::Second) => DataType::Time32(TimeUnit::Second),
                (32, TimeUnit::Millisecond) => DataType::Time32(TimeUnit::Millisecond),
                (64, TimeUnit::Microsecond) => DataType::Time64(TimeUnit::Microsecond),
                (64, TimeUnit::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond),
                (bits, precision) => {
                    return Err(ArrowError::nyi(format!(
                        "Time type with bit width of {} and unit of {:?}",
                        bits, precision
                    )))
                }
            };
            (data_type, IpcField::default())
        }
        Timestamp(timestamp) => {
            // An absent timezone yields a timezone-naive timestamp.
            let timezone = timestamp.timezone()?.map(|tz| tz.to_string());
            let time_unit = deserialize_timeunit(timestamp.unit()?)?;
            (
                DataType::Timestamp(time_unit, timezone),
                IpcField::default(),
            )
        }
        Interval(interval) => {
            let data_type = match interval.unit()? {
                arrow_format::ipc::IntervalUnit::YearMonth => {
                    DataType::Interval(IntervalUnit::YearMonth)
                }
                arrow_format::ipc::IntervalUnit::DayTime => {
                    DataType::Interval(IntervalUnit::DayTime)
                }
                arrow_format::ipc::IntervalUnit::MonthDayNano => {
                    DataType::Interval(IntervalUnit::MonthDayNano)
                }
            };
            (data_type, IpcField::default())
        }
        Duration(duration) => {
            let time_unit = deserialize_timeunit(duration.unit()?)?;
            (DataType::Duration(time_unit), IpcField::default())
        }
        Decimal(decimal) => {
            let data_type =
                DataType::Decimal(decimal.precision()? as usize, decimal.scale()? as usize);
            (data_type, IpcField::default())
        }
        // List-like types: exactly one child describing the values.
        List(_) => {
            let children = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: List must contain children"))?;
            let inner = children
                .get(0)
                .ok_or_else(|| ArrowError::oos("IPC: List must contain one child"))??;
            let (field, ipc_field) = deserialize_field(inner)?;
            (
                DataType::List(Box::new(field)),
                IpcField {
                    fields: vec![ipc_field],
                    dictionary_id: None,
                },
            )
        }
        LargeList(_) => {
            let children = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: List must contain children"))?;
            let inner = children
                .get(0)
                .ok_or_else(|| ArrowError::oos("IPC: List must contain one child"))??;
            let (field, ipc_field) = deserialize_field(inner)?;
            (
                DataType::LargeList(Box::new(field)),
                IpcField {
                    fields: vec![ipc_field],
                    dictionary_id: None,
                },
            )
        }
        FixedSizeList(list) => {
            let children = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: FixedSizeList must contain children"))?;
            let inner = children
                .get(0)
                .ok_or_else(|| ArrowError::oos("IPC: FixedSizeList must contain one child"))??;
            let (field, ipc_field) = deserialize_field(inner)?;
            let size = list.list_size()? as usize;
            (
                DataType::FixedSizeList(Box::new(field), size),
                IpcField {
                    fields: vec![ipc_field],
                    dictionary_id: None,
                },
            )
        }
        // Struct and Union: one child per member field.
        Struct(_) => {
            let fields = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: Struct must contain children"))?;
            if fields.is_empty() {
                return Err(ArrowError::oos(
                    "IPC: Struct must contain at least one child",
                ));
            }
            let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
                let (field, fields) = deserialize_field(field?)?;
                Ok((field, fields))
            }))?;
            let ipc_field = IpcField {
                fields: ipc_fields,
                dictionary_id: None,
            };
            (DataType::Struct(fields), ipc_field)
        }
        Union(union_) => {
            let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
            // Optional explicit type ids; when absent, implicit ids are used.
            let ids = union_.type_ids()?.map(|x| x.iter().collect());
            let fields = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: Union must contain children"))?;
            if fields.is_empty() {
                return Err(ArrowError::oos(
                    "IPC: Union must contain at least one child",
                ));
            }
            let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
                let (field, fields) = deserialize_field(field?)?;
                Ok((field, fields))
            }))?;
            let ipc_field = IpcField {
                fields: ipc_fields,
                dictionary_id: None,
            };
            (DataType::Union(fields, ids, mode), ipc_field)
        }
        // Map: a single child (the entries struct) plus a sortedness flag.
        Map(map) => {
            let is_sorted = map.keys_sorted()?;
            let children = field
                .children()?
                .ok_or_else(|| ArrowError::oos("IPC: Map must contain children"))?;
            let inner = children
                .get(0)
                .ok_or_else(|| ArrowError::oos("IPC: Map must contain one child"))??;
            let (field, ipc_field) = deserialize_field(inner)?;
            let data_type = DataType::Map(Box::new(field), is_sorted);
            (
                data_type,
                IpcField {
                    fields: vec![ipc_field],
                    dictionary_id: None,
                },
            )
        }
    })
}
/// Deserializes a [`Schema`] and [`IpcSchema`] from raw IPC message bytes.
///
/// The bytes must encode an IPC `Message` whose header is a `Schema`;
/// any other header kind is rejected.
///
/// # Errors
/// Returns an error if the bytes are not a valid flatbuffer message, the
/// message has no header, or the header is not a schema.
pub fn deserialize_schema(bytes: &[u8]) -> Result<(Schema, IpcSchema)> {
    let message = arrow_format::ipc::MessageRef::read_as_root(bytes)
        .map_err(|err| ArrowError::oos(format!("Unable to deserialize message: {:?}", err)))?;
    // NOTE: the previous error messages here referred to "flight data" and
    // "RecordBatch"; they were copy-paste artifacts and misleading for this
    // schema-only entry point.
    let schema = match message
        .header()?
        .ok_or_else(|| ArrowError::oos("IPC message must contain a header".to_string()))?
    {
        arrow_format::ipc::MessageHeaderRef::Schema(schema) => Ok(schema),
        _ => Err(ArrowError::nyi(
            "Deserializing a schema requires the message header to be a Schema",
        )),
    }?;
    fb_to_schema(schema)
}
/// Converts a flatbuffer schema reference into a [`Schema`] and the matching
/// [`IpcSchema`] (per-field IPC details plus endianness).
pub(super) fn fb_to_schema(schema: arrow_format::ipc::SchemaRef) -> Result<(Schema, IpcSchema)> {
    let raw_fields = schema
        .fields()?
        .ok_or_else(|| ArrowError::oos("IPC: Schema must contain fields"))?;
    // Deserialize every field, splitting the results into the arrow fields
    // and their IPC counterparts.
    let (fields, ipc_fields) =
        try_unzip_vec(raw_fields.iter().map(|field| deserialize_field(field?)))?;
    let is_little_endian = matches!(
        schema.endianness()?,
        arrow_format::ipc::Endianness::Little
    );
    // Schema-level custom metadata; entries missing a key or value are skipped.
    let mut metadata = Metadata::default();
    if let Some(pairs) = schema.custom_metadata()? {
        for pair in pairs {
            let pair = pair?;
            if let (Some(key), Some(value)) = (pair.key()?, pair.value()?) {
                metadata.insert(key.to_string(), value.to_string());
            }
        }
    }
    Ok((
        Schema { fields, metadata },
        IpcSchema {
            fields: ipc_fields,
            is_little_endian,
        },
    ))
}
/// Deserializes the metadata of an IPC stream from its first message, which
/// must be a schema message.
pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> Result<StreamMetadata> {
    let message = arrow_format::ipc::MessageRef::read_as_root(meta).map_err(|err| {
        ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
    })?;
    let version = message.version()?;
    let header = message
        .header()?
        .ok_or_else(|| ArrowError::oos("Unable to read the first IPC message"))?;
    // The stream protocol mandates that the first message is the schema.
    let (schema, ipc_schema) = match header {
        arrow_format::ipc::MessageHeaderRef::Schema(schema) => fb_to_schema(schema)?,
        _ => {
            return Err(ArrowError::oos(
                "The first IPC message of the stream must be a schema",
            ))
        }
    };
    Ok(StreamMetadata {
        schema,
        version,
        ipc_schema,
    })
}