use std::sync::Arc;
use arrow_format::ipc::planus::ReadAsRoot;
use futures::future::BoxFuture;
use futures::AsyncRead;
use futures::AsyncReadExt;
use futures::Stream;
use crate::array::*;
use crate::chunk::Chunk;
use crate::error::{ArrowError, Result};
use super::super::CONTINUATION_MARKER;
use super::common::{read_dictionary, read_record_batch};
use super::schema::deserialize_stream_metadata;
use super::Dictionaries;
use super::StreamMetadata;
/// Everything needed to read the next message of an IPC stream.
///
/// It is owned by the in-flight read future (`maybe_next`) and handed back to the
/// caller inside a [`StreamState`] when that read completes.
struct ReadState<R> {
    /// The source the stream's messages are read from.
    pub reader: R,
    /// The stream's metadata (schema fields, IPC schema and version) used to decode messages.
    pub metadata: StreamMetadata,
    /// Dictionaries read so far from dictionary batches; consulted when decoding record batches.
    pub dictionaries: Dictionaries,
    /// Reusable buffer holding the body of the current record batch.
    pub data_buffer: Vec<u8>,
    /// Reusable buffer holding the flatbuffer metadata of the current message.
    pub message_buffer: Vec<u8>,
}
/// Successful outcome of one read of the stream, carrying the read state back to the caller
/// so the next read can be scheduled.
enum StreamState<R> {
    /// No chunk was produced by this read (for example, a dictionary batch was consumed);
    /// reading can resume from the contained state.
    Waiting(ReadState<R>),
    /// A record batch was decoded into a chunk.
    Some((ReadState<R>, Chunk<Arc<dyn Array>>)),
}
pub async fn read_stream_metadata_async<R: AsyncRead + Unpin + Send>(
reader: &mut R,
) -> Result<StreamMetadata> {
let mut meta_size: [u8; 4] = [0; 4];
reader.read_exact(&mut meta_size).await?;
let meta_len = {
if meta_size == CONTINUATION_MARKER {
reader.read_exact(&mut meta_size).await?;
}
i32::from_le_bytes(meta_size)
};
let mut meta_buffer = vec![0; meta_len as usize];
reader.read_exact(&mut meta_buffer).await?;
deserialize_stream_metadata(&meta_buffer)
}
async fn maybe_next<R: AsyncRead + Unpin + Send>(
mut state: ReadState<R>,
) -> Result<Option<StreamState<R>>> {
let mut meta_length: [u8; 4] = [0; 4];
match state.reader.read_exact(&mut meta_length).await {
Ok(()) => (),
Err(e) => {
return if e.kind() == std::io::ErrorKind::UnexpectedEof {
Ok(Some(StreamState::Waiting(state)))
} else {
Err(ArrowError::from(e))
};
}
}
let meta_length = {
if meta_length == CONTINUATION_MARKER {
state.reader.read_exact(&mut meta_length).await?;
}
i32::from_le_bytes(meta_length) as usize
};
if meta_length == 0 {
return Ok(None);
}
state.message_buffer.clear();
state.message_buffer.resize(meta_length, 0);
state.reader.read_exact(&mut state.message_buffer).await?;
let message =
arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer).map_err(|err| {
ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
})?;
let header = message.header()?.ok_or_else(|| {
ArrowError::oos("IPC: unable to fetch the message header. The file or stream is corrupted.")
})?;
match header {
arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(ArrowError::oos("A stream ")),
arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
state.data_buffer.clear();
state.data_buffer.resize(message.body_length()? as usize, 0);
state.reader.read_exact(&mut state.data_buffer).await?;
read_record_batch(
batch,
&state.metadata.schema.fields,
&state.metadata.ipc_schema,
None,
&state.dictionaries,
state.metadata.version,
&mut std::io::Cursor::new(&state.data_buffer),
0,
)
.map(|chunk| Some(StreamState::Some((state, chunk))))
}
arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
let mut buf = vec![0; message.body_length()? as usize];
state.reader.read_exact(&mut buf).await?;
let mut dict_reader = std::io::Cursor::new(buf);
read_dictionary(
batch,
&state.metadata.schema.fields,
&state.metadata.ipc_schema,
&mut state.dictionaries,
&mut dict_reader,
0,
)?;
Ok(Some(StreamState::Waiting(state)))
}
t => Err(ArrowError::OutOfSpec(format!(
"Reading types other than record batches not yet supported, unable to read {:?} ",
t
))),
}
}
/// A [`Stream`] of [`Chunk`]s read asynchronously from an Arrow IPC stream.
///
/// Built with [`AsyncStreamReader::new`] from a reader and the stream's [`StreamMetadata`]
/// (for example, as returned by [`read_stream_metadata_async`]).
pub struct AsyncStreamReader<'a, R: AsyncRead + Unpin + Send + 'a> {
    /// The metadata of the stream, exposed through [`AsyncStreamReader::metadata`].
    metadata: StreamMetadata,
    /// The in-flight read of the next message. It is `None` once the stream has finished,
    /// either because the end of the stream was reached or because an error was yielded.
    future: Option<BoxFuture<'a, Result<Option<StreamState<R>>>>>,
}
impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncStreamReader<'a, R> {
pub fn new(reader: R, metadata: StreamMetadata) -> Self {
let state = ReadState {
reader,
metadata: metadata.clone(),
dictionaries: Default::default(),
data_buffer: Default::default(),
message_buffer: Default::default(),
};
let future = Some(Box::pin(maybe_next(state)) as _);
Self { metadata, future }
}
pub fn metadata(&self) -> &StreamMetadata {
&self.metadata
}
}
impl<'a, R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<'a, R> {
    type Item = Result<Chunk<Arc<dyn Array>>>;

    /// Polls the in-flight read and yields the next decoded chunk.
    ///
    /// Once the end of the stream is reached or an error has been yielded, the read future is
    /// dropped and every subsequent poll returns `Poll::Ready(None)`.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::pin::Pin;
        use std::task::Poll;

        let me = Pin::into_inner(self);

        let fut = match &mut me.future {
            Some(fut) => fut,
            // The stream already finished.
            None => return Poll::Ready(None),
        };

        match fut.as_mut().poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(None)) => {
                me.future = None;
                Poll::Ready(None)
            }
            Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => {
                me.future = Some(Box::pin(maybe_next(state)));
                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(Ok(Some(StreamState::Waiting(state)))) => {
                // A message was consumed without producing a chunk. The current future has
                // completed and must not be polled again, so resume from the returned state.
                // The waker must be triggered explicitly: nothing else is registered to wake
                // this task, and returning `Pending` alone would stall it forever.
                me.future = Some(Box::pin(maybe_next(state)));
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Poll::Ready(Err(err)) => {
                me.future = None;
                Poll::Ready(Some(Err(err)))
            }
        }
    }
}