1use arrow_array::RecordBatch;
7use arrow_schema::{ArrowError, SchemaRef};
8use futures::stream::{self, Stream, StreamExt};
9use std::pin::Pin;
10
11pub fn rechunk_stream_by_size<S, E>(
19 input: S,
20 input_schema: SchemaRef,
21 min_bytes: usize,
22 max_bytes: usize,
23) -> impl Stream<Item = Result<RecordBatch, E>>
24where
25 S: Stream<Item = Result<RecordBatch, E>>,
26 E: From<ArrowError>,
27{
28 stream::try_unfold(
29 RechunkState {
30 input: Box::pin(input),
31 accumulated: Vec::new(),
32 acc_bytes: 0,
33 done: false,
34 input_schema,
35 min_bytes,
36 max_bytes,
37 },
38 |mut state| async move {
39 if state.done && state.accumulated.is_empty() {
40 return Ok(None);
41 }
42
43 while !state.done && state.acc_bytes < state.min_bytes {
45 match state.input.next().await {
46 Some(Ok(batch)) => {
47 state.acc_bytes += batch.get_array_memory_size();
48 state.accumulated.push(batch);
49 }
50 Some(Err(e)) => return Err(e),
51 None => {
52 state.done = true;
53 }
54 }
55 }
56
57 if state.accumulated.is_empty() {
58 return Ok(None);
59 }
60
61 if state.accumulated.len() > 1
65 && state.accumulated[0].get_array_memory_size() >= state.min_bytes
66 {
67 let b = state.accumulated.remove(0);
68 state.acc_bytes -= b.get_array_memory_size();
69 return Ok(Some((b, state)));
70 }
71
72 let batch = if state.accumulated.len() == 1 {
73 state.accumulated.pop().unwrap()
74 } else {
75 let b =
76 arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
77 .map_err(E::from)?;
78 state.accumulated.clear();
79 b
80 };
81 state.acc_bytes = 0;
82
83 let batch_bytes = batch.get_array_memory_size();
85 let num_rows = batch.num_rows();
86 if batch_bytes <= state.max_bytes || num_rows <= 1 {
87 Ok(Some((batch, state)))
88 } else {
89 let rows_per_chunk =
90 (state.max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
91 let mut slices = Vec::new();
92 let mut offset = 0;
93 while offset < num_rows {
94 let len = rows_per_chunk.min(num_rows - offset);
95 slices.push(batch.slice(offset, len));
96 offset += len;
97 }
98
99 let first = slices.remove(0);
100
101 for a in &slices {
103 state.acc_bytes += a.get_array_memory_size();
104 }
105 state.accumulated = slices;
106
107 Ok(Some((first, state)))
108 }
109 },
110 )
111}
112
/// Internal state threaded through `stream::try_unfold` by
/// `rechunk_stream_by_size`.
struct RechunkState<S> {
    /// Pinned upstream batch stream.
    input: Pin<Box<S>>,
    /// Batches buffered but not yet emitted; also holds the leftover slices
    /// after an oversized batch has been split.
    accumulated: Vec<RecordBatch>,
    /// Running sum of `get_array_memory_size()` over `accumulated`.
    acc_bytes: usize,
    /// True once the upstream stream has returned `None`.
    done: bool,
    /// Schema used when concatenating buffered batches.
    input_schema: SchemaRef,
    /// Lower output-size target, in bytes.
    min_bytes: usize,
    /// Upper output-size target, in bytes.
    max_bytes: usize,
}
125
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    /// Builds a single-column Int32 batch holding the values `0..num_rows`.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Schema shared by all test batches: one non-nullable Int32 column "a".
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Runs `rechunk_stream_by_size` over `batches` with the given bounds and
    /// collects all output batches, panicking on any stream error.
    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Total row count across a set of batches.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    /// An empty input stream yields an empty output stream.
    #[test]
    fn test_empty_stream() {
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    /// A batch already within [min, max] passes through unchanged.
    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    /// Batches smaller than min_bytes are merged: 8 inputs must come out as
    /// fewer than 8 outputs with all 80 rows preserved.
    #[test]
    fn test_small_batches_concatenated() {
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    /// A batch larger than max_bytes (max = bytes/4) is split into at least
    /// 4 slices, with no rows lost.
    #[test]
    fn test_large_batch_sliced() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    /// Proves the slices of a split batch are emitted zero-copy: every output
    /// batch's data pointer must lie inside the original batch's buffer
    /// allocation. A re-concatenation would allocate a fresh buffer and land
    /// outside that range.
    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        // Capture the original allocation's address range before the batch is
        // moved into the rechunker.
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                 [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    /// Data below min_bytes is still flushed when the input ends, rather
    /// than being held forever.
    #[test]
    fn test_flush_remainder_on_stream_end() {
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    /// Mixed input: one large batch followed by five small ones. All 1050
    /// rows survive and the small batches are coalesced (fewer outputs than
    /// inputs).
    #[test]
    fn test_large_then_small_batches() {
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        assert!(result.len() < 6);
    }

    /// Slicing must preserve both row count and row order: flattening the
    /// output column values must reproduce 0..237 exactly. The odd row count
    /// (237) exercises a non-even final slice.
    #[test]
    fn test_row_preservation_across_slicing() {
        let batch = make_batch(237); let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    /// An upstream error must surface in the output stream rather than being
    /// swallowed.
    #[test]
    fn test_error_propagation() {
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}