datafusion_functions_aggregate/
min_max.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
19//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
20
21mod min_max_bytes;
22
23use arrow::array::{
24    ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
25    Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
26    Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
27    IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
28    LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
29    Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
30    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
31    TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
32};
33use arrow::compute;
34use arrow::datatypes::{
35    DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
36    Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, UInt16Type, UInt32Type,
37    UInt64Type, UInt8Type,
38};
39use datafusion_common::stats::Precision;
40use datafusion_common::{
41    downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
42};
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
44use datafusion_physical_expr::expressions;
45use std::cmp::Ordering;
46use std::fmt::Debug;
47
48use arrow::datatypes::i256;
49use arrow::datatypes::{
50    Date32Type, Date64Type, Time32MillisecondType, Time32SecondType,
51    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
52    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
53};
54
55use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
56use datafusion_common::ScalarValue;
57use datafusion_expr::{
58    function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation,
59    SetMonotonicity, Signature, Volatility,
60};
61use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
62use datafusion_macros::user_doc;
63use half::f16;
64use std::mem::size_of_val;
65use std::ops::Deref;
66
67fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
68    // make sure that the input types only has one element.
69    if input_types.len() != 1 {
70        return exec_err!(
71            "min/max was called with {} arguments. It requires only 1.",
72            input_types.len()
73        );
74    }
75    // min and max support the dictionary data type
76    // unpack the dictionary to get the value
77    match &input_types[0] {
78        DataType::Dictionary(_, dict_value_type) => {
79            // TODO add checker, if the value type is complex data type
80            Ok(vec![dict_value_type.deref().clone()])
81        }
82        // TODO add checker for datatype which min and max supported
83        // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
84        _ => Ok(input_types.to_vec()),
85    }
86}
87
88#[user_doc(
89    doc_section(label = "General Functions"),
90    description = "Returns the maximum value in the specified column.",
91    syntax_example = "max(expression)",
92    sql_example = r#"```sql
93> SELECT max(column_name) FROM table_name;
94+----------------------+
95| max(column_name)      |
96+----------------------+
97| 150                  |
98+----------------------+
99```"#,
100    standard_argument(name = "expression",)
101)]
102// MAX aggregate UDF
103#[derive(Debug)]
104pub struct Max {
105    signature: Signature,
106}
107
108impl Max {
109    pub fn new() -> Self {
110        Self {
111            signature: Signature::user_defined(Volatility::Immutable),
112        }
113    }
114}
115
116impl Default for Max {
117    fn default() -> Self {
118        Self::new()
119    }
120}
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` over the
/// specified [`ArrowPrimitiveType`].
///
/// Values are compared with [`ArrowNativeTypeOp::compare`], which is a total
/// order (for floats it is IEEE 754 `totalOrder`, the same ordering
/// `total_cmp` gives). This makes the per-group result independent of the
/// order rows arrive in, including when NaN values are present. It also agrees
/// with `typed_min_max_float`, which merges float scalars with `total_cmp`.
///
/// Each group is seeded with `MIN_TOTAL_ORDER`, the smallest value in that
/// total order. Any input therefore replaces it, including `-inf`, which is
/// below the finite `f64::MIN`.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
/// [`ArrowNativeTypeOp::compare`]: arrow::datatypes::ArrowNativeTypeOp::compare
macro_rules! primitive_max_accumulator {
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                // Replace only when `new` is strictly greater in the total order.
                if arrow::datatypes::ArrowNativeTypeOp::compare(new, *cur)
                    == Ordering::Greater
                {
                    *cur = new
                }
            })
            // Initialize each accumulator to the minimum of the total order
            .with_starting_value(
                <$NATIVE as arrow::datatypes::ArrowNativeTypeOp>::MIN_TOTAL_ORDER,
            ),
        ))
    }};
}
142
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` over the
/// specified [`ArrowPrimitiveType`].
///
/// Values are compared with [`ArrowNativeTypeOp::compare`], which is a total
/// order (for floats it is IEEE 754 `totalOrder`, the same ordering
/// `total_cmp` gives). This makes the per-group result independent of the
/// order rows arrive in, including when NaN values are present. It also agrees
/// with `typed_min_max_float`, which merges float scalars with `total_cmp`.
///
/// Each group is seeded with `MAX_TOTAL_ORDER`, the largest value in that
/// total order. Any input therefore replaces it, including `+inf`, which is
/// above the finite `f64::MAX`.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
/// [`ArrowNativeTypeOp::compare`]: arrow::datatypes::ArrowNativeTypeOp::compare
macro_rules! primitive_min_accumulator {
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
                // Replace only when `new` is strictly less in the total order.
                if arrow::datatypes::ArrowNativeTypeOp::compare(new, *cur)
                    == Ordering::Less
                {
                    *cur = new
                }
            })
            // Initialize each accumulator to the maximum of the total order
            .with_starting_value(
                <$NATIVE as arrow::datatypes::ArrowNativeTypeOp>::MAX_TOTAL_ORDER,
            ),
        ))
    }};
}
165
166trait FromColumnStatistics {
167    fn value_from_column_statistics(
168        &self,
169        stats: &ColumnStatistics,
170    ) -> Option<ScalarValue>;
171
172    fn value_from_statistics(
173        &self,
174        statistics_args: &StatisticsArgs,
175    ) -> Option<ScalarValue> {
176        if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
177            match *num_rows {
178                0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
179                value if value > 0 => {
180                    let col_stats = &statistics_args.statistics.column_statistics;
181                    if statistics_args.exprs.len() == 1 {
182                        // TODO optimize with exprs other than Column
183                        if let Some(col_expr) = statistics_args.exprs[0]
184                            .as_any()
185                            .downcast_ref::<expressions::Column>()
186                        {
187                            return self.value_from_column_statistics(
188                                &col_stats[col_expr.index()],
189                            );
190                        }
191                    }
192                }
193                _ => {}
194            }
195        }
196        None
197    }
198}
199
200impl FromColumnStatistics for Max {
201    fn value_from_column_statistics(
202        &self,
203        col_stats: &ColumnStatistics,
204    ) -> Option<ScalarValue> {
205        if let Precision::Exact(ref val) = col_stats.max_value {
206            if !val.is_null() {
207                return Some(val.clone());
208            }
209        }
210        None
211    }
212}
213
214impl AggregateUDFImpl for Max {
215    fn as_any(&self) -> &dyn std::any::Any {
216        self
217    }
218
219    fn name(&self) -> &str {
220        "max"
221    }
222
223    fn signature(&self) -> &Signature {
224        &self.signature
225    }
226
227    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
228        Ok(arg_types[0].to_owned())
229    }
230
231    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
232        Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?))
233    }
234
235    fn aliases(&self) -> &[String] {
236        &[]
237    }
238
239    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
240        use DataType::*;
241        matches!(
242            args.return_type,
243            Int8 | Int16
244                | Int32
245                | Int64
246                | UInt8
247                | UInt16
248                | UInt32
249                | UInt64
250                | Float16
251                | Float32
252                | Float64
253                | Decimal128(_, _)
254                | Decimal256(_, _)
255                | Date32
256                | Date64
257                | Time32(_)
258                | Time64(_)
259                | Timestamp(_, _)
260                | Utf8
261                | LargeUtf8
262                | Utf8View
263                | Binary
264                | LargeBinary
265                | BinaryView
266        )
267    }
268
269    fn create_groups_accumulator(
270        &self,
271        args: AccumulatorArgs,
272    ) -> Result<Box<dyn GroupsAccumulator>> {
273        use DataType::*;
274        use TimeUnit::*;
275        let data_type = args.return_type;
276        match data_type {
277            Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
278            Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
279            Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
280            Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
281            UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
282            UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
283            UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
284            UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
285            Float16 => {
286                primitive_max_accumulator!(data_type, f16, Float16Type)
287            }
288            Float32 => {
289                primitive_max_accumulator!(data_type, f32, Float32Type)
290            }
291            Float64 => {
292                primitive_max_accumulator!(data_type, f64, Float64Type)
293            }
294            Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
295            Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
296            Time32(Second) => {
297                primitive_max_accumulator!(data_type, i32, Time32SecondType)
298            }
299            Time32(Millisecond) => {
300                primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
301            }
302            Time64(Microsecond) => {
303                primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
304            }
305            Time64(Nanosecond) => {
306                primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
307            }
308            Timestamp(Second, _) => {
309                primitive_max_accumulator!(data_type, i64, TimestampSecondType)
310            }
311            Timestamp(Millisecond, _) => {
312                primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
313            }
314            Timestamp(Microsecond, _) => {
315                primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
316            }
317            Timestamp(Nanosecond, _) => {
318                primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
319            }
320            Decimal128(_, _) => {
321                primitive_max_accumulator!(data_type, i128, Decimal128Type)
322            }
323            Decimal256(_, _) => {
324                primitive_max_accumulator!(data_type, i256, Decimal256Type)
325            }
326            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
327                Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
328            }
329
330            // This is only reached if groups_accumulator_supported is out of sync
331            _ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
332        }
333    }
334
335    fn create_sliding_accumulator(
336        &self,
337        args: AccumulatorArgs,
338    ) -> Result<Box<dyn Accumulator>> {
339        Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?))
340    }
341
342    fn is_descending(&self) -> Option<bool> {
343        Some(true)
344    }
345
346    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
347        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
348    }
349
350    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
351        get_min_max_result_type(arg_types)
352    }
353    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
354        datafusion_expr::ReversedUDAF::Identical
355    }
356    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
357        self.value_from_statistics(statistics_args)
358    }
359
360    fn documentation(&self) -> Option<&Documentation> {
361        self.doc()
362    }
363
364    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
365        // `MAX` is monotonically increasing as it always increases or stays
366        // the same as new values are seen.
367        SetMonotonicity::Increasing
368    }
369}
370
// Statically-typed version of min/max(array) -> ScalarValue for string types.
// `$OP` returns a `&str` borrowed from the array, which is copied into an
// owned `String` for the resulting scalar.
macro_rules! typed_min_max_batch_string {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
        let array = downcast_value!($VALUES, $ARRAYTYPE);
        let value = compute::$OP(array).map(|s| s.to_string());
        ScalarValue::$SCALAR(value)
    }};
}
// Statically-typed version of min/max(array) -> ScalarValue for binary types.
// `$OP` returns a byte slice borrowed from the array, which is copied into an
// owned `Vec<u8>` for the resulting scalar.
macro_rules! typed_min_max_batch_binary {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
        let array = downcast_value!($VALUES, $ARRAYTYPE);
        let value = compute::$OP(array).map(|bytes| bytes.to_vec());
        ScalarValue::$SCALAR(value)
    }};
}
389
// Statically-typed min/max(array) -> ScalarValue for non-string types.
// Any trailing identifiers (e.g. decimal precision/scale or a timestamp's
// timezone) are cloned and appended to the scalar constructor arguments.
macro_rules! typed_min_max_batch {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
        let typed_array = downcast_value!($VALUES, $ARRAYTYPE);
        let extreme = compute::$OP(typed_array);
        ScalarValue::$SCALAR(extreme, $($EXTRA_ARGS.clone()),*)
    }};
}
398
// Dynamically dispatches min/max(array) -> ScalarValue over the non-string
// types: null, decimals, integers, floats, timestamps, dates, times and
// intervals. `$OP` names the `compute` kernel (`min` or `max`). Any other
// type returns an internal error from the enclosing function.
macro_rules! min_max_batch {
    ($VALUES:expr, $OP:ident) => {{
        match $VALUES.data_type() {
            DataType::Null => ScalarValue::Null,

            // Decimals keep their precision and scale.
            DataType::Decimal128(precision, scale) => {
                typed_min_max_batch!($VALUES, Decimal128Array, Decimal128, $OP, precision, scale)
            }
            DataType::Decimal256(precision, scale) => {
                typed_min_max_batch!($VALUES, Decimal256Array, Decimal256, $OP, precision, scale)
            }

            // Signed integers.
            DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP),
            DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
            DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
            DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),

            // Unsigned integers.
            DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP),
            DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP),
            DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP),
            DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP),

            // Floating point.
            DataType::Float16 => typed_min_max_batch!($VALUES, Float16Array, Float16, $OP),
            DataType::Float32 => typed_min_max_batch!($VALUES, Float32Array, Float32, $OP),
            DataType::Float64 => typed_min_max_batch!($VALUES, Float64Array, Float64, $OP),

            // Timestamps keep their (optional) timezone.
            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
                typed_min_max_batch!($VALUES, TimestampSecondArray, TimestampSecond, $OP, tz_opt)
            }
            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
                typed_min_max_batch!(
                    $VALUES,
                    TimestampMillisecondArray,
                    TimestampMillisecond,
                    $OP,
                    tz_opt
                )
            }
            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
                typed_min_max_batch!(
                    $VALUES,
                    TimestampMicrosecondArray,
                    TimestampMicrosecond,
                    $OP,
                    tz_opt
                )
            }
            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
                typed_min_max_batch!(
                    $VALUES,
                    TimestampNanosecondArray,
                    TimestampNanosecond,
                    $OP,
                    tz_opt
                )
            }

            // Dates.
            DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP),
            DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP),

            // Times of day.
            DataType::Time32(TimeUnit::Second) => {
                typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                typed_min_max_batch!($VALUES, Time32MillisecondArray, Time32Millisecond, $OP)
            }
            DataType::Time64(TimeUnit::Microsecond) => {
                typed_min_max_batch!($VALUES, Time64MicrosecondArray, Time64Microsecond, $OP)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                typed_min_max_batch!($VALUES, Time64NanosecondArray, Time64Nanosecond, $OP)
            }

            // Intervals.
            DataType::Interval(IntervalUnit::YearMonth) => {
                typed_min_max_batch!($VALUES, IntervalYearMonthArray, IntervalYearMonth, $OP)
            }
            DataType::Interval(IntervalUnit::DayTime) => {
                typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP)
            }
            DataType::Interval(IntervalUnit::MonthDayNano) => {
                typed_min_max_batch!(
                    $VALUES,
                    IntervalMonthDayNanoArray,
                    IntervalMonthDayNano,
                    $OP
                )
            }

            unsupported => {
                // This should have been handled before
                return internal_err!(
                    "Min/Max accumulator not implemented for type {:?}",
                    unsupported
                );
            }
        }
    }};
}
531
532/// dynamically-typed min(array) -> ScalarValue
533fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
534    Ok(match values.data_type() {
535        DataType::Utf8 => {
536            typed_min_max_batch_string!(values, StringArray, Utf8, min_string)
537        }
538        DataType::LargeUtf8 => {
539            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
540        }
541        DataType::Utf8View => {
542            typed_min_max_batch_string!(
543                values,
544                StringViewArray,
545                Utf8View,
546                min_string_view
547            )
548        }
549        DataType::Boolean => {
550            typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
551        }
552        DataType::Binary => {
553            typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary)
554        }
555        DataType::LargeBinary => {
556            typed_min_max_batch_binary!(
557                &values,
558                LargeBinaryArray,
559                LargeBinary,
560                min_binary
561            )
562        }
563        DataType::BinaryView => {
564            typed_min_max_batch_binary!(
565                &values,
566                BinaryViewArray,
567                BinaryView,
568                min_binary_view
569            )
570        }
571        _ => min_max_batch!(values, min),
572    })
573}
574
575/// dynamically-typed max(array) -> ScalarValue
576fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
577    Ok(match values.data_type() {
578        DataType::Utf8 => {
579            typed_min_max_batch_string!(values, StringArray, Utf8, max_string)
580        }
581        DataType::LargeUtf8 => {
582            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
583        }
584        DataType::Utf8View => {
585            typed_min_max_batch_string!(
586                values,
587                StringViewArray,
588                Utf8View,
589                max_string_view
590            )
591        }
592        DataType::Boolean => {
593            typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
594        }
595        DataType::Binary => {
596            typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
597        }
598        DataType::BinaryView => {
599            typed_min_max_batch_binary!(
600                &values,
601                BinaryViewArray,
602                BinaryView,
603                max_binary_view
604            )
605        }
606        DataType::LargeBinary => {
607            typed_min_max_batch_binary!(
608                &values,
609                LargeBinaryArray,
610                LargeBinary,
611                max_binary
612            )
613        }
614        _ => min_max_batch!(values, max),
615    })
616}
617
// min/max of two non-string scalar values. When only one side is present it
// is returned as-is; when both are present `$OP` (`min`/`max`) picks one.
// Trailing identifiers are cloned into the scalar constructor.
macro_rules! typed_min_max {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
        let combined = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => Some((*a).$OP(*b)),
            (Some(only), None) | (None, Some(only)) => Some(*only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(combined, $($EXTRA_ARGS.clone()),*)
    }};
}
// min/max of two float scalars, compared with `total_cmp` so the choice is
// well-defined even when NaN is involved. A missing side yields the other.
macro_rules! typed_min_max_float {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let picked = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => {
                // `b` wins only when `a` falls on the losing side for `$OP`.
                if a.total_cmp(b) == choose_min_max!($OP) {
                    Some(*b)
                } else {
                    Some(*a)
                }
            }
            (Some(only), None) | (None, Some(only)) => Some(*only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(picked)
    }};
}
645
// min/max of two owned string/binary scalar values; the winner is cloned
// into the resulting scalar. A missing side yields a clone of the other.
macro_rules! typed_min_max_string {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let chosen = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => Some(a.$OP(b).clone()),
            (Some(only), None) | (None, Some(only)) => Some(only.clone()),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(chosen)
    }};
}
657
// Maps `min`/`max` to the `Ordering` of `lhs.cmp(rhs)` at which the right-hand
// side should be taken: for `max` when lhs is Less, for `min` when lhs is Greater.
macro_rules! choose_min_max {
    (max) => {
        std::cmp::Ordering::Less
    };
    (min) => {
        std::cmp::Ordering::Greater
    };
}
666
// min/max of two interval scalars via `partial_cmp`. If the values cannot be
// compared, the enclosing function returns an internal error.
macro_rules! interval_min_max {
    ($OP:tt, $LHS:expr, $RHS:expr) => {{
        let ordering = match $LHS.partial_cmp(&$RHS) {
            Some(ordering) => ordering,
            None => {
                return internal_err!("Comparison error while computing interval min/max")
            }
        };
        if ordering == choose_min_max!($OP) {
            $RHS.clone()
        } else {
            $LHS.clone()
        }
    }};
}
678
679// min/max of two scalar values of the same type
680macro_rules! min_max {
681    ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
682        Ok(match ($VALUE, $DELTA) {
683            (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
684            (
685                lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
686                rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
687            ) => {
688                if lhsp.eq(rhsp) && lhss.eq(rhss) {
689                    typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss)
690                } else {
691                    return internal_err!(
692                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
693                    (lhs, rhs)
694                );
695                }
696            }
697            (
698                lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss),
699                rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss)
700            ) => {
701                if lhsp.eq(rhsp) && lhss.eq(rhss) {
702                    typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss)
703                } else {
704                    return internal_err!(
705                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
706                    (lhs, rhs)
707                );
708                }
709            }
710            (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => {
711                typed_min_max!(lhs, rhs, Boolean, $OP)
712            }
713            (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
714                typed_min_max_float!(lhs, rhs, Float64, $OP)
715            }
716            (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
717                typed_min_max_float!(lhs, rhs, Float32, $OP)
718            }
719            (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
720                typed_min_max_float!(lhs, rhs, Float16, $OP)
721            }
722            (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
723                typed_min_max!(lhs, rhs, UInt64, $OP)
724            }
725            (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
726                typed_min_max!(lhs, rhs, UInt32, $OP)
727            }
728            (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
729                typed_min_max!(lhs, rhs, UInt16, $OP)
730            }
731            (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
732                typed_min_max!(lhs, rhs, UInt8, $OP)
733            }
734            (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
735                typed_min_max!(lhs, rhs, Int64, $OP)
736            }
737            (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
738                typed_min_max!(lhs, rhs, Int32, $OP)
739            }
740            (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
741                typed_min_max!(lhs, rhs, Int16, $OP)
742            }
743            (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
744                typed_min_max!(lhs, rhs, Int8, $OP)
745            }
746            (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => {
747                typed_min_max_string!(lhs, rhs, Utf8, $OP)
748            }
749            (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
750                typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
751            }
752            (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
753                typed_min_max_string!(lhs, rhs, Utf8View, $OP)
754            }
755            (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
756                typed_min_max_string!(lhs, rhs, Binary, $OP)
757            }
758            (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
759                typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
760            }
761            (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
762                typed_min_max_string!(lhs, rhs, BinaryView, $OP)
763            }
764            (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
765                typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
766            }
767            (
768                ScalarValue::TimestampMillisecond(lhs, l_tz),
769                ScalarValue::TimestampMillisecond(rhs, _),
770            ) => {
771                typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz)
772            }
773            (
774                ScalarValue::TimestampMicrosecond(lhs, l_tz),
775                ScalarValue::TimestampMicrosecond(rhs, _),
776            ) => {
777                typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz)
778            }
779            (
780                ScalarValue::TimestampNanosecond(lhs, l_tz),
781                ScalarValue::TimestampNanosecond(rhs, _),
782            ) => {
783                typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz)
784            }
785            (
786                ScalarValue::Date32(lhs),
787                ScalarValue::Date32(rhs),
788            ) => {
789                typed_min_max!(lhs, rhs, Date32, $OP)
790            }
791            (
792                ScalarValue::Date64(lhs),
793                ScalarValue::Date64(rhs),
794            ) => {
795                typed_min_max!(lhs, rhs, Date64, $OP)
796            }
797            (
798                ScalarValue::Time32Second(lhs),
799                ScalarValue::Time32Second(rhs),
800            ) => {
801                typed_min_max!(lhs, rhs, Time32Second, $OP)
802            }
803            (
804                ScalarValue::Time32Millisecond(lhs),
805                ScalarValue::Time32Millisecond(rhs),
806            ) => {
807                typed_min_max!(lhs, rhs, Time32Millisecond, $OP)
808            }
809            (
810                ScalarValue::Time64Microsecond(lhs),
811                ScalarValue::Time64Microsecond(rhs),
812            ) => {
813                typed_min_max!(lhs, rhs, Time64Microsecond, $OP)
814            }
815            (
816                ScalarValue::Time64Nanosecond(lhs),
817                ScalarValue::Time64Nanosecond(rhs),
818            ) => {
819                typed_min_max!(lhs, rhs, Time64Nanosecond, $OP)
820            }
821            (
822                ScalarValue::IntervalYearMonth(lhs),
823                ScalarValue::IntervalYearMonth(rhs),
824            ) => {
825                typed_min_max!(lhs, rhs, IntervalYearMonth, $OP)
826            }
827            (
828                ScalarValue::IntervalMonthDayNano(lhs),
829                ScalarValue::IntervalMonthDayNano(rhs),
830            ) => {
831                typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP)
832            }
833            (
834                ScalarValue::IntervalDayTime(lhs),
835                ScalarValue::IntervalDayTime(rhs),
836            ) => {
837                typed_min_max!(lhs, rhs, IntervalDayTime, $OP)
838            }
839            (
840                ScalarValue::IntervalYearMonth(_),
841                ScalarValue::IntervalMonthDayNano(_),
842            ) | (
843                ScalarValue::IntervalYearMonth(_),
844                ScalarValue::IntervalDayTime(_),
845            ) | (
846                ScalarValue::IntervalMonthDayNano(_),
847                ScalarValue::IntervalDayTime(_),
848            ) | (
849                ScalarValue::IntervalMonthDayNano(_),
850                ScalarValue::IntervalYearMonth(_),
851            ) | (
852                ScalarValue::IntervalDayTime(_),
853                ScalarValue::IntervalYearMonth(_),
854            ) | (
855                ScalarValue::IntervalDayTime(_),
856                ScalarValue::IntervalMonthDayNano(_),
857            ) => {
858                interval_min_max!($OP, $VALUE, $DELTA)
859            }
860                    (
861                ScalarValue::DurationSecond(lhs),
862                ScalarValue::DurationSecond(rhs),
863            ) => {
864                typed_min_max!(lhs, rhs, DurationSecond, $OP)
865            }
866                                (
867                ScalarValue::DurationMillisecond(lhs),
868                ScalarValue::DurationMillisecond(rhs),
869            ) => {
870                typed_min_max!(lhs, rhs, DurationMillisecond, $OP)
871            }
872                                (
873                ScalarValue::DurationMicrosecond(lhs),
874                ScalarValue::DurationMicrosecond(rhs),
875            ) => {
876                typed_min_max!(lhs, rhs, DurationMicrosecond, $OP)
877            }
878                                        (
879                ScalarValue::DurationNanosecond(lhs),
880                ScalarValue::DurationNanosecond(rhs),
881            ) => {
882                typed_min_max!(lhs, rhs, DurationNanosecond, $OP)
883            }
884            e => {
885                return internal_err!(
886                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
887                    e
888                )
889            }
890        })
891    }};
892}
893
894/// An accumulator to compute the maximum value
895#[derive(Debug)]
896pub struct MaxAccumulator {
897    max: ScalarValue,
898}
899
900impl MaxAccumulator {
901    /// new max accumulator
902    pub fn try_new(datatype: &DataType) -> Result<Self> {
903        Ok(Self {
904            max: ScalarValue::try_from(datatype)?,
905        })
906    }
907}
908
909impl Accumulator for MaxAccumulator {
910    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
911        let values = &values[0];
912        let delta = &max_batch(values)?;
913        let new_max: Result<ScalarValue, DataFusionError> =
914            min_max!(&self.max, delta, max);
915        self.max = new_max?;
916        Ok(())
917    }
918
919    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
920        self.update_batch(states)
921    }
922
923    fn state(&mut self) -> Result<Vec<ScalarValue>> {
924        Ok(vec![self.evaluate()?])
925    }
926    fn evaluate(&mut self) -> Result<ScalarValue> {
927        Ok(self.max.clone())
928    }
929
930    fn size(&self) -> usize {
931        size_of_val(self) - size_of_val(&self.max) + self.max.size()
932    }
933}
934
935#[derive(Debug)]
936pub struct SlidingMaxAccumulator {
937    max: ScalarValue,
938    moving_max: MovingMax<ScalarValue>,
939}
940
941impl SlidingMaxAccumulator {
942    /// new max accumulator
943    pub fn try_new(datatype: &DataType) -> Result<Self> {
944        Ok(Self {
945            max: ScalarValue::try_from(datatype)?,
946            moving_max: MovingMax::<ScalarValue>::new(),
947        })
948    }
949}
950
951impl Accumulator for SlidingMaxAccumulator {
952    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
953        for idx in 0..values[0].len() {
954            let val = ScalarValue::try_from_array(&values[0], idx)?;
955            self.moving_max.push(val);
956        }
957        if let Some(res) = self.moving_max.max() {
958            self.max = res.clone();
959        }
960        Ok(())
961    }
962
963    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
964        for _idx in 0..values[0].len() {
965            (self.moving_max).pop();
966        }
967        if let Some(res) = self.moving_max.max() {
968            self.max = res.clone();
969        }
970        Ok(())
971    }
972
973    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
974        self.update_batch(states)
975    }
976
977    fn state(&mut self) -> Result<Vec<ScalarValue>> {
978        Ok(vec![self.max.clone()])
979    }
980
981    fn evaluate(&mut self) -> Result<ScalarValue> {
982        Ok(self.max.clone())
983    }
984
985    fn supports_retract_batch(&self) -> bool {
986        true
987    }
988
989    fn size(&self) -> usize {
990        size_of_val(self) - size_of_val(&self.max) + self.max.size()
991    }
992}
993
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the minimum value in the specified column.",
    syntax_example = "min(expression)",
    sql_example = r#"```sql
