Skip to main content

datafusion_functions_nested/
min_max.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_min and array_max functions.
19use crate::utils::make_scalar_function;
20use arrow::array::{
21    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, GenericListArray,
22    OffsetSizeTrait, PrimitiveBuilder, downcast_primitive,
23};
24use arrow::datatypes::DataType;
25use arrow::datatypes::DataType::{LargeList, List};
26use datafusion_common::Result;
27use datafusion_common::cast::{as_large_list_array, as_list_array};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{ScalarValue, exec_err, plan_err};
30use datafusion_doc::Documentation;
31use datafusion_expr::{
32    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};
35use datafusion_macros::user_doc;
36use itertools::Itertools;
37use std::sync::Arc;
38
39make_udf_expr_and_func!(
40    ArrayMax,
41    array_max,
42    array,
43    "returns the maximum value in the array.",
44    array_max_udf
45);
46
47#[user_doc(
48    doc_section(label = "Array Functions"),
49    description = "Returns the maximum value in the array.",
50    syntax_example = "array_max(array)",
51    sql_example = r#"```sql
52> select array_max([3,1,4,2]);
53+-----------------------------------------+
54| array_max(List([3,1,4,2]))              |
55+-----------------------------------------+
56| 4                                       |
57+-----------------------------------------+
58```"#,
59    argument(
60        name = "array",
61        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62    )
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct ArrayMax {
66    signature: Signature,
67    aliases: Vec<String>,
68}
69
70impl Default for ArrayMax {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl ArrayMax {
77    pub fn new() -> Self {
78        Self {
79            signature: Signature::array(Volatility::Immutable),
80            aliases: vec!["list_max".to_string()],
81        }
82    }
83}
84
85impl ScalarUDFImpl for ArrayMax {
86    fn name(&self) -> &str {
87        "array_max"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95        let [array] = take_function_args(self.name(), arg_types)?;
96        match array {
97            List(field) | LargeList(field) => Ok(field.data_type().clone()),
98            arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
99        }
100    }
101
102    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103        make_scalar_function(array_max_inner)(&args.args)
104    }
105
106    fn aliases(&self) -> &[String] {
107        &self.aliases
108    }
109
110    fn documentation(&self) -> Option<&Documentation> {
111        self.doc()
112    }
113}
114
115fn array_max_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
116    let [array] = take_function_args("array_max", args)?;
117    match array.data_type() {
118        List(_) => array_min_max_helper(as_list_array(array)?, false),
119        LargeList(_) => array_min_max_helper(as_large_list_array(array)?, false),
120        arg_type => exec_err!("array_max does not support type: {arg_type}"),
121    }
122}
123
124make_udf_expr_and_func!(
125    ArrayMin,
126    array_min,
127    array,
128    "returns the minimum value in the array",
129    array_min_udf
130);
131#[user_doc(
132    doc_section(label = "Array Functions"),
133    description = "Returns the minimum value in the array.",
134    syntax_example = "array_min(array)",
135    sql_example = r#"```sql
136> select array_min([3,1,4,2]);
137+-----------------------------------------+
138| array_min(List([3,1,4,2]))              |
139+-----------------------------------------+
140| 1                                       |
141+-----------------------------------------+
142```"#,
143    argument(
144        name = "array",
145        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
146    )
147)]
148#[derive(Debug, PartialEq, Eq, Hash)]
149struct ArrayMin {
150    signature: Signature,
151}
152
153impl Default for ArrayMin {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl ArrayMin {
160    fn new() -> Self {
161        Self {
162            signature: Signature::array(Volatility::Immutable),
163        }
164    }
165}
166
167impl ScalarUDFImpl for ArrayMin {
168    fn name(&self) -> &str {
169        "array_min"
170    }
171
172    fn signature(&self) -> &Signature {
173        &self.signature
174    }
175
176    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
177        let [array] = take_function_args(self.name(), arg_types)?;
178        match array {
179            List(field) | LargeList(field) => Ok(field.data_type().clone()),
180            arg_type => plan_err!("{} does not support type {}", self.name(), arg_type),
181        }
182    }
183
184    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
185        make_scalar_function(array_min_inner)(&args.args)
186    }
187
188    fn documentation(&self) -> Option<&Documentation> {
189        self.doc()
190    }
191}
192
193fn array_min_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
194    let [array] = take_function_args("array_min", args)?;
195    match array.data_type() {
196        List(_) => array_min_max_helper(as_list_array(array)?, true),
197        LargeList(_) => array_min_max_helper(as_large_list_array(array)?, true),
198        arg_type => exec_err!("array_min does not support type: {arg_type}"),
199    }
200}
201
202fn array_min_max_helper<O: OffsetSizeTrait>(
203    array: &GenericListArray<O>,
204    is_min: bool,
205) -> Result<ArrayRef> {
206    // Try the primitive fast path first
207    if let Some(result) = try_primitive_array_min_max(array, is_min) {
208        return result;
209    }
210
211    // Fallback: per-row ScalarValue path for non-primitive types
212    let agg_fn = if is_min { min_batch } else { max_batch };
213    let null_value = ScalarValue::try_from(array.value_type())?;
214    let result_vec: Vec<ScalarValue> = array
215        .iter()
216        .map(|arr| arr.as_ref().map_or_else(|| Ok(null_value.clone()), agg_fn))
217        .try_collect()?;
218    ScalarValue::iter_to_array(result_vec)
219}
220
221/// Dispatches to a typed primitive min/max implementation, or returns `None` if
222/// the element type is not a primitive.
223fn try_primitive_array_min_max<O: OffsetSizeTrait>(
224    list_array: &GenericListArray<O>,
225    is_min: bool,
226) -> Option<Result<ArrayRef>> {
227    macro_rules! helper {
228        ($t:ty) => {
229            return Some(primitive_array_min_max::<O, $t>(list_array, is_min))
230        };
231    }
232    downcast_primitive! {
233        list_array.value_type() => (helper),
234        _ => {}
235    }
236    None
237}
238
/// Row length at or above which the primitive path hands a row to the
/// `arrow::compute::{min, max}` kernels instead of scanning it directly.
/// Those kernels have a per-call overhead, so a plain scan is faster for
/// short rows.
const ARROW_COMPUTE_THRESHOLD: usize = 32;
243
244/// Computes min or max for each row of a primitive ListArray.
245fn primitive_array_min_max<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
246    list_array: &GenericListArray<O>,
247    is_min: bool,
248) -> Result<ArrayRef> {
249    let values_array = list_array.values().as_primitive::<T>();
250    let values_slice = values_array.values();
251    let values_nulls = values_array.nulls();
252    let mut result_builder = PrimitiveBuilder::<T>::with_capacity(list_array.len())
253        .with_data_type(values_array.data_type().clone());
254
255    for (row, w) in list_array.offsets().windows(2).enumerate() {
256        let row_result = if list_array.is_null(row) {
257            None
258        } else {
259            let start = w[0].as_usize();
260            let end = w[1].as_usize();
261            let len = end - start;
262
263            match len {
264                0 => None,
265                _ if len < ARROW_COMPUTE_THRESHOLD => {
266                    scalar_min_max::<T>(values_slice, values_nulls, start, end, is_min)
267                }
268                _ => {
269                    let slice = values_array.slice(start, len);
270                    if is_min {
271                        arrow::compute::min::<T>(&slice)
272                    } else {
273                        arrow::compute::max::<T>(&slice)
274                    }
275                }
276            }
277        };
278
279        result_builder.append_option(row_result);
280    }
281
282    Ok(Arc::new(result_builder.finish()) as ArrayRef)
283}
284
285/// Computes min or max for a single list row by directly scanning a slice of
286/// the flat values buffer.
287#[inline]
288fn scalar_min_max<T: ArrowPrimitiveType>(
289    values_slice: &[T::Native],
290    values_nulls: Option<&arrow::buffer::NullBuffer>,
291    start: usize,
292    end: usize,
293    is_min: bool,
294) -> Option<T::Native> {
295    let mut best: Option<T::Native> = None;
296    for (i, &val) in values_slice[start..end].iter().enumerate() {
297        if let Some(nulls) = values_nulls
298            && !nulls.is_valid(start + i)
299        {
300            continue;
301        }
302        let update_best = match best {
303            None => true,
304            Some(current) if is_min => val.is_lt(current),
305            Some(current) => val.is_gt(current),
306        };
307        if update_best {
308            best = Some(val);
309        }
310    }
311    best
312}