datafusion_functions_aggregate/min_max/
min_max_bytes.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray,
20    LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder,
21};
22use arrow::datatypes::DataType;
23use datafusion_common::{internal_err, Result};
24use datafusion_expr::{EmitTo, GroupsAccumulator};
25use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
26use std::mem::size_of;
27use std::sync::Arc;
28
29/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`],
30/// [`BinaryArray`], [`StringViewArray`], etc)
31///
32/// This implementation dispatches to the appropriate specialized code in
33/// [`MinMaxBytesState`] based on data type and comparison function
34///
35/// [`StringArray`]: arrow::array::StringArray
36/// [`BinaryArray`]: arrow::array::BinaryArray
37/// [`StringViewArray`]: arrow::array::StringViewArray
38#[derive(Debug)]
39pub(crate) struct MinMaxBytesAccumulator {
40    /// Inner data storage.
41    inner: MinMaxBytesState,
42    /// if true, is `MIN` otherwise is `MAX`
43    is_min: bool,
44}
45
46impl MinMaxBytesAccumulator {
47    /// Create a new accumulator for computing `min(val)`
48    pub fn new_min(data_type: DataType) -> Self {
49        Self {
50            inner: MinMaxBytesState::new(data_type),
51            is_min: true,
52        }
53    }
54
55    /// Create a new accumulator fo computing `max(val)`
56    pub fn new_max(data_type: DataType) -> Self {
57        Self {
58            inner: MinMaxBytesState::new(data_type),
59            is_min: false,
60        }
61    }
62}
63
64impl GroupsAccumulator for MinMaxBytesAccumulator {
65    fn update_batch(
66        &mut self,
67        values: &[ArrayRef],
68        group_indices: &[usize],
69        opt_filter: Option<&BooleanArray>,
70        total_num_groups: usize,
71    ) -> Result<()> {
72        let array = &values[0];
73        assert_eq!(array.len(), group_indices.len());
74        assert_eq!(array.data_type(), &self.inner.data_type);
75
76        // apply filter if needed
77        let array = apply_filter_as_nulls(array, opt_filter)?;
78
79        // dispatch to appropriate kernel / specialized implementation
80        fn string_min(a: &[u8], b: &[u8]) -> bool {
81            // safety: only called from this function, which ensures a and b come
82            // from an array with valid utf8 data
83            unsafe {
84                let a = std::str::from_utf8_unchecked(a);
85                let b = std::str::from_utf8_unchecked(b);
86                a < b
87            }
88        }
89        fn string_max(a: &[u8], b: &[u8]) -> bool {
90            // safety: only called from this function, which ensures a and b come
91            // from an array with valid utf8 data
92            unsafe {
93                let a = std::str::from_utf8_unchecked(a);
94                let b = std::str::from_utf8_unchecked(b);
95                a > b
96            }
97        }
98        fn binary_min(a: &[u8], b: &[u8]) -> bool {
99            a < b
100        }
101
102        fn binary_max(a: &[u8], b: &[u8]) -> bool {
103            a > b
104        }
105
106        fn str_to_bytes<'a>(
107            it: impl Iterator<Item = Option<&'a str>>,
108        ) -> impl Iterator<Item = Option<&'a [u8]>> {
109            it.map(|s| s.map(|s| s.as_bytes()))
110        }
111
112        match (self.is_min, &self.inner.data_type) {
113            // Utf8/LargeUtf8/Utf8View Min
114            (true, &DataType::Utf8) => self.inner.update_batch(
115                str_to_bytes(array.as_string::<i32>().iter()),
116                group_indices,
117                total_num_groups,
118                string_min,
119            ),
120            (true, &DataType::LargeUtf8) => self.inner.update_batch(
121                str_to_bytes(array.as_string::<i64>().iter()),
122                group_indices,
123                total_num_groups,
124                string_min,
125            ),
126            (true, &DataType::Utf8View) => self.inner.update_batch(
127                str_to_bytes(array.as_string_view().iter()),
128                group_indices,
129                total_num_groups,
130                string_min,
131            ),
132
133            // Utf8/LargeUtf8/Utf8View Max
134            (false, &DataType::Utf8) => self.inner.update_batch(
135                str_to_bytes(array.as_string::<i32>().iter()),
136                group_indices,
137                total_num_groups,
138                string_max,
139            ),
140            (false, &DataType::LargeUtf8) => self.inner.update_batch(
141                str_to_bytes(array.as_string::<i64>().iter()),
142                group_indices,
143                total_num_groups,
144                string_max,
145            ),
146            (false, &DataType::Utf8View) => self.inner.update_batch(
147                str_to_bytes(array.as_string_view().iter()),
148                group_indices,
149                total_num_groups,
150                string_max,
151            ),
152
153            // Binary/LargeBinary/BinaryView Min
154            (true, &DataType::Binary) => self.inner.update_batch(
155                array.as_binary::<i32>().iter(),
156                group_indices,
157                total_num_groups,
158                binary_min,
159            ),
160            (true, &DataType::LargeBinary) => self.inner.update_batch(
161                array.as_binary::<i64>().iter(),
162                group_indices,
163                total_num_groups,
164                binary_min,
165            ),
166            (true, &DataType::BinaryView) => self.inner.update_batch(
167                array.as_binary_view().iter(),
168                group_indices,
169                total_num_groups,
170                binary_min,
171            ),
172
173            // Binary/LargeBinary/BinaryView Max
174            (false, &DataType::Binary) => self.inner.update_batch(
175                array.as_binary::<i32>().iter(),
176                group_indices,
177                total_num_groups,
178                binary_max,
179            ),
180            (false, &DataType::LargeBinary) => self.inner.update_batch(
181                array.as_binary::<i64>().iter(),
182                group_indices,
183                total_num_groups,
184                binary_max,
185            ),
186            (false, &DataType::BinaryView) => self.inner.update_batch(
187                array.as_binary_view().iter(),
188                group_indices,
189                total_num_groups,
190                binary_max,
191            ),
192
193            _ => internal_err!(
194                "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})",
195                self.is_min,
196                self.inner.data_type
197            ),
198        }
199    }
200
201    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
202        let (data_capacity, min_maxes) = self.inner.emit_to(emit_to);
203
204        // Convert the Vec of bytes to a vec of Strings (at no cost)
205        fn bytes_to_str(
206            min_maxes: Vec<Option<Vec<u8>>>,
207        ) -> impl Iterator<Item = Option<String>> {
208            min_maxes.into_iter().map(|opt| {
209                opt.map(|bytes| {
210                    // Safety: only called on data added from update_batch which ensures
211                    // the input type matched the output type
212                    unsafe { String::from_utf8_unchecked(bytes) }
213                })
214            })
215        }
216
217        let result: ArrayRef = match self.inner.data_type {
218            DataType::Utf8 => {
219                let mut builder =
220                    StringBuilder::with_capacity(min_maxes.len(), data_capacity);
221                for opt in bytes_to_str(min_maxes) {
222                    match opt {
223                        None => builder.append_null(),
224                        Some(s) => builder.append_value(s.as_str()),
225                    }
226                }
227                Arc::new(builder.finish())
228            }
229            DataType::LargeUtf8 => {
230                let mut builder =
231                    LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity);
232                for opt in bytes_to_str(min_maxes) {
233                    match opt {
234                        None => builder.append_null(),
235                        Some(s) => builder.append_value(s.as_str()),
236                    }
237                }
238                Arc::new(builder.finish())
239            }
240            DataType::Utf8View => {
241                let block_size = capacity_to_view_block_size(data_capacity);
242
243                let mut builder = StringViewBuilder::with_capacity(min_maxes.len())
244                    .with_fixed_block_size(block_size);
245                for opt in bytes_to_str(min_maxes) {
246                    match opt {
247                        None => builder.append_null(),
248                        Some(s) => builder.append_value(s.as_str()),
249                    }
250                }
251                Arc::new(builder.finish())
252            }
253            DataType::Binary => {
254                let mut builder =
255                    BinaryBuilder::with_capacity(min_maxes.len(), data_capacity);
256                for opt in min_maxes {
257                    match opt {
258                        None => builder.append_null(),
259                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
260                    }
261                }
262                Arc::new(builder.finish())
263            }
264            DataType::LargeBinary => {
265                let mut builder =
266                    LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity);
267                for opt in min_maxes {
268                    match opt {
269                        None => builder.append_null(),
270                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
271                    }
272                }
273                Arc::new(builder.finish())
274            }
275            DataType::BinaryView => {
276                let block_size = capacity_to_view_block_size(data_capacity);
277
278                let mut builder = BinaryViewBuilder::with_capacity(min_maxes.len())
279                    .with_fixed_block_size(block_size);
280                for opt in min_maxes {
281                    match opt {
282                        None => builder.append_null(),
283                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
284                    }
285                }
286                Arc::new(builder.finish())
287            }
288            _ => {
289                return internal_err!(
290                    "Unexpected data type for MinMaxBytesAccumulator: {:?}",
291                    self.inner.data_type
292                );
293            }
294        };
295
296        assert_eq!(&self.inner.data_type, result.data_type());
297        Ok(result)
298    }
299
300    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
301        // min/max are their own states (no transition needed)
302        self.evaluate(emit_to).map(|arr| vec![arr])
303    }
304
305    fn merge_batch(
306        &mut self,
307        values: &[ArrayRef],
308        group_indices: &[usize],
309        opt_filter: Option<&BooleanArray>,
310        total_num_groups: usize,
311    ) -> Result<()> {
312        // min/max are their own states (no transition needed)
313        self.update_batch(values, group_indices, opt_filter, total_num_groups)
314    }
315
316    fn convert_to_state(
317        &self,
318        values: &[ArrayRef],
319        opt_filter: Option<&BooleanArray>,
320    ) -> Result<Vec<ArrayRef>> {
321        // Min/max do not change the values as they are their own states
322        // apply the filter by combining with the null mask, if any
323        let output = apply_filter_as_nulls(&values[0], opt_filter)?;
324        Ok(vec![output])
325    }
326
327    fn supports_convert_to_state(&self) -> bool {
328        true
329    }
330
331    fn size(&self) -> usize {
332        self.inner.size()
333    }
334}
335
/// Picks the block size (contiguous buffer size) for a view builder that
/// will hold `data_capacity` bytes of string data in total.
///
/// This is a heuristic to avoid allocating too many small buffers:
/// * a capacity of zero yields `1`, because `with_fixed_block_size()` must
///   not be given a block size of zero;
/// * otherwise the capacity itself is used, capped at 2 MiB (this also
///   covers capacities that do not fit in a `u32`).
fn capacity_to_view_block_size(data_capacity: usize) -> u32 {
    const MAX_BLOCK_SIZE: u32 = 2 * 1024 * 1024;

    match u32::try_from(data_capacity) {
        Ok(0) => 1,
        Ok(capacity) => capacity.min(MAX_BLOCK_SIZE),
        // larger than `u32::MAX`, hence certainly larger than the cap
        Err(_) => MAX_BLOCK_SIZE,
    }
}
352
353/// Stores internal Min/Max state for "bytes" types.
354///
355/// This implementation is general and stores the minimum/maximum for each
356/// groups in an individual byte array, which balances allocations and memory
357/// fragmentation (aka garbage).
358///
359/// ```text
360///                    ┌─────────────────────────────────┐
361///   ┌─────┐    ┌────▶│Option<Vec<u8>> (["A"])          │───────────▶   "A"
362///   │  0  │────┘     └─────────────────────────────────┘
363///   ├─────┤          ┌─────────────────────────────────┐
364///   │  1  │─────────▶│Option<Vec<u8>> (["Z"])          │───────────▶   "Z"
365///   └─────┘          └─────────────────────────────────┘               ...
366///     ...               ...
367///   ┌─────┐          ┌────────────────────────────────┐
368///   │ N-2 │─────────▶│Option<Vec<u8>> (["A"])         │────────────▶   "A"
369///   ├─────┤          └────────────────────────────────┘
370///   │ N-1 │────┐     ┌────────────────────────────────┐
371///   └─────┘    └────▶│Option<Vec<u8>> (["Q"])         │────────────▶   "Q"
372///                    └────────────────────────────────┘
373///
374///                      min_max: Vec<Option<Vec<u8>>
375/// ```
376///
377/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially
378/// more efficient implementations (e.g. by managing a string data buffer
379/// directly), but then garbage collection, memory management, and final array
380/// construction becomes more complex.
381///
382/// See discussion on <https://github.com/apache/datafusion/issues/6906>
383#[derive(Debug)]
384struct MinMaxBytesState {
385    /// The minimum/maximum value for each group
386    min_max: Vec<Option<Vec<u8>>>,
387    /// The data type of the array
388    data_type: DataType,
389    /// The total bytes of the string data (for pre-allocating the final array,
390    /// and tracking memory usage)
391    total_data_bytes: usize,
392}
393
/// Where the current best (min or max) value of a group lives while a batch
/// is being processed.
#[derive(Clone, Copy, Debug)]
enum MinMaxLocation<'a> {
    /// The best value is still the one already stored in the `min_max` array
    ExistingMinMax,
    /// The best value is this byte slice borrowed from the input array
    Input(&'a [u8]),
}
401
402/// Implement the MinMaxBytesAccumulator with a comparison function
403/// for comparing strings
404impl MinMaxBytesState {
405    /// Create a new MinMaxBytesAccumulator
406    ///
407    /// # Arguments:
408    /// * `data_type`: The data type of the arrays that will be passed to this accumulator
409    fn new(data_type: DataType) -> Self {
410        Self {
411            min_max: vec![],
412            data_type,
413            total_data_bytes: 0,
414        }
415    }
416
417    /// Set the specified group to the given value, updating memory usage appropriately
418    fn set_value(&mut self, group_index: usize, new_val: &[u8]) {
419        match self.min_max[group_index].as_mut() {
420            None => {
421                self.min_max[group_index] = Some(new_val.to_vec());
422                self.total_data_bytes += new_val.len();
423            }
424            Some(existing_val) => {
425                // Copy data over to avoid re-allocating
426                self.total_data_bytes -= existing_val.len();
427                self.total_data_bytes += new_val.len();
428                existing_val.clear();
429                existing_val.extend_from_slice(new_val);
430            }
431        }
432    }
433
434    /// Updates the min/max values for the given string values
435    ///
436    /// `cmp` is the  comparison function to use, called like `cmp(new_val, existing_val)`
437    /// returns true if the `new_val` should replace `existing_val`
438    fn update_batch<'a, F, I>(
439        &mut self,
440        iter: I,
441        group_indices: &[usize],
442        total_num_groups: usize,
443        mut cmp: F,
444    ) -> Result<()>
445    where
446        F: FnMut(&[u8], &[u8]) -> bool + Send + Sync,
447        I: IntoIterator<Item = Option<&'a [u8]>>,
448    {
449        self.min_max.resize(total_num_groups, None);
450        // Minimize value copies by calculating the new min/maxes for each group
451        // in this batch (either the existing min/max or the new input value)
452        // and updating the owned values in `self.min_maxes` at most once
453        let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups];
454
455        // Figure out the new min value for each group
456        for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) {
457            let group_index = *group_index;
458            let Some(new_val) = new_val else {
459                continue; // skip nulls
460            };
461
462            let existing_val = match locations[group_index] {
463                // previous input value was the min/max, so compare it
464                MinMaxLocation::Input(existing_val) => existing_val,
465                MinMaxLocation::ExistingMinMax => {
466                    let Some(existing_val) = self.min_max[group_index].as_ref() else {
467                        // no existing min/max, so this is the new min/max
468                        locations[group_index] = MinMaxLocation::Input(new_val);
469                        continue;
470                    };
471                    existing_val.as_ref()
472                }
473            };
474
475            // Compare the new value to the existing value, replacing if necessary
476            if cmp(new_val, existing_val) {
477                locations[group_index] = MinMaxLocation::Input(new_val);
478            }
479        }
480
481        // Update self.min_max with any new min/max values we found in the input
482        for (group_index, location) in locations.iter().enumerate() {
483            match location {
484                MinMaxLocation::ExistingMinMax => {}
485                MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val),
486            }
487        }
488        Ok(())
489    }
490
491    /// Emits the specified min_max values
492    ///
493    /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes
494    ///
495    /// - `data_capacity`: the total length of all strings and their contents,
496    /// - `min_maxes`: the actual min/max values for each group
497    fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<Vec<u8>>>) {
498        match emit_to {
499            EmitTo::All => {
500                (
501                    std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max
502                    std::mem::take(&mut self.min_max),
503                )
504            }
505            EmitTo::First(n) => {
506                let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect();
507                let first_data_capacity: usize = first_min_maxes
508                    .iter()
509                    .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0))
510                    .sum();
511                self.total_data_bytes -= first_data_capacity;
512                (first_data_capacity, first_min_maxes)
513            }
514        }
515    }
516
517    fn size(&self) -> usize {
518        self.total_data_bytes + self.min_max.len() * size_of::<Option<Vec<u8>>>()
519    }
520}