use crate::compression::Compression;
use crate::enums::IPCMessageProtocol;
use crate::models::encoders::ipc::table_stream::GTableStreamEncoder;
use crate::traits::stream_buffer::StreamBuffer;
use futures_core::Stream;
use minarrow::{Field, Table, Vec64};
use std::io;
use tokio::io::AsyncWrite;
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
/// [`GTableSink`] whose encoder produces frames in a standard `Vec<u8>` buffer.
pub type TableSink<Writer> = GTableSink<Writer, Vec<u8>>;
/// [`GTableSink`] whose encoder produces frames in a minarrow `Vec64<u8>` buffer.
pub type TableSink64<Writer> = GTableSink<Writer, Vec64<u8>>;
/// Asynchronous [`Sink`] of [`Table`]s that encodes each table into IPC frames and writes
/// those frames to the [`AsyncWrite`] destination `W`.
///
/// `B` is the buffer type the encoder emits frames in; `TableSink` and `TableSink64`
/// fix it to `Vec<u8>` and `Vec64<u8>` respectively.
pub struct GTableSink<W: AsyncWrite + Unpin + Send + Sync + 'static, B: StreamBuffer> {
    /// Schema supplied at construction (a copy is also handed to the encoder).
    pub(crate) schema: Vec<Field>,
    /// Encoder that turns tables into frames; drained as a stream of `B` buffers.
    pub(crate) inner: GTableStreamEncoder<B>,
    /// Writer that receives the encoded bytes.
    pub(crate) destination: W,
    /// IPC message protocol chosen at construction.
    pub(crate) protocol: IPCMessageProtocol,
    /// Set once the schema frame has been queued on the encoder.
    pub(crate) schema_written: bool,
    /// Set once `finish()` has been called on the encoder.
    pub(crate) finished: bool,
    /// Frame currently being written to `destination`, if one is in flight.
    pub(crate) frame_buf: Option<B>,
    /// How many bytes of `frame_buf` have already been written.
    pub(crate) frame_pos: usize,
}
impl<W, B> GTableSink<W, B>
where
W: AsyncWrite + Unpin + Send + Sync + 'static,
B: StreamBuffer + std::fmt::Debug + Unpin + 'static,
{
pub fn new(sink: W, schema: Vec<Field>, protocol: IPCMessageProtocol) -> io::Result<Self> {
Ok(Self {
inner: GTableStreamEncoder::new(schema.clone(), protocol),
schema,
destination: sink,
protocol,
schema_written: false,
finished: false,
frame_buf: None,
frame_pos: 0,
})
}
pub fn with_compression(
sink: W,
schema: Vec<Field>,
protocol: IPCMessageProtocol,
compression: Compression,
) -> io::Result<Self> {
Ok(Self {
inner: GTableStreamEncoder::with_compression(schema.clone(), protocol, compression),
schema,
destination: sink,
protocol,
schema_written: false,
finished: false,
frame_buf: None,
frame_pos: 0,
})
}
pub fn sink_mut(&mut self) -> &mut W {
&mut self.destination
}
}
impl<W, B> Sink<Table> for GTableSink<W, B>
where
    W: AsyncWrite + Unpin + Send + Sync + 'static,
    B: StreamBuffer + std::fmt::Debug + Unpin + 'static,
{
    type Error = io::Error;

    /// Always ready.
    ///
    /// Tables are encoded in `start_send` and only written out by `poll_flush` /
    /// `poll_close`, so this sink applies no backpressure of its own. Callers that send
    /// many tables without flushing presumably accumulate encoded frames inside the
    /// encoder until the next flush.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Encodes `table` as a record-batch frame.
    ///
    /// On the first call the schema frame is queued first. Nothing reaches the
    /// destination until `poll_flush` or `poll_close` runs.
    ///
    /// # Errors
    /// Fails if the sink has already been closed, or if the encoder rejects the schema or
    /// the table.
    fn start_send(mut self: Pin<&mut Self>, table: Table) -> Result<(), Self::Error> {
        if self.finished {
            // `finish()` has already run on the encoder; frames encoded now would follow
            // whatever it emitted to terminate the stream.
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "cannot send a table after the sink has been closed",
            ));
        }
        if !self.schema_written {
            self.inner.write_schema_frame()?;
            self.schema_written = true;
        }
        self.inner.write_record_batch_frame(&table)?;
        Ok(())
    }

    /// Writes every frame the encoder currently has available, then flushes the destination.
    ///
    /// A partially written frame is kept in `frame_buf` / `frame_pos`, so a `Pending` or an
    /// error never loses bytes: the next call resumes where this one stopped.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Upper bound on the bytes offered to a single `poll_write`, so very large frames
        // are handed to the writer in bounded pieces.
        const MAX_WRITE_CHUNK: usize = 1024 * 1024;

        loop {
            // Resume the in-flight frame, or pull the next one from the encoder.
            let buf = match self.frame_buf.take() {
                Some(buf) => buf,
                None => match Pin::new(&mut self.inner).poll_next(cx) {
                    Poll::Ready(Some(Ok(frame))) => {
                        self.frame_pos = 0;
                        frame
                    }
                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
                    // No frame available right now, or the encoder is exhausted: everything
                    // produced so far has been handed to the writer.
                    Poll::Ready(None) | Poll::Pending => break,
                },
            };

            let total = buf.as_ref().len();
            if self.frame_pos >= total {
                // Nothing (left) to write, e.g. an empty frame. Skip it rather than offer a
                // zero-length slice, which writers typically answer with `Ok(0)` and which
                // would then be misreported as `WriteZero`.
                self.frame_pos = 0;
                continue;
            }

            let poll = {
                let remaining = &buf.as_ref()[self.frame_pos..];
                let chunk = &remaining[..remaining.len().min(MAX_WRITE_CHUNK)];
                Pin::new(&mut self.destination).poll_write(cx, chunk)
            };

            match poll {
                Poll::Ready(Ok(0)) => {
                    // Keep the frame so no already-encoded bytes are silently discarded.
                    self.frame_buf = Some(buf);
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
                Poll::Ready(Ok(n)) => {
                    self.frame_pos += n;
                    if self.frame_pos < total {
                        // Partial write: keep going. The writer itself returns `Pending`
                        // once it cannot accept more, so there is no need to force a yield.
                        self.frame_buf = Some(buf);
                    } else {
                        self.frame_pos = 0;
                    }
                }
                Poll::Ready(Err(e)) => {
                    self.frame_buf = Some(buf);
                    return Poll::Ready(Err(e));
                }
                Poll::Pending => {
                    self.frame_buf = Some(buf);
                    return Poll::Pending;
                }
            }
        }

        Pin::new(&mut self.destination).poll_flush(cx)
    }

    /// Finishes the encoder (once), writes all remaining frames, and shuts down the
    /// destination.
    ///
    /// If no table was ever sent, the schema frame is queued before finishing, so that
    /// even an empty stream starts with its schema.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if !self.finished {
            if !self.schema_written {
                self.inner.write_schema_frame()?;
                self.schema_written = true;
            }
            self.inner.finish()?;
            self.finished = true;
        }

        match self.as_mut().poll_flush(cx) {
            Poll::Ready(Ok(())) => {}
            pending_or_err => return pending_or_err,
        }

        Pin::new(&mut self.destination).poll_shutdown(cx)
    }
}