use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::arrow::message::org::apache::arrow::flatbuf as fb;
use crate::enums::{BatchState, IPCMessageProtocol};
use crate::models::decoders::ipc::parser::{
handle_dictionary_batch, handle_record_batch, handle_schema_header,
};
use crate::models::decoders::ipc::protocol::ArrowIPCFrameDecoder;
use crate::models::streams::framed_byte_stream::FramedByteStream;
use crate::traits::stream_buffer::StreamBuffer;
use futures_core::Stream;
use minarrow::*;
/// [`GTableStreamDecoder`] specialised to streams whose chunks are `Vec<u8>` buffers.
pub type TableStreamDecoder<S> = GTableStreamDecoder<S, Vec<u8>>;
/// [`GTableStreamDecoder`] specialised to streams whose chunks are `Vec64<u8>` buffers.
///
/// NOTE(review): `Vec64` is only visible here through the `minarrow` glob import; presumably
/// it is a 64-byte-aligned vector - confirm against the `minarrow` documentation.
pub type TableStreamDecoder64<S> = GTableStreamDecoder<S, Vec64<u8>>;
/// Asynchronous Arrow IPC decoder that is generic over the byte-chunk stream `S` and the
/// buffer type `B` the chunks arrive in.
///
/// Its `Stream` implementation pulls frames from `inner`, interprets each one as an Arrow
/// flatbuffer `Message`, and yields an `io::Result<Table>` for every record batch.
pub struct GTableStreamDecoder<S, B>
where
S: Stream<Item = Result<B, io::Error>> + Unpin + Send + Sync,
B: StreamBuffer,
{
/// Framed view over the source stream; each item polled from it exposes a `message` and a
/// `body` buffer.
pub(crate) inner: FramedByteStream<S, ArrowIPCFrameDecoder<B>, B>,
/// Decoding progress: starts at `BatchState::NeedSchema`, becomes `BatchState::Ready` once a
/// schema message is handled, and `BatchState::Done` once end-of-stream is seen.
state: BatchState,
/// Fields produced by `handle_schema_header`; empty until the schema message is handled.
pub fields: Vec<Field>,
/// Dictionary values filled in by `handle_dictionary_batch` and handed to
/// `handle_record_batch`. Presumably keyed by Arrow dictionary id - confirm in the parser.
pub dicts: HashMap<i64, Vec<String>>,
/// Protocol the decoder was constructed with; the same value configures the frame decoder.
pub protocol: IPCMessageProtocol,
}
impl<S, B> GTableStreamDecoder<S, B>
where
    S: Stream<Item = Result<B, io::Error>> + Unpin + Send + Sync,
    B: StreamBuffer,
{
    /// Creates a decoder over `stream` that has not yet seen a schema.
    ///
    /// * `stream` - source of byte chunks, each an `io::Result<B>`.
    /// * `initial_capacity` - forwarded unchanged to `FramedByteStream::new`.
    /// * `protocol` - used to build the `ArrowIPCFrameDecoder` and also stored on the
    ///   returned decoder.
    ///
    /// The decoder starts in `BatchState::NeedSchema` with no fields and no dictionaries.
    pub fn new(stream: S, initial_capacity: usize, protocol: IPCMessageProtocol) -> Self {
        let frame_decoder = ArrowIPCFrameDecoder::new(protocol);
        let inner = FramedByteStream::new(stream, frame_decoder, initial_capacity);
        let fields = Vec::new();
        let dicts = HashMap::new();

        Self {
            inner,
            state: BatchState::NeedSchema,
            fields,
            dicts,
            protocol,
        }
    }
}
/// Drives the framed byte stream and yields one [`Table`] per Arrow IPC record batch.
///
/// Each frame pulled from `inner` is handled according to the current decoder state:
///
/// * a `Schema` message is accepted only while the schema is still outstanding; it fills
///   `fields` and moves the decoder to the ready state;
/// * a `DictionaryBatch` message (ready state only) is applied to `dicts`;
/// * a `RecordBatch` message (ready state only) is decoded and yielded as a table;
/// * a frame whose message and body are both empty, or a message with header type `NONE`,
///   ends the stream.
///
/// Any other message/state combination yields an `InvalidData` error.
///
/// Once the end of the stream has been reported - either cleanly or through the
/// `UnexpectedEof` error below - every later poll returns `None` without touching `inner`.
impl<S, B> Stream for GTableStreamDecoder<S, B>
where
    S: Stream<Item = Result<B, io::Error>> + Unpin + Send + Sync,
    B: StreamBuffer + Unpin,
{
    type Item = io::Result<Table>;

    /// Polls for the next decoded table.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` if `inner` finishes before an end-of-stream frame was seen.
    /// * `InvalidData` if the flatbuffer message fails verification, a header is missing,
    ///   or a message arrives in a state that does not accept it.
    /// * Any error produced by `inner` or by the parser helpers is passed through.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // The terminal state is sticky: never poll `inner` again once the end has been
        // reported, since it may already have returned `None`.
        if matches!(this.state, BatchState::Done) {
            return Poll::Ready(None);
        }

        loop {
            let raw_frame = match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    // The guard above means the state is not `Done` here, so the source
                    // ended early. Latch `Done` so the error is reported exactly once and
                    // the exhausted inner stream is not polled again.
                    this.state = BatchState::Done;
                    return Poll::Ready(Some(Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "Underlying stream ended before Arrow EOF (EOS marker or footer not seen)",
                    ))));
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(frame))) => frame,
            };

            // A completely empty frame is treated as the end-of-stream marker.
            if raw_frame.message.is_empty() && raw_frame.body.is_empty() {
                this.state = BatchState::Done;
                return Poll::Ready(None);
            }

            // `?` converts the error into `Poll::Ready(Some(Err(_)))` for this return type.
            let af_msg = flatbuffers::root::<fb::Message>(&raw_frame.message.as_ref())
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

            match af_msg.header_type() {
                fb::MessageHeader::Schema if matches!(this.state, BatchState::NeedSchema) => {
                    this.fields = handle_schema_header(&af_msg)?;
                    this.state = BatchState::Ready;
                    continue;
                }
                fb::MessageHeader::DictionaryBatch if matches!(this.state, BatchState::Ready) => {
                    let db = af_msg.header_as_dictionary_batch().ok_or_else(|| {
                        io::Error::new(io::ErrorKind::InvalidData, "missing DictionaryBatch header")
                    })?;
                    handle_dictionary_batch(&db, &raw_frame.body.as_ref(), &mut this.dicts)?;
                    continue;
                }
                fb::MessageHeader::RecordBatch if matches!(this.state, BatchState::Ready) => {
                    let rec = af_msg.header_as_record_batch().ok_or_else(|| {
                        io::Error::new(io::ErrorKind::InvalidData, "missing RecordBatch header")
                    })?;
                    let table = handle_record_batch(
                        &rec,
                        &this.fields,
                        &this.dicts,
                        &raw_frame.body.as_ref(),
                    )?;
                    return Poll::Ready(Some(Ok(table)));
                }
                // A header-less message ends the stream whatever the protocol is.
                fb::MessageHeader::NONE => {
                    this.state = BatchState::Done;
                    return Poll::Ready(None);
                }
                _ => {
                    return Poll::Ready(Some(Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "unexpected message order",
                    ))));
                }
            }
        }
    }
}