Skip to main content

datafusion_physical_plan/coalesce/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::RecordBatch;
19use arrow::compute::BatchCoalescer;
20use arrow::datatypes::SchemaRef;
21use datafusion_common::{Result, assert_or_internal_err};
22
/// Concatenate multiple [`RecordBatch`]es and apply a limit
///
/// Wraps the arrow [`BatchCoalescer`], additionally tracking the total
/// number of rows pushed so an optional `fetch` (LIMIT) can be enforced.
///
/// See [`BatchCoalescer`] for more details on how coalescing works.
#[derive(Debug)]
pub struct LimitedBatchCoalescer {
    /// The arrow structure that builds the output batches
    inner: BatchCoalescer,
    /// Total number of rows pushed into `inner` so far (after any
    /// limit truncation); compared against `fetch` to detect the limit
    total_rows: usize,
    /// Limit: maximum number of rows to fetch, `None` means fetch all rows
    fetch: Option<usize>,
    /// Indicates if the coalescer is finished; set by `finish()`, after
    /// which further calls to `push_batch()` are an internal error
    finished: bool,
}
37
/// Status returned by [`LimitedBatchCoalescer::push_batch`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PushBatchStatus {
    /// The limit has **not** been reached, and more batches can be pushed
    Continue,
    /// The limit **has** been reached after processing this batch
    /// (or had already been reached by an earlier batch).
    /// The caller should call [`LimitedBatchCoalescer::finish`]
    /// to flush any buffered rows and stop pushing more batches.
    LimitReached,
}
48
49impl LimitedBatchCoalescer {
50    /// Create a new `BatchCoalescer`
51    ///
52    /// # Arguments
53    /// - `schema` - the schema of the output batches
54    /// - `target_batch_size` - the minimum number of rows for each
55    ///   output batch (until limit reached)
56    /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows
57    pub fn new(
58        schema: SchemaRef,
59        target_batch_size: usize,
60        fetch: Option<usize>,
61    ) -> Self {
62        Self {
63            inner: BatchCoalescer::new(schema, target_batch_size)
64                .with_biggest_coalesce_batch_size(Some(target_batch_size / 2)),
65            total_rows: 0,
66            fetch,
67            finished: false,
68        }
69    }
70
71    /// Return the schema of the output batches
72    pub fn schema(&self) -> SchemaRef {
73        self.inner.schema()
74    }
75
76    /// Pushes the next [`RecordBatch`] into the coalescer and returns its status.
77    ///
78    /// # Arguments
79    /// * `batch` - The [`RecordBatch`] to append.
80    ///
81    /// # Returns
82    /// * [`PushBatchStatus::Continue`] - More batches can still be pushed.
83    /// * [`PushBatchStatus::LimitReached`] - The row limit was reached after processing
84    ///   this batch. The caller should call [`Self::finish`] before retrieving the
85    ///   remaining buffered batches.
86    ///
87    /// # Errors
88    /// Returns an error if called after [`Self::finish`] or if the internal push
89    /// operation fails.
90    pub fn push_batch(&mut self, batch: RecordBatch) -> Result<PushBatchStatus> {
91        assert_or_internal_err!(
92            !self.finished,
93            "LimitedBatchCoalescer: cannot push batch after finish"
94        );
95
96        // if we are at the limit, return LimitReached
97        if let Some(fetch) = self.fetch {
98            // limit previously reached
99            if self.total_rows >= fetch {
100                return Ok(PushBatchStatus::LimitReached);
101            }
102
103            // limit now reached
104            if self.total_rows + batch.num_rows() >= fetch {
105                // Limit is reached
106                let remaining_rows = fetch - self.total_rows;
107                debug_assert!(remaining_rows > 0);
108
109                let batch_head = batch.slice(0, remaining_rows);
110                self.total_rows += batch_head.num_rows();
111                self.inner.push_batch(batch_head)?;
112                return Ok(PushBatchStatus::LimitReached);
113            }
114        }
115
116        // Limit not reached, push the entire batch
117        self.total_rows += batch.num_rows();
118        self.inner.push_batch(batch)?;
119
120        Ok(PushBatchStatus::Continue)
121    }
122
123    /// Return true if there is no data buffered
124    pub fn is_empty(&self) -> bool {
125        self.inner.is_empty()
126    }
127
128    /// Complete the current buffered batch and finish the coalescer
129    ///
130    /// Any subsequent calls to `push_batch()` will return an Err
131    pub fn finish(&mut self) -> Result<()> {
132        self.inner.finish_buffered_batch()?;
133        self.finished = true;
134        Ok(())
135    }
136
137    pub(crate) fn is_finished(&self) -> bool {
138        self.finished
139    }
140
141    /// Return the next completed batch, if any
142    pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
143        self.inner.next_completed_batch()
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Range;
    use std::sync::Arc;

    use arrow::array::UInt32Array;
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};

    #[test]
    fn test_coalesce() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            // expected output is batches of exactly 21 rows (except for the final batch)
            .with_target_batch_size(21)
            .with_expected_output_sizes(vec![21, 21, 21, 17])
            .run()
    }

    #[test]
    fn test_coalesce_with_fetch_larger_than_input_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 100
            // expected to behave the same as `test_concat_batches`
            .with_target_batch_size(21)
            .with_fetch(Some(100))
            .with_expected_output_sizes(vec![21, 21, 21, 17])
            .run();
    }

    #[test]
    fn test_coalesce_with_fetch_less_than_input_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 50
            // so the final (partial) batch holds the 8 rows left after 2x21
            .with_target_batch_size(21)
            .with_fetch(Some(50))
            .with_expected_output_sizes(vec![21, 21, 8])
            .run();
    }

    #[test]
    fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 48
            // limit is an exact multiple of the target size: no partial batch
            .with_target_batch_size(24)
            .with_fetch(Some(48))
            .with_expected_output_sizes(vec![24, 24])
            .run();
    }

    #[test]
    fn test_coalesce_with_fetch_less_target_batch_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 10
            // limit is below the target size: a single 10-row batch comes out
            .with_target_batch_size(21)
            .with_fetch(Some(10))
            .with_expected_output_sizes(vec![10])
            .run();
    }

    #[test]
    fn test_coalesce_single_large_batch_over_fetch() {
        // one input batch larger than both target size and fetch:
        // only the first `fetch` rows should be emitted
        let large_batch = uint32_batch(0..100);
        Test::new()
            .with_batch(large_batch)
            .with_target_batch_size(20)
            .with_fetch(Some(7))
            .with_expected_output_sizes(vec![7])
            .run()
    }

    /// Test for [`LimitedBatchCoalescer`]
    ///
    /// Pushes the input batches to the coalescer and verifies that the resulting
    /// batches have the expected number of rows and contents.
    #[derive(Debug, Clone, Default)]
    struct Test {
        /// Batches to feed to the coalescer. Tests must have at least one
        /// schema
        input_batches: Vec<RecordBatch>,
        /// Expected output sizes of the resulting batches
        expected_output_sizes: Vec<usize>,
        /// target batch size
        target_batch_size: usize,
        /// Fetch (limit)
        fetch: Option<usize>,
    }

    impl Test {
        fn new() -> Self {
            Self::default()
        }

        /// Set the target batch size
        fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
            self.target_batch_size = target_batch_size;
            self
        }

        /// Set the fetch (limit)
        fn with_fetch(mut self, fetch: Option<usize>) -> Self {
            self.fetch = fetch;
            self
        }

        /// Extend the input batches with `batch`
        fn with_batch(mut self, batch: RecordBatch) -> Self {
            self.input_batches.push(batch);
            self
        }

        /// Extends the input batches with `batches`
        fn with_batches(
            mut self,
            batches: impl IntoIterator<Item = RecordBatch>,
        ) -> Self {
            self.input_batches.extend(batches);
            self
        }

        /// Extends `sizes` to expected output sizes
        fn with_expected_output_sizes(
            mut self,
            sizes: impl IntoIterator<Item = usize>,
        ) -> Self {
            self.expected_output_sizes.extend(sizes);
            self
        }

        /// Runs the test -- see documentation on [`Test`] for details
        fn run(self) {
            let Self {
                input_batches,
                target_batch_size,
                fetch,
                expected_output_sizes,
            } = self;

            let schema = input_batches[0].schema();

            // create a single large input batch for output comparison
            let single_input_batch = concat_batches(&schema, &input_batches).unwrap();

            let mut coalescer =
                LimitedBatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch);

            let mut output_batches = vec![];
            // push until the coalescer reports the limit was reached
            // (or all input is consumed)
            for batch in input_batches {
                match coalescer.push_batch(batch).unwrap() {
                    PushBatchStatus::Continue => {
                        // continue pushing batches
                    }
                    PushBatchStatus::LimitReached => {
                        break;
                    }
                }
            }
            // flush buffered rows, then drain all completed batches
            coalescer.finish().unwrap();
            while let Some(batch) = coalescer.next_completed_batch() {
                output_batches.push(batch);
            }

            let actual_output_sizes: Vec<usize> =
                output_batches.iter().map(|b| b.num_rows()).collect();
            assert_eq!(
                expected_output_sizes, actual_output_sizes,
                "Unexpected number of rows in output batches\n\
                Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
            );

            // make sure we got the expected number of output batches and content
            let mut starting_idx = 0;
            assert_eq!(expected_output_sizes.len(), output_batches.len());
            for (i, (expected_size, batch)) in
                expected_output_sizes.iter().zip(output_batches).enumerate()
            {
                assert_eq!(
                    *expected_size,
                    batch.num_rows(),
                    "Unexpected number of rows in Batch {i}"
                );

                // compare the contents of the batch (using `==` compares the
                // underlying memory layout too)
                let expected_batch =
                    single_input_batch.slice(starting_idx, *expected_size);
                let batch_strings = batch_to_pretty_strings(&batch);
                let expected_batch_strings = batch_to_pretty_strings(&expected_batch);
                let batch_strings = batch_strings.lines().collect::<Vec<_>>();
                let expected_batch_strings =
                    expected_batch_strings.lines().collect::<Vec<_>>();
                assert_eq!(
                    expected_batch_strings, batch_strings,
                    "Unexpected content in Batch {i}:\
                    \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
                );
                starting_idx += *expected_size;
            }
        }
    }

    /// Return a single-column batch of UInt32 with the specified range
    fn uint32_batch(range: Range<u32>) -> RecordBatch {
        let schema =
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from_iter_values(range))],
        )
        .unwrap()
    }

    /// Render a batch with `pretty_format_batches` for line-by-line comparison
    fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
        arrow::util::pretty::pretty_format_batches(std::slice::from_ref(batch))
            .unwrap()
            .to_string()
    }
}