use std::{collections::HashMap, sync::Arc};
use arrow_format::ipc;
use arrow_format::ipc::flatbuffers::FlatBufferBuilder;
use arrow_format::ipc::Message::CompressionType;
use crate::array::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::io::ipc::endianess::is_native_little_endian;
use crate::record_batch::RecordBatch;
use super::{write, write_dictionary};
/// Compression codec that can be applied to the buffers of IPC message bodies.
///
/// When writing, `LZ4` is recorded in the message as `CompressionType::LZ4_FRAME`
/// and `ZSTD` as `CompressionType::ZSTD`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Compression {
    /// LZ4 (frame format).
    LZ4,
    /// Zstandard.
    ZSTD,
}
/// Options that control how record batches and dictionaries are encoded to IPC.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct WriteOptions {
    /// Codec used to compress message body buffers; `None` (the default) disables
    /// compression and no `BodyCompression` entry is written to the message.
    pub compression: Option<Compression>,
}
/// Recursively walks `array` (described by `field`) and, for every dictionary-encoded
/// array found, appends an encoded dictionary batch to `encoded_dictionaries` whenever
/// `dictionary_tracker` reports that it must be emitted.
///
/// Nested types (struct, list, large list, fixed-size list, union, map) are traversed
/// through their children. The child [`Field`]s are always taken from `field`'s data type,
/// because the `dict_id` used for each dictionary is read from the field (not the array).
///
/// # Errors
/// Propagates errors from [`DictionaryTracker::insert`], e.g. when a dictionary is
/// replaced and the tracker was created with `error_on_replacement = true`.
///
/// # Panics
/// Panics if a dictionary-typed `field` has no `dict_id`, or if `field`'s data type does
/// not match the physical type of `array`.
fn encode_dictionary(
    field: &Field,
    array: &Arc<dyn Array>,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> Result<()> {
    use PhysicalType::*;
    match array.data_type().to_physical_type() {
        // Flat, non-dictionary types cannot contain dictionaries.
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field
                .dict_id()
                .expect("All Dictionary types have `dict_id`");
            let values = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap().values();
            // The dictionary values may themselves contain dictionaries; encode those
            // first so that they precede this dictionary batch in the output.
            let field = Field::new("item", values.data_type().clone(), true);
            encode_dictionary(
                &field,
                values,
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )?;
            // The tracker returns `false` when identical values were already written
            // under this id, in which case no new dictionary batch is needed.
            let emit = dictionary_tracker.insert(dict_id, array)?;
            if emit {
                encoded_dictionaries.push(dictionary_batch_to_bytes(
                    dict_id,
                    array.as_ref(),
                    options,
                    is_native_little_endian(),
                ));
            };
            Ok(())
        }),
        Struct => {
            let values = array
                .as_any()
                .downcast_ref::<StructArray>()
                .unwrap()
                .values();
            // Read the children from the schema `field` (as every other nested branch
            // does), so nested `dict_id`s match the ones declared in the schema.
            let fields = if let DataType::Struct(fields) = field.data_type() {
                fields
            } else {
                unreachable!()
            };
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    encode_dictionary(
                        field,
                        values,
                        options,
                        dictionary_tracker,
                        encoded_dictionaries,
                    )
                })
        }
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = if let DataType::List(field) = field.data_type() {
                field.as_ref()
            } else {
                unreachable!()
            };
            encode_dictionary(
                field,
                values,
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = if let DataType::LargeList(field) = field.data_type() {
                field.as_ref()
            } else {
                unreachable!()
            };
            encode_dictionary(
                field,
                values,
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = if let DataType::FixedSizeList(field, _) = field.data_type() {
                field.as_ref()
            } else {
                unreachable!()
            };
            encode_dictionary(
                field,
                values,
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = if let DataType::Union(fields, _, _) = field.data_type() {
                fields
            } else {
                unreachable!()
            };
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    encode_dictionary(
                        field,
                        values,
                        options,
                        dictionary_tracker,
                        encoded_dictionaries,
                    )
                })
        }
        Map => {
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = if let DataType::Map(field, _) = field.data_type() {
                field.as_ref()
            } else {
                unreachable!()
            };
            encode_dictionary(
                field,
                values,
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
    }
}
/// Encodes `batch` into IPC messages.
///
/// Returns the dictionary batches that must be written before the record batch
/// (as decided by `dictionary_tracker`), followed by the encoded record batch itself.
///
/// # Errors
/// Fails if encoding any column's dictionaries fails (see [`DictionaryTracker::insert`]).
pub fn encoded_batch(
    batch: &RecordBatch,
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
    let schema = batch.schema();
    let fields = schema.fields();

    // At most (but typically) one top-level dictionary per field.
    let mut dictionaries = Vec::with_capacity(fields.len());
    fields
        .iter()
        .zip(batch.columns())
        .try_for_each(|(field, column)| {
            encode_dictionary(
                field,
                column,
                options,
                dictionary_tracker,
                &mut dictionaries,
            )
        })?;

    let message = record_batch_to_bytes(batch, options);
    Ok((dictionaries, message))
}
/// Serializes `batch` into a `RecordBatch` IPC message.
///
/// The column buffers are written into `arrow_data` (the message body). `ipc_message`
/// holds the FlatBuffers-encoded `Message` header that describes them.
fn record_batch_to_bytes(batch: &RecordBatch, options: &WriteOptions) -> EncodedData {
    let mut fbb = FlatBufferBuilder::new();

    let mut nodes: Vec<ipc::Message::FieldNode> = Vec::new();
    let mut buffers: Vec<ipc::Schema::Buffer> = Vec::new();
    let mut arrow_data: Vec<u8> = Vec::new();
    let mut offset = 0;
    let little_endian = is_native_little_endian();
    batch.columns().iter().for_each(|column| {
        write(
            column.as_ref(),
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            &mut offset,
            little_endian,
            options.compression,
        );
    });

    // FlatBuffers tables are created bottom-up: vectors, then compression, then the
    // record batch, then the enclosing message.
    let buffers = fbb.create_vector(&buffers);
    let nodes = fbb.create_vector(&nodes);
    let compression = options.compression.map(|codec| {
        let codec = match codec {
            Compression::LZ4 => CompressionType::LZ4_FRAME,
            Compression::ZSTD => CompressionType::ZSTD,
        };
        let mut builder = ipc::Message::BodyCompressionBuilder::new(&mut fbb);
        builder.add_codec(codec);
        builder.finish()
    });

    let header = {
        let mut builder = ipc::Message::RecordBatchBuilder::new(&mut fbb);
        builder.add_length(batch.num_rows() as i64);
        builder.add_nodes(nodes);
        builder.add_buffers(buffers);
        if let Some(compression) = compression {
            builder.add_compression(compression);
        }
        builder.finish().as_union_value()
    };

    let body_length = arrow_data.len() as i64;
    let message = {
        let mut builder = ipc::Message::MessageBuilder::new(&mut fbb);
        builder.add_version(ipc::Schema::MetadataVersion::V5);
        builder.add_header_type(ipc::Message::MessageHeader::RecordBatch);
        builder.add_bodyLength(body_length);
        builder.add_header(header);
        builder.finish()
    };
    fbb.finish(message, None);

    EncodedData {
        ipc_message: fbb.finished_data().to_vec(),
        arrow_data,
    }
}
/// Serializes the dictionary of `array` into a `DictionaryBatch` IPC message with id
/// `dict_id`.
///
/// The dictionary's buffers go into `arrow_data` (the message body). The record-batch
/// length is the value returned by `write_dictionary`.
fn dictionary_batch_to_bytes(
    dict_id: i64,
    array: &dyn Array,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut fbb = FlatBufferBuilder::new();

    let mut nodes: Vec<ipc::Message::FieldNode> = Vec::new();
    let mut buffers: Vec<ipc::Schema::Buffer> = Vec::new();
    let mut arrow_data: Vec<u8> = Vec::new();
    let mut offset = 0;
    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut offset,
        is_little_endian,
        options.compression,
        false,
    );

    // FlatBuffers tables are created bottom-up: vectors, compression, record batch,
    // dictionary batch, then the enclosing message.
    let buffers = fbb.create_vector(&buffers);
    let nodes = fbb.create_vector(&nodes);
    let compression = options.compression.map(|codec| {
        let codec = match codec {
            Compression::LZ4 => CompressionType::LZ4_FRAME,
            Compression::ZSTD => CompressionType::ZSTD,
        };
        let mut builder = ipc::Message::BodyCompressionBuilder::new(&mut fbb);
        builder.add_codec(codec);
        builder.finish()
    });

    let record_batch = {
        let mut builder = ipc::Message::RecordBatchBuilder::new(&mut fbb);
        builder.add_length(length as i64);
        builder.add_nodes(nodes);
        builder.add_buffers(buffers);
        if let Some(compression) = compression {
            builder.add_compression(compression);
        }
        builder.finish()
    };

    let dictionary_batch = {
        let mut builder = ipc::Message::DictionaryBatchBuilder::new(&mut fbb);
        builder.add_id(dict_id);
        builder.add_data(record_batch);
        builder.finish().as_union_value()
    };

    let body_length = arrow_data.len() as i64;
    let message = {
        let mut builder = ipc::Message::MessageBuilder::new(&mut fbb);
        builder.add_version(ipc::Schema::MetadataVersion::V5);
        builder.add_header_type(ipc::Message::MessageHeader::DictionaryBatch);
        builder.add_bodyLength(body_length);
        builder.add_header(dictionary_batch);
        builder.finish()
    };
    fbb.finish(message, None);

    EncodedData {
        ipc_message: fbb.finished_data().to_vec(),
        arrow_data,
    }
}
/// Remembers, per dictionary id, the dictionary values that were last written, so that
/// unchanged dictionaries are not re-emitted.
pub struct DictionaryTracker {
    /// Last dictionary values written for each dictionary id.
    written: HashMap<i64, Arc<dyn Array>>,
    /// When `true`, a changed dictionary for an already-written id is an error
    /// instead of a replacement.
    error_on_replacement: bool,
}

