use ::arrow::array::RecordBatch;
use arrow::error::ArrowError;
use bytes::Buf;
use bytes::Bytes;
use datafusion_common::Result;
use futures::stream::BoxStream;
use futures::StreamExt as _;
use futures::{ready, Stream};
use std::collections::VecDeque;
use std::fmt;
use std::task::Poll;
/// Outcome of a single [`BatchDeserializer::next`] call.
#[derive(Debug, PartialEq)]
pub enum DeserializerOutput {
/// A decoded batch that is ready to be handed to the consumer.
RecordBatch(RecordBatch),
/// Nothing can be produced from the data seen so far and the input has not
/// been marked as finished; the caller should digest more data and retry.
RequiresMoreData,
/// The input was marked as finished and nothing further can be produced.
InputExhausted,
}
/// Incremental deserializer that is fed messages of type `T` and yields
/// [`RecordBatch`]es through [`DeserializerOutput`].
pub trait BatchDeserializer<T>: Send + fmt::Debug {
/// Feeds one `message` to the deserializer and returns a `usize` count.
///
/// The `DecoderDeserializer` implementation in this file returns the number
/// of bytes it queued (0 for an empty message).
fn digest(&mut self, message: T) -> usize;
/// Attempts to produce output from the messages digested so far.
///
/// # Errors
///
/// Returns an [`ArrowError`] if deserialization fails.
fn next(&mut self) -> Result<DeserializerOutput, ArrowError>;
/// Signals that the input is finished; `deserialize_stream` calls this once
/// its input stream yields `None`.
fn finish(&mut self);
}
/// Byte-oriented decoder driven by `DecoderDeserializer`.
pub trait Decoder: Send + fmt::Debug {
/// Decodes data from `buf` and returns a byte count.
///
/// `DecoderDeserializer::next` advances its buffer by the returned count,
/// i.e. it treats it as the number of bytes consumed, and treats a return
/// value of 0 as the trigger to call [`Decoder::flush`].
fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError>;
/// Emits a [`RecordBatch`] if one is available; `Ok(None)` means there is
/// nothing to emit right now.
fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError>;
/// When this returns `true`, `DecoderDeserializer::next` calls
/// [`Decoder::flush`] after every `decode` call instead of only when
/// `decode` returns 0.
fn can_flush_early(&self) -> bool;
}
/// Debug output shows the queued input and the finalized flag; the wrapped
/// decoder itself is not printed.
impl<T: Decoder> fmt::Debug for DecoderDeserializer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Deserializer");
        builder.field("buffered_queue", &self.buffered_queue);
        builder.field("finalized", &self.finalized);
        builder.finish()
    }
}
/// Drives a [`Decoder`] from a FIFO queue of [`Bytes`] chunks.
impl<T: Decoder> BatchDeserializer<Bytes> for DecoderDeserializer<T> {
    /// Queues `message` for decoding and returns its length in bytes.
    ///
    /// Empty messages are not queued and yield 0.
    fn digest(&mut self, message: Bytes) -> usize {
        let len = message.len();
        if len > 0 {
            self.buffered_queue.push_back(message);
        }
        len
    }

    /// Decodes queued chunks front to back until a batch can be returned or
    /// the queue is drained.
    ///
    /// Once the queue is drained, returns `InputExhausted` if `finish` has
    /// been called and `RequiresMoreData` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any [`ArrowError`] from the decoder's `decode` or `flush`.
    fn next(&mut self) -> Result<DeserializerOutput, ArrowError> {
        while let Some(chunk) = self.buffered_queue.front_mut() {
            let consumed = self.decoder.decode(chunk)?;
            chunk.advance(consumed);

            // Drop a chunk from the queue once it is fully consumed.
            if chunk.is_empty() {
                self.buffered_queue.pop_front();
            }

            // Flush when the decoder consumed nothing (e.g. the empty marker
            // chunk pushed by `finish`) or when it allows early flushing.
            let flush_now = consumed == 0 || self.decoder.can_flush_early();
            if flush_now {
                if let Some(batch) = self.decoder.flush()? {
                    return Ok(DeserializerOutput::RecordBatch(batch));
                }
                // Nothing to emit yet: keep decoding the remaining chunks.
            }
        }

        Ok(match self.finalized {
            true => DeserializerOutput::InputExhausted,
            false => DeserializerOutput::RequiresMoreData,
        })
    }

    /// Marks the input as finished.
    fn finish(&mut self) {
        self.finalized = true;
        // An empty chunk makes the next `decode` consume 0 bytes, which forces
        // `next` to flush the decoder.
        self.buffered_queue.push_back(Bytes::new());
    }
}
/// Adapter that implements [`BatchDeserializer`] over [`Bytes`] by driving a
/// [`Decoder`].
pub struct DecoderDeserializer<T: Decoder> {
/// The decoder that turns bytes into record batches.
pub(crate) decoder: T,
/// Chunks that have been digested but not yet fully decoded, oldest first.
pub(crate) buffered_queue: VecDeque<Bytes>,
/// Set to `true` by `finish`; starts out `false`.
pub(crate) finalized: bool,
}
impl<T: Decoder> DecoderDeserializer<T> {
    /// Wraps `decoder` in a deserializer with an empty input queue that has
    /// not yet been finalized.
    pub fn new(decoder: T) -> Self {
        let buffered_queue = VecDeque::new();
        Self {
            decoder,
            buffered_queue,
            finalized: false,
        }
    }
}
/// Turns a stream of byte chunks into a stream of [`RecordBatch`]es by
/// feeding every chunk to `deserializer`.
///
/// On each poll the next input item is digested (or `finish` is called once
/// the input has ended) and the deserializer is asked for output:
///
/// * a batch is yielded to the consumer,
/// * `RequiresMoreData` polls the input again,
/// * `InputExhausted` ends the returned stream.
///
/// # Errors
///
/// Errors from `input` and from the deserializer are yielded as stream items.
pub fn deserialize_stream<'a>(
    input: impl Stream<Item = Result<Bytes>> + Unpin + Send + 'a,
    mut deserializer: impl BatchDeserializer<Bytes> + 'a,
) -> BoxStream<'a, Result<RecordBatch, ArrowError>> {
    // The returned stream is polled again after the batch produced by the
    // final flush, which means `input` gets polled after it already returned
    // `None`. Polling a finished stream is unspecified by the `Stream`
    // contract (it may panic or block), so fuse it: once finished it keeps
    // returning `None` without touching the underlying stream.
    let mut input = input.fuse();

    futures::stream::poll_fn(move |cx| loop {
        match ready!(input.poll_next_unpin(cx)).transpose()? {
            Some(b) => _ = deserializer.digest(b),
            // End of input: have the deserializer flush what it still holds.
            None => deserializer.finish(),
        };

        return match deserializer.next()? {
            DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))),
            DeserializerOutput::InputExhausted => Poll::Ready(None),
            // Nothing to emit yet: go back and poll the input for more bytes.
            DeserializerOutput::RequiresMoreData => continue,
        };
    })
    .boxed()
}