use std::io;
use std::marker::PhantomData;
use crate::constants::{
ARROW_MAGIC_NUMBER, ARROW_MAGIC_NUMBER_PADDED, CONTINUATION_SENTINEL, FILE_CLOSING_MAGIC_LEN,
FILE_OPENING_MAGIC_LEN, METADATA_SIZE_PREFIX,
};
use crate::enums::{DecodeResult, DecodeState, IPCMessageProtocol};
use crate::models::frames::ipc_message::ArrowIPCMessage;
use crate::traits::frame_decoder::FrameDecoder;
use crate::traits::stream_buffer::StreamBuffer;
use crate::utils::{align_8, align_to};
/// Incremental decoder that cuts a byte buffer into Arrow IPC message frames.
///
/// The decoder is a small state machine driven by `FrameDecoder::decode`. It
/// never owns the input: every call receives the caller's buffer and reports
/// how many leading bytes a produced frame consumed.
pub struct ArrowIPCFrameDecoder<B: StreamBuffer> {
/// Framing protocol in use. `decode` switches this from `Stream` to `File`
/// when it sees the opening file magic while still in the initial state.
format: IPCMessageProtocol,
/// Current position in the decoding state machine.
state: DecodeState<B>,
/// Number of bytes that precede the metadata of the message currently being
/// decoded; set by the length-parsing step and reset to 0 when a frame is emitted.
pending_prefix_len: usize,
/// `true` while the opening file magic is still at the front of the caller's
/// buffer; offsets are shifted past it until the first frame is emitted.
file_magic_unconsumed: bool,
/// Marker tying the decoder to the buffer type `B`.
_phantom: PhantomData<B>,
}
impl<B: StreamBuffer> FrameDecoder for ArrowIPCFrameDecoder<B> {
    type Frame = ArrowIPCMessage<B>;

