datafusion_physical_plan/coalesce/
mod.rs1use arrow::array::RecordBatch;
19use arrow::compute::BatchCoalescer;
20use arrow::datatypes::SchemaRef;
21use datafusion_common::{Result, assert_or_internal_err};
22
23#[derive(Debug)]
27pub struct LimitedBatchCoalescer {
28 inner: BatchCoalescer,
30 total_rows: usize,
32 fetch: Option<usize>,
34 finished: bool,
36}
37
/// Outcome reported by `LimitedBatchCoalescer::push_batch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PushBatchStatus {
    /// The whole batch was accepted and the row limit (if any) has not been
    /// reached yet; the caller may keep pushing batches.
    Continue,
    /// The row limit has been reached. At most a prefix of the pushed batch was
    /// accepted, and any further batches pushed will be ignored.
    LimitReached,
}
48
49impl LimitedBatchCoalescer {
50 pub fn new(
58 schema: SchemaRef,
59 target_batch_size: usize,
60 fetch: Option<usize>,
61 ) -> Self {
62 Self {
63 inner: BatchCoalescer::new(schema, target_batch_size)
64 .with_biggest_coalesce_batch_size(Some(target_batch_size / 2)),
65 total_rows: 0,
66 fetch,
67 finished: false,
68 }
69 }
70
71 pub fn schema(&self) -> SchemaRef {
73 self.inner.schema()
74 }
75
76 pub fn push_batch(&mut self, batch: RecordBatch) -> Result<PushBatchStatus> {
91 assert_or_internal_err!(
92 !self.finished,
93 "LimitedBatchCoalescer: cannot push batch after finish"
94 );
95
96 if let Some(fetch) = self.fetch {
98 if self.total_rows >= fetch {
100 return Ok(PushBatchStatus::LimitReached);
101 }
102
103 if self.total_rows + batch.num_rows() >= fetch {
105 let remaining_rows = fetch - self.total_rows;
107 debug_assert!(remaining_rows > 0);
108
109 let batch_head = batch.slice(0, remaining_rows);
110 self.total_rows += batch_head.num_rows();
111 self.inner.push_batch(batch_head)?;
112 return Ok(PushBatchStatus::LimitReached);
113 }
114 }
115
116 self.total_rows += batch.num_rows();
118 self.inner.push_batch(batch)?;
119
120 Ok(PushBatchStatus::Continue)
121 }
122
123 pub fn is_empty(&self) -> bool {
125 self.inner.is_empty()
126 }
127
128 pub fn finish(&mut self) -> Result<()> {
132 self.inner.finish_buffered_batch()?;
133 self.finished = true;
134 Ok(())
135 }
136
137 pub(crate) fn is_finished(&self) -> bool {
138 self.finished
139 }
140
141 pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
143 self.inner.next_completed_batch()
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Range;
    use std::sync::Arc;

    use arrow::array::UInt32Array;
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};

    /// 10 batches x 8 rows = 80 rows, no limit: three full batches of 21 rows
    /// followed by the 17 leftover rows.
    #[test]
    fn test_coalesce() {
        Test::new()
            .with_batches(repeated_eight_row_batch(10))
            .with_target_batch_size(21)
            .with_expected_output_sizes([21, 21, 21, 17])
            .run();
    }

    /// A limit (100) above the 80 input rows does not change the output.
    #[test]
    fn test_coalesce_with_fetch_larger_than_input_size() {
        Test::new()
            .with_batches(repeated_eight_row_batch(10))
            .with_target_batch_size(21)
            .with_fetch(Some(100))
            .with_expected_output_sizes([21, 21, 21, 17])
            .run();
    }

    /// A limit of 50 rows yields two full batches of 21 plus a final batch of 8.
    #[test]
    fn test_coalesce_with_fetch_less_than_input_size() {
        Test::new()
            .with_batches(repeated_eight_row_batch(10))
            .with_target_batch_size(21)
            .with_fetch(Some(50))
            .with_expected_output_sizes([21, 21, 8])
            .run();
    }

    /// The limit (48) is an exact multiple of the target size (24), so no
    /// partial batch is left over.
    #[test]
    fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
        Test::new()
            .with_batches(repeated_eight_row_batch(10))
            .with_target_batch_size(24)
            .with_fetch(Some(48))
            .with_expected_output_sizes([24, 24])
            .run();
    }

    /// A limit (10) below the target size (21) yields a single 10-row batch.
    #[test]
    fn test_coalesce_with_fetch_less_target_batch_size() {
        Test::new()
            .with_batches(repeated_eight_row_batch(10))
            .with_target_batch_size(21)
            .with_fetch(Some(10))
            .with_expected_output_sizes([10])
            .run();
    }

    /// A single 100-row batch with a limit of 7 yields exactly 7 rows.
    #[test]
    fn test_coalesce_single_large_batch_over_fetch() {
        Test::new()
            .with_batch(uint32_batch(0..100))
            .with_target_batch_size(20)
            .with_fetch(Some(7))
            .with_expected_output_sizes([7])
            .run();
    }

    /// Builder-style fixture describing one coalescing scenario.
    #[derive(Debug, Clone, Default)]
    struct Test {
        /// Batches pushed into the coalescer, in order.
        input_batches: Vec<RecordBatch>,
        /// Expected row count of each output batch, in order.
        expected_output_sizes: Vec<usize>,
        /// `target_batch_size` passed to `LimitedBatchCoalescer::new`.
        target_batch_size: usize,
        /// `fetch` passed to `LimitedBatchCoalescer::new`.
        fetch: Option<usize>,
    }

    impl Test {
        /// Starts an empty scenario (no input, no expectations, no limit).
        fn new() -> Self {
            Self::default()
        }

        /// Sets the target batch size.
        fn with_target_batch_size(self, target_batch_size: usize) -> Self {
            Self {
                target_batch_size,
                ..self
            }
        }

        /// Sets the row limit.
        fn with_fetch(self, fetch: Option<usize>) -> Self {
            Self { fetch, ..self }
        }

        /// Appends a single input batch.
        fn with_batch(self, batch: RecordBatch) -> Self {
            self.with_batches(std::iter::once(batch))
        }

        /// Appends several input batches.
        fn with_batches(
            mut self,
            batches: impl IntoIterator<Item = RecordBatch>,
        ) -> Self {
            self.input_batches.extend(batches);
            self
        }

        /// Appends expected output batch sizes.
        fn with_expected_output_sizes(
            mut self,
            sizes: impl IntoIterator<Item = usize>,
        ) -> Self {
            self.expected_output_sizes.extend(sizes);
            self
        }

        /// Runs the scenario and checks both the size and the content of every
        /// output batch.
        ///
        /// Panics (via indexing) if no input batch was configured.
        fn run(self) {
            let Self {
                input_batches,
                expected_output_sizes,
                target_batch_size,
                fetch,
            } = self;

            let schema = input_batches[0].schema();

            // All input rows as one batch, used as the reference for content checks.
            let single_input_batch = concat_batches(&schema, &input_batches).unwrap();

            let mut coalescer =
                LimitedBatchCoalescer::new(schema, target_batch_size, fetch);

            // Feed the input until it is exhausted or the limit is reported.
            for batch in input_batches {
                match coalescer.push_batch(batch).unwrap() {
                    PushBatchStatus::Continue => continue,
                    PushBatchStatus::LimitReached => break,
                }
            }
            coalescer.finish().unwrap();

            // Drain every completed batch.
            let output_batches: Vec<RecordBatch> =
                std::iter::from_fn(|| coalescer.next_completed_batch()).collect();

            let actual_output_sizes: Vec<usize> = output_batches
                .iter()
                .map(RecordBatch::num_rows)
                .collect();
            assert_eq!(
                expected_output_sizes, actual_output_sizes,
                "Unexpected number of rows in output batches\n\
                 Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
            );

            assert_eq!(expected_output_sizes.len(), output_batches.len());

            // Output batch `i` must equal the slice of the concatenated input
            // starting where the previous output batch ended.
            let mut starting_idx = 0;
            let sizes_and_batches = expected_output_sizes
                .iter()
                .copied()
                .zip(output_batches)
                .enumerate();
            for (i, (expected_size, batch)) in sizes_and_batches {
                assert_eq!(
                    expected_size,
                    batch.num_rows(),
                    "Unexpected number of rows in Batch {i}"
                );

                let expected_batch =
                    single_input_batch.slice(starting_idx, expected_size);

                let actual_text = batch_to_pretty_strings(&batch);
                let expected_text = batch_to_pretty_strings(&expected_batch);
                let batch_strings: Vec<&str> = actual_text.lines().collect();
                let expected_batch_strings: Vec<&str> = expected_text.lines().collect();
                assert_eq!(
                    expected_batch_strings, batch_strings,
                    "Unexpected content in Batch {i}:\
                     \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
                );

                starting_idx += expected_size;
            }
        }
    }

    /// Returns `count` copies of the 8-row batch `uint32_batch(0..8)`.
    fn repeated_eight_row_batch(count: usize) -> impl Iterator<Item = RecordBatch> {
        std::iter::repeat_n(uint32_batch(0..8), count)
    }

    /// Builds a batch with one non-nullable `UInt32` column named `c0` holding
    /// the values of `range`.
    fn uint32_batch(range: Range<u32>) -> RecordBatch {
        let field = Field::new("c0", DataType::UInt32, false);
        let schema = Arc::new(Schema::new(vec![field]));

        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from_iter_values(range))],
        )
        .unwrap()
    }

    /// Renders `batch` as a pretty-printed table string.
    fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
        let formatted =
            arrow::util::pretty::pretty_format_batches(std::slice::from_ref(batch))
                .unwrap();
        formatted.to_string()
    }
}