use crate::enums::IPCMessageProtocol;
use crate::models::encoders::ipc::table_stream::GTableStreamEncoder;
use crate::traits::stream_buffer::StreamBuffer;
use crate::utils::extract_dictionary_values_from_col;
use futures_core::Stream;
use minarrow::{Field, Table};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
pub struct TableStreamWriter<B>
where
B: StreamBuffer + Unpin + 'static,
{
encoder: GTableStreamEncoder<B>,
}
impl<B> TableStreamWriter<B>
where
B: StreamBuffer + Unpin + 'static,
{
pub fn new(schema: Vec<Field>, protocol: IPCMessageProtocol) -> Self {
Self {
encoder: GTableStreamEncoder::<B>::new(schema, protocol),
}
}
pub fn register_dictionary(&mut self, dict_id: i64, values: Vec<String>) {
self.encoder.register_dictionary(dict_id, values);
}
pub fn write(&mut self, table: &Table) -> io::Result<()> {
self.encoder.write_record_batch_frame(table)
}
pub fn finish(&mut self) -> io::Result<()> {
self.encoder.finish()
}
pub fn next_frame(&mut self) -> Option<io::Result<B>> {
self.encoder.out_frames.pop_front().map(Ok)
}
pub fn drain_all_frames(&mut self) -> Vec<B> {
let mut out = Vec::new();
while let Some(frame) = self.encoder.out_frames.pop_front() {
out.push(frame);
}
out
}
pub fn is_finished(&self) -> bool {
self.encoder.finished && self.encoder.out_frames.is_empty()
}
pub fn schema(&self) -> &[Field] {
&self.encoder.schema
}
}
impl<B> Stream for TableStreamWriter<B>
where
B: StreamBuffer + Unpin + 'static,
{
type Item = io::Result<B>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(frame) = this.encoder.out_frames.pop_front() {
Poll::Ready(Some(Ok(frame)))
} else if this.encoder.finished {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
pub async fn write_tables_to_stream<W, B>(
mut stream: W,
tables: &[Table],
schema: Vec<Field>,
protocol: IPCMessageProtocol,
) -> io::Result<()>
where
W: AsyncWrite + Unpin + Send + Sync,
B: StreamBuffer + Unpin,
{
let mut writer = TableStreamWriter::<B>::new(schema, protocol);
for table in tables {
for (col_idx, col) in table.cols.iter().enumerate() {
if let Some(values) = extract_dictionary_values_from_col(col) {
writer.register_dictionary(col_idx as i64, values);
}
}
writer.write(table)?;
}
writer.finish()?;
while let Some(frame) = writer.next_frame() {
let buf = frame?;
stream.write_all(buf.as_ref()).await?;
}
stream.flush().await?;
Ok(())
}
pub async fn write_table_to_stream<W, B>(
mut stream: W,
table: &Table,
schema: Vec<Field>,
protocol: IPCMessageProtocol,
) -> io::Result<()>
where
W: AsyncWrite + Unpin + Send + Sync,
B: StreamBuffer + Unpin,
{
let mut writer = TableStreamWriter::<B>::new(schema, protocol);
for (col_idx, col) in table.cols.iter().enumerate() {
if let Some(values) = extract_dictionary_values_from_col(col) {
writer.register_dictionary(col_idx as i64, values);
}
}
writer.write(table)?;
writer.finish()?;
while let Some(frame) = writer.next_frame() {
let buf = frame?;
stream.write_all(buf.as_ref()).await?;
}
stream.flush().await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::enums::IPCMessageProtocol;
use crate::test_helpers::*;
use minarrow::{Field, Table, Vec64};
use std::io;
fn all_types_schema() -> Vec<Field> {
make_schema_all_types()
}
fn test_table() -> Table {
make_all_types_table()
}
#[test]
fn test_table_stream_writer_schema_and_finish() {
let schema = all_types_schema();
let mut writer =
TableStreamWriter::<Vec64<u8>>::new(schema.clone(), IPCMessageProtocol::Stream);
assert_eq!(writer.schema(), &schema[..]);
assert!(!writer.is_finished());
writer.finish().unwrap();
assert!(writer.is_finished());
}
#[test]
fn test_write_and_drain_one_table() {
let schema = all_types_schema();
let table = test_table();
let mut writer =
TableStreamWriter::<Vec64<u8>>::new(schema.clone(), IPCMessageProtocol::Stream);
for (col_idx, col) in table.cols.iter().enumerate() {
if let Some(values) = extract_dictionary_values_from_col(col) {
writer.register_dictionary(col_idx as i64, values);
}
}
writer.write(&table).unwrap();
writer.finish().unwrap();
let frames = writer.drain_all_frames();
assert!(
!frames.is_empty(),
"No frames emitted after writing table and finish"
);
assert!(frames.len() >= 2);
let total_len: usize = frames.iter().map(|f| f.len()).sum();
assert!(total_len > 0);
}
#[test]
fn test_multiple_batches_emit_multiple_frames() {
let schema = all_types_schema();
let table1 = test_table();
let mut table2 = test_table();
table2.name = "another".into();
let mut writer =
TableStreamWriter::<Vec64<u8>>::new(schema.clone(), IPCMessageProtocol::Stream);
for (col_idx, col) in table1.cols.iter().enumerate() {
if let Some(values) = extract_dictionary_values_from_col(col) {
writer.register_dictionary(col_idx as i64, values);
}
}
writer.write(&table1).unwrap();
writer.write(&table2).unwrap();
writer.finish().unwrap();
let frames = writer.drain_all_frames();
assert!(
frames.len() >= 4,
"Expected at least 4 frames: schema, 2 batches, EOS"
);
}
#[test]
fn test_next_frame_returns_none_when_empty() {
let schema = all_types_schema();
let mut writer = TableStreamWriter::<Vec64<u8>>::new(schema, IPCMessageProtocol::Stream);
assert!(writer.next_frame().is_none());
writer.finish().unwrap();
assert!(writer.next_frame().is_none());
}
#[test]
fn test_error_on_schema_mismatch() {
let schema = all_types_schema();
let mut bad_table = test_table();
bad_table.cols.pop(); let mut writer = TableStreamWriter::<Vec64<u8>>::new(schema, IPCMessageProtocol::Stream);
let err = writer.write(&bad_table).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
#[test]
fn test_stream_trait_polling() {
let schema = all_types_schema();
let table = test_table();
let mut writer = TableStreamWriter::<Vec64<u8>>::new(schema, IPCMessageProtocol::Stream);
for (col_idx, col) in table.cols.iter().enumerate() {
if let Some(values) = extract_dictionary_values_from_col(col) {
writer.register_dictionary(col_idx as i64, values);
}
}
writer.write(&table).unwrap();
writer.finish().unwrap();
let mut pin_writer = Box::pin(writer);
let mut frames = Vec::new();
let cx = futures_util::task::noop_waker_ref();
loop {
match Pin::new(&mut pin_writer)
.as_mut()
.poll_next(&mut Context::from_waker(cx))
{
Poll::Ready(Some(Ok(frame))) => frames.push(frame),
Poll::Ready(None) => break,
Poll::Ready(Some(Err(e))) => panic!("Unexpected error from poll_next: {e}"),
Poll::Pending => continue,
}
}
assert!(
!frames.is_empty(),
"Should emit at least some frames through poll_next"
);
}
}