use std::borrow::Cow;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::iter;
use std::sync::Arc;
use flatbuffers::FlatBufferBuilder;
use flatbuffers::Follow;
use flatbuffers::WIPOffset;
use flatbuffers::root;
use vortex_buffer::Alignment;
use vortex_buffer::ByteBuffer;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_flatbuffers::FlatBuffer;
use vortex_flatbuffers::WriteFlatBuffer;
use vortex_flatbuffers::array as fba;
use vortex_flatbuffers::array::Compression;
use vortex_session::VortexSession;
use vortex_session::registry::ReadContext;
use vortex_utils::aliases::hash_map::HashMap;
use crate::ArrayContext;
use crate::ArrayRef;
use crate::array::new_foreign_array;
use crate::buffer::BufferHandle;
use crate::dtype::DType;
use crate::dtype::TryFromBytes;
use crate::session::ArraySessionExt;
use crate::stats::StatsSet;
/// Options controlling how an array is laid out when serialized into byte buffers.
#[derive(Default, Debug)]
pub struct SerializeOptions {
    /// Byte position at which the serialized output is assumed to begin; alignment
    /// padding is computed relative to this offset.
    pub offset: usize,
    /// When true, zero-filled padding buffers are inserted so each buffer (and the
    /// trailing flatbuffer) starts at its natural alignment relative to `offset`.
    pub include_padding: bool,
}
impl ArrayRef {
    /// Serializes this array (and all descendants) into an ordered sequence of
    /// [`ByteBuffer`]s.
    ///
    /// Output layout: one empty, maximally-aligned marker buffer first, then each array
    /// buffer in depth-first order (each optionally preceded by zero padding so it lands
    /// at its natural alignment relative to `options.offset`), then the flatbuffer-encoded
    /// encoding tree, and finally the little-endian `u32` length of that flatbuffer so a
    /// reader can locate it from the end of the stream.
    ///
    /// # Errors
    ///
    /// Fails if any buffer length or the flatbuffer length exceeds `u32::MAX`, if an
    /// encoding is not permitted by `ctx`, or if any node lacks serialized metadata.
    pub fn serialize(
        &self,
        ctx: &ArrayContext,
        session: &VortexSession,
        options: &SerializeOptions,
    ) -> VortexResult<Vec<ByteBuffer>> {
        // All buffers of this array and its descendants, in depth-first order. This order
        // must match the buffer indexing produced by ArrayNodeFlatBuffer.
        let array_buffers = self
            .depth_first_traversal()
            .flat_map(|f| f.buffers())
            .collect::<Vec<_>>();

        // Worst case: a padding buffer before every array buffer, plus the leading
        // alignment marker, flatbuffer padding, flatbuffer, and trailing length word.
        let mut buffers = Vec::with_capacity(2 * array_buffers.len() + 4);
        // FIX: was `Vec::with_capacity(buffers.capacity())`, which is always 0 because
        // `buffers` had just been created empty — size to the actual buffer count.
        let mut fb_buffers = Vec::with_capacity(array_buffers.len());

        // The strictest alignment of any buffer (or the flatbuffer itself) determines the
        // alignment of the leading marker and the size of the shared zero-fill source.
        let max_alignment = array_buffers
            .iter()
            .map(|buf| buf.alignment())
            .chain(iter::once(FlatBuffer::alignment()))
            .max()
            .unwrap_or_else(FlatBuffer::alignment);
        let zeros = ByteBuffer::zeroed(*max_alignment);

        // Empty maximally-aligned buffer communicates the overall alignment to readers.
        buffers.push(ByteBuffer::zeroed_aligned(0, max_alignment));

        // Logical write position, tracked to compute per-buffer padding.
        let mut pos = options.offset;
        for buffer in array_buffers {
            let padding = if options.include_padding {
                let padding = pos.next_multiple_of(*buffer.alignment()) - pos;
                if padding > 0 {
                    pos += padding;
                    buffers.push(zeros.slice(0..padding));
                }
                padding
            } else {
                0
            };
            fb_buffers.push(fba::Buffer::new(
                u16::try_from(padding).vortex_expect("padding fits into u16"),
                buffer.alignment().exponent(),
                Compression::None,
                u32::try_from(buffer.len())
                    .map_err(|_| vortex_err!("All buffers must fit into u32 for serialization"))?,
            ));
            pos += buffer.len();
            // Alignment was already handled via explicit padding; drop it from the buffer.
            buffers.push(buffer.aligned(Alignment::none()));
        }

        // Encode the array metadata tree into a flatbuffer.
        let mut fbb = FlatBufferBuilder::new();
        let root = ArrayNodeFlatBuffer::try_new(ctx, session, self)?;
        let fb_root = root.try_write_flatbuffer(&mut fbb)?;
        let fb_buffers = fbb.create_vector(&fb_buffers);
        let fb_array = fba::Array::create(
            &mut fbb,
            &fba::ArrayArgs {
                root: Some(fb_root),
                buffers: Some(fb_buffers),
            },
        );
        fbb.finish_minimal(fb_array);
        let (fb_vec, fb_start) = fbb.collapse();
        let fb_end = fb_vec.len();
        let fb_buffer = ByteBuffer::from(fb_vec).slice(fb_start..fb_end);
        let fb_length = fb_buffer.len();

        // Pad so the flatbuffer itself lands at its required alignment.
        if options.include_padding {
            let padding = pos.next_multiple_of(*FlatBuffer::alignment()) - pos;
            if padding > 0 {
                buffers.push(zeros.slice(0..padding));
            }
        }
        buffers.push(fb_buffer);

        // Trailing little-endian u32 lets a reader find the flatbuffer from the end.
        buffers.push(ByteBuffer::from(
            u32::try_from(fb_length)
                .map_err(|_| vortex_err!("Array metadata flatbuffer must fit into u32 for serialization. Array encoding tree is too large."))?
                .to_le_bytes()
                .to_vec(),
        ));
        Ok(buffers)
    }
}
/// Writer for one node of the array encoding tree, tracking the absolute index of the
/// node's first buffer within the depth-first buffer ordering used by serialization.
pub struct ArrayNodeFlatBuffer<'a> {
    ctx: &'a ArrayContext,
    session: &'a VortexSession,
    array: &'a ArrayRef,
    // Index of this node's first buffer in the flattened (depth-first) buffer list.
    buffer_idx: u16,
}
impl<'a> ArrayNodeFlatBuffer<'a> {
    /// Validates that `array` and all descendants can be serialized, and constructs a
    /// writer rooted at buffer index 0.
    ///
    /// # Errors
    ///
    /// Fails if any node in the tree has no serialized metadata, or if the tree's total
    /// buffer count exceeds `u16::MAX` (buffer indices are stored as `u16`).
    pub fn try_new(
        ctx: &'a ArrayContext,
        session: &'a VortexSession,
        array: &'a ArrayRef,
    ) -> VortexResult<Self> {
        // Fail fast if any node in the tree cannot provide metadata for serialization.
        for child in array.depth_first_traversal() {
            if child.metadata(session)?.is_none() {
                vortex_bail!(
                    "Array {} does not support serialization",
                    child.encoding_id()
                );
            }
        }
        // Buffer indices are u16 in the flatbuffer schema, so the whole tree must fit.
        let n_buffers_recursive = array.nbuffers_recursive();
        if n_buffers_recursive > u16::MAX as usize {
            vortex_bail!(
                "Array and all descendent arrays can have at most u16::MAX buffers: {}",
                n_buffers_recursive
            );
        };
        Ok(Self {
            ctx,
            session,
            array,
            buffer_idx: 0,
        })
    }