> SELECT min(column_name) FROM table_name;
+----------------------+
| min(column_name)      |
+----------------------+
| 12                   |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `min` aggregate UDF: returns the smallest value of its argument.
///
/// The `#[user_doc]` attribute above presumably generates the `doc()` method
/// that `documentation()` returns (confirm against the macro definition).
#[derive(Debug)]
pub struct Min {
    /// Created as `Signature::user_defined` in `Min::new`, so argument types
    /// are normalised by `coerce_types` rather than by a fixed type list.
    signature: Signature,
}
1012
1013impl Min {
1014    pub fn new() -> Self {
1015        Self {
1016            signature: Signature::user_defined(Volatility::Immutable),
1017        }
1018    }
1019}
1020
1021impl Default for Min {
1022    fn default() -> Self {
1023        Self::new()
1024    }
1025}
1026
1027impl FromColumnStatistics for Min {
1028    fn value_from_column_statistics(
1029        &self,
1030        col_stats: &ColumnStatistics,
1031    ) -> Option<ScalarValue> {
1032        if let Precision::Exact(ref val) = col_stats.min_value {
1033            if !val.is_null() {
1034                return Some(val.clone());
1035            }
1036        }
1037        None
1038    }
1039}
1040
1041impl AggregateUDFImpl for Min {
1042    fn as_any(&self) -> &dyn std::any::Any {
1043        self
1044    }
1045
1046    fn name(&self) -> &str {
1047        "min"
1048    }
1049
1050    fn signature(&self) -> &Signature {
1051        &self.signature
1052    }
1053
1054    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1055        Ok(arg_types[0].to_owned())
1056    }
1057
1058    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1059        Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?))
1060    }
1061
1062    fn aliases(&self) -> &[String] {
1063        &[]
1064    }
1065
1066    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1067        use DataType::*;
1068        matches!(
1069            args.return_type,
1070            Int8 | Int16
1071                | Int32
1072                | Int64
1073                | UInt8
1074                | UInt16
1075                | UInt32
1076                | UInt64
1077                | Float16
1078                | Float32
1079                | Float64
1080                | Decimal128(_, _)
1081                | Decimal256(_, _)
1082                | Date32
1083                | Date64
1084                | Time32(_)
1085                | Time64(_)
1086                | Timestamp(_, _)
1087                | Utf8
1088                | LargeUtf8
1089                | Utf8View
1090                | Binary
1091                | LargeBinary
1092                | BinaryView
1093        )
1094    }
1095
1096    fn create_groups_accumulator(
1097        &self,
1098        args: AccumulatorArgs,
1099    ) -> Result<Box<dyn GroupsAccumulator>> {
1100        use DataType::*;
1101        use TimeUnit::*;
1102        let data_type = args.return_type;
1103        match data_type {
1104            Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
1105            Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
1106            Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
1107            Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
1108            UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
1109            UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
1110            UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
1111            UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
1112            Float16 => {
1113                primitive_min_accumulator!(data_type, f16, Float16Type)
1114            }
1115            Float32 => {
1116                primitive_min_accumulator!(data_type, f32, Float32Type)
1117            }
1118            Float64 => {
1119                primitive_min_accumulator!(data_type, f64, Float64Type)
1120            }
1121            Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
1122            Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
1123            Time32(Second) => {
1124                primitive_min_accumulator!(data_type, i32, Time32SecondType)
1125            }
1126            Time32(Millisecond) => {
1127                primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
1128            }
1129            Time64(Microsecond) => {
1130                primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
1131            }
1132            Time64(Nanosecond) => {
1133                primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
1134            }
1135            Timestamp(Second, _) => {
1136                primitive_min_accumulator!(data_type, i64, TimestampSecondType)
1137            }
1138            Timestamp(Millisecond, _) => {
1139                primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
1140            }
1141            Timestamp(Microsecond, _) => {
1142                primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
1143            }
1144            Timestamp(Nanosecond, _) => {
1145                primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
1146            }
1147            Decimal128(_, _) => {
1148                primitive_min_accumulator!(data_type, i128, Decimal128Type)
1149            }
1150            Decimal256(_, _) => {
1151                primitive_min_accumulator!(data_type, i256, Decimal256Type)
1152            }
1153            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
1154                Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
1155            }
1156
1157            // This is only reached if groups_accumulator_supported is out of sync
1158            _ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
1159        }
1160    }
1161
1162    fn create_sliding_accumulator(
1163        &self,
1164        args: AccumulatorArgs,
1165    ) -> Result<Box<dyn Accumulator>> {
1166        Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?))
1167    }
1168
1169    fn is_descending(&self) -> Option<bool> {
1170        Some(false)
1171    }
1172
1173    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1174        self.value_from_statistics(statistics_args)
1175    }
1176    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
1177        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
1178    }
1179
1180    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1181        get_min_max_result_type(arg_types)
1182    }
1183
1184    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1185        datafusion_expr::ReversedUDAF::Identical
1186    }
1187
1188    fn documentation(&self) -> Option<&Documentation> {
1189        self.doc()
1190    }
1191
1192    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
1193        // `MIN` is monotonically decreasing as it always decreases or stays
1194        // the same as new values are seen.
1195        SetMonotonicity::Decreasing
1196    }
1197}
1198
1199/// An accumulator to compute the minimum value
1200#[derive(Debug)]
1201pub struct MinAccumulator {
1202    min: ScalarValue,
1203}
1204
1205impl MinAccumulator {
1206    /// new min accumulator
1207    pub fn try_new(datatype: &DataType) -> Result<Self> {
1208        Ok(Self {
1209            min: ScalarValue::try_from(datatype)?,
1210        })
1211    }
1212}
1213
1214impl Accumulator for MinAccumulator {
1215    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1216        Ok(vec![self.evaluate()?])
1217    }
1218
1219    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1220        let values = &values[0];
1221        let delta = &min_batch(values)?;
1222        let new_min: Result<ScalarValue, DataFusionError> =
1223            min_max!(&self.min, delta, min);
1224        self.min = new_min?;
1225        Ok(())
1226    }
1227
1228    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1229        self.update_batch(states)
1230    }
1231
1232    fn evaluate(&mut self) -> Result<ScalarValue> {
1233        Ok(self.min.clone())
1234    }
1235
1236    fn size(&self) -> usize {
1237        size_of_val(self) - size_of_val(&self.min) + self.min.size()
1238    }
1239}
1240
1241#[derive(Debug)]
1242pub struct SlidingMinAccumulator {
1243    min: ScalarValue,
1244    moving_min: MovingMin<ScalarValue>,
1245}
1246
1247impl SlidingMinAccumulator {
1248    pub fn try_new(datatype: &DataType) -> Result<Self> {
1249        Ok(Self {
1250            min: ScalarValue::try_from(datatype)?,
1251            moving_min: MovingMin::<ScalarValue>::new(),
1252        })
1253    }
1254}
1255
1256impl Accumulator for SlidingMinAccumulator {
1257    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1258        Ok(vec![self.min.clone()])
1259    }
1260
1261    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1262        for idx in 0..values[0].len() {
1263            let val = ScalarValue::try_from_array(&values[0], idx)?;
1264            if !val.is_null() {
1265                self.moving_min.push(val);
1266            }
1267        }
1268        if let Some(res) = self.moving_min.min() {
1269            self.min = res.clone();
1270        }
1271        Ok(())
1272    }
1273
1274    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1275        for idx in 0..values[0].len() {
1276            let val = ScalarValue::try_from_array(&values[0], idx)?;
1277            if !val.is_null() {
1278                (self.moving_min).pop();
1279            }
1280        }
1281        if let Some(res) = self.moving_min.min() {
1282            self.min = res.clone();
1283        }
1284        Ok(())
1285    }
1286
1287    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1288        self.update_batch(states)
1289    }
1290
1291    fn evaluate(&mut self) -> Result<ScalarValue> {
1292        Ok(self.min.clone())
1293    }
1294
1295    fn supports_retract_batch(&self) -> bool {
1296        true
1297    }
1298
1299    fn size(&self) -> usize {
1300        size_of_val(self) - size_of_val(&self.min) + self.min.size()
1301    }
1302}
1303
/// Keep track of the minimum value in a sliding window.
///
/// The implementation is taken from <https://github.com/spebern/moving_min_max/blob/master/src/lib.rs>
///
/// `moving min max` provides one data structure for keeping track of the
/// minimum value and one for keeping track of the maximum value in a sliding
/// window.
///
/// The window is a FIFO queue built from two stacks. Each stack entry stores
/// an element together with the minimum of that entry and every entry below
/// it on the same stack.
///
/// - [`push`](Self::push) appends to the push stack.
/// - [`pop`](Self::pop) removes the oldest element from the pop stack. When
///   the pop stack is empty, the whole push stack is first moved onto it,
///   which reverses the order and recomputes the running minimums.
/// - The minimum of the window is the smaller of the running minimums at the
///   top of the two stacks.
///
/// The complexity of the operations are
/// - O(1) for getting the minimum/maximum
/// - O(1) for push
/// - amortized O(1) for pop
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMin;
/// let mut moving_min = MovingMin::<i32>::new();
/// moving_min.push(2);
/// moving_min.push(1);
/// moving_min.push(3);
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(2));
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(1));
///
/// assert_eq!(moving_min.min(), Some(&3));
/// assert_eq!(moving_min.pop(), Some(3));
///
/// assert_eq!(moving_min.min(), None);
/// assert_eq!(moving_min.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMin<T> {
    /// Newest elements (newest on top), each paired with the running minimum
    /// of this stack.
    push_stack: Vec<(T, T)>,
    /// Older elements (oldest on top), each paired with the running minimum
    /// of this stack.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMin<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMin<T> {
    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window, pre-allocating `capacity` slots in each of its two stacks.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the minimum of the sliding window or `None` if the window is
    /// empty.
    #[inline]
    pub fn min(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, min)), None) => Some(min),
            (None, Some((_, min))) => Some(min),
            (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window; it becomes the newest
    /// element.
    #[inline]
    pub fn push(&mut self, val: T) {
        let min = match self.push_stack.last() {
            Some((_, min)) if val > *min => min.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, min));
    }

    /// Removes and returns the oldest value of the sliding window, or `None`
    /// if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move the push stack over (newest first). The oldest element ends
            // up on top, and each entry's minimum covers itself and everything
            // newer than it.
            while let Some((val, _)) = self.push_stack.pop() {
                let min = match self.pop_stack.last() {
                    Some((_, min)) if *min < val => min.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, min));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1437
