1use std::pin::Pin;
5use std::task::Poll;
6use std::{collections::VecDeque, task::Context};
7
8use arrow::compute::kernels;
9use arrow_array::RecordBatch;
10use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
11use datafusion_common::DataFusionError;
12use futures::{ready, Stream, StreamExt, TryStreamExt};
13
14use lance_core::error::DataFusionResult;
15use lance_core::Result;
16
/// Buffers batches pulled from an upstream record-batch stream so they can be handed
/// back out in groups holding `output_size` rows.
struct BatchReaderChunker {
    /// Upstream stream that batches are pulled from.
    inner: SendableRecordBatchStream,
    /// Batches pulled from `inner` that have not been completely handed out yet.
    buffered: VecDeque<RecordBatch>,
    /// Target number of rows in each emitted chunk.
    output_size: usize,
    /// Number of rows at the start of the front batch in `buffered` that were already
    /// handed out as part of an earlier chunk.
    i: usize,
}
29
30impl BatchReaderChunker {
31 fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
32 Self {
33 inner,
34 buffered: VecDeque::new(),
35 output_size,
36 i: 0,
37 }
38 }
39
40 fn buffered_len(&self) -> usize {
41 let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
42 buffer_total - self.i
43 }
44
45 async fn fill_buffer(&mut self) -> Result<()> {
46 while self.buffered_len() < self.output_size {
47 match self.inner.next().await {
48 Some(Ok(batch)) => self.buffered.push_back(batch),
49 Some(Err(e)) => return Err(e.into()),
50 None => break,
51 }
52 }
53 Ok(())
54 }
55
56 async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
57 match self.fill_buffer().await {
58 Ok(_) => {}
59 Err(e) => return Some(Err(e)),
60 };
61
62 let mut batches = Vec::new();
63
64 let mut rows_collected = 0;
65
66 while rows_collected < self.output_size {
67 if let Some(batch) = self.buffered.pop_front() {
68 if batch.num_rows() == 0 {
70 continue;
71 }
72
73 let rows_remaining_in_batch = batch.num_rows() - self.i;
74 let rows_to_take =
75 std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);
76
77 if rows_to_take == rows_remaining_in_batch {
78 let batch = if self.i == 0 {
80 batch
81 } else {
82 batch.slice(self.i, rows_to_take)
84 };
85 batches.push(batch);
86 self.i = 0;
87 } else {
88 batches.push(batch.slice(self.i, rows_to_take));
90 self.i += rows_to_take;
92 self.buffered.push_front(batch);
93 }
94
95 rows_collected += rows_to_take;
96 } else {
97 break;
98 }
99 }
100
101 if batches.is_empty() {
102 None
103 } else {
104 Some(Ok(batches))
105 }
106 }
107}
108
/// Per-batch state used by `break_stream` to split one incoming batch at chunk boundaries.
struct BreakStreamState {
    /// Chunk size: an emitted batch never extends past a multiple of this many rows.
    max_rows: usize,
    /// Rows already emitted into the current, not yet full, chunk.
    rows_seen: usize,
    /// Rows of `batch` that still have to be emitted.
    rows_remaining: usize,
    /// The not-yet-emitted remainder of the incoming batch; taken (left as `None`) when
    /// the last piece is emitted.
    batch: Option<RecordBatch>,
}
115
116impl BreakStreamState {
117 fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
118 if self.rows_remaining == 0 {
119 return None;
120 }
121 if self.rows_remaining + self.rows_seen <= self.max_rows {
122 self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
123 self.rows_remaining = 0;
124 let next = self.batch.take().unwrap();
125 Some((Ok(next), self))
126 } else {
127 let rows_to_emit = self.max_rows - self.rows_seen;
128 self.rows_seen = 0;
129 self.rows_remaining -= rows_to_emit;
130 let batch = self.batch.as_mut().unwrap();
131 let next = batch.slice(0, rows_to_emit);
132 *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
133 Some((Ok(next), self))
134 }
135 }
136}
137
138pub fn break_stream(
146 stream: SendableRecordBatchStream,
147 max_chunk_size: usize,
148) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
149 let mut rows_already_seen = 0;
150 stream
151 .map_ok(move |batch| {
152 let state = BreakStreamState {
153 rows_remaining: batch.num_rows(),
154 max_rows: max_chunk_size,
155 rows_seen: rows_already_seen,
156 batch: Some(batch),
157 };
158 rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;
159
160 futures::stream::unfold(state, move |state| std::future::ready(state.next())).boxed()
161 })
162 .try_flatten()
163 .boxed()
164}
165
166pub fn chunk_stream(
167 stream: SendableRecordBatchStream,
168 chunk_size: usize,
169) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
170 let chunker = BatchReaderChunker::new(stream, chunk_size);
171 futures::stream::unfold(chunker, |mut chunker| async move {
172 match chunker.next().await {
173 Some(Ok(batches)) => Some((Ok(batches), chunker)),
174 Some(Err(e)) => Some((Err(e), chunker)),
175 None => None,
176 }
177 })
178 .boxed()
179}
180
181pub fn chunk_concat_stream(
182 stream: SendableRecordBatchStream,
183 chunk_size: usize,
184) -> SendableRecordBatchStream {
185 let schema = stream.schema();
186 let schema_copy = schema.clone();
187 let chunked = chunk_stream(stream, chunk_size);
188 let chunk_concat = chunked
189 .and_then(move |batches| {
190 std::future::ready(
191 kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
194 )
195 })
196 .map_err(DataFusionError::from)
197 .boxed();
198 Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
199}
200
201pub struct StrictBatchSizeStream<S> {
207 inner: S,
208 batch_size: usize,
209 residual: Option<RecordBatch>,
210}
211
212impl<S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin> StrictBatchSizeStream<S> {
213 pub fn new(inner: S, batch_size: usize) -> Self {
214 Self {
215 inner,
216 batch_size,
217 residual: None,
218 }
219 }
220}
221
222impl<S> Stream for StrictBatchSizeStream<S>
238where
239 S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin,
240{
241 type Item = DataFusionResult<RecordBatch>;
242
243 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
244 loop {
245 if let Some(residual) = self.residual.take() {
247 if residual.num_rows() >= self.batch_size {
248 let split_at = self.batch_size;
249 let chunk = residual.slice(0, split_at);
250 let new_residual = residual.slice(split_at, residual.num_rows() - split_at);
251 self.residual = Some(new_residual);
252 return Poll::Ready(Some(Ok(chunk)));
253 } else {
254 self.residual = Some(residual);
256 }
257 }
258
259 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
261 Some(Ok(batch)) => {
262 let current_batch = if let Some(residual) = self.residual.take() {
264 arrow::compute::concat_batches(&residual.schema(), &[residual, batch])
265 .map_err(|e| DataFusionError::External(Box::new(e)))?
266 } else {
267 batch
268 };
269
270 if current_batch.num_rows() >= self.batch_size {
271 let split_at = self.batch_size;
272 let chunk = current_batch.slice(0, split_at);
273 let new_residual =
274 current_batch.slice(split_at, current_batch.num_rows() - split_at);
275 if new_residual.num_rows() > 0 {
276 self.residual = Some(new_residual);
277 }
278 return Poll::Ready(Some(Ok(chunk)));
279 } else {
280 self.residual = Some(current_batch);
282 continue;
283 }
284 }
285 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
286 None => {
287 return Poll::Ready(
288 self.residual
289 .take()
290 .filter(|r| r.num_rows() > 0)
291 .map(Ok::<_, DataFusionError>),
292 );
293 }
294 }
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::{array, BatchCount, RowCount};

    use crate::datagen::DatafusionDatagenExt;

    /// Runs `chunk_stream`, `chunk_concat_stream` and `break_stream` over the same input
    /// (batches of 10, 5, 13 and 0 rows) with a chunk size of 10.
    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let make_batch = |num_rows: u32| {
            lance_datagen::gen()
                .anon_col(lance_datagen::array::step::<Int32Type>())
                .into_batch_rows(RowCount::from(num_rows as u64))
                .unwrap()
        };

        // 28 rows in total; the trailing empty batch must not produce any output.
        let batches = vec![make_batch(10), make_batch(5), make_batch(13), make_batch(0)];

        // Builds a fresh stream over the same batches for each function under test.
        let make_stream = || {
            let stream = futures::stream::iter(
                batches
                    .clone()
                    .into_iter()
                    .map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream))
        };

        let chunked = super::chunk_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Expected grouping: [10], [5, 5 (head of the 13-row batch)], [8 (its tail)].
        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].len(), 1);
        assert_eq!(chunked[0][0].num_rows(), 10);
        assert_eq!(chunked[1].len(), 2);
        assert_eq!(chunked[1][0].num_rows(), 5);
        assert_eq!(chunked[1][1].num_rows(), 5);
        assert_eq!(chunked[2].len(), 1);
        assert_eq!(chunked[2][0].num_rows(), 8);

        let chunked = super::chunk_concat_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Same grouping, but each chunk is concatenated into a single batch.
        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 10);
        assert_eq!(chunked[2].num_rows(), 8);

        let chunked = super::break_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Batches are only split at chunk boundaries, never merged: 10 | 5 + 5 | 8.
        assert_eq!(chunked.len(), 4);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 5);
        assert_eq!(chunked[2].num_rows(), 5);
        assert_eq!(chunked[3].num_rows(), 8);
    }

    /// Checks that `StrictBatchSizeStream` re-chunks its input into 10-row batches.
    #[tokio::test]
    async fn test_strict_batch_size_stream() {
        // Presumably 10 batches of 7 rows each (70 rows) — consistent with the seven
        // 10-row batches asserted below.
        let batches = lance_datagen::gen()
            .anon_col(array::step::<Int32Type>())
            .anon_col(array::step::<Int64Type>())
            .into_df_stream(RowCount::from(7), BatchCount::from(10));

        let stream = super::StrictBatchSizeStream::new(batches, 10);

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 7);

        for batch in batches {
            assert_eq!(batch.num_rows(), 10);
        }
    }
}