use std::collections::VecDeque;
use std::pin::Pin;

use arrow::compute::kernels;
use arrow_array::RecordBatch;
use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use datafusion_common::DataFusionError;
use futures::{Stream, StreamExt, TryStreamExt};
use lance_core::Result;

/// Wraps a [`SendableRecordBatchStream`], re-chunking it into groups of
/// batches that total `output_size` rows each.
struct BatchReaderChunker {
    /// The underlying stream of record batches.
    inner: SendableRecordBatchStream,
    /// Batches read from `inner` that have not yet been fully yielded.
    buffered: VecDeque<RecordBatch>,
    /// The target number of rows per output chunk.
    output_size: usize,
    /// Offset of the first unconsumed row in the front batch of `buffered`.
    i: usize,
}

impl BatchReaderChunker {
    fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
        Self {
            inner,
            buffered: VecDeque::new(),
            output_size,
            i: 0,
        }
    }

    /// The number of buffered rows that have not yet been yielded.
    fn buffered_len(&self) -> usize {
        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
        buffer_total - self.i
    }

    /// Reads from the inner stream until at least `output_size` rows are
    /// buffered, or the stream is exhausted.
    async fn fill_buffer(&mut self) -> Result<()> {
        while self.buffered_len() < self.output_size {
            match self.inner.next().await {
                Some(Ok(batch)) => self.buffered.push_back(batch),
                Some(Err(e)) => return Err(e.into()),
                None => break,
            }
        }
        Ok(())
    }

    /// Returns the next chunk: a group of (possibly sliced) batches whose row
    /// counts sum to `output_size`, or fewer once the inner stream is
    /// exhausted. Returns `None` when no rows remain.
    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
        match self.fill_buffer().await {
            Ok(_) => {}
            Err(e) => return Some(Err(e)),
        };

        let mut batches = Vec::new();
        let mut rows_collected = 0;

        while rows_collected < self.output_size {
            if let Some(batch) = self.buffered.pop_front() {
                if batch.num_rows() == 0 {
                    continue;
                }
                let rows_remaining_in_batch = batch.num_rows() - self.i;
                let rows_to_take =
                    std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);

                if rows_to_take == rows_remaining_in_batch {
                    // The rest of this batch is consumed in full, so we can
                    // skip the slice when there is no leading offset.
                    let batch = if self.i == 0 {
                        batch
                    } else {
                        batch.slice(self.i, rows_to_take)
                    };
                    batches.push(batch);
                    self.i = 0;
                } else {
                    // Only part of this batch is consumed: slice off the rows
                    // we need and put the batch back for the next chunk.
                    batches.push(batch.slice(self.i, rows_to_take));
                    self.i += rows_to_take;
                    self.buffered.push_front(batch);
                }

                rows_collected += rows_to_take;
            } else {
                break;
            }
        }

        if batches.is_empty() {
            None
        } else {
            Some(Ok(batches))
        }
    }
}

/// Re-chunks `stream` so that each yielded item is a group of batches whose
/// row counts sum to `chunk_size` (the final group may be smaller). Batches
/// are sliced, not copied, so no concatenation is performed.
pub fn chunk_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
    let chunker = BatchReaderChunker::new(stream, chunk_size);
    futures::stream::unfold(chunker, |mut chunker| async move {
        match chunker.next().await {
            Some(Ok(batches)) => Some((Ok(batches), chunker)),
            Some(Err(e)) => Some((Err(e), chunker)),
            None => None,
        }
    })
    .boxed()
}

/// Re-chunks `stream` into batches of `chunk_size` rows, concatenating each
/// group of buffered batches into a single [`RecordBatch`].
pub fn chunk_concat_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> SendableRecordBatchStream {
    // `schema()` already returns an owned `SchemaRef`, so no extra clone of
    // the underlying schema is needed here.
    let schema = stream.schema();
    let schema_copy = schema.clone();
    let chunked = chunk_stream(stream, chunk_size);
    let chunk_concat = chunked
        .and_then(move |batches| {
            // Concatenate the pieces of each chunk into one contiguous batch.
            std::future::ready(
                kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
            )
        })
        .map_err(DataFusionError::from)
        .boxed();
    Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
}