lance_datafusion/
chunker.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::pin::Pin;
5use std::task::Poll;
6use std::{collections::VecDeque, task::Context};
7
8use arrow::compute::kernels;
9use arrow_array::RecordBatch;
10use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
11use datafusion_common::DataFusionError;
12use futures::{ready, Stream, StreamExt, TryStreamExt};
13
14use lance_core::error::DataFusionResult;
15use lance_core::Result;
16
17/// Wraps a [`SendableRecordBatchStream`] into a stream of RecordBatch chunks of
18/// a given size.  This slices but does not copy any buffers.
19struct BatchReaderChunker {
20    /// The inner stream
21    inner: SendableRecordBatchStream,
22    /// The batches that have been read from the inner stream but not yet fully yielded
23    buffered: VecDeque<RecordBatch>,
24    /// The number of rows to yield in each chunk
25    output_size: usize,
26    /// The position within the first batch in the buffer to start yielding from
27    i: usize,
28}
29
30impl BatchReaderChunker {
31    fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
32        Self {
33            inner,
34            buffered: VecDeque::new(),
35            output_size,
36            i: 0,
37        }
38    }
39
40    fn buffered_len(&self) -> usize {
41        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
42        buffer_total - self.i
43    }
44
45    async fn fill_buffer(&mut self) -> Result<()> {
46        while self.buffered_len() < self.output_size {
47            match self.inner.next().await {
48                Some(Ok(batch)) => self.buffered.push_back(batch),
49                Some(Err(e)) => return Err(e.into()),
50                None => break,
51            }
52        }
53        Ok(())
54    }
55
56    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
57        match self.fill_buffer().await {
58            Ok(_) => {}
59            Err(e) => return Some(Err(e)),
60        };
61
62        let mut batches = Vec::new();
63
64        let mut rows_collected = 0;
65
66        while rows_collected < self.output_size {
67            if let Some(batch) = self.buffered.pop_front() {
68                // Skip empty batch
69                if batch.num_rows() == 0 {
70                    continue;
71                }
72
73                let rows_remaining_in_batch = batch.num_rows() - self.i;
74                let rows_to_take =
75                    std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);
76
77                if rows_to_take == rows_remaining_in_batch {
78                    // We're taking the whole batch, so we can just move it
79                    let batch = if self.i == 0 {
80                        batch
81                    } else {
82                        // We are taking the remainder of the batch, so we need to slice it
83                        batch.slice(self.i, rows_to_take)
84                    };
85                    batches.push(batch);
86                    self.i = 0;
87                } else {
88                    // We're taking a slice of the batch, so we need to copy it
89                    batches.push(batch.slice(self.i, rows_to_take));
90                    // And then we need to push the remainder back onto the front of the queue
91                    self.i += rows_to_take;
92                    self.buffered.push_front(batch);
93                }
94
95                rows_collected += rows_to_take;
96            } else {
97                break;
98            }
99        }
100
101        if batches.is_empty() {
102            None
103        } else {
104            Some(Ok(batches))
105        }
106    }
107}
108
109struct BreakStreamState {
110    max_rows: usize,
111    rows_seen: usize,
112    rows_remaining: usize,
113    batch: Option<RecordBatch>,
114}
115
116impl BreakStreamState {
117    fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
118        if self.rows_remaining == 0 {
119            return None;
120        }
121        if self.rows_remaining + self.rows_seen <= self.max_rows {
122            self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
123            self.rows_remaining = 0;
124            let next = self.batch.take().unwrap();
125            Some((Ok(next), self))
126        } else {
127            let rows_to_emit = self.max_rows - self.rows_seen;
128            self.rows_seen = 0;
129            self.rows_remaining -= rows_to_emit;
130            let batch = self.batch.as_mut().unwrap();
131            let next = batch.slice(0, rows_to_emit);
132            *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
133            Some((Ok(next), self))
134        }
135    }
136}
137
138// Given a stream of record batches, and a desired break point, this will
139// make sure that a new record batch is emitted every time `break_point` rows
140// have passed.
141//
142// This method will not combine record batches in any way.  For example, if
143// the input lengths are [3, 5, 8, 3, 5], and the break point is 10 then the
144// output batches will be [3, 5, 2 (break inserted) 6, 3, 1 (break inserted) 4]
145pub fn break_stream(
146    stream: SendableRecordBatchStream,
147    max_chunk_size: usize,
148) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
149    let mut rows_already_seen = 0;
150    stream
151        .map_ok(move |batch| {
152            let state = BreakStreamState {
153                rows_remaining: batch.num_rows(),
154                max_rows: max_chunk_size,
155                rows_seen: rows_already_seen,
156                batch: Some(batch),
157            };
158            rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;
159
160            futures::stream::unfold(state, move |state| std::future::ready(state.next()))
161                .fuse()
162                .boxed()
163        })
164        .try_flatten()
165        .boxed()
166}
167
168/// Given a stream of record batches, this will yield batches of a fixed size.
169///
170/// In order to avoid copying data the batches will be converted into a stream of
171/// `Vec<RecordBatch>` where each item is a `Vec` of batches whose total size is
172/// `chunk_size`.
173pub fn chunk_stream(
174    stream: SendableRecordBatchStream,
175    chunk_size: usize,
176) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
177    let chunker = BatchReaderChunker::new(stream, chunk_size);
178    futures::stream::unfold(chunker, |mut chunker| async move {
179        match chunker.next().await {
180            Some(Ok(batches)) => Some((Ok(batches), chunker)),
181            Some(Err(e)) => Some((Err(e), chunker)),
182            None => None,
183        }
184    })
185    .fuse()
186    .boxed()
187}
188
189/// Given a stream of record batches, this will yield batches of a fixed size.
190///
191/// This stream _will_ combine record batches and so it can be fairly expensive as it will
192/// likely force a copy of incoming data.  However, it can be useful when users require
193/// precise batch sizing.
194pub fn chunk_concat_stream(
195    stream: SendableRecordBatchStream,
196    chunk_size: usize,
197) -> SendableRecordBatchStream {
198    let schema = stream.schema();
199    let schema_copy = schema.clone();
200    let chunked = chunk_stream(stream, chunk_size);
201    let chunk_concat = chunked
202        .and_then(move |batches| {
203            std::future::ready(
204                // chunk_stream is zero-copy and so it gives us pieces of batches.  However, the btree
205                // index needs 1 batch-per-page and so we concatenate here.
206                kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
207            )
208        })
209        .map_err(DataFusionError::from)
210        .boxed();
211    Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
212}
213
214/// Given a stream of record batches, this will yield batches of a fixed size.
215///
216/// This stream _will_ combine record batches and so it can be fairly expensive as it will
217/// likely force a copy of all incoming data.  However, it can be useful when users require
218/// precise batch sizing.
219pub struct StrictBatchSizeStream<S> {
220    inner: S,
221    batch_size: usize,
222    residual: Option<RecordBatch>,
223}
224
225impl<S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin> StrictBatchSizeStream<S> {
226    pub fn new(inner: S, batch_size: usize) -> Self {
227        Self {
228            inner,
229            batch_size,
230            residual: None,
231        }
232    }
233}
234
235/// Internal polling method for strict batch size enforcement.
236///
237/// # Use Case
238/// When precise batch sizing is required (e.g., ML batch processing), this method guarantees
239/// output batches exactly match batch_size until final partial batch. Maintains data integrity
240/// across splits using row-aware splitting.
241///
242/// # Example
243/// With batch_size=5 and input sequence:
244/// - Fragment 1: 7 rows → splits into [5,2]
245///   (queues 5, carries 2)
246/// - Fragment 2: 4 rows → combines carried 2 + 4 = 6
247///   splits into [5,1]
248///
249/// - Output batches: [5], [5], [1]
250impl<S> Stream for StrictBatchSizeStream<S>
251where
252    S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin,
253{
254    type Item = DataFusionResult<RecordBatch>;
255
256    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
257        loop {
258            // Process residual first if present
259            if let Some(residual) = self.residual.take() {
260                if residual.num_rows() >= self.batch_size {
261                    let split_at = self.batch_size;
262                    let chunk = residual.slice(0, split_at);
263                    let new_residual = residual.slice(split_at, residual.num_rows() - split_at);
264                    self.residual = Some(new_residual);
265                    return Poll::Ready(Some(Ok(chunk)));
266                } else {
267                    // Keep residual and proceed to get more data
268                    self.residual = Some(residual);
269                }
270            }
271
272            // Poll the inner stream for next batch
273            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
274                Some(Ok(batch)) => {
275                    // Combine with residual if any
276                    let current_batch = if let Some(residual) = self.residual.take() {
277                        arrow::compute::concat_batches(&residual.schema(), &[residual, batch])
278                            .map_err(|e| DataFusionError::External(Box::new(e)))?
279                    } else {
280                        batch
281                    };
282
283                    if current_batch.num_rows() >= self.batch_size {
284                        let split_at = self.batch_size;
285                        let chunk = current_batch.slice(0, split_at);
286                        let new_residual =
287                            current_batch.slice(split_at, current_batch.num_rows() - split_at);
288                        if new_residual.num_rows() > 0 {
289                            self.residual = Some(new_residual);
290                        }
291                        return Poll::Ready(Some(Ok(chunk)));
292                    } else {
293                        // Not enough rows, store as residual
294                        self.residual = Some(current_batch);
295                        continue;
296                    }
297                }
298                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
299                None => {
300                    return Poll::Ready(
301                        self.residual
302                            .take()
303                            .filter(|r| r.num_rows() > 0)
304                            .map(Ok::<_, DataFusionError>),
305                    );
306                }
307            }
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::{array, BatchCount, RowCount};

    use crate::datagen::DatafusionDatagenExt;

    /// Runs `chunk_stream`, `chunk_concat_stream` and `break_stream` over input batches of
    /// 10, 5, 13 and 0 rows with a target size of 10, and checks the resulting row counts.
    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let input: Vec<_> = [10_u64, 5, 13, 0]
            .iter()
            .map(|&num_rows| {
                lance_datagen::gen_batch()
                    .anon_col(array::step::<Int32Type>())
                    .into_batch_rows(RowCount::from(num_rows))
                    .unwrap()
            })
            .collect();

        // Each call builds a fresh stream over the same input batches.
        let make_stream = || {
            let batches = futures::stream::iter(
                input.clone().into_iter().map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), batches))
        };

        // chunk_stream: groups of zero-copy slices totalling 10 rows (last one shorter).
        let chunk_rows: Vec<Vec<usize>> = super::chunk_stream(make_stream(), 10)
            .map_ok(|chunk| {
                chunk
                    .iter()
                    .map(|batch| batch.num_rows())
                    .collect::<Vec<_>>()
            })
            .try_collect()
            .await
            .unwrap();
        assert_eq!(chunk_rows, vec![vec![10], vec![5, 5], vec![8]]);

        // chunk_concat_stream: the same chunks, concatenated into single batches.
        let concat_rows: Vec<usize> = super::chunk_concat_stream(make_stream(), 10)
            .map_ok(|batch| batch.num_rows())
            .try_collect()
            .await
            .unwrap();
        assert_eq!(concat_rows, vec![10, 10, 8]);

        // break_stream: input batches are only ever split, never combined.
        let break_rows: Vec<usize> = super::break_stream(make_stream(), 10)
            .map_ok(|batch| batch.num_rows())
            .try_collect()
            .await
            .unwrap();
        assert_eq!(break_rows, vec![10, 5, 5, 8]);
    }

    /// Ten 7-row batches re-chunked to a strict size of 10 give seven 10-row batches.
    #[tokio::test]
    async fn test_strict_batch_size_stream() {
        let input = lance_datagen::gen_batch()
            .anon_col(array::step::<Int32Type>())
            .anon_col(array::step::<Int64Type>())
            .into_df_stream(RowCount::from(7), BatchCount::from(10));

        let row_counts: Vec<usize> = super::StrictBatchSizeStream::new(input, 10)
            .map_ok(|batch| batch.num_rows())
            .try_collect()
            .await
            .unwrap();
        assert_eq!(row_counts, vec![10; 7]);
    }
}