1use arrow_array::RecordBatch;
7use arrow_schema::{ArrowError, SchemaRef};
8use futures::stream::{self, Stream, StreamExt};
9use std::pin::Pin;
10
11use crate::deepcopy::deep_copy_batch_sliced;
12
/// Re-chunks a stream of [`RecordBatch`]es by in-memory size.
///
/// Batches are accumulated until at least `min_bytes` are buffered, then
/// emitted; any resulting batch larger than `max_bytes` is split into row
/// slices. Emitted slices share buffers with the input batches (zero-copy
/// `RecordBatch::slice`); use [`rechunk_stream_by_size_deep_copy`] if
/// compact, independently-owned batches are required.
///
/// The final batch may be smaller than `min_bytes` (remainder flush when
/// the input ends).
pub fn rechunk_stream_by_size<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    rechunk_stream_by_size_inner(input, input_schema, min_bytes, max_bytes, false)
}
32
/// Like [`rechunk_stream_by_size`], but deep-copies sliced batches.
///
/// When an oversized batch is split, each slice is passed through
/// `deep_copy_batch_sliced` (presumably producing a compact copy holding
/// only the sliced rows — confirm against that helper's docs) and the copy
/// is recursively re-split. This means variable-width rows still end up in
/// batches at or under `max_bytes`, or in single-row batches when one row
/// alone exceeds the limit.
pub fn rechunk_stream_by_size_deep_copy<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    rechunk_stream_by_size_inner(input, input_schema, min_bytes, max_bytes, true)
}
60
/// Shared implementation behind [`rechunk_stream_by_size`] and
/// [`rechunk_stream_by_size_deep_copy`].
///
/// Drives a `try_unfold` state machine ([`RechunkState`]); each step of the
/// closure yields at most one output batch:
///   1. pull from `input` until `acc_bytes >= min_bytes` (or the input ends),
///   2. if the leading buffered batch alone meets `min_bytes` — typically a
///      leftover slice from a previous split — emit it untouched, or
///   3. concatenate the buffer, split the result at `max_bytes`, emit the
///      first piece, and keep the remaining pieces buffered.
fn rechunk_stream_by_size_inner<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
    deep_copy: bool,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    stream::try_unfold(
        RechunkState {
            input: Box::pin(input),
            accumulated: Vec::new(),
            acc_bytes: 0,
            done: false,
            input_schema,
            min_bytes,
            max_bytes,
            deep_copy,
        },
        |mut state| async move {
            // Input exhausted and nothing buffered: the stream is finished.
            if state.done && state.accumulated.is_empty() {
                return Ok(None);
            }

            // Fill the buffer until we hold at least one batch and at least
            // `min_bytes` worth of data, or the input ends.
            while !state.done && (state.accumulated.is_empty() || state.acc_bytes < state.min_bytes)
            {
                match state.input.next().await {
                    Some(Ok(batch)) => {
                        state.acc_bytes += batch.get_array_memory_size();
                        state.accumulated.push(batch);
                    }
                    Some(Err(e)) => return Err(e),
                    None => {
                        state.done = true;
                    }
                }
            }

            // Input ended with an empty buffer: nothing left to flush.
            if state.accumulated.is_empty() {
                return Ok(None);
            }

            // Fast path: the first buffered batch already meets `min_bytes`
            // (e.g. a leftover slice from a previous split). Emit it as-is
            // rather than re-concatenating, which would copy its data into a
            // fresh allocation and defeat zero-copy slicing.
            if state.accumulated.len() > 1
                && state.accumulated[0].get_array_memory_size() >= state.min_bytes
            {
                let b = state.accumulated.remove(0);
                state.acc_bytes -= b.get_array_memory_size();
                return Ok(Some((b, state)));
            }

            // Merge the buffer into a single batch (no-op for one batch).
            let batch = if state.accumulated.len() == 1 {
                state.accumulated.pop().unwrap()
            } else {
                let b =
                    arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
                        .map_err(E::from)?;
                state.accumulated.clear();
                b
            };
            state.acc_bytes = 0;

            // Split anything larger than `max_bytes` into row slices.
            let mut slices =
                slice_batch(batch, state.max_bytes, state.deep_copy).map_err(E::from)?;

            if slices.len() == 1 {
                Ok(Some((slices.pop().unwrap(), state)))
            } else {
                // Emit the first slice now; buffer the rest (and their byte
                // total) so later steps emit them via the fast path above.
                let first = slices.remove(0);

                for a in &slices {
                    state.acc_bytes += a.get_array_memory_size();
                }
                state.accumulated = slices;

                Ok(Some((first, state)))
            }
        },
    )
}
150
151fn slice_batch(
163 batch: RecordBatch,
164 max_bytes: usize,
165 deep_copy: bool,
166) -> Result<Vec<RecordBatch>, ArrowError> {
167 let batch_bytes = batch.get_array_memory_size();
168 let num_rows = batch.num_rows();
169
170 if batch_bytes <= max_bytes || num_rows <= 1 {
171 return Ok(vec![batch]);
172 }
173
174 let rows_per_chunk = (max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
175
176 let mut result = Vec::new();
177 let mut offset = 0;
178 while offset < num_rows {
179 let len = rows_per_chunk.min(num_rows - offset);
180 let slice = batch.slice(offset, len);
181 if deep_copy {
182 let copied = deep_copy_batch_sliced(&slice)?;
183 result.extend(slice_batch(copied, max_bytes, true)?);
186 } else {
187 result.push(slice);
188 }
189 offset += len;
190 }
191
192 Ok(result)
193}
194
/// Unfold state for [`rechunk_stream_by_size_inner`].
struct RechunkState<S> {
    /// Upstream batch source; pinned so it can be polled inside the unfold.
    input: Pin<Box<S>>,
    /// Batches buffered but not yet emitted.
    accumulated: Vec<RecordBatch>,
    /// Running sum of `get_array_memory_size()` over `accumulated`.
    acc_bytes: usize,
    /// Set once `input` has returned `None`.
    done: bool,
    /// Schema used when concatenating buffered batches.
    input_schema: SchemaRef,
    /// Minimum bytes to buffer before emitting (final flush may be smaller).
    min_bytes: usize,
    /// Maximum bytes per emitted batch; larger batches are sliced.
    max_bytes: usize,
    /// Whether sliced batches are deep-copied into compact allocations.
    deep_copy: bool,
}
208
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    /// Builds a single-column `Int32` batch holding the values `0..num_rows`.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Schema matching [`make_batch`]: one non-nullable `Int32` column "a".
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Runs the zero-copy rechunker over `batches` and unwraps every result.
    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Sums the row counts over all batches.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn test_empty_stream() {
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    /// A batch already between `min` and `max` passes through unchanged.
    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    /// Batches smaller than `min_bytes` get concatenated before emission.
    #[test]
    fn test_small_batches_concatenated() {
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    /// A batch above `max_bytes` is split into multiple slices.
    #[test]
    fn test_large_batch_sliced() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    /// Zero-copy invariant: leftover slices from a split must be emitted
    /// as-is, never concatenated into a fresh allocation. Verified by
    /// checking every output buffer pointer still lies inside the original
    /// batch's allocation.
    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    /// When the input ends below `min_bytes`, the remainder is still flushed.
    #[test]
    fn test_flush_remainder_on_stream_end() {
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    /// A large leading batch followed by small ones: all rows survive and
    /// the small trailing batches are merged.
    #[test]
    fn test_large_then_small_batches() {
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        assert!(result.len() < 6);
    }

    /// Slicing preserves row order and values exactly, including with a row
    /// count that does not divide evenly into chunks.
    #[test]
    fn test_row_preservation_across_slicing() {
        let batch = make_batch(237); let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    /// `min_bytes == 0` must not drop rows or stall the stream.
    #[test]
    fn test_min_bytes_zero_still_yields_all_rows() {
        let batches: Vec<_> = (0..5).map(|_| make_batch(100)).collect();
        let batch_bytes = batches[0].get_array_memory_size();
        let result = collect_rechunked(batches, 0, batch_bytes * 2);
        assert_eq!(total_rows(&result), 500);
    }

    /// `min_bytes == 0` still enforces the `max_bytes` split.
    #[test]
    fn test_min_bytes_zero_slices_oversized() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], 0, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    /// Builds a one-column Utf8 batch where every row is `small_size` bytes
    /// except row `big_row_idx`, which is `big_size` bytes — used to stress
    /// the deep-copy path with highly non-uniform row sizes.
    fn make_variable_batch(
        num_rows: usize,
        small_size: usize,
        big_row_idx: usize,
        big_size: usize,
    ) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)]));
        let values: Vec<String> = (0..num_rows)
            .map(|i| {
                if i == big_row_idx {
                    "X".repeat(big_size)
                } else {
                    "x".repeat(small_size)
                }
            })
            .collect();
        let array = arrow_array::StringArray::from(values);
        RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    /// Schema matching [`make_variable_batch`]: one non-nullable Utf8 column.
    fn variable_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)]))
    }

    /// Runs the deep-copy rechunker over `batches` and unwraps every result.
    fn collect_rechunked_variable(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked =
            rechunk_stream_by_size_deep_copy(input, variable_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Deep-copy mode: every output is either under `max_bytes` or a single
    /// (irreducible) row, even when one row alone exceeds the limit.
    #[test]
    fn test_oversized_row_at_end() {
        let batch = make_variable_batch(100, 64, 99, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    /// Same invariant as above with the oversized row first.
    #[test]
    fn test_oversized_row_at_start() {
        let batch = make_variable_batch(100, 64, 0, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    /// Same invariant as above with the oversized row in the middle.
    #[test]
    fn test_oversized_row_in_middle() {
        let batch = make_variable_batch(100, 64, 50, 100 * 1024);
        let max_bytes = 64 * 1024;
        let result = collect_rechunked_variable(vec![batch], 0, max_bytes);
        assert_eq!(total_rows(&result), 100);
        for (i, b) in result.iter().enumerate() {
            let size = b.get_array_memory_size();
            assert!(
                size <= max_bytes || b.num_rows() == 1,
                "batch {i} has {size} bytes (max {max_bytes}) and {} rows",
                b.num_rows()
            );
        }
    }

    /// An error from the upstream stream is surfaced to the consumer.
    #[test]
    fn test_error_propagation() {
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}