// lance_arrow/stream.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with streams of [`RecordBatch`].
5
use std::collections::VecDeque;
use std::pin::Pin;

use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};
use futures::stream::{self, Stream, StreamExt};
10
11/// Rechunks a stream of [`RecordBatch`] so that each output batch has
12/// approximately `target_bytes` of array data.
13///
14/// Small input batches are accumulated (by concatenation) until at least
15/// `min_bytes` of data has been collected. If the resulting batch exceeds
16/// `max_bytes`, it is sliced into roughly equal pieces of ~`max_bytes`
17/// (assuming uniform row sizes).
18pub fn rechunk_stream_by_size<S, E>(
19    input: S,
20    input_schema: SchemaRef,
21    min_bytes: usize,
22    max_bytes: usize,
23) -> impl Stream<Item = Result<RecordBatch, E>>
24where
25    S: Stream<Item = Result<RecordBatch, E>>,
26    E: From<ArrowError>,
27{
28    stream::try_unfold(
29        RechunkState {
30            input: Box::pin(input),
31            accumulated: Vec::new(),
32            acc_bytes: 0,
33            done: false,
34            input_schema,
35            min_bytes,
36            max_bytes,
37        },
38        |mut state| async move {
39            if state.done && state.accumulated.is_empty() {
40                return Ok(None);
41            }
42
43            // Pull batches until we reach the byte target or exhaust input.
44            while !state.done && state.acc_bytes < state.min_bytes {
45                match state.input.next().await {
46                    Some(Ok(batch)) => {
47                        state.acc_bytes += batch.get_array_memory_size();
48                        state.accumulated.push(batch);
49                    }
50                    Some(Err(e)) => return Err(e),
51                    None => {
52                        state.done = true;
53                    }
54                }
55            }
56
57            if state.accumulated.is_empty() {
58                return Ok(None);
59            }
60
61            // Fast path: if the first accumulated batch already meets the
62            // byte threshold, deliver it directly instead of concatenating
63            // everything together (which would just get sliced back apart).
64            if state.accumulated.len() > 1
65                && state.accumulated[0].get_array_memory_size() >= state.min_bytes
66            {
67                let b = state.accumulated.remove(0);
68                state.acc_bytes -= b.get_array_memory_size();
69                return Ok(Some((b, state)));
70            }
71
72            let batch = if state.accumulated.len() == 1 {
73                state.accumulated.pop().unwrap()
74            } else {
75                let b =
76                    arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
77                        .map_err(E::from)?;
78                state.accumulated.clear();
79                b
80            };
81            state.acc_bytes = 0;
82
83            // Slice the batch into ~max_bytes pieces assuming uniform row sizes.
84            let batch_bytes = batch.get_array_memory_size();
85            let num_rows = batch.num_rows();
86            if batch_bytes <= state.max_bytes || num_rows <= 1 {
87                Ok(Some((batch, state)))
88            } else {
89                let rows_per_chunk =
90                    (state.max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
91                let mut slices = Vec::new();
92                let mut offset = 0;
93                while offset < num_rows {
94                    let len = rows_per_chunk.min(num_rows - offset);
95                    slices.push(batch.slice(offset, len));
96                    offset += len;
97                }
98
99                let first = slices.remove(0);
100
101                // Stash leftover slices for subsequent iterations.
102                for a in &slices {
103                    state.acc_bytes += a.get_array_memory_size();
104                }
105                state.accumulated = slices;
106
107                Ok(Some((first, state)))
108            }
109        },
110    )
111}
112
113/// Internal state for [`rechunk_stream`].
114///
115/// Kept as a named struct so the `try_unfold` closure stays readable.
116struct RechunkState<S> {
117    input: Pin<Box<S>>,
118    accumulated: Vec<RecordBatch>,
119    acc_bytes: usize,
120    done: bool,
121    input_schema: SchemaRef,
122    min_bytes: usize,
123    max_bytes: usize,
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    /// Builds a single-column Int32 batch with values `0..num_rows`.
    ///
    /// The ascending values double as row identifiers, letting tests verify
    /// ordering and row preservation after rechunking.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Schema shared by all test batches: one non-nullable Int32 column "a".
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Runs `batches` through `rechunk_stream_by_size` on a blocking
    /// executor and collects the output, panicking on any stream error.
    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Sum of row counts across `batches`.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    // An empty input stream must produce an empty output stream, not a
    // zero-row batch.
    #[test]
    fn test_empty_stream() {
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        // Batch is between min and max — should pass through as-is.
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    #[test]
    fn test_small_batches_concatenated() {
        // NOTE(review): assumes all 10-row batches report the same
        // `get_array_memory_size` — true here since they share one schema
        // and row count.
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        // min = 5 batches worth, max = 10 batches worth.
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        // Should have been concatenated into fewer batches than the 8 inputs.
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    // A single batch ~4x larger than max_bytes should be split into at
    // least 4 pieces while preserving every row.
    #[test]
    fn test_large_batch_sliced() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        // Key test for the fast-path optimisation. When a large batch is
        // sliced, leftover slices should be delivered one-at-a-time without
        // being concatenated back together.  We verify this by checking that
        // every output buffer pointer falls inside the original batch's
        // allocation (i.e. they are all zero-copy slices, not fresh copies).
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        // Capture the address range of the original values buffer before the
        // batch is consumed by the stream.
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        // NOTE(review): `to_data()` on a slice keeps the shared buffer but an
        // adjusted offset, so `as_ptr` of a true zero-copy slice lands inside
        // the original allocation; a concat would allocate fresh memory.
        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                 [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    #[test]
    fn test_flush_remainder_on_stream_end() {
        // Data below min_bytes should still be flushed when the stream ends.
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    #[test]
    fn test_large_then_small_batches() {
        // After a large batch is fully drained, subsequent small batches
        // should be accumulated normally.
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        // The large batch should appear (possibly sliced) followed by
        // concatenated small batches, so we should have fewer output batches
        // than the 6 inputs.
        assert!(result.len() < 6);
    }

    #[test]
    fn test_row_preservation_across_slicing() {
        // Verify that every input row appears exactly once in the output
        // and in the correct order after slicing.
        let batch = make_batch(237); // odd count to exercise remainder slice
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        // NOTE(review): relies on `Int32Array::values()` honoring the slice
        // offset/length (true for arrow-rs `ScalarBuffer`) — confirm if the
        // arrow dependency is ever downgraded.
        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    // Errors from the input stream must surface in the output stream rather
    // than being swallowed by the accumulation loop.
    #[test]
    fn test_error_propagation() {
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}
305}