impl DictionaryTracker {
    /// Creates an empty tracker.
    ///
    /// `error_on_replacement` selects whether [`DictionaryTracker::insert`] rejects a
    /// dictionary whose values differ from the ones previously written under the same id.
    pub fn new(error_on_replacement: bool) -> Self {
        Self {
            error_on_replacement,
            written: HashMap::new(),
        }
    }

    /// Records the values of the dictionary `array` under `dict_id`.
    ///
    /// Returns `Ok(false)` when the values equal those last written for `dict_id`, so
    /// nothing needs to be emitted. Otherwise it stores the values (an `Arc` clone) and
    /// returns `Ok(true)`.
    ///
    /// # Errors
    /// Returns `ArrowError::InvalidArgumentError` when different values were already
    /// written for `dict_id` and the tracker was built with `error_on_replacement`.
    ///
    /// # Panics
    /// Panics if `array` is not dictionary-typed.
    pub fn insert(&mut self, dict_id: i64, array: &Arc<dyn Array>) -> Result<bool> {
        let values = if let DataType::Dictionary(key_type, _) = array.data_type() {
            match_integer_type!(key_type, |$T| {
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<$T>>()
                    .unwrap()
                    .values()
            })
        } else {
            unreachable!()
        };

        match self.written.get(&dict_id) {
            Some(previous) if previous.as_ref() == values.as_ref() => return Ok(false),
            Some(_) if self.error_on_replacement => {
                return Err(ArrowError::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
            _ => {}
        }

        self.written.insert(dict_id, values.clone());
        Ok(true)
    }
}
/// One encoded IPC message: the FlatBuffers header and its body bytes.
pub struct EncodedData {
    /// Finished FlatBuffers bytes of the IPC `Message` header.
    pub ipc_message: Vec<u8>,
    /// Body bytes (the array buffers) described by `ipc_message`.
    pub arrow_data: Vec<u8>,
}
#[inline]
pub(crate) fn pad_to_8(len: usize) -> usize {
(((len + 7) & !7) - len) as usize
}