use crate::error::StreamingError;
/// Magic byte string that `ArrowIpcWriter` emits at the start and end of its
/// output and that `ArrowIpcReader::is_arrow_file` looks for at the start.
pub const ARROW_MAGIC: &[u8] = b"ARROW1";
/// Number of bytes in [`ARROW_MAGIC`], derived from the slice so the two can
/// never disagree.
pub const ARROW_MAGIC_LEN: usize = ARROW_MAGIC.len();
/// Byte boundary to which metadata and body sections are zero-padded.
pub const ARROW_ALIGNMENT: usize = 8;
/// Kind of message carried by one IPC frame.
///
/// `ArrowIpcReader::next_message` picks the variant from a single type byte
/// in the message metadata: `1` through `5` map to the variants below in
/// declaration order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IpcMessageType {
/// Selected for type byte `1`.
Schema,
/// Selected for type byte `2`.
DictionaryBatch,
/// Selected for type byte `3`; also used by the reader as the fallback when
/// the type byte is not one of the recognized values.
RecordBatch,
/// Selected for type byte `4`.
Tensor,
/// Selected for type byte `5`.
SparseTensor,
}
/// Framing information for one message, as returned by
/// `ArrowIpcReader::next_message`.
#[derive(Debug, Clone)]
pub struct IpcMessageHeader {
/// Message kind derived from the metadata's type byte.
pub message_type: IpcMessageType,
/// Metadata size in bytes exactly as stored in the frame's 4-byte length
/// prefix (alignment padding is not included).
pub metadata_length: i32,
/// Body size in bytes as read from the 8-byte field that follows the
/// padded metadata (alignment padding is not included).
pub body_length: i64,
/// Absolute byte offset, within the reader's buffer, at which the body
/// starts. Pass this to `ArrowIpcReader::read_buffer`.
pub body_offset: u64,
}
/// Location of one buffer inside a message body.
///
/// `ArrowIpcReader::read_buffer` adds `offset` to a message's `body_offset`
/// and returns the `length` bytes found there.
#[derive(Debug, Clone)]
pub struct IpcBuffer {
/// Byte offset of the buffer, relative to the start of the message body.
pub offset: i64,
/// Size of the buffer in bytes.
pub length: i64,
}
/// Description of a record batch: a length plus lists of field nodes and
/// buffer locations.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// consumes this type, so the field notes below are inferred from names only —
/// confirm against the code that populates it.
#[derive(Debug, Clone)]
pub struct IpcRecordBatch {
/// Presumably the number of rows in the batch — TODO confirm.
pub length: i64,
/// Per-field node entries (see [`IpcFieldNode`]).
pub nodes: Vec<IpcFieldNode>,
/// Buffer locations (see [`IpcBuffer`]); presumably relative to the body of
/// the message that carries this batch — TODO confirm.
pub buffers: Vec<IpcBuffer>,
}
/// Per-field entry stored in [`IpcRecordBatch::nodes`].
///
/// NOTE(review): not used elsewhere in the visible part of this file; the
/// field notes are inferred from names only — confirm against the producer.
#[derive(Debug, Clone)]
pub struct IpcFieldNode {
/// Presumably the number of values in the field — TODO confirm.
pub length: i64,
/// Presumably how many of those values are null — TODO confirm.
pub null_count: i64,
}
/// Cursor-based reader over a complete in-memory buffer laid out in the
/// framing that [`ArrowIpcWriter`] produces:
///
/// ```text
/// magic (6 bytes) | 2 padding bytes
/// repeated: [0xFFFF_FFFF marker] | metadata length (i32 LE) | metadata, zero-padded
///           | body length (i64 LE) | body, zero-padded
/// ```
///
/// A metadata length of zero (or less) marks the end of the message sequence.
pub struct ArrowIpcReader {
    /// The complete encoded input.
    data: Vec<u8>,
    /// Byte position at which the next call to `next_message` starts reading.
    offset: usize,
}

impl ArrowIpcReader {
    /// Wraps `data` with the cursor positioned at byte 0.
    ///
    /// Call [`parse_file_header`](Self::parse_file_header) before
    /// [`next_message`](Self::next_message) to skip the magic preamble.
    #[must_use]
    pub fn new(data: Vec<u8>) -> Self {
        Self { data, offset: 0 }
    }

    /// Returns `true` when the buffer begins with [`ARROW_MAGIC`].
    #[must_use]
    pub fn is_arrow_file(&self) -> bool {
        // `starts_with` already returns false when the buffer is shorter than the magic.
        self.data.starts_with(ARROW_MAGIC)
    }

    /// Returns the `len` bytes starting at `offset`, or `None` if that range
    /// is not fully inside the buffer (including when `offset + len` overflows).
    fn read_bytes(&self, offset: usize, len: usize) -> Option<&[u8]> {
        let end = offset.checked_add(len)?;
        self.data.get(offset..end)
    }

    /// Reads a little-endian `i32` at `offset`; `None` if out of bounds.
    fn read_i32(&self, offset: usize) -> Option<i32> {
        let b = self.read_bytes(offset, 4)?;
        Some(i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
    }

    /// Reads a little-endian `i64` at `offset`; `None` if out of bounds.
    fn read_i64(&self, offset: usize) -> Option<i64> {
        let b = self.read_bytes(offset, 8)?;
        Some(i64::from_le_bytes([
            b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
        ]))
    }

    /// Reads a little-endian `u32` at `offset`; `None` if out of bounds.
    fn read_u32(&self, offset: usize) -> Option<u32> {
        let b = self.read_bytes(offset, 4)?;
        Some(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
    }

    /// Verifies the magic bytes and moves the cursor past the preamble.
    ///
    /// # Errors
    ///
    /// Returns `StreamingError::Other` when the buffer does not start with
    /// [`ARROW_MAGIC`]; the cursor is left untouched in that case.
    pub fn parse_file_header(&mut self) -> Result<(), StreamingError> {
        if !self.is_arrow_file() {
            return Err(StreamingError::Other("Not an Arrow IPC file".into()));
        }
        // The magic is followed by two padding bytes.
        self.offset = ARROW_MAGIC_LEN + 2;
        Ok(())
    }

    /// Reads the framing of the next message and advances the cursor past its
    /// padded body.
    ///
    /// Returns `Ok(None)` when fewer than four bytes remain or when the
    /// metadata length is zero or negative (end-of-stream marker). An optional
    /// `0xFFFF_FFFF` continuation marker in front of the length is skipped.
    ///
    /// The message type is taken from the fifth metadata byte when the
    /// metadata is at least eight bytes long; unrecognized values and shorter
    /// metadata fall back to [`IpcMessageType::RecordBatch`].
    ///
    /// # Errors
    ///
    /// Returns `StreamingError::Other` when the input is cut off or corrupt:
    /// the metadata length prefix, the metadata itself, the body length field
    /// or the body does not fit in the remaining data, or the body length is
    /// negative. The cursor is not moved when an error is returned.
    pub fn next_message(&mut self) -> Result<Option<IpcMessageHeader>, StreamingError> {
        let total = self.data.len();
        if total.saturating_sub(self.offset) < 4 {
            return Ok(None);
        }

        // Work on a local cursor and commit it only once the frame is validated,
        // so a failed read never leaves the reader in a half-advanced state.
        let mut cursor = self.offset;
        if self.read_u32(cursor) == Some(0xFFFF_FFFF) {
            cursor += 4;
        }

        let metadata_length = self
            .read_i32(cursor)
            .ok_or_else(|| StreamingError::Other("Truncated metadata length".into()))?;
        if metadata_length <= 0 {
            // End of stream: stay on the length field, just past any marker.
            self.offset = cursor;
            return Ok(None);
        }
        cursor += 4;

        // Positive i32, so the conversion cannot lose information.
        let meta_len = metadata_length as usize;
        let message_type = {
            let metadata = self
                .read_bytes(cursor, meta_len)
                .ok_or_else(|| StreamingError::Other("Truncated metadata".into()))?;
            if meta_len >= 8 {
                match metadata[4] {
                    1 => IpcMessageType::Schema,
                    2 => IpcMessageType::DictionaryBatch,
                    3 => IpcMessageType::RecordBatch,
                    4 => IpcMessageType::Tensor,
                    5 => IpcMessageType::SparseTensor,
                    _ => IpcMessageType::RecordBatch,
                }
            } else {
                IpcMessageType::RecordBatch
            }
        };
        // The metadata fits in the buffer, so this exceeds `total` by at most the padding.
        cursor += align_to(meta_len, ARROW_ALIGNMENT);

        let body_length = self
            .read_i64(cursor)
            .ok_or_else(|| StreamingError::Other("Truncated body length".into()))?;
        cursor += 8;
        if body_length < 0 {
            return Err(StreamingError::Other("Negative body length".into()));
        }
        // `read_i64` succeeded, so `cursor <= total` and the subtraction cannot underflow.
        let remaining = (total - cursor) as u64;
        if body_length as u64 > remaining {
            return Err(StreamingError::Other("Truncated message body".into()));
        }

        let body_offset = cursor as u64;
        // Bounded by `remaining`, so the cast and the padded advance cannot overflow.
        cursor += align_to(body_length as usize, ARROW_ALIGNMENT);
        self.offset = cursor;

        Ok(Some(IpcMessageHeader {
            message_type,
            metadata_length,
            body_length,
            body_offset,
        }))
    }

    /// Returns the bytes of `buf`, located relative to `body_offset`.
    ///
    /// Yields `None` when `buf.offset` or `buf.length` is negative, when the
    /// arithmetic overflows, or when the range is not fully inside the
    /// reader's data.
    #[must_use]
    pub fn read_buffer<'a>(&'a self, body_offset: u64, buf: &IpcBuffer) -> Option<&'a [u8]> {
        if buf.offset < 0 || buf.length < 0 {
            return None;
        }
        // Do the arithmetic in u64 so nothing is truncated on 32-bit targets.
        let start = body_offset.checked_add(buf.offset as u64)?;
        let end = start.checked_add(buf.length as u64)?;
        if end > self.data.len() as u64 {
            return None;
        }
        // Both bounds are <= data.len(), so they fit in usize.
        self.data.get(start as usize..end as usize)
    }

    /// Total number of bytes held by the reader.
    #[must_use]
    pub fn data_len(&self) -> usize {
        self.data.len()
    }

    /// Current cursor position in bytes from the start of the data.
    #[must_use]
    pub fn current_offset(&self) -> usize {
        self.offset
    }
}
/// Builds an in-memory byte buffer in the framing that `ArrowIpcReader`
/// consumes: a magic preamble, a sequence of length-prefixed messages, and a
/// closing end-of-stream marker followed by the magic again.
pub struct ArrowIpcWriter {
    /// Bytes written so far.
    buf: Vec<u8>,
}

impl ArrowIpcWriter {
    /// Creates a writer whose buffer already holds [`ARROW_MAGIC`] followed by
    /// two zero padding bytes.
    #[must_use]
    pub fn new() -> Self {
        let mut buf = Vec::new();
        buf.extend_from_slice(ARROW_MAGIC);
        buf.extend_from_slice(&[0u8; 2]);
        Self { buf }
    }

    /// Appends one framed message.
    ///
    /// Layout: `0xFFFF_FFFF` marker, metadata length as little-endian `i32`,
    /// the metadata zero-padded to [`ARROW_ALIGNMENT`], body length as
    /// little-endian `i64`, then the body zero-padded the same way.
    pub fn write_message(&mut self, metadata: &[u8], body: &[u8]) {
        // NOTE(review): the `as i32` cast truncates if metadata exceeds
        // i32::MAX bytes; callers are expected to stay well below that.
        let metadata_prefix = (metadata.len() as i32).to_le_bytes();
        let body_prefix = (body.len() as i64).to_le_bytes();

        self.buf.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes());
        self.buf.extend_from_slice(&metadata_prefix);
        self.push_padded(metadata);
        self.buf.extend_from_slice(&body_prefix);
        self.push_padded(body);
    }

    /// Appends `bytes`, then enough zero bytes to round the section's length
    /// up to a multiple of [`ARROW_ALIGNMENT`].
    fn push_padded(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes);
        let padding = align_to(bytes.len(), ARROW_ALIGNMENT) - bytes.len();
        let padded_len = self.buf.len() + padding;
        self.buf.resize(padded_len, 0u8);
    }

    /// Appends the end-of-stream marker (`0xFFFF_FFFF` plus a zero `i32`
    /// length) and the trailing magic, then returns the finished buffer.
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        let mut out = self.buf;
        out.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes());
        out.extend_from_slice(&0i32.to_le_bytes());
        out.extend_from_slice(ARROW_MAGIC);
        out
    }
}

impl Default for ArrowIpcWriter {
    /// Equivalent to [`ArrowIpcWriter::new`].
    fn default() -> Self {
        ArrowIpcWriter::new()
    }
}
/// Rounds `size` up to the nearest multiple of `alignment`.
///
/// An `alignment` of zero means "no alignment" and returns `size` unchanged.
/// Any non-zero alignment is supported, not only powers of two.
///
/// # Panics
///
/// When overflow checks are enabled, panics if the rounded value does not fit
/// in `usize`.
#[must_use]
pub fn align_to(size: usize, alignment: usize) -> usize {
    if alignment == 0 {
        return size;
    }
    // Remainder arithmetic instead of a bit mask: the mask trick silently
    // returns wrong values when `alignment` is not a power of two.
    match size % alignment {
        0 => size,
        rem => size + (alignment - rem),
    }
}