use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::task::JoinSet;
use tracing::Span;
use tracing::dispatcher;
use crate::DeltaTableError;
use crate::errors::DeltaResult;
/// A stream of `DeltaResult<RecordBatch>` items that can also report an Arrow schema.
///
/// Implementors are ordinary [`Stream`]s. The extra [`schema`](Self::schema) accessor
/// takes `&self`, so consumers can inspect it without polling the stream.
/// NOTE(review): presumably every yielded batch matches this schema — implementors
/// should uphold that.
pub trait RecordBatchStream
: Stream<Item = DeltaResult<RecordBatch>>
{
    /// Returns the Arrow schema associated with this stream.
    fn schema(&self) -> SchemaRef;
}
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
pub type SendableRBStream = Pin<Box<dyn Stream<Item = DeltaResult<RecordBatch>> + Send>>;
/// Builder for a stream that is fed by background tasks over a tokio mpsc channel.
///
/// Tasks are registered with [`ReceiverStreamBuilder::spawn_blocking`]. Senders for the
/// channel are handed out by [`ReceiverStreamBuilder::tx`] (presumably captured by those
/// tasks — confirm at call sites). [`ReceiverStreamBuilder::build`] combines the receiving
/// end and the task set into a single stream.
pub(crate) struct ReceiverStreamBuilder<O> {
    /// The builder's own sender; cloned out by `tx()` and dropped in `build()`.
    tx: Sender<DeltaResult<O>>,
    /// Receiving end that `build()` turns into the output stream.
    rx: Receiver<DeltaResult<O>>,
    /// Tasks started via `spawn_blocking`; their results are checked in `build()`.
    join_set: JoinSet<DeltaResult<()>>,
}
impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Creates a builder whose channel buffers up to `capacity` items.
    ///
    /// `tokio::sync::mpsc::channel` panics when given a capacity of zero, so the
    /// requested capacity is clamped to at least one.
    pub fn new(capacity: usize) -> Self {
        // A zero-sized bounded tokio channel is a panic; a buffer of one is the
        // closest valid behaviour.
        let (tx, rx) = tokio::sync::mpsc::channel(capacity.max(1));
        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Returns a new sender for this builder's channel.
    ///
    /// The stream produced by [`Self::build`] only ends after every sender has been
    /// dropped, including every clone handed out here.
    pub fn tx(&self) -> Sender<DeltaResult<O>> {
        self.tx.clone()
    }

    /// Runs `f` on tokio's blocking thread pool as part of this builder's task set.
    ///
    /// The caller's current tracing dispatcher and span are captured and re-entered
    /// inside the blocking closure, so events emitted by `f` go to the same subscriber
    /// and span. An `Err` returned by `f` is surfaced as an item of the stream produced
    /// by [`Self::build`].
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> DeltaResult<()> + Send + 'static,
    {
        // Blocking threads do not inherit the caller's tracing context, so carry it over.
        let dispatch = dispatcher::get_default(|d| d.clone());
        let span = Span::current();
        self.join_set.spawn_blocking(move || {
            dispatcher::with_default(&dispatch, || {
                let _enter = span.enter();
                f()
            })
        });
    }

    /// Consumes the builder and returns a stream of everything sent on the channel.
    ///
    /// The returned stream interleaves two sources:
    /// * items received on the channel;
    /// * at most one extra item reporting the first failure among the spawned tasks.
    ///
    /// # Errors
    /// * If a spawned task returns `Err`, that error is yielded as a stream item.
    /// * Any non-panic join failure (e.g. cancellation) is yielded as
    ///   `DeltaTableError::Generic`.
    ///
    /// # Panics
    /// If a spawned task panics, the panic is resumed on whichever task polls the stream.
    pub fn build(self) -> BoxStream<'static, DeltaResult<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // Drop the builder's own sender so the channel can close once all other
        // senders are gone; otherwise the receiver stream would never terminate.
        drop(tx);

        // Waits on every spawned task and stops at the first failure. When this future
        // finishes (or is dropped), `join_set` is dropped with it.
        let check = async move {
            while let Some(joined) = join_set.join_next().await {
                match joined {
                    Ok(Ok(())) => {}
                    Ok(Err(error)) => return Some(Err(error)),
                    // Re-raise worker panics on the task that consumes the stream.
                    Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
                    Err(e) => {
                        return Some(Err(DeltaTableError::Generic(format!(
                            "Non Panic Task error: {e}"
                        ))));
                    }
                }
            }
            None
        };

        // Only an error is emitted; a clean completion of all tasks contributes nothing.
        let check_stream = futures::stream::once(check).filter_map(|item| async move { item });

        // Adapt the receiver into a stream that ends when the channel is closed and drained.
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            rx.recv().await.map(|item| (item, rx))
        });

        // Poll both sources so task failures surface even while the channel is still
        // open, and channel items are delivered as soon as they arrive.
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}
/// A [`ReceiverStreamBuilder`] whose channel carries Arrow [`RecordBatch`] items.
pub(crate) type RecordBatchReceiverStreamBuilder = ReceiverStreamBuilder<RecordBatch>;