Skip to main content

lance_arrow/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with streams of [`RecordBatch`].
5
6use arrow_array::RecordBatch;
7use arrow_schema::{ArrowError, SchemaRef};
8use futures::stream::{self, Stream, StreamExt};
9use std::pin::Pin;
10
11use crate::deepcopy::deep_copy_batch_sliced;
12
13/// Rechunks a stream of [`RecordBatch`] so that each output batch has
14/// approximately `target_bytes` of array data.
15///
16/// Small input batches are accumulated (by concatenation) until at least
17/// `min_bytes` of data has been collected. If the resulting batch exceeds
18/// `max_bytes`, it is sliced into roughly equal pieces of ~`max_bytes`
19/// (assuming uniform row sizes).
20pub fn rechunk_stream_by_size<S, E>(
21    input: S,
22    input_schema: SchemaRef,
23    min_bytes: usize,
24    max_bytes: usize,
25) -> impl Stream<Item = Result<RecordBatch, E>>
26where
27    S: Stream<Item = Result<RecordBatch, E>>,
28    E: From<ArrowError>,
29{
30    rechunk_stream_by_size_inner(input, input_schema, min_bytes, max_bytes, false)
31}
32
33/// Like [`rechunk_stream_by_size`] but deep-copies slices so that
34/// `get_array_memory_size` reflects the true size of each output batch.
35///
36/// After a normal `RecordBatch::slice`, the backing buffers are shared with
37/// the original batch, so `get_array_memory_size` still reports the full
38/// parent size.  This variant deep-copies every slice produced during the
39/// splitting phase, which allows the stream to detect and re-split slices
40/// that still exceed `max_bytes` (e.g. because a single row is much larger
41/// than average).
42///
43/// The deep copy is a last resort and potentially expensive for large
44/// batches.  However, it is only performed when a batch actually needs to be
45/// sliced — batches that are already within the target range pass through at
46/// zero cost.  Use this only when the hard cap on `max_bytes` is a
47/// correctness requirement, not merely a performance hint.
48pub fn rechunk_stream_by_size_deep_copy<S, E>(
49    input: S,
50    input_schema: SchemaRef,
51    min_bytes: usize,
52    max_bytes: usize,
53) -> impl Stream<Item = Result<RecordBatch, E>>
54where
55    S: Stream<Item = Result<RecordBatch, E>>,
56    E: From<ArrowError>,
57{
58    rechunk_stream_by_size_inner(input, input_schema, min_bytes, max_bytes, true)
59}
60
/// Shared implementation behind [`rechunk_stream_by_size`] and
/// [`rechunk_stream_by_size_deep_copy`].
///
/// Built as a `try_unfold` state machine: each turn of the closure either
/// pulls from `input` until `min_bytes` is buffered, emits one batch, or
/// terminates. Leftover slices from a split batch are stashed back into
/// `state.accumulated` and emitted on later turns via the fast path below.
fn rechunk_stream_by_size_inner<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
    deep_copy: bool,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    stream::try_unfold(
        RechunkState {
            input: Box::pin(input),
            accumulated: Vec::new(),
            acc_bytes: 0,
            done: false,
            input_schema,
            min_bytes,
            max_bytes,
            deep_copy,
        },
        |mut state| async move {
            // Input exhausted and nothing buffered: the output stream ends.
            if state.done && state.accumulated.is_empty() {
                return Ok(None);
            }

            // Pull batches until we reach the byte target or exhaust input.
            // Always pull at least one batch so that min_bytes=0 works.
            while !state.done && (state.accumulated.is_empty() || state.acc_bytes < state.min_bytes)
            {
                match state.input.next().await {
                    Some(Ok(batch)) => {
                        state.acc_bytes += batch.get_array_memory_size();
                        state.accumulated.push(batch);
                    }
                    // Errors from upstream are propagated, ending the unfold.
                    Some(Err(e)) => return Err(e),
                    None => {
                        state.done = true;
                    }
                }
            }

            // Possible when the stream ended with nothing buffered.
            if state.accumulated.is_empty() {
                return Ok(None);
            }

            // Fast path: if the first accumulated batch already meets the
            // byte threshold, deliver it directly instead of concatenating
            // everything together (which would just get sliced back apart).
            // This is how stashed leftover slices drain one-at-a-time; in
            // the zero-copy case their reported size includes the shared
            // parent buffers, which keeps them over `min_bytes`.
            if state.accumulated.len() > 1
                && state.accumulated[0].get_array_memory_size() >= state.min_bytes
            {
                let b = state.accumulated.remove(0);
                state.acc_bytes -= b.get_array_memory_size();
                return Ok(Some((b, state)));
            }

            // Coalesce: a single buffered batch is taken as-is; multiple
            // batches are concatenated under `input_schema`.
            let batch = if state.accumulated.len() == 1 {
                state.accumulated.pop().unwrap()
            } else {
                let b =
                    arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
                        .map_err(E::from)?;
                state.accumulated.clear();
                b
            };
            // The buffer is now empty; reset the byte count to match.
            state.acc_bytes = 0;

            // Slice the batch into ~max_bytes pieces assuming uniform row sizes.
            let mut slices =
                slice_batch(batch, state.max_bytes, state.deep_copy).map_err(E::from)?;

            if slices.len() == 1 {
                // No split needed: emit the whole batch.
                Ok(Some((slices.pop().unwrap(), state)))
            } else {
                let first = slices.remove(0);

                // Stash leftover slices for subsequent iterations.
                for a in &slices {
                    state.acc_bytes += a.get_array_memory_size();
                }
                state.accumulated = slices;

                Ok(Some((first, state)))
            }
        },
    )
}
150
151/// Slice a batch into pieces of at most `max_bytes`.
152///
153/// When `deep_copy` is false, slices share buffers with the original batch
154/// and `get_array_memory_size` will still report the parent buffer size.
155/// This is fine when the caller only needs approximate sizing.
156///
157/// When `deep_copy` is true, each slice is deep-copied so that
158/// `get_array_memory_size` reflects the true size.  If a deep-copied slice
159/// still exceeds `max_bytes` (due to non-uniform row sizes), it is
160/// recursively split until every piece is within budget or contains only a
161/// single row.
162fn slice_batch(
163    batch: RecordBatch,
164    max_bytes: usize,
165    deep_copy: bool,
166) -> Result<Vec<RecordBatch>, ArrowError> {
167    let batch_bytes = batch.get_array_memory_size();
168    let num_rows = batch.num_rows();
169
170    if batch_bytes <= max_bytes || num_rows <= 1 {
171        return Ok(vec![batch]);
172    }
173
174    let rows_per_chunk = (max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
175
176    let mut result = Vec::new();
177    let mut offset = 0;
178    while offset < num_rows {
179        let len = rows_per_chunk.min(num_rows - offset);
180        let slice = batch.slice(offset, len);
181        if deep_copy {
182            let copied = deep_copy_batch_sliced(&slice)?;
183            // Recurse: the deep-copied slice has accurate sizes, so if it
184            // still exceeds max_bytes we can split further.
185            result.extend(slice_batch(copied, max_bytes, true)?);
186        } else {
187            result.push(slice);
188        }
189        offset += len;
190    }
191
192    Ok(result)
193}
194
/// Internal state threaded through the [`rechunk_stream_by_size_inner`]
/// `try_unfold` loop.
///
/// Kept as a named struct so the `try_unfold` closure stays readable.
struct RechunkState<S> {
    // Upstream batch source, pinned so it can be polled inside the closure.
    input: Pin<Box<S>>,
    // Batches (or leftover slices from a split) awaiting emission/concat.
    accumulated: Vec<RecordBatch>,
    // Sum of `get_array_memory_size` over `accumulated`.
    acc_bytes: usize,
    // True once `input` has returned `None`.
    done: bool,
    // Schema used when concatenating accumulated batches.
    input_schema: SchemaRef,
    // Keep pulling input until at least this many bytes are buffered.
    min_bytes: usize,
    // Batches larger than this are sliced before emission.
    max_bytes: usize,
    // When true, slices are deep-copied so reported sizes are exact.
    deep_copy: bool,
}
208
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    /// Build a single-column Int32 batch holding the values `0..num_rows`.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Schema matching [`make_batch`]: one non-nullable Int32 column "a".
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Run `rechunk_stream_by_size` over `batches` synchronously and collect
    /// the unwrapped output batches.
    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Total row count across all batches.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn test_empty_stream() {
        // An empty input stream must produce an empty output stream.
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        // Batch is between min and max — should pass through as-is.
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    #[test]
    fn test_small_batches_concatenated() {
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        // min = 5 batches worth, max = 10 batches worth.
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        // Should have been concatenated into fewer batches than the 8 inputs.
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    #[test]
    fn test_large_batch_sliced() {
        // A batch over max_bytes must be sliced; rows must be preserved.
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        // Key test for the fast-path optimisation. When a large batch is
        // sliced, leftover slices should be delivered one-at-a-time without
        // being concatenated back together.  We verify this by checking that
        // every output buffer pointer falls inside the original batch's
        // allocation (i.e. they are all zero-copy slices, not fresh copies).
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                 [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    #[test]
    fn test_flush_remainder_on_stream_end() {
        // Data below min_bytes should still be flushed when the stream ends.
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    #[test]
    fn test_large_then_small_batches() {
        // After a large batch is fully drained, subsequent small batches
        // should be accumulated normally.
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        // The large batch should appear (possibly sliced) followed by
        // concatenated small batches, so we should have fewer output batches
        // than the 6 inputs.
        assert!(result.len() < 6);
    }

    #[test]
    fn test_row_preservation_across_slicing() {
        // Verify that every input row appears exactly once in the output
        // and in the correct order after slicing.
        let batch = make_batch(237); // odd count to exercise remainder slice
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    #[test]
    fn test_min_bytes_zero_still_yields_all_rows() {
        // When min_bytes=0, the stream should still yield every batch.
        // This is the "chop only, don't coalesce" use case.
        let batches: Vec<_> = (0..5).map(|_| make_batch(100)).collect();
        let batch_bytes = batches[0].get_array_memory_size();
        let result = collect_rechunked(batches, 0, batch_bytes * 2);
        assert_eq!(total_rows(&result), 500);
    }

    #[test]
    fn test_min_bytes_zero_slices_oversized() {
        // min_bytes=0 with a small max_bytes should still slice large batches.
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], 0, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    /// Build a batch with one variable-length string column.
    /// Every row is `small_size` bytes except the row at index `big_row_idx`
    /// which is `big_size` bytes.
    fn make_variable_batch(
        num_rows: usize,
        small_size: usize,
        big_row_idx: usize,
        big_size: usize,
    ) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)]));
        let values: Vec<String> = (0..num_rows)
            .map(|i| {
                if i == big_row_idx {
                    "X".repeat(big_size)
                } else {
                    "x".repeat(small_size)
                }
            })
            .collect();
        let array = arrow_array::StringArray::from(values);
        RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    /// Schema matching [`make_variable_batch`]: one non-nullable Utf8 column "s".
    fn variable_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)]))
    }

    /// Like [`collect_rechunked`] but drives the deep-copy variant, which
    /// enforces `max_bytes` as a hard cap (down to single-row batches).
    fn collect_rechunked_variable(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked =
            rechunk_stream_by_size_deep_copy(input, variable_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    #[test]
    fn test_oversized_row_at_end() {
        // 100 rows: 99 small (64 bytes each) + 1 large (100KiB) at the end.
        let batch = make_variable_batch(100, 64, 99, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            // Every batch fits the cap unless it is a single unsplittable row.
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    #[test]
    fn test_oversized_row_at_start() {
        // 100 rows: 1 large (100KiB) at the start + 99 small (64 bytes each).
        let batch = make_variable_batch(100, 64, 0, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            // Every batch fits the cap unless it is a single unsplittable row.
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    #[test]
    fn test_oversized_row_in_middle() {
        // 100 rows: 1 large (100KiB) in the middle + 99 small (64 bytes each).
        let batch = make_variable_batch(100, 64, 50, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            // Every batch fits the cap unless it is a single unsplittable row.
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    #[test]
    fn test_error_propagation() {
        // An Err item from upstream must surface in the output stream.
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}