    /// Attempts to decode one frame from the front of `buf`.
    ///
    /// Runs the state machine until a step produces a result: either a frame
    /// together with the number of bytes it consumed, or `NeedMore` when `buf`
    /// does not yet hold enough data.
    ///
    /// If the decoder is in its initial state, was configured for the stream
    /// protocol, and `buf` starts with the opening file magic, it switches to
    /// the file protocol before decoding.
    ///
    /// # Errors
    ///
    /// Propagates `io::ErrorKind::InvalidData` errors raised by the individual
    /// decoding steps.
    fn decode(&mut self, buf: &[u8]) -> io::Result<DecodeResult<Self::Frame>> {
        loop {
            // Auto-detect a file that was handed to a stream-mode decoder.
            if matches!(self.state, DecodeState::Initial)
                && matches!(self.format, IPCMessageProtocol::Stream)
                && Self::has_opening_file_magic(buf)
            {
                self.format = IPCMessageProtocol::File;
                self.file_magic_unconsumed = true;
                self.state = DecodeState::ReadingMessageLength;
            }

            // Take ownership of the state (some variants carry a buffer). Unless
            // the selected step stores a new state, the decoder is left in
            // `Initial` and the next call starts over from the buffer front.
            let state = std::mem::replace(&mut self.state, DecodeState::Initial);
            let step = match state {
                DecodeState::Initial => self.decode_initial(buf)?,
                DecodeState::ReadingMessageLength => self.decode_message_length(buf)?,
                DecodeState::ReadingMessage { msg_len } => self.decode_message(buf, msg_len)?,
                DecodeState::ReadingBody { body_len, message } => {
                    self.decode_body(buf, body_len, message)?
                }
                DecodeState::ReadingFooter {
                    footer_len,
                    footer_offset,
                } => self.decode_footer(buf, footer_len, footer_offset)?,
                DecodeState::Done => {
                    // Completion is terminal: put the state back, otherwise the
                    // `mem::replace` above would silently rewind to `Initial`.
                    self.state = DecodeState::Done;
                    Some(DecodeResult::NeedMore)
                }
                // Pure pass-through states: all continue with the length prefix.
                DecodeState::ReadingContinuationSize { .. }
                | DecodeState::AfterMagic
                | DecodeState::AfterContMarker => {
                    self.state = DecodeState::ReadingMessageLength;
                    None
                }
            };
            if let Some(done) = step {
                return Ok(done);
            }
        }
    }
}
impl<B: StreamBuffer> ArrowIPCFrameDecoder<B> {
/// Creates a decoder for `format` that starts in the initial state.
///
/// For the file protocol the opening magic is marked as still pending in
/// the caller's buffer.
pub fn new(format: IPCMessageProtocol) -> Self {
    let expects_magic = matches!(format, IPCMessageProtocol::File);
    Self {
        state: DecodeState::Initial,
        format,
        file_magic_unconsumed: expects_magic,
        pending_prefix_len: 0,
        _phantom: PhantomData,
    }
}
/// Creates a decoder that does not expect the opening file magic.
///
/// A file-protocol decoder starts directly at the message length prefix;
/// a stream-protocol decoder starts in the initial state. In both cases
/// no opening magic is marked as pending.
pub fn new_without_header_check(format: IPCMessageProtocol) -> Self {
    let state = match format {
        IPCMessageProtocol::File => DecodeState::ReadingMessageLength,
        IPCMessageProtocol::Stream => DecodeState::Initial,
    };
    Self {
        format,
        state,
        pending_prefix_len: 0,
        file_magic_unconsumed: false,
        _phantom: PhantomData,
    }
}
/// Reads the first four bytes of `buf` as a little-endian `u32`.
///
/// # Panics
///
/// Panics if `buf` holds fewer than four bytes.
#[inline]
fn read_u32_le(buf: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&buf[..4]);
    u32::from_le_bytes(word)
}
/// Returns `true` when `buf` starts with the padded opening file magic.
#[inline]
fn has_opening_file_magic(buf: &[u8]) -> bool {
    buf.get(..FILE_OPENING_MAGIC_LEN)
        .map_or(false, |head| head == ARROW_MAGIC_NUMBER_PADDED)
}
/// Returns `true` when `buf` is long enough to hold a size prefix and its
/// leading little-endian `u32` equals the continuation sentinel.
#[inline]
fn has_continuation_sentinel(buf: &[u8]) -> bool {
    if buf.len() < METADATA_SIZE_PREFIX {
        return false;
    }
    Self::read_u32_le(buf) == CONTINUATION_SENTINEL
}
/// Returns `true` when `buf` starts with the 8-byte end-of-stream marker:
/// a `0xFFFF_FFFF` word followed by a zero word (both little-endian).
#[inline]
fn has_eos_marker(buf: &[u8]) -> bool {
    if buf.len() < 8 {
        return false;
    }
    let marker = Self::read_u32_le(&buf[0..4]);
    let length = Self::read_u32_le(&buf[4..8]);
    marker == 0xFFFF_FFFF && length == 0x0000_0000
}
/// Reports whether `buf` is long enough to hold a length prefix at
/// `len_off`, a non-empty payload of `msg_len` bytes, and
/// `FILE_OPENING_MAGIC_LEN` further bytes after it.
///
/// This is a pure length check; the caller inspects the bytes themselves.
#[inline]
fn has_file_footer_markers(buf: &[u8], len_off: usize, msg_len: usize) -> bool {
    if msg_len == 0 {
        return false;
    }
    let probe_end = len_off + METADATA_SIZE_PREFIX + msg_len + FILE_OPENING_MAGIC_LEN;
    probe_end <= buf.len()
}
/// Offset in the caller's buffer at which the current message starts:
/// past the opening magic while it is still pending in file mode,
/// otherwise the very front of the buffer.
#[inline]
fn current_base_offset(&self) -> usize {
    let magic_pending =
        self.file_magic_unconsumed && self.format == IPCMessageProtocol::File;
    if magic_pending {
        FILE_OPENING_MAGIC_LEN
    } else {
        0
    }
}
/// Handles the `Initial` state.
///
/// * File protocol: validates the opening magic and moves to `AfterMagic`.
///   When no magic is pending any more (it was consumed with the first frame,
///   or header checking was disabled), the check is skipped and decoding
///   resumes at the length prefix. The dispatcher parks the decoder in
///   `Initial` whenever a step asks for more data, so without this skip a
///   partially received message in the middle of a file would be rejected as
///   "invalid magic" on the next call.
/// * Stream protocol: emits an empty frame consuming 8 bytes for an
///   end-of-stream marker, otherwise records the expected prefix size and
///   moves on to the length prefix.
///
/// Returns `Ok(None)` when the state machine should keep running.
///
/// # Errors
///
/// `InvalidData` if a pending file header does not match the Arrow magic.
#[inline]
fn decode_initial(
    &mut self,
    buf: &[u8],
) -> io::Result<Option<DecodeResult<ArrowIPCMessage<B>>>> {
    match self.format {
        IPCMessageProtocol::File => {
            if !self.file_magic_unconsumed {
                // The buffer no longer starts with the magic; resume at the
                // next message instead of re-validating a header that is gone.
                self.state = DecodeState::ReadingMessageLength;
                return Ok(None);
            }
            if buf.len() < FILE_OPENING_MAGIC_LEN {
                return Ok(Some(DecodeResult::NeedMore));
            }
            if !Self::has_opening_file_magic(buf) {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Invalid Arrow file magic header",
                ));
            }
            self.state = DecodeState::AfterMagic;
        }
        IPCMessageProtocol::Stream => {
            if Self::has_eos_marker(buf) {
                return Ok(Some(DecodeResult::Frame {
                    frame: ArrowIPCMessage {
                        message: B::default(),
                        body: B::default(),
                    },
                    consumed: 8,
                }));
            }
            // Sentinel + length is an 8-byte prefix; a bare length is 4 bytes.
            self.pending_prefix_len = if Self::has_continuation_sentinel(buf) {
                8
            } else {
                4
            };
            self.state = DecodeState::ReadingMessageLength;
        }
    }
    Ok(None)
}
/// Parses the length prefix of the next message and selects the follow-up state.
///
/// Stream protocol: an end-of-stream marker or a zero length yields an empty
/// frame; otherwise the optional continuation sentinel is skipped, the prefix
/// size is recorded and the decoder moves to `ReadingMessage`.
///
/// File protocol: the prefix is a bare length. If the bytes following the
/// announced payload equal the Arrow magic the decoder moves to
/// `ReadingFooter`, otherwise to `ReadingMessage`.
///
/// Returns `NeedMore` when the prefix is not fully buffered yet.
///
/// # Errors
///
/// `InvalidData` for a zero-length message in file mode.
fn decode_message_length(
    &mut self,
    buf: &[u8],
) -> io::Result<Option<DecodeResult<ArrowIPCMessage<B>>>> {
    let start = self.current_base_offset();
    let empty_frame = |consumed: usize| DecodeResult::Frame {
        frame: ArrowIPCMessage {
            message: B::default(),
            body: B::default(),
        },
        consumed,
    };

    match self.format {
        IPCMessageProtocol::Stream => {
            if buf.len() >= start + 8 && Self::has_eos_marker(&buf[start..]) {
                return Ok(Some(empty_frame(start + 8)));
            }

            let marked =
                buf.len() >= start + 4 && Self::has_continuation_sentinel(&buf[start..]);
            let (len_at, prefix) = if marked {
                (start + 4, 8)
            } else {
                (start, 4)
            };
            self.pending_prefix_len = prefix;

            let len_end = len_at + METADATA_SIZE_PREFIX;
            if buf.len() < len_end {
                return Ok(Some(DecodeResult::NeedMore));
            }
            let msg_len = Self::read_u32_le(&buf[len_at..len_end]) as usize;
            if msg_len == 0 {
                return Ok(Some(empty_frame(start + prefix)));
            }
            self.state = DecodeState::ReadingMessage { msg_len };
            Ok(None)
        }
        IPCMessageProtocol::File => {
            self.pending_prefix_len = METADATA_SIZE_PREFIX;

            let len_end = start + METADATA_SIZE_PREFIX;
            if buf.len() < len_end {
                return Ok(Some(DecodeResult::NeedMore));
            }
            let msg_len = Self::read_u32_le(&buf[start..len_end]) as usize;

            // NOTE(review): this probe compares FILE_OPENING_MAGIC_LEN bytes
            // located directly after the payload against ARROW_MAGIC_NUMBER,
            // while `decode_footer` expects a 4-byte size field before a
            // FILE_CLOSING_MAGIC_LEN-byte magic. Confirm the intended footer
            // layout and that the compared lengths agree.
            if Self::has_file_footer_markers(buf, start, msg_len) {
                let magic_at = len_end + msg_len;
                let candidate = &buf[magic_at..magic_at + FILE_OPENING_MAGIC_LEN];
                if candidate == ARROW_MAGIC_NUMBER {
                    self.state = DecodeState::ReadingFooter {
                        footer_len: msg_len,
                        footer_offset: len_end,
                    };
                    return Ok(None);
                }
            }

            if msg_len == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Zero-length message",
                ));
            }
            self.state = DecodeState::ReadingMessage { msg_len };
            Ok(None)
        }
    }
}
fn decode_message(
&mut self,
buf: &[u8],
msg_len: usize,
) -> io::Result<Option<DecodeResult<ArrowIPCMessage<B>>>> {
let base_off = self.current_base_offset();
let prefix = self.pending_prefix_len;
let meta_start = base_off + prefix;
let meta_end = meta_start + msg_len;
if buf.len() < meta_end {
return Ok(Some(DecodeResult::NeedMore));
}
let message = B::from_slice(&buf[meta_start..meta_end]);
use crate::AFMessage;
use flatbuffers::root;
let root = root::<AFMessage>(&message.as_ref()).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to parse message: {e}"),
)
})?;
let body_len = root.bodyLength() as usize;
let meta_pad = align_8(msg_len);
let body_start = meta_end + meta_pad;
if body_len > 0 {
let body_end = body_start + body_len;
if buf.len() < body_end {
self.state = DecodeState::ReadingBody { body_len, message };
return Ok(Some(DecodeResult::NeedMore));
}
let body = B::from_slice(&buf[body_start..body_end]);
let consumed_before_body_pad = base_off + prefix + msg_len + meta_pad + body_len;
let body_pad = align_to::<B>(consumed_before_body_pad);
let consumed = consumed_before_body_pad + body_pad;
self.state = DecodeState::ReadingMessageLength;
self.pending_prefix_len = 0;
if self.file_magic_unconsumed && self.format == IPCMessageProtocol::File {
self.file_magic_unconsumed = false;
}
return Ok(Some(DecodeResult::Frame {
frame: ArrowIPCMessage { message, body },
consumed,
}));
} else {
let consumed = base_off + prefix + msg_len + meta_pad;
self.state = DecodeState::ReadingMessageLength;
self.pending_prefix_len = 0;
if self.file_magic_unconsumed && self.format == IPCMessageProtocol::File {
self.file_magic_unconsumed = false;
}
let frame = ArrowIPCMessage {
message,
body: B::default(),
};
return Ok(Some(DecodeResult::Frame { frame, consumed }));
}
}
/// Completes a message whose metadata was already decoded by waiting for its
/// body.
///
/// `message` is the metadata buffer carried by the `ReadingBody` state and
/// `body_len` the body length it announced. Offsets are recomputed from the
/// buffer front using the recorded prefix size and the metadata length.
///
/// While the body is incomplete the `ReadingBody` state is stored again, with
/// the same metadata, before `NeedMore` is returned. Otherwise the dispatcher
/// would leave the decoder in `Initial`, dropping the parsed metadata and
/// forcing the whole message to be re-parsed on the next call.
///
/// On success a frame is emitted whose `consumed` count includes the padding
/// reported by `align_to::<B>`, the state returns to `ReadingMessageLength`
/// and a pending file magic is marked as consumed.
///
/// # Errors
///
/// `InvalidData` if the end offset of the body overflows `usize`.
fn decode_body(
    &mut self,
    buf: &[u8],
    body_len: usize,
    message: B,
) -> io::Result<Option<DecodeResult<ArrowIPCMessage<B>>>> {
    let base_off = self.current_base_offset();
    let prefix = self.pending_prefix_len;
    let meta_len = message.len();
    let body_start = base_off + prefix + meta_len + align_8(meta_len);
    let body_end = body_start.checked_add(body_len).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "Invalid message body length")
    })?;

    if buf.len() < body_end {
        // Stay in this state so the next call resumes here.
        self.state = DecodeState::ReadingBody { body_len, message };
        return Ok(Some(DecodeResult::NeedMore));
    }

    let body = B::from_slice(&buf[body_start..body_end]);

    self.state = DecodeState::ReadingMessageLength;
    self.pending_prefix_len = 0;
    if self.file_magic_unconsumed && self.format == IPCMessageProtocol::File {
        self.file_magic_unconsumed = false;
    }

    let consumed = body_end + align_to::<B>(body_end);
    Ok(Some(DecodeResult::Frame {
        frame: ArrowIPCMessage::<B> { message, body },
        consumed,
    }))
}
#[inline]
fn decode_footer(
&mut self,
buf: &[u8],
footer_len: usize,
footer_offset: usize,
) -> io::Result<Option<DecodeResult<ArrowIPCMessage<B>>>> {
if buf.len() < footer_offset + footer_len {
return Ok(Some(DecodeResult::NeedMore));
}
if buf.len() < footer_offset + footer_len + 4 + FILE_CLOSING_MAGIC_LEN {
return Ok(Some(DecodeResult::NeedMore));
}
let size_offset = footer_offset + footer_len;
let footer_size =
u32::from_le_bytes(buf[size_offset..size_offset + 4].try_into().unwrap()) as usize;
if footer_size != footer_len {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Arrow file footer size mismatch: expected {footer_len}, found {footer_size}"
),
));
}
let magic = &buf[size_offset + 4..size_offset + 4 + FILE_CLOSING_MAGIC_LEN];
if magic != ARROW_MAGIC_NUMBER {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid Arrow file trailing magic",
));
}
self.state = DecodeState::Done;
Ok(None)
}
}
/// The default decoder is configured for the stream protocol.
impl<B: StreamBuffer> Default for ArrowIPCFrameDecoder<B> {
    fn default() -> Self {
        Self::new(IPCMessageProtocol::Stream)
    }
}