/// Keep track of the maximum value in a sliding window.
///
/// See [`MovingMin`] for more details; this is the same two-stack FIFO queue,
/// with each entry carrying the running maximum of its stack instead of the
/// minimum.
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMax;
/// let mut moving_max = MovingMax::<i32>::new();
/// moving_max.push(2);
/// moving_max.push(3);
/// moving_max.push(1);
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(2));
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(3));
///
/// assert_eq!(moving_max.max(), Some(&1));
/// assert_eq!(moving_max.pop(), Some(1));
///
/// assert_eq!(moving_max.max(), None);
/// assert_eq!(moving_max.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMax<T> {
    /// Newest elements (newest on top), each paired with the running maximum
    /// of this stack.
    push_stack: Vec<(T, T)>,
    /// Older elements (oldest on top), each paired with the running maximum
    /// of this stack.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMax<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMax<T> {
    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window,
    /// pre-allocating `capacity` slots in each of its two stacks.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the maximum of the sliding window or `None` if the window is empty.
    #[inline]
    pub fn max(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, max)), None) => Some(max),
            (None, Some((_, max))) => Some(max),
            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window; it becomes the newest
    /// element.
    #[inline]
    pub fn push(&mut self, val: T) {
        let max = match self.push_stack.last() {
            Some((_, max)) if val < *max => max.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, max));
    }

    /// Removes and returns the oldest value of the sliding window, or `None`
    /// if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move the push stack over (newest first). The oldest element ends
            // up on top, and each entry's maximum covers itself and everything
            // newer than it.
            while let Some((val, _)) = self.push_stack.pop() {
                let max = match self.pop_stack.last() {
                    Some((_, max)) if *max > val => max.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, max));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1555
