use std::pin::Pin;
use std::task::Poll;
use std::{collections::VecDeque, task::Context};

use arrow::compute::kernels;
use arrow_array::RecordBatch;
use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use datafusion_common::DataFusionError;
use futures::{ready, Stream, StreamExt, TryStreamExt};

use lance_core::error::DataFusionResult;
use lance_core::Result;

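/// Wraps a stream of record batches and regroups it into chunks of
/// `output_size` rows, slicing batches as needed but never concatenating them.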
struct BatchReaderChunker {
    /// The source stream
    inner: SendableRecordBatchStream,
    /// Batches read from `inner` but not yet fully emitted
    buffered: VecDeque<RecordBatch>,
    /// The desired number of rows per output chunk
    output_size: usize,
    /// Number of rows of the front buffered batch that have already been emitted
    i: usize,
}

impl BatchReaderChunker {
    fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
        Self {
            inner,
            buffered: VecDeque::new(),
            output_size,
            i: 0,
        }
    }

    /// Number of buffered rows that have not yet been emitted
    fn buffered_len(&self) -> usize {
        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
        buffer_total - self.i
    }

    /// Pull from the inner stream until we have enough rows for one chunk, or
    /// the stream is exhausted
    async fn fill_buffer(&mut self) -> Result<()> {
        while self.buffered_len() < self.output_size {
            match self.inner.next().await {
                Some(Ok(batch)) => self.buffered.push_back(batch),
                Some(Err(e)) => return Err(e.into()),
                None => break,
            }
        }
        Ok(())
    }

    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
        match self.fill_buffer().await {
            Ok(_) => {}
            Err(e) => return Some(Err(e)),
        };

        let mut batches = Vec::new();
        let mut rows_collected = 0;

        while rows_collected < self.output_size {
            if let Some(batch) = self.buffered.pop_front() {
                // Skip over empty batches
                if batch.num_rows() == 0 {
                    continue;
                }

                let rows_remaining_in_batch = batch.num_rows() - self.i;
                let rows_to_take =
                    std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);

                if rows_to_take == rows_remaining_in_batch {
                    // We're consuming the rest of this batch; avoid slicing if
                    // we're taking it from the start
                    let batch = if self.i == 0 {
                        batch
                    } else {
                        batch.slice(self.i, rows_to_take)
                    };
                    batches.push(batch);
                    self.i = 0;
                } else {
                    // Take a prefix and return the batch to the front of the
                    // buffer for the next chunk
                    batches.push(batch.slice(self.i, rows_to_take));
                    self.i += rows_to_take;
                    self.buffered.push_front(batch);
                }

                rows_collected += rows_to_take;
            } else {
                break;
            }
        }

        if batches.is_empty() {
            None
        } else {
            Some(Ok(batches))
        }
    }
}

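/// State for [`break_stream`]: a single input batch being sliced so that no
/// output batch crosses a `max_rows` boundary (counted from the start of the
/// stream).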
struct BreakStreamState {
    /// Maximum number of rows between cut points
    max_rows: usize,
    /// Rows emitted since the last cut point
    rows_seen: usize,
    /// Rows of `batch` not yet emitted
    rows_remaining: usize,
    /// The batch currently being sliced
    batch: Option<RecordBatch>,
}

impl BreakStreamState {
    fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
        if self.rows_remaining == 0 {
            return None;
        }
        if self.rows_remaining + self.rows_seen <= self.max_rows {
            // The rest of the batch fits before the next cut point; emit it whole
            self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
            self.rows_remaining = 0;
            let next = self.batch.take().unwrap();
            Some((Ok(next), self))
        } else {
            // Emit just enough rows to reach the next cut point
            let rows_to_emit = self.max_rows - self.rows_seen;
            self.rows_seen = 0;
            self.rows_remaining -= rows_to_emit;
            let batch = self.batch.as_mut().unwrap();
            let next = batch.slice(0, rows_to_emit);
            *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
            Some((Ok(next), self))
        }
    }
}

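/// Breaks a stream of record batches at multiples of `max_chunk_size` rows.
///
/// Batches are sliced but never combined, so output batches may be smaller
/// than `max_chunk_size`; the guarantee is only that no output batch crosses a
/// `max_chunk_size` boundary, counted from the start of the stream.
///
/// A minimal usage sketch (assuming `stream` is some `SendableRecordBatchStream`
/// already in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // No emitted batch will straddle a 1024-row boundary.
/// let batches: Vec<RecordBatch> = break_stream(stream, 1024).try_collect().await?;
/// ```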
pub fn break_stream(
    stream: SendableRecordBatchStream,
    max_chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
    // Carries the row count (modulo max_chunk_size) across input batches so
    // cut points are aligned to the whole stream, not to each batch
    let mut rows_already_seen = 0;
    stream
        .map_ok(move |batch| {
            let state = BreakStreamState {
                rows_remaining: batch.num_rows(),
                max_rows: max_chunk_size,
                rows_seen: rows_already_seen,
                batch: Some(batch),
            };
            rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;

            futures::stream::unfold(state, move |state| std::future::ready(state.next()))
                .fuse()
                .boxed()
        })
        .try_flatten()
        .boxed()
}

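/// Regroups a stream of record batches into chunks of `chunk_size` rows.
///
/// Every chunk except possibly the last contains exactly `chunk_size` rows,
/// spread across one or more batches. Batches are sliced but never
/// concatenated; see [`chunk_concat_stream`] for a variant that merges each
/// chunk into a single batch.
///
/// A minimal usage sketch (assuming `stream` is some `SendableRecordBatchStream`
/// already in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Each element totals 4096 rows (except possibly the last).
/// let chunks: Vec<Vec<RecordBatch>> = chunk_stream(stream, 4096).try_collect().await?;
/// ```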
pub fn chunk_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
    let chunker = BatchReaderChunker::new(stream, chunk_size);
    futures::stream::unfold(chunker, |mut chunker| async move {
        match chunker.next().await {
            Some(Ok(batches)) => Some((Ok(batches), chunker)),
            Some(Err(e)) => Some((Err(e), chunker)),
            None => None,
        }
    })
    .fuse()
    .boxed()
}

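/// Like [`chunk_stream`], but concatenates each chunk into a single batch.
///
/// Every output batch except possibly the last contains exactly `chunk_size`
/// rows. The concatenation copies data, so prefer [`chunk_stream`] when
/// downstream code can work with a list of batches.
///
/// A minimal usage sketch (assuming `stream` is some `SendableRecordBatchStream`
/// already in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Each output batch has exactly 4096 rows, except possibly the last.
/// let batches: Vec<RecordBatch> = chunk_concat_stream(stream, 4096).try_collect().await?;
/// ```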
pub fn chunk_concat_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> SendableRecordBatchStream {
    let schema = stream.schema();
    let schema_copy = schema.clone();
    let chunked = chunk_stream(stream, chunk_size);
    let chunk_concat = chunked
        .and_then(move |batches| {
            // Merge the sliced batches in each chunk into one contiguous batch
            std::future::ready(
                kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
            )
        })
        .map_err(DataFusionError::from)
        .boxed();
    Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
}

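/// A stream adapter that enforces an exact batch size.
///
/// Incoming batches are re-sliced (and concatenated where needed) so that
/// every emitted batch has exactly `batch_size` rows, except possibly the
/// final batch, which carries whatever remains.
///
/// A minimal usage sketch (assuming `stream` yields
/// `DataFusionResult<RecordBatch>` and is `Unpin`):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let strict = StrictBatchSizeStream::new(stream, 10);
/// // Every batch except possibly the last has exactly 10 rows.
/// let batches: Vec<RecordBatch> = strict.try_collect().await?;
/// ```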
pub struct StrictBatchSizeStream<S> {
    /// The source stream
    inner: S,
    /// Exact number of rows for every emitted batch (except the last)
    batch_size: usize,
    /// Leftover rows carried between polls
    residual: Option<RecordBatch>,
}

impl<S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin> StrictBatchSizeStream<S> {
    pub fn new(inner: S, batch_size: usize) -> Self {
        Self {
            inner,
            batch_size,
            residual: None,
        }
    }
}

impl<S> Stream for StrictBatchSizeStream<S>
where
    S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin,
{
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // First, emit from the residual if it already holds a full batch
            if let Some(residual) = self.residual.take() {
                if residual.num_rows() >= self.batch_size {
                    let split_at = self.batch_size;
                    let chunk = residual.slice(0, split_at);
                    let new_residual = residual.slice(split_at, residual.num_rows() - split_at);
                    self.residual = Some(new_residual);
                    return Poll::Ready(Some(Ok(chunk)));
                } else {
                    self.residual = Some(residual);
                }
            }

            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                Some(Ok(batch)) => {
                    // Prepend any residual rows to the newly arrived batch
                    let current_batch = if let Some(residual) = self.residual.take() {
                        arrow::compute::concat_batches(&residual.schema(), &[residual, batch])
                            .map_err(|e| DataFusionError::External(Box::new(e)))?
                    } else {
                        batch
                    };

                    if current_batch.num_rows() >= self.batch_size {
                        // Emit one full batch and stash the remainder
                        let split_at = self.batch_size;
                        let chunk = current_batch.slice(0, split_at);
                        let new_residual =
                            current_batch.slice(split_at, current_batch.num_rows() - split_at);
                        if new_residual.num_rows() > 0 {
                            self.residual = Some(new_residual);
                        }
                        return Poll::Ready(Some(Ok(chunk)));
                    } else {
                        // Not enough rows yet; keep buffering
                        self.residual = Some(current_batch);
                        continue;
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    // Input exhausted: flush any remaining rows as a final,
                    // possibly short, batch
                    return Poll::Ready(
                        self.residual
                            .take()
                            .filter(|r| r.num_rows() > 0)
                            .map(Ok::<_, DataFusionError>),
                    );
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::{array, BatchCount, RowCount};

    use crate::datagen::DatafusionDatagenExt;

    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let make_batch = |num_rows: u32| {
            lance_datagen::gen_batch()
                .anon_col(lance_datagen::array::step::<Int32Type>())
                .into_batch_rows(RowCount::from(num_rows as u64))
                .unwrap()
        };

        // 10 + 5 + 13 + 0 = 28 rows in total
        let batches = vec![make_batch(10), make_batch(5), make_batch(13), make_batch(0)];

        let make_stream = || {
            let stream = futures::stream::iter(
                batches
                    .clone()
                    .into_iter()
                    .map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream))
        };

        // chunk_stream slices into 10-row chunks without concatenating
        let chunked = super::chunk_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].len(), 1);
        assert_eq!(chunked[0][0].num_rows(), 10);
        assert_eq!(chunked[1].len(), 2);
        assert_eq!(chunked[1][0].num_rows(), 5);
        assert_eq!(chunked[1][1].num_rows(), 5);
        assert_eq!(chunked[2].len(), 1);
        assert_eq!(chunked[2][0].num_rows(), 8);

        // chunk_concat_stream merges each chunk into a single batch
        let chunked = super::chunk_concat_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 10);
        assert_eq!(chunked[2].num_rows(), 8);

        // break_stream cuts at 10-row boundaries but never merges batches
        let chunked = super::break_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 4);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 5);
        assert_eq!(chunked[2].num_rows(), 5);
        assert_eq!(chunked[3].num_rows(), 8);
    }

    #[tokio::test]
    async fn test_strict_batch_size_stream() {
        // 10 batches of 7 rows = 70 rows, re-sliced into 7 batches of exactly 10
        let batches = lance_datagen::gen_batch()
            .anon_col(array::step::<Int32Type>())
            .anon_col(array::step::<Int64Type>())
            .into_df_stream(RowCount::from(7), BatchCount::from(10));

        let stream = super::StrictBatchSizeStream::new(batches, 10);

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 7);

        for batch in batches {
            assert_eq!(batch.num_rows(), 10);
        }
    }
}