    /// Recursively writes this node and its children as `fba::ArrayNode` tables.
    ///
    /// Buffer indices are assigned depth-first: this node's own buffers come first
    /// (starting at `self.buffer_idx`), followed by each child subtree's buffers in order.
    pub fn try_write_flatbuffer<'fb>(
        &self,
        fbb: &mut FlatBufferBuilder<'fb>,
    ) -> VortexResult<WIPOffset<fba::ArrayNode<'fb>>> {
        // Map the encoding id to its interned index within the serialization context.
        let encoding_idx = self
            .ctx
            .intern(&self.array.encoding_id())
            .ok_or_else(|| {
                vortex_err!(
                    "Array encoding {} not permitted by ctx",
                    self.array.encoding_id()
                )
            })?;
        let metadata = self.array.metadata(self.session)?.ok_or_else(|| {
            vortex_err!(
                "Array {} does not support serialization",
                self.array.encoding_id()
            )
        })?;
        let metadata = Some(fbb.create_vector(metadata.as_slice()));
        let nbuffers = u16::try_from(self.array.nbuffers())
            .map_err(|_| vortex_err!("Array can have at most u16::MAX buffers"))?;
        // The first child's buffers start immediately after this node's own buffers.
        // (In-range: try_new already verified the recursive total fits in u16.)
        let mut child_buffer_idx = self.buffer_idx + nbuffers;
        let children = self
            .array
            .children()
            .iter()
            .map(|child| {
                let msg = ArrayNodeFlatBuffer {
                    ctx: self.ctx,
                    session: self.session,
                    array: child,
                    buffer_idx: child_buffer_idx,
                }
                .try_write_flatbuffer(fbb)?;
                // Advance past the entire child subtree's buffers (checked add in u16).
                child_buffer_idx = u16::try_from(child.nbuffers_recursive())
                    .ok()
                    .and_then(|nbuffers| nbuffers.checked_add(child_buffer_idx))
                    .ok_or_else(|| vortex_err!("Too many buffers (u16) for Array"))?;
                Ok(msg)
            })
            .collect::<VortexResult<Vec<_>>>()?;
        let children = Some(fbb.create_vector(&children));
        // The node references its buffers by absolute index into the flattened list.
        let buffers = Some(fbb.create_vector_from_iter((0..nbuffers).map(|i| i + self.buffer_idx)));
        let stats = Some(self.array.statistics().write_flatbuffer(fbb)?);
        Ok(fba::ArrayNode::create(
            fbb,
            &fba::ArrayNodeArgs {
                encoding: encoding_idx,
                metadata,
                children,
                buffers,
                stats,
            },
        ))
    }
}
/// Source of child arrays used during deserialization, allowing children to be decoded
/// lazily with their expected dtype and length.
pub trait ArrayChildren {
    /// Returns the child at `index`; the caller supplies the expected `dtype` and `len`.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef>;

    /// The number of children available.
    fn len(&self) -> usize;

    /// Returns true if there are no children.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T: AsRef<[ArrayRef]>> ArrayChildren for T {
fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
let array = self.as_ref()[index].clone();
assert_eq!(array.len(), len);
assert_eq!(array.dtype(), dtype);
Ok(array)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
/// A cheap, cloneable view over a serialized array tree: a flatbuffer describing the
/// encoding tree plus the shared list of all buffers referenced by that tree.
///
/// Child views share the same flatbuffer and buffer list, differing only in
/// `flatbuffer_loc`.
#[derive(Clone)]
pub struct SerializedArray {
    flatbuffer: FlatBuffer,
    // Byte offset of this view's `ArrayNode` table within `flatbuffer`.
    flatbuffer_loc: usize,
    // All buffers of the whole tree; nodes reference these by absolute index.
    buffers: Arc<[BufferHandle]>,
}
impl Debug for SerializedArray {
    /// Debug-formats the node with its children and buffers rendered recursively.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // FIX: the previous implementation passed bare `std::iter::Map` adapters to
        // `.field(..)`. `Map`'s Debug impl prints the adapter struct ("Map { iter: .. }"),
        // not the items, so children and buffers never appeared in the output. Collect
        // into Vecs so the actual contents are formatted.
        // NOTE(review): assumes `BufferHandle` implements Debug — confirm.
        f.debug_struct("SerializedArray")
            .field("encoding_id", &self.encoding_id())
            .field(
                "children",
                &(0..self.nchildren())
                    .map(|i| self.child(i))
                    .collect::<Vec<_>>(),
            )
            .field(
                "buffers",
                &(0..self.nbuffers())
                    .map(|i| self.buffer(i).ok())
                    .collect::<Vec<_>>(),
            )
            .field("metadata", &self.metadata())
            .finish()
    }
}
impl SerializedArray {
    /// Decodes this serialized node into an [`ArrayRef`] with the expected `dtype` and
    /// `len`, using the encoding plugin resolved via `ctx` and `session`.
    ///
    /// If the encoding is unknown and the session allows unknown encodings, the node is
    /// decoded as an opaque "foreign" array instead.
    ///
    /// # Errors / Panics
    ///
    /// Returns an error for unresolvable encoding indices or disallowed unknown
    /// encodings; asserts (panics) if the decoded array disagrees with the expected
    /// length, dtype, or encoding — those indicate a buggy encoding plugin.
    pub fn decode(
        &self,
        dtype: &DType,
        len: usize,
        ctx: &ReadContext,
        session: &VortexSession,
    ) -> VortexResult<ArrayRef> {
        let encoding_idx = self.flatbuffer().encoding();
        let encoding_id = ctx
            .resolve(encoding_idx)
            .ok_or_else(|| vortex_err!("Unknown encoding index: {}", encoding_idx))?;
        let Some(plugin) = session.arrays().registry().find(&encoding_id) else {
            if session.allows_unknown() {
                return self.decode_foreign(encoding_id, dtype, len, ctx);
            }
            return Err(vortex_err!("Unknown encoding: {}", encoding_id));
        };
        // Children are decoded lazily through this adapter, on demand from the plugin.
        let children = SerializedArrayChildren {
            ser: self,
            ctx,
            session,
        };
        let buffers = self.collect_buffers()?;
        let decoded =
            plugin.deserialize(dtype, len, self.metadata(), &buffers, &children, session)?;
        // Sanity-check the plugin's output against the caller-supplied expectations.
        assert_eq!(
            decoded.len(),
            len,
            "Array decoded from {} has incorrect length {}, expected {}",
            encoding_id,
            decoded.len(),
            len
        );
        assert_eq!(
            decoded.dtype(),
            dtype,
            "Array decoded from {} has incorrect dtype {}, expected {}",
            encoding_id,
            decoded.dtype(),
            dtype,
        );
        assert_eq!(
            decoded.encoding_id(),
            encoding_id,
            "Array decoded from {} has incorrect encoding {}",
            encoding_id,
            decoded.encoding_id(),
        );
        // Carry serialized statistics over to the in-memory array, if present.
        if let Some(stats) = self.flatbuffer().stats() {
            decoded
                .statistics()
                .set_iter(StatsSet::from_flatbuffer(&stats, dtype, session)?.into_iter());
        }
        Ok(decoded)
    }

    /// Decodes this node (and, recursively, its children) as opaque foreign arrays that
    /// preserve metadata and buffers without interpreting them.
    ///
    /// NOTE(review): children are decoded with the *parent's* `dtype` and `len` — their
    /// true dtypes/lengths cannot be known without the encoding plugin. Presumably
    /// `new_foreign_array` tolerates these placeholder values for children; confirm.
    fn decode_foreign(
        &self,
        encoding_id: crate::array::ArrayId,
        dtype: &DType,
        len: usize,
        ctx: &ReadContext,
    ) -> VortexResult<ArrayRef> {
        let children = (0..self.nchildren())
            .map(|idx| {
                let child = self.child(idx);
                let child_encoding_idx = child.flatbuffer().encoding();
                let child_encoding_id = ctx
                    .resolve(child_encoding_idx)
                    .ok_or_else(|| vortex_err!("Unknown encoding index: {}", child_encoding_idx))?;
                child.decode_foreign(child_encoding_id, dtype, len, ctx)
            })
            .collect::<VortexResult<Vec<_>>>()?;
        new_foreign_array(
            encoding_id,
            dtype.clone(),
            len,
            self.metadata().to_vec(),
            self.collect_buffers()?.into_owned(),
            children,
        )
    }

    /// The raw, context-relative encoding index of this node.
    pub fn encoding_id(&self) -> u16 {
        self.flatbuffer().encoding()
    }

    /// The node's opaque metadata bytes (empty slice if none were written).
    pub fn metadata(&self) -> &[u8] {
        self.flatbuffer()
            .metadata()
            .map(|metadata| metadata.bytes())
            .unwrap_or(&[])
    }

    /// Number of child nodes of this node.
    pub fn nchildren(&self) -> usize {
        self.flatbuffer()
            .children()
            .map_or(0, |children| children.len())
    }

    /// Returns a view over the `idx`-th child, sharing this view's flatbuffer and
    /// buffer list.
    ///
    /// # Panics
    ///
    /// Panics if the node has no children vector or `idx` is out of range.
    pub fn child(&self, idx: usize) -> SerializedArray {
        let children = self
            .flatbuffer()
            .children()
            .vortex_expect("Expected array to have children");
        if idx >= children.len() {
            vortex_panic!(
                "Invalid child index {} for array with {} children",
                idx,
                children.len()
            );
        }
        self.with_root(children.get(idx))
    }

    /// Number of buffers referenced directly by this node (not by its children).
    pub fn nbuffers(&self) -> usize {
        self.flatbuffer()
            .buffers()
            .map_or(0, |buffers| buffers.len())
    }

    /// Resolves this node's `idx`-th buffer reference into the shared buffer list.
    pub fn buffer(&self, idx: usize) -> VortexResult<BufferHandle> {
        // The flatbuffer stores absolute indices into the tree-wide buffer list.
        let buffer_idx = self
            .flatbuffer()
            .buffers()
            .ok_or_else(|| vortex_err!("Array has no buffers"))?
            .get(idx);
        self.buffers
            .get(buffer_idx as usize)
            .cloned()
            .ok_or_else(|| {
                vortex_err!(
                    "Invalid buffer index {} for array with {} buffers",
                    buffer_idx,
                    self.nbuffers()
                )
            })
    }

    /// Gathers this node's buffers, borrowing a sub-slice of the shared list when the
    /// flatbuffer indices form a contiguous run, cloning individual handles otherwise.
    fn collect_buffers(&self) -> VortexResult<Cow<'_, [BufferHandle]>> {
        let Some(fb_buffers) = self.flatbuffer().buffers() else {
            return Ok(Cow::Borrowed(&[]));
        };
        let count = fb_buffers.len();
        if count == 0 {
            return Ok(Cow::Borrowed(&[]));
        }
        let start = fb_buffers.get(0) as usize;
        // Common case: indices are contiguous, so no handle cloning is needed.
        let contiguous = fb_buffers
            .iter()
            .enumerate()
            .all(|(i, idx)| idx as usize == start + i);
        if contiguous {
            self.buffers.get(start..start + count).map_or_else(
                || {
                    vortex_bail!(
                        "buffer indices {}..{} out of range for {} buffers",
                        start,
                        start + count,
                        self.buffers.len()
                    )
                },
                |slice| Ok(Cow::Borrowed(slice)),
            )
        } else {
            (0..count)
                .map(|idx| self.buffer(idx))
                .collect::<VortexResult<Vec<_>>>()
                .map(Cow::Owned)
        }
    }

    /// Lengths (in bytes) of all buffers declared at the root of the array tree.
    ///
    /// NOTE(review): this reads the root `fba::Array` table, not this node's table — on a
    /// child view it still reports the whole tree's buffer lengths. Confirm intended.
    pub fn buffer_lengths(&self) -> Vec<usize> {
        let fb_array = root::<fba::Array>(self.flatbuffer.as_ref())
            .vortex_expect("SerializedArray flatbuffer must be a valid Array");
        fb_array
            .buffers()
            .map(|buffers| buffers.iter().map(|b| b.length() as usize).collect())
            .unwrap_or_default()
    }

    /// Verifies `array_tree` is a valid `fba::Array` flatbuffer with a root node,
    /// returning the aligned buffer and the byte offset of the root `ArrayNode` table.
    fn validate_array_tree(array_tree: impl Into<ByteBuffer>) -> VortexResult<(FlatBuffer, usize)> {
        let fb_buffer = FlatBuffer::align_from(array_tree.into());
        let fb_array = root::<fba::Array>(fb_buffer.as_ref())?;
        let fb_root = fb_array
            .root()
            .ok_or_else(|| vortex_err!("Array must have a root node"))?;
        let flatbuffer_loc = fb_root._tab.loc();
        Ok((fb_buffer, flatbuffer_loc))
    }

    /// Builds a view from an array-tree flatbuffer plus an explicit buffer list.
    pub fn from_flatbuffer_with_buffers(
        array_tree: impl Into<ByteBuffer>,
        buffers: Vec<BufferHandle>,
    ) -> VortexResult<Self> {
        let (flatbuffer, flatbuffer_loc) = Self::validate_array_tree(array_tree)?;
        Ok(SerializedArray {
            flatbuffer,
            flatbuffer_loc,
            buffers: buffers.into(),
        })
    }

    /// Builds a buffer-less view from an array-tree flatbuffer.
    pub fn from_array_tree(array_tree: impl Into<ByteBuffer>) -> VortexResult<Self> {
        let (flatbuffer, flatbuffer_loc) = Self::validate_array_tree(array_tree)?;
        Ok(SerializedArray {
            flatbuffer,
            flatbuffer_loc,
            buffers: Arc::new([]),
        })
    }

    /// Returns the `ArrayNode` table this view points at.
    fn flatbuffer(&self) -> fba::ArrayNode<'_> {
        // SAFETY: `flatbuffer_loc` was produced from a verified flatbuffer (see
        // `validate_array_tree`) or from a child table within one.
        unsafe { fba::ArrayNode::follow(self.flatbuffer.as_ref(), self.flatbuffer_loc) }
    }

    /// Re-points a clone of this view at another node within the same flatbuffer.
    fn with_root(&self, root: fba::ArrayNode) -> Self {
        let mut this = self.clone();
        this.flatbuffer_loc = root._tab.loc();
        this
    }

    /// Builds a view from a flatbuffer plus a single contiguous segment containing all
    /// buffers back-to-back (with their serialized padding).
    pub fn from_flatbuffer_and_segment(
        array_tree: ByteBuffer,
        segment: BufferHandle,
    ) -> VortexResult<Self> {
        Self::from_flatbuffer_and_segment_with_overrides(array_tree, segment, &HashMap::new())
    }

    /// Like [`Self::from_flatbuffer_and_segment`], but individual buffers may be replaced
    /// by host-memory overrides keyed by buffer index, without modifying the segment.
    pub fn from_flatbuffer_and_segment_with_overrides(
        array_tree: ByteBuffer,
        segment: BufferHandle,
        buffer_overrides: &HashMap<u32, ByteBuffer>,
    ) -> VortexResult<Self> {
        let segment = segment.ensure_aligned(Alignment::none())?;
        let (fb_buffer, flatbuffer_loc) = Self::validate_array_tree(array_tree)?;
        // SAFETY: validate_array_tree just verified this buffer as an fba::Array.
        let fb_array = unsafe { fba::root_as_array_unchecked(fb_buffer.as_ref()) };
        // Walk the segment, using each buffer's recorded padding and length to slice it.
        let mut offset = 0;
        let buffers = fb_array
            .buffers()
            .unwrap_or_default()
            .iter()
            .enumerate()
            .map(|(idx, fb_buf)| {
                offset += fb_buf.padding() as usize;
                let buffer_len = fb_buf.length() as usize;
                let alignment = Alignment::from_exponent(fb_buf.alignment_exponent());
                let idx = u32::try_from(idx).vortex_expect("buffer count must fit in u32");
                let handle = if let Some(host_data) = buffer_overrides.get(&idx) {
                    BufferHandle::new_host(host_data.clone()).ensure_aligned(alignment)?
                } else {
                    let buffer = segment.slice(offset..(offset + buffer_len));
                    buffer.ensure_aligned(alignment)?
                };
                // Even when overridden, advance past the segment's copy of the buffer.
                offset += buffer_len;
                Ok(handle)
            })
            .collect::<VortexResult<Arc<[_]>>>()?;
        Ok(SerializedArray {
            flatbuffer: fb_buffer,
            flatbuffer_loc,
            buffers,
        })
    }
}
/// Adapter exposing a [`SerializedArray`]'s children through the [`ArrayChildren`]
/// trait, decoding each child on demand.
struct SerializedArrayChildren<'a> {
    ser: &'a SerializedArray,
    ctx: &'a ReadContext,
    session: &'a VortexSession,
}
impl ArrayChildren for SerializedArrayChildren<'_> {
    /// Decodes the serialized child at `index` with the expected dtype and length.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
        let child_view = self.ser.child(index);
        child_view.decode(dtype, len, self.ctx, self.session)
    }

    /// Number of serialized children available.
    fn len(&self) -> usize {
        self.ser.nchildren()
    }
}
impl TryFrom<ByteBuffer> for SerializedArray {
    type Error = VortexError;

    /// Parses the end-to-end layout written by [`ArrayRef::serialize`]: the final 4 bytes
    /// are a little-endian `u32` flatbuffer length, the flatbuffer immediately precedes
    /// them, and everything before that is the buffer segment.
    fn try_from(value: ByteBuffer) -> Result<Self, Self::Error> {
        if value.len() < 4 {
            vortex_bail!("SerializedArray buffer is too short");
        }
        let value = value.aligned(Alignment::none());
        let fb_length = u32::try_from_le_bytes(&value.as_slice()[value.len() - 4..])? as usize;
        // FIX: compare against `value.len() - 4` (non-negative after the check above)
        // rather than `4 + fb_length`, which can overflow usize on 32-bit targets when
        // the trailing length word is corrupt or adversarial.
        if value.len() - 4 < fb_length {
            vortex_bail!("SerializedArray buffer is too short for flatbuffer");
        }
        let fb_offset = value.len() - 4 - fb_length;
        let array_tree = value.slice(fb_offset..fb_offset + fb_length);
        let segment = BufferHandle::new_host(value.slice(0..fb_offset));
        Self::from_flatbuffer_and_segment(array_tree, segment)
    }
}
impl TryFrom<BufferHandle> for SerializedArray {
    type Error = VortexError;

    /// Materializes the handle on the host, then parses it as a serialized array.
    fn try_from(value: BufferHandle) -> Result<Self, Self::Error> {
        let host_buffer = value.try_to_host_sync()?;
        host_buffer.try_into()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use flatbuffers::FlatBufferBuilder;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;
    use super::SerializeOptions;
    use super::SerializedArray;
    use crate::ArrayContext;
    use crate::array::ArrayId;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::flatbuffers as fba;
    use crate::session::ArraySession;

    static SESSION: LazyLock<VortexSession> = LazyLock::new(VortexSession::empty);

    /// Decoding a tree whose encodings are not registered should fall back to opaque
    /// foreign arrays — preserving metadata and children — when the session allows
    /// unknown encodings, and the result should still be serializable afterwards.
    #[test]
    fn unknown_array_encoding_allow_unknown() {
        // Hand-build a two-node array tree flatbuffer: a root (encoding index 0) with
        // one child (encoding index 1), each carrying distinctive metadata bytes.
        let mut fbb = FlatBufferBuilder::new();
        let child_metadata = fbb.create_vector(&[9u8]);
        let child = fba::ArrayNode::create(
            &mut fbb,
            &fba::ArrayNodeArgs {
                encoding: 1,
                metadata: Some(child_metadata),
                children: None,
                buffers: None,
                stats: None,
            },
        );
        let children = fbb.create_vector(&[child]);
        let metadata = fbb.create_vector(&[1u8, 2, 3]);
        let root = fba::ArrayNode::create(
            &mut fbb,
            &fba::ArrayNodeArgs {
                encoding: 0,
                metadata: Some(metadata),
                children: Some(children),
                buffers: None,
                stats: None,
            },
        );
        let array = fba::Array::create(
            &mut fbb,
            &fba::ArrayArgs {
                root: Some(root),
                buffers: None,
            },
        );
        fbb.finish_minimal(array);
        let (buf, start) = fbb.collapse();
        let tree = vortex_buffer::ByteBuffer::from(buf).slice(start..);
        let ser = SerializedArray::from_array_tree(tree).unwrap();

        // The read context resolves encoding indices 0 and 1 to ids that are NOT
        // registered in the session, forcing the foreign-array fallback.
        let ctx = ReadContext::new([
            ArrayId::new_ref("vortex.test.foreign_array"),
            ArrayId::new_ref("vortex.test.foreign_child"),
        ]);
        let session = VortexSession::empty()
            .with::<ArraySession>()
            .allow_unknown();
        let decoded = ser
            .decode(&DType::Variant(Nullability::Nullable), 5, &ctx, &session)
            .unwrap();

        // The foreign array retains the encoding ids, child structure, and metadata.
        assert_eq!(decoded.encoding_id().as_ref(), "vortex.test.foreign_array");
        assert_eq!(decoded.nchildren(), 1);
        assert_eq!(
            decoded.nth_child(0).unwrap().encoding_id().as_ref(),
            "vortex.test.foreign_child"
        );
        assert_eq!(decoded.metadata(&SESSION).unwrap().unwrap(), vec![1, 2, 3]);
        assert_eq!(
            decoded
                .nth_child(0)
                .unwrap()
                .metadata(&SESSION)
                .unwrap()
                .unwrap(),
            vec![9]
        );

        // A foreign array must round-trip back through serialize without error.
        let serialized = decoded
            .serialize(
                &ArrayContext::default(),
                &SESSION,
                &SerializeOptions::default(),
            )
            .unwrap();
        assert!(!serialized.is_empty());
    }
}