// Public entry points for the `max` aggregate backed by [`Max`]. The macro
// presumably generates a `max(expression)` expression builder plus the
// `max_udaf()` accessor named in the last argument. Confirm against the
// macro definition.
make_udaf_expr_and_func!(
    Max,
    max,
    expression,
    "Returns the maximum of a group of values.",
    max_udaf
);

// Same as above for the `min` aggregate backed by [`Min`] (`min_udaf()`).
make_udaf_expr_and_func!(
    Min,
    min,
    expression,
    "Returns the minimum of a group of values.",
    min_udaf
);
1571
1572#[cfg(test)]
1573mod tests {
1574    use super::*;
1575    use arrow::datatypes::{
1576        IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
1577    };
1578    use std::sync::Arc;
1579
1580    #[test]
1581    fn interval_min_max() {
1582        // IntervalYearMonth
1583        let b = IntervalYearMonthArray::from(vec![
1584            IntervalYearMonthType::make_value(0, 1),
1585            IntervalYearMonthType::make_value(5, 34),
1586            IntervalYearMonthType::make_value(-2, 4),
1587            IntervalYearMonthType::make_value(7, -4),
1588            IntervalYearMonthType::make_value(0, 1),
1589        ]);
1590        let b: ArrayRef = Arc::new(b);
1591
1592        let mut min =
1593            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1594                .unwrap();
1595        min.update_batch(&[Arc::clone(&b)]).unwrap();
1596        let min_res = min.evaluate().unwrap();
1597        assert_eq!(
1598            min_res,
1599            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1600                -2, 4
1601            )))
1602        );
1603
1604        let mut max =
1605            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1606                .unwrap();
1607        max.update_batch(&[Arc::clone(&b)]).unwrap();
1608        let max_res = max.evaluate().unwrap();
1609        assert_eq!(
1610            max_res,
1611            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1612                5, 34
1613            )))
1614        );
1615
1616        // IntervalDayTime
1617        let b = IntervalDayTimeArray::from(vec![
1618            IntervalDayTimeType::make_value(0, 0),
1619            IntervalDayTimeType::make_value(5, 454000),
1620            IntervalDayTimeType::make_value(-34, 0),
1621            IntervalDayTimeType::make_value(7, -4000),
1622            IntervalDayTimeType::make_value(1, 0),
1623        ]);
1624        let b: ArrayRef = Arc::new(b);
1625
1626        let mut min =
1627            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1628        min.update_batch(&[Arc::clone(&b)]).unwrap();
1629        let min_res = min.evaluate().unwrap();
1630        assert_eq!(
1631            min_res,
1632            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0)))
1633        );
1634
1635        let mut max =
1636            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1637        max.update_batch(&[Arc::clone(&b)]).unwrap();
1638        let max_res = max.evaluate().unwrap();
1639        assert_eq!(
1640            max_res,
1641            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000)))
1642        );
1643
1644        // IntervalMonthDayNano
1645        let b = IntervalMonthDayNanoArray::from(vec![
1646            IntervalMonthDayNanoType::make_value(1, 0, 0),
1647            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
1648            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
1649            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
1650            IntervalMonthDayNanoType::make_value(1, 0, 0),
1651        ]);
1652        let b: ArrayRef = Arc::new(b);
1653
1654        let mut min =
1655            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1656                .unwrap();
1657        min.update_batch(&[Arc::clone(&b)]).unwrap();
1658        let min_res = min.evaluate().unwrap();
1659        assert_eq!(
1660            min_res,
1661            ScalarValue::IntervalMonthDayNano(Some(
1662                IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000)
1663            ))
1664        );
1665
1666        let mut max =
1667            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1668                .unwrap();
1669        max.update_batch(&[Arc::clone(&b)]).unwrap();
1670        let max_res = max.evaluate().unwrap();
1671        assert_eq!(
1672            max_res,
1673            ScalarValue::IntervalMonthDayNano(Some(
1674                IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000)
1675            ))
1676        );
1677    }
1678
1679    #[test]
1680    fn float_min_max_with_nans() {
1681        let pos_nan = f32::NAN;
1682        let zero = 0_f32;
1683        let neg_inf = f32::NEG_INFINITY;
1684
1685        let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
1686            for batch in values.iter() {
1687                let batch =
1688                    Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
1689                acc.update_batch(&[batch]).unwrap();
1690            }
1691            let result = acc.evaluate().unwrap();
1692            assert_eq!(result, ScalarValue::Float32(Some(expected)));
1693        };
1694
1695        // This test checks both comparison between batches (which uses the min_max macro
1696        // defined above) and within a batch (which uses the arrow min/max compute function
1697        // and verifies both respect the total order comparison for floats)
1698
1699        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
1700        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
1701
1702        check(&mut min(), &[&[zero], &[pos_nan]], zero);
1703        check(&mut min(), &[&[zero, pos_nan]], zero);
1704        check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
1705        check(&mut min(), &[&[zero, neg_inf]], neg_inf);
1706        check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
1707        check(&mut max(), &[&[zero, pos_nan]], pos_nan);
1708        check(&mut max(), &[&[zero], &[neg_inf]], zero);
1709        check(&mut max(), &[&[zero, neg_inf]], zero);
1710    }
1711
1712    use datafusion_common::Result;
1713    use rand::Rng;
1714
1715    fn get_random_vec_i32(len: usize) -> Vec<i32> {
1716        let mut rng = rand::thread_rng();
1717        let mut input = Vec::with_capacity(len);
1718        for _i in 0..len {
1719            input.push(rng.gen_range(0..100));
1720        }
1721        input
1722    }
1723
1724    fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1725        let data = get_random_vec_i32(len);
1726        let mut expected = Vec::with_capacity(len);
1727        let mut moving_min = MovingMin::<i32>::new();
1728        let mut res = Vec::with_capacity(len);
1729        for i in 0..len {
1730            let start = i.saturating_sub(n_sliding_window);
1731            expected.push(*data[start..i + 1].iter().min().unwrap());
1732
1733            moving_min.push(data[i]);
1734            if i > n_sliding_window {
1735                moving_min.pop();
1736            }
1737            res.push(*moving_min.min().unwrap());
1738        }
1739        assert_eq!(res, expected);
1740        Ok(())
1741    }
1742
1743    fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1744        let data = get_random_vec_i32(len);
1745        let mut expected = Vec::with_capacity(len);
1746        let mut moving_max = MovingMax::<i32>::new();
1747        let mut res = Vec::with_capacity(len);
1748        for i in 0..len {
1749            let start = i.saturating_sub(n_sliding_window);
1750            expected.push(*data[start..i + 1].iter().max().unwrap());
1751
1752            moving_max.push(data[i]);
1753            if i > n_sliding_window {
1754                moving_max.pop();
1755            }
1756            res.push(*moving_max.max().unwrap());
1757        }
1758        assert_eq!(res, expected);
1759        Ok(())
1760    }
1761
1762    #[test]
1763    fn moving_min_tests() -> Result<()> {
1764        moving_min_i32(100, 10)?;
1765        moving_min_i32(100, 20)?;
1766        moving_min_i32(100, 50)?;
1767        moving_min_i32(100, 100)?;
1768        Ok(())
1769    }
1770
1771    #[test]
1772    fn moving_max_tests() -> Result<()> {
1773        moving_max_i32(100, 10)?;
1774        moving_max_i32(100, 20)?;
1775        moving_max_i32(100, 50)?;
1776        moving_max_i32(100, 100)?;
1777        Ok(())
1778    }
1779
1780    #[test]
1781    fn test_min_max_coerce_types() {
1782        // the coerced types is same with input types
1783        let funs: Vec<Box<dyn AggregateUDFImpl>> =
1784            vec![Box::new(Min::new()), Box::new(Max::new())];
1785        let input_types = vec![
1786            vec![DataType::Int32],
1787            vec![DataType::Decimal128(10, 2)],
1788            vec![DataType::Decimal256(1, 1)],
1789            vec![DataType::Utf8],
1790        ];
1791        for fun in funs {
1792            for input_type in &input_types {
1793                let result = fun.coerce_types(input_type);
1794                assert_eq!(*input_type, result.unwrap());
1795            }
1796        }
1797    }
1798
1799    #[test]
1800    fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
1801        let data_type =
1802            DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32));
1803        let result = get_min_max_result_type(&[data_type])?;
1804        assert_eq!(result, vec![DataType::Int32]);
1805        Ok(())
1806    }
1807}