lance_datafusion/chunker.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::pin::Pin;
use std::task::Poll;
use std::{collections::VecDeque, task::Context};

use arrow::compute::kernels;
use arrow_array::RecordBatch;
use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use datafusion_common::DataFusionError;
use futures::{ready, Stream, StreamExt, TryStreamExt};

use lance_core::error::DataFusionResult;
use lance_core::Result;

/// Wraps a [`SendableRecordBatchStream`] into a stream of RecordBatch chunks of
/// a given size.  This slices but does not copy any buffers.
struct BatchReaderChunker {
    /// The inner stream
    inner: SendableRecordBatchStream,
    /// The batches that have been read from the inner stream but not yet fully yielded
    buffered: VecDeque<RecordBatch>,
    /// The number of rows to yield in each chunk
    output_size: usize,
    /// The position within the first batch in the buffer to start yielding from
    i: usize,
}

impl BatchReaderChunker {
    fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
        Self {
            inner,
            buffered: VecDeque::new(),
            output_size,
            i: 0,
        }
    }

    /// The number of buffered rows that have not yet been yielded
    fn buffered_len(&self) -> usize {
        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
        buffer_total - self.i
    }

    /// Pulls batches from the inner stream until at least `output_size` rows
    /// are buffered, or the inner stream is exhausted
    async fn fill_buffer(&mut self) -> Result<()> {
        while self.buffered_len() < self.output_size {
            match self.inner.next().await {
                Some(Ok(batch)) => self.buffered.push_back(batch),
                Some(Err(e)) => return Err(e.into()),
                None => break,
            }
        }
        Ok(())
    }

    /// Yields the next chunk of up to `output_size` rows as zero-copy slices,
    /// or `None` once the inner stream is exhausted
    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
        match self.fill_buffer().await {
            Ok(_) => {}
            Err(e) => return Some(Err(e)),
        };

        let mut batches = Vec::new();

        let mut rows_collected = 0;

        while rows_collected < self.output_size {
            if let Some(batch) = self.buffered.pop_front() {
                // Skip empty batches
                if batch.num_rows() == 0 {
                    continue;
                }

                let rows_remaining_in_batch = batch.num_rows() - self.i;
                let rows_to_take =
                    std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);
                if rows_to_take == rows_remaining_in_batch {
                    // We're consuming the rest of this batch
                    let batch = if self.i == 0 {
                        // Taking the whole batch, so we can just move it
                        batch
                    } else {
                        // Taking the remainder of the batch, so slice it (zero-copy)
                        batch.slice(self.i, rows_to_take)
                    };
                    batches.push(batch);
                    self.i = 0;
                } else {
                    // We're taking a zero-copy slice from the front of the batch
                    batches.push(batch.slice(self.i, rows_to_take));
                    // Push the batch back onto the front of the queue and advance
                    // the cursor past the rows we just took
                    self.i += rows_to_take;
                    self.buffered.push_front(batch);
                }

                rows_collected += rows_to_take;
            } else {
                break;
            }
        }

        if batches.is_empty() {
            None
        } else {
            Some(Ok(batches))
        }
    }
}

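/// Iteration state used by [`break_stream`] to split a single [`RecordBatch`]
/// at `max_rows` boundaries.
///
/// `rows_seen` counts rows emitted since the last break (always less than
/// `max_rows`), `rows_remaining` counts the rows of `batch` not yet emitted,
/// and each call to `next` yields the next slice up to the following boundary.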
struct BreakStreamState {
    max_rows: usize,
    rows_seen: usize,
    rows_remaining: usize,
    batch: Option<RecordBatch>,
}

impl BreakStreamState {
    fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
        if self.rows_remaining == 0 {
            return None;
        }
        if self.rows_remaining + self.rows_seen <= self.max_rows {
            self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
            self.rows_remaining = 0;
            let next = self.batch.take().unwrap();
            Some((Ok(next), self))
        } else {
            let rows_to_emit = self.max_rows - self.rows_seen;
            self.rows_seen = 0;
            self.rows_remaining -= rows_to_emit;
            let batch = self.batch.as_mut().unwrap();
            let next = batch.slice(0, rows_to_emit);
            *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
            Some((Ok(next), self))
        }
    }
}

/// Given a stream of record batches, this ensures that a new record batch is
/// emitted every time `max_chunk_size` rows have passed.
///
/// This function will not combine record batches in any way.  For example, if
/// the input lengths are [3, 5, 8, 3, 5] and `max_chunk_size` is 10, then the
/// output lengths will be [3, 5, 2, 6, 3, 1, 4] (with breaks inserted after
/// the 2 and after the 1).
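///
/// A minimal usage sketch (not from the original docs); `source` is assumed
/// to be a `SendableRecordBatchStream` already in scope:
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Re-slice the stream so no batch crosses a 1024-row boundary
/// let batches: Vec<RecordBatch> = break_stream(source, 1024).try_collect().await?;
/// ```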
pub fn break_stream(
    stream: SendableRecordBatchStream,
    max_chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
    let mut rows_already_seen = 0;
    stream
        .map_ok(move |batch| {
            let state = BreakStreamState {
                rows_remaining: batch.num_rows(),
                max_rows: max_chunk_size,
                rows_seen: rows_already_seen,
                batch: Some(batch),
            };
            rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;

            futures::stream::unfold(state, move |state| std::future::ready(state.next())).boxed()
        })
        .try_flatten()
        .boxed()
}

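/// Splits `stream` into chunks of `chunk_size` rows, yielding each chunk as a
/// `Vec<RecordBatch>` of zero-copy slices (the final chunk may be smaller).
///
/// A minimal usage sketch (hypothetical; `source` is assumed to be a
/// `SendableRecordBatchStream` in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Each element is a Vec<RecordBatch> totalling 4096 rows (except the last)
/// let chunks: Vec<Vec<RecordBatch>> = chunk_stream(source, 4096).try_collect().await?;
/// ```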
pub fn chunk_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
    let chunker = BatchReaderChunker::new(stream, chunk_size);
    futures::stream::unfold(chunker, |mut chunker| async move {
        chunker.next().await.map(|item| (item, chunker))
    })
    .boxed()
}

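/// Like [`chunk_stream`], but concatenates each chunk into a single
/// [`RecordBatch`], so every output batch has exactly `chunk_size` rows
/// (except possibly the last).  Unlike [`chunk_stream`], this copies data.
///
/// A minimal usage sketch (hypothetical; `source` is assumed to be a
/// `SendableRecordBatchStream` in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut concatenated = chunk_concat_stream(source, 4096);
/// while let Some(batch) = concatenated.try_next().await? {
///     // batch.num_rows() == 4096, except possibly for the final batch
/// }
/// ```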
pub fn chunk_concat_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> SendableRecordBatchStream {
    let schema = stream.schema();
    let schema_copy = schema.clone();
    let chunked = chunk_stream(stream, chunk_size);
    let chunk_concat = chunked
        .and_then(move |batches| {
            std::future::ready(
                // chunk_stream is zero-copy and so it gives us pieces of batches.  However, the btree
                // index needs 1 batch-per-page and so we concatenate here.
                kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
            )
        })
        .map_err(DataFusionError::from)
        .boxed();
    Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
}

/// Given a stream of record batches, this will yield batches of a fixed size.
///
/// This stream _will_ combine record batches and so it can be fairly expensive as it will
/// likely force a copy of all incoming data.  However, it can be useful when users require
/// precise batch sizing.
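///
/// A minimal usage sketch (hypothetical; `source` is assumed to be a stream
/// of `DataFusionResult<RecordBatch>` in scope):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Every yielded batch has exactly 1024 rows, except possibly the last
/// let sized = StrictBatchSizeStream::new(source, 1024);
/// let batches: Vec<RecordBatch> = sized.try_collect().await?;
/// ```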
pub struct StrictBatchSizeStream<S> {
    inner: S,
    batch_size: usize,
    residual: Option<RecordBatch>,
}

impl<S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin> StrictBatchSizeStream<S> {
    pub fn new(inner: S, batch_size: usize) -> Self {
        Self {
            inner,
            batch_size,
            residual: None,
        }
    }
}

/// The polling logic that enforces the strict batch size.
///
/// # Use Case
///
/// When precise batch sizing is required (e.g., ML batch processing), this
/// guarantees that output batches exactly match `batch_size` until the final
/// partial batch, carrying leftover rows across input batch boundaries as a
/// residual.
///
/// # Example
///
/// With `batch_size = 5` and the following input sequence:
/// - Fragment 1: 7 rows → splits into [5, 2] (emits 5, carries 2)
/// - Fragment 2: 4 rows → combines carried 2 + 4 = 6, splits into [5, 1]
///
/// Output batches: [5], [5], [1]
impl<S> Stream for StrictBatchSizeStream<S>
where
    S: Stream<Item = DataFusionResult<RecordBatch>> + Unpin,
{
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Process residual first if present
            if let Some(residual) = self.residual.take() {
                if residual.num_rows() >= self.batch_size {
                    let split_at = self.batch_size;
                    let chunk = residual.slice(0, split_at);
                    let new_residual = residual.slice(split_at, residual.num_rows() - split_at);
                    self.residual = Some(new_residual);
                    return Poll::Ready(Some(Ok(chunk)));
                } else {
                    // Keep residual and proceed to get more data
                    self.residual = Some(residual);
                }
            }

            // Poll the inner stream for next batch
            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                Some(Ok(batch)) => {
                    // Combine with residual if any
                    let current_batch = if let Some(residual) = self.residual.take() {
                        arrow::compute::concat_batches(&residual.schema(), &[residual, batch])
                            .map_err(|e| DataFusionError::External(Box::new(e)))?
                    } else {
                        batch
                    };

                    if current_batch.num_rows() >= self.batch_size {
                        let split_at = self.batch_size;
                        let chunk = current_batch.slice(0, split_at);
                        let new_residual =
                            current_batch.slice(split_at, current_batch.num_rows() - split_at);
                        if new_residual.num_rows() > 0 {
                            self.residual = Some(new_residual);
                        }
                        return Poll::Ready(Some(Ok(chunk)));
                    } else {
                        // Not enough rows, store as residual
                        self.residual = Some(current_batch);
                        continue;
                    }
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => {
                    // Inner stream exhausted; flush any remaining residual
                    return Poll::Ready(
                        self.residual
                            .take()
                            .filter(|r| r.num_rows() > 0)
                            .map(Ok::<_, DataFusionError>),
                    );
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::{array, BatchCount, RowCount};

    use crate::datagen::DatafusionDatagenExt;

    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let make_batch = |num_rows: u32| {
            lance_datagen::gen_batch()
                .anon_col(lance_datagen::array::step::<Int32Type>())
                .into_batch_rows(RowCount::from(num_rows as u64))
                .unwrap()
        };

        let batches = vec![make_batch(10), make_batch(5), make_batch(13), make_batch(0)];

        let make_stream = || {
            let stream = futures::stream::iter(
                batches
                    .clone()
                    .into_iter()
                    .map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream))
        };

        let chunked = super::chunk_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].len(), 1);
        assert_eq!(chunked[0][0].num_rows(), 10);
        assert_eq!(chunked[1].len(), 2);
        assert_eq!(chunked[1][0].num_rows(), 5);
        assert_eq!(chunked[1][1].num_rows(), 5);
        assert_eq!(chunked[2].len(), 1);
        assert_eq!(chunked[2][0].num_rows(), 8);

        let chunked = super::chunk_concat_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 10);
        assert_eq!(chunked[2].num_rows(), 8);

        let chunked = super::break_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 4);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 5);
        assert_eq!(chunked[2].num_rows(), 5);
        assert_eq!(chunked[3].num_rows(), 8);
    }

    #[tokio::test]
    async fn test_strict_batch_size_stream() {
        let batches = lance_datagen::gen_batch()
            .anon_col(array::step::<Int32Type>())
            .anon_col(array::step::<Int64Type>())
            .into_df_stream(RowCount::from(7), BatchCount::from(10));

        let stream = super::StrictBatchSizeStream::new(batches, 10);

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 7);

        for batch in batches {
            assert_eq!(batch.num_rows(), 10);
        }
    }
}