use std::fmt::{Debug, Formatter};
use std::iter;
use std::sync::Arc;
use flatbuffers::{FlatBufferBuilder, Follow, WIPOffset, root};
use itertools::Itertools;
use vortex_buffer::{Alignment, ByteBuffer};
use vortex_dtype::{DType, TryFromBytes};
use vortex_error::{
VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
};
use vortex_flatbuffers::array::Compression;
use vortex_flatbuffers::{
FlatBuffer, FlatBufferRoot, ReadFlatBuffer, WriteFlatBuffer, array as fba,
};
use crate::stats::StatsSet;
use crate::{Array, ArrayContext, ArrayRef, ArrayVisitor, ArrayVisitorExt};
/// Options controlling how an array is serialized into a sequence of byte buffers.
#[derive(Default, Debug)]
pub struct SerializeOptions {
    /// Absolute byte position in the destination stream at which the serialized message
    /// will begin; used to compute alignment padding relative to the final stream.
    pub offset: usize,
    /// If true, zero-filled padding buffers are inserted so each buffer (and the trailing
    /// flatbuffer) lands at its required alignment in the final stream.
    pub include_padding: bool,
}
impl dyn Array + '_ {
    /// Serializes this array into a sequence of byte buffers forming a single logical
    /// message with layout:
    ///
    /// `[array buffers (each optionally preceded by zero padding)] [padding] [flatbuffer] [u32 LE flatbuffer length]`
    ///
    /// The array buffers are emitted in depth-first traversal order, which must match the
    /// buffer-index assignment performed by [`ArrayNodeFlatBuffer`]. The returned vector is
    /// intended to be written contiguously (e.g. with vectored I/O).
    ///
    /// # Errors
    ///
    /// Returns an error if any node in the tree does not support serialization, or if a
    /// buffer length (or the flatbuffer length) does not fit into a `u32`.
    pub fn serialize(
        &self,
        ctx: &ArrayContext,
        options: &SerializeOptions,
    ) -> VortexResult<Vec<ByteBuffer>> {
        // Flatten all buffers of the tree in depth-first order; the flatbuffer metadata
        // references them by position in this order.
        let array_buffers = self
            .depth_first_traversal()
            .flat_map(|f| f.buffers())
            .collect::<Vec<_>>();

        let mut buffers = vec![];
        // One flatbuffer descriptor per array buffer.
        let mut fb_buffers = Vec::with_capacity(array_buffers.len());

        // Largest alignment required by any buffer (or by the flatbuffer itself); a single
        // zero-filled buffer of that size can be sliced for any padding we need below.
        let max_alignment = array_buffers
            .iter()
            .map(|buf| buf.alignment())
            .chain(iter::once(FlatBuffer::alignment()))
            .max()
            .unwrap_or_else(FlatBuffer::alignment);
        let zeros = ByteBuffer::zeroed(*max_alignment);

        // Empty leading buffer carrying the maximum alignment — presumably so consumers of
        // the returned Vec can discover the alignment of the whole message (TODO confirm).
        buffers.push(ByteBuffer::zeroed_aligned(0, max_alignment));

        // Absolute position within the final stream, starting at the caller-provided
        // offset, so padding aligns each buffer where it will actually land.
        let mut pos = options.offset;
        for buffer in array_buffers {
            let padding = if options.include_padding {
                let padding = pos.next_multiple_of(*buffer.alignment()) - pos;
                if padding > 0 {
                    pos += padding;
                    buffers.push(zeros.slice(0..padding));
                }
                padding
            } else {
                0
            };
            // Record padding, alignment, compression and length so deserialization can
            // locate each buffer without re-deriving alignment.
            fb_buffers.push(fba::Buffer::new(
                u16::try_from(padding).vortex_expect("padding fits into u16"),
                buffer.alignment().exponent(),
                Compression::None,
                u32::try_from(buffer.len())
                    .map_err(|_| vortex_err!("All buffers must fit into u32 for serialization"))?,
            ));
            pos += buffer.len();
            // Alignment is erased here: in-stream alignment is now guaranteed by padding,
            // not by the buffer's own allocation.
            buffers.push(buffer.aligned(Alignment::none()));
        }

        // Write the flatbuffer describing the array tree plus the buffer descriptors.
        let mut fbb = FlatBufferBuilder::new();
        let root = ArrayNodeFlatBuffer::try_new(ctx, self)?;
        let fb_root = root.write_flatbuffer(&mut fbb);
        let fb_buffers = fbb.create_vector(&fb_buffers);
        let fb_array = fba::Array::create(
            &mut fbb,
            &fba::ArrayArgs {
                root: Some(fb_root),
                buffers: Some(fb_buffers),
            },
        );
        fbb.finish_minimal(fb_array);
        let (fb_vec, fb_start) = fbb.collapse();
        let fb_end = fb_vec.len();
        let fb_buffer = ByteBuffer::from(fb_vec).slice(fb_start..fb_end);
        let fb_length = fb_buffer.len();

        // Pad so the flatbuffer itself starts at its required alignment.
        if options.include_padding {
            let padding = pos.next_multiple_of(*FlatBuffer::alignment()) - pos;
            if padding > 0 {
                buffers.push(zeros.slice(0..padding));
            }
        }
        buffers.push(fb_buffer);

        // Trailing little-endian u32 length lets a reader find the flatbuffer by walking
        // backwards from the end of the message (see `TryFrom<ByteBuffer> for ArrayParts`).
        buffers.push(ByteBuffer::from(
            u32::try_from(fb_length)
                .map_err(|_| vortex_err!("Array metadata flatbuffer must fit into u32 for serialization. Array encoding tree is too large."))?
                .to_le_bytes()
                .to_vec(),
        ));

        Ok(buffers)
    }
}
/// Writer for the flatbuffer node tree describing an array and its children.
pub struct ArrayNodeFlatBuffer<'a> {
    /// Encoding registry used to map each node's encoding to an index.
    ctx: &'a ArrayContext,
    /// The array (sub)tree this node describes.
    array: &'a dyn Array,
    /// Index of this node's first buffer within the flattened buffer list — assumes the
    /// same depth-first order `serialize` uses to emit buffers (TODO confirm ordering).
    buffer_idx: u16,
}
impl<'a> ArrayNodeFlatBuffer<'a> {
    /// Creates a writer rooted at `array`, first validating that every node in the tree
    /// can produce serialized metadata.
    ///
    /// # Errors
    ///
    /// Fails if metadata retrieval errors for any node, or if any node reports `None`
    /// metadata (i.e. its encoding does not support serialization).
    pub fn try_new(ctx: &'a ArrayContext, array: &'a dyn Array) -> VortexResult<Self> {
        // Reject the whole tree up-front so `write_flatbuffer` can unwrap safely later.
        for node in array.depth_first_traversal() {
            match node.metadata()? {
                Some(_) => {}
                None => vortex_bail!(
                    "Array {} does not support serialization",
                    node.encoding_id()
                ),
            }
        }
        Ok(Self {
            ctx,
            array,
            buffer_idx: 0,
        })
    }
}
/// Marker: an [`ArrayNodeFlatBuffer`] may serve as the root table of a flatbuffer message.
impl FlatBufferRoot for ArrayNodeFlatBuffer<'_> {}
impl WriteFlatBuffer for ArrayNodeFlatBuffer<'_> {
    type Target<'t> = fba::ArrayNode<'t>;
    /// Recursively writes this node and all of its children into `fbb`, returning the
    /// offset of the created [`fba::ArrayNode`] table.
    fn write_flatbuffer<'fb>(
        &self,
        fbb: &mut FlatBufferBuilder<'fb>,
    ) -> WIPOffset<Self::Target<'fb>> {
        let encoding = self.ctx.encoding_idx(&self.array.encoding());
        // The expects are justified by `try_new`, which verified that every node in the
        // tree yields `Ok(Some(metadata))` before any writer is constructed.
        let metadata = self
            .array
            .metadata()
            .vortex_expect("Failed to serialize metadata")
            .vortex_expect("Validated that all arrays support serialization");
        let metadata = Some(fbb.create_vector(metadata.as_slice()));
        let nbuffers = u16::try_from(self.array.nbuffers())
            .vortex_expect("Array can have at most u16::MAX buffers");
        // Buffer indices are assigned preorder: this node's own buffers occupy
        // [buffer_idx, buffer_idx + nbuffers); each child subtree then claims the next
        // `nbuffers_recursive()` indices in order.
        let mut child_buffer_idx = self.buffer_idx + nbuffers;
        let children = &self
            .array
            .children()
            .iter()
            .map(|child| {
                let msg = ArrayNodeFlatBuffer {
                    ctx: self.ctx,
                    array: child,
                    buffer_idx: child_buffer_idx,
                }
                .write_flatbuffer(fbb);
                // Advance past everything the child's subtree consumed; checked so an
                // oversized tree panics rather than silently wrapping the u16 index.
                child_buffer_idx = u16::try_from(child.nbuffers_recursive())
                    .ok()
                    .and_then(|nbuffers| nbuffers.checked_add(child_buffer_idx))
                    .vortex_expect("Too many buffers (u16) for Array");
                msg
            })
            .collect::<Vec<_>>();
        let children = Some(fbb.create_vector(children));
        // This node's own buffer indices, offset by its starting position.
        let buffers = Some(fbb.create_vector_from_iter((0..nbuffers).map(|i| i + self.buffer_idx)));
        let stats = Some(self.array.statistics().write_flatbuffer(fbb));
        fba::ArrayNode::create(
            fbb,
            &fba::ArrayNodeArgs {
                encoding,
                metadata,
                children,
                buffers,
                stats,
            },
        )
    }
}
/// Lazy accessor for an array's children during decoding: children are only materialized
/// when an encoding's `build` implementation asks for them.
pub trait ArrayChildren {
    /// Returns the child at `index`, decoded with the expected `dtype` and `len`.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef>;
    /// The number of children available.
    fn len(&self) -> usize;
    /// Returns true if there are no children.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A deserialized-but-not-yet-decoded array: the flatbuffer metadata tree plus the raw
/// buffers. Cloning is cheap — the flatbuffer and the buffer list are shared.
#[derive(Clone)]
pub struct ArrayParts {
    /// The validated flatbuffer containing the full `ArrayNode` tree.
    flatbuffer: FlatBuffer,
    /// Byte location of this node's `ArrayNode` table within `flatbuffer`.
    flatbuffer_loc: usize,
    /// All buffers of the whole tree, indexed by the node's flatbuffer buffer indices.
    buffers: Arc<[ByteBuffer]>,
}
impl Debug for ArrayParts {
    /// Formats the node tree for diagnostics, recursing into children.
    ///
    /// The children/buffers are collected into `Vec`s before formatting: passing the bare
    /// iterator adapter would print `Map { iter: .. }` (std's `Debug` for `Map` shows only
    /// the inner iterator), not the actual elements.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArrayParts")
            .field("encoding_id", &self.encoding_id())
            .field(
                "children",
                &(0..self.nchildren()).map(|i| self.child(i)).collect::<Vec<_>>(),
            )
            .field(
                "buffers",
                &(0..self.nbuffers())
                    .map(|i| self.buffer(i).ok())
                    .collect::<Vec<_>>(),
            )
            .field("metadata", &self.metadata())
            .finish()
    }
}
impl ArrayParts {
    /// Decodes this node into an [`ArrayRef`], resolving its encoding through `ctx` and
    /// decoding children lazily via [`ArrayPartsChildren`].
    ///
    /// # Errors
    ///
    /// Fails if the encoding id is unknown, a buffer index is invalid, or the encoding's
    /// `build` fails.
    ///
    /// # Panics
    ///
    /// Asserts that the decoded array matches the requested `len` and `dtype` and the
    /// resolved encoding — a mismatch indicates corrupt or inconsistent serialized data,
    /// treated as a bug rather than a recoverable error.
    pub fn decode(&self, ctx: &ArrayContext, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
        let encoding_id = self.flatbuffer().encoding();
        let vtable = ctx
            .lookup_encoding(encoding_id)
            .ok_or_else(|| vortex_err!("Unknown encoding: {}", encoding_id))?;
        // This node's own buffers are materialized eagerly; children are not.
        let buffers: Vec<_> = (0..self.nbuffers())
            .map(|idx| self.buffer(idx))
            .try_collect()?;
        let children = ArrayPartsChildren { parts: self, ctx };
        let decoded = vtable.build(dtype, len, self.metadata(), &buffers, &children)?;
        assert_eq!(
            decoded.len(),
            len,
            "Array decoded from {} has incorrect length {}, expected {}",
            vtable.id(),
            decoded.len(),
            len
        );
        assert_eq!(
            decoded.dtype(),
            dtype,
            "Array decoded from {} has incorrect dtype {}, expected {}",
            vtable.id(),
            decoded.dtype(),
            dtype,
        );
        assert_eq!(
            decoded.encoding_id(),
            vtable.id(),
            "Array decoded from {} has incorrect encoding {}",
            vtable.id(),
            decoded.encoding_id(),
        );
        // Restore any statistics that were persisted alongside the array node.
        if let Some(stats) = self.flatbuffer().stats() {
            let decoded_statistics = decoded.statistics();
            StatsSet::read_flatbuffer(&stats)?
                .into_iter()
                .for_each(|(stat, val)| decoded_statistics.set(stat, val));
        }
        Ok(decoded)
    }
    /// The raw encoding id recorded for this node.
    pub fn encoding_id(&self) -> u16 {
        self.flatbuffer().encoding()
    }
    /// This node's metadata bytes, or an empty slice if none were recorded.
    pub fn metadata(&self) -> &[u8] {
        self.flatbuffer()
            .metadata()
            .map(|metadata| metadata.bytes())
            .unwrap_or(&[])
    }
    /// Number of children recorded for this node.
    pub fn nchildren(&self) -> usize {
        self.flatbuffer()
            .children()
            .map_or(0, |children| children.len())
    }
    /// Returns the `idx`-th child as a new [`ArrayParts`] sharing this one's flatbuffer
    /// and buffers.
    ///
    /// # Panics
    ///
    /// Panics if this node has no children vector or `idx` is out of range.
    pub fn child(&self, idx: usize) -> ArrayParts {
        let children = self
            .flatbuffer()
            .children()
            .vortex_expect("Expected array to have children");
        if idx >= children.len() {
            vortex_panic!(
                "Invalid child index {} for array with {} children",
                idx,
                children.len()
            );
        }
        self.with_root(children.get(idx))
    }
    /// Number of buffers recorded for this node (not including children's buffers).
    pub fn nbuffers(&self) -> usize {
        self.flatbuffer()
            .buffers()
            .map_or(0, |buffers| buffers.len())
    }
    /// Returns this node's `idx`-th buffer by resolving the recorded buffer index into
    /// the shared, tree-wide buffer list.
    pub fn buffer(&self, idx: usize) -> VortexResult<ByteBuffer> {
        let buffer_idx = self
            .flatbuffer()
            .buffers()
            .ok_or_else(|| vortex_err!("Array has no buffers"))?
            .get(idx);
        self.buffers
            .get(buffer_idx as usize)
            .cloned()
            .ok_or_else(|| {
                vortex_err!(
                    "Invalid buffer index {} for array with {} buffers",
                    buffer_idx,
                    self.nbuffers()
                )
            })
    }
    /// Re-reads this node's `ArrayNode` table from the shared flatbuffer.
    fn flatbuffer(&self) -> fba::ArrayNode<'_> {
        // SAFETY: `flatbuffer_loc` is either the root table location of a flatbuffer that
        // was verified by `root::<fba::Array>` in `TryFrom<ByteBuffer>`, or a child table
        // location obtained from that same verified flatbuffer via `with_root`.
        unsafe { fba::ArrayNode::follow(self.flatbuffer.as_ref(), self.flatbuffer_loc) }
    }
    /// Clones this `ArrayParts` re-pointed at a different node table of the same flatbuffer.
    fn with_root(&self, root: fba::ArrayNode) -> Self {
        let mut this = self.clone();
        this.flatbuffer_loc = root._tab.loc();
        this
    }
}
/// Adapter exposing an [`ArrayParts`] node's children through the [`ArrayChildren`]
/// trait, so encodings can decode children lazily during `build`.
struct ArrayPartsChildren<'a> {
    /// The node whose children are exposed.
    parts: &'a ArrayParts,
    /// Context used to resolve child encodings.
    ctx: &'a ArrayContext,
}
impl ArrayChildren for ArrayPartsChildren<'_> {
    /// Decodes the `index`-th child of the underlying node with the expected
    /// `dtype` and `len`.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
        let child = self.parts.child(index);
        child.decode(self.ctx, dtype, len)
    }

    /// Number of children recorded for the underlying node.
    fn len(&self) -> usize {
        self.parts.nchildren()
    }
}
impl TryFrom<ByteBuffer> for ArrayParts {
    type Error = VortexError;
    /// Parses a serialized message (as produced by `serialize`) back into [`ArrayParts`].
    ///
    /// Wire layout, back to front: the last 4 bytes are a little-endian u32 giving the
    /// flatbuffer's length; the flatbuffer precedes it; the array buffers sit at the
    /// front, each located by the padding/length recorded in its flatbuffer descriptor.
    fn try_from(value: ByteBuffer) -> Result<Self, Self::Error> {
        if value.len() < 4 {
            vortex_bail!("ArrayParts buffer is too short");
        }
        let value = value.aligned(Alignment::none());
        // Trailing u32 length locates the flatbuffer region.
        let fb_length = u32::try_from_le_bytes(&value.as_slice()[value.len() - 4..])? as usize;
        if value.len() < 4 + fb_length {
            vortex_bail!("ArrayParts buffer is too short for flatbuffer");
        }
        let fb_offset = value.len() - 4 - fb_length;
        let fb_buffer = value.slice(fb_offset..fb_offset + fb_length);
        let fb_buffer = FlatBuffer::align_from(fb_buffer);
        // `root` verifies the flatbuffer, which is what makes the later unchecked
        // `ArrayNode::follow` in `ArrayParts::flatbuffer` sound.
        let fb_array = root::<fba::Array>(fb_buffer.as_ref())?;
        let fb_root = fb_array.root().vortex_expect("Array must have a root node");
        // Walk the buffer descriptors from the start of the message, skipping each
        // buffer's recorded padding and restoring its recorded alignment.
        let mut offset = 0;
        let buffers: Arc<[ByteBuffer]> = fb_array
            .buffers()
            .unwrap_or_default()
            .iter()
            .map(|fb_buffer| {
                offset += fb_buffer.padding() as usize;
                let buffer_len = fb_buffer.length() as usize;
                let buffer = value
                    .slice(offset..(offset + buffer_len))
                    .aligned(Alignment::from_exponent(fb_buffer.alignment_exponent()));
                offset += buffer_len;
                buffer
            })
            .collect();
        Ok(ArrayParts {
            flatbuffer: fb_buffer.clone(),
            flatbuffer_loc: fb_root._tab.loc(),
            buffers,
        })
    }
}