use arrow::{
datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
};
use futures::{Stream, StreamExt};
use pin_project_lite::pin_project;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use super::common::AbortOnDropSingle;
use super::{RecordBatchStream, SendableRecordBatchStream};
/// Record batch stream fed by a tokio mpsc channel.
///
/// Wraps the receiving half of a channel carrying `ArrowResult<RecordBatch>`
/// items together with the schema reported to consumers and a helper that
/// owns the `JoinHandle` passed to [`RecordBatchReceiverStream::create`].
pub struct RecordBatchReceiverStream {
/// Schema handed out by `RecordBatchStream::schema`.
schema: SchemaRef,
/// Channel receiver wrapped so that it can be polled as a `Stream`.
inner: ReceiverStream<ArrowResult<RecordBatch>>,
// This field is never read in this file, which is why the `dead_code` lint
// is silenced; it is kept only so that it is dropped together with the stream.
// NOTE(review): `AbortOnDropSingle` presumably aborts the wrapped task when
// dropped — confirm against `super::common`.
#[allow(dead_code)]
drop_helper: AbortOnDropSingle<()>,
}
impl RecordBatchReceiverStream {
pub fn create(
schema: &SchemaRef,
rx: tokio::sync::mpsc::Receiver<ArrowResult<RecordBatch>>,
join_handle: JoinHandle<()>,
) -> SendableRecordBatchStream {
let schema = schema.clone();
let inner = ReceiverStream::new(rx);
Box::pin(Self {
schema,
inner,
drop_helper: AbortOnDropSingle::new(join_handle),
})
}
}
impl Stream for RecordBatchReceiverStream {
    type Item = ArrowResult<RecordBatch>;

    /// Delegates polling to the wrapped channel receiver stream.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // The struct has no structurally pinned fields, so a plain mutable
        // reference is enough to reach `inner`.
        let this = self.get_mut();
        this.inner.poll_next_unpin(cx)
    }
}
impl RecordBatchStream for RecordBatchReceiverStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
pin_project! {
/// Pairs an arbitrary stream `S` with a [`SchemaRef`].
///
/// When `S` yields `ArrowResult<RecordBatch>` items, the pair implements
/// both `Stream` and `RecordBatchStream`.
pub struct RecordBatchStreamAdapter<S> {
// Schema handed out by `RecordBatchStream::schema`.
schema: SchemaRef,
// Wrapped stream; marked `#[pin]` so `poll_next` can project to `Pin<&mut S>`.
#[pin]
stream: S,
}
}
impl<S> RecordBatchStreamAdapter<S> {
pub fn new(schema: SchemaRef, stream: S) -> Self {
Self { schema, stream }
}
}
impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    /// Formats the adapter showing only its schema.
    ///
    /// The wrapped stream is left out, so no `Debug` bound on `S` is needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RecordBatchStreamAdapter");
        builder.field("schema", &self.schema);
        builder.finish()
    }
}
impl<S> Stream for RecordBatchStreamAdapter<S>
where
S: Stream<Item = ArrowResult<RecordBatch>>,
{
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.project().stream.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
}
impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
S: Stream<Item = ArrowResult<RecordBatch>>,
{
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}