datafusion_functions_aggregate/
min_max.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
19//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
20
21mod min_max_bytes;
22mod min_max_struct;
23
24use arrow::array::ArrayRef;
25use arrow::datatypes::{
26    DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
27    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
28    DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
29    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30};
31use datafusion_common::stats::Precision;
32use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result};
33use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
34use datafusion_physical_expr::expressions;
35use std::cmp::Ordering;
36use std::fmt::Debug;
37
38use arrow::datatypes::i256;
39use arrow::datatypes::{
40    Date32Type, Date64Type, Time32MillisecondType, Time32SecondType,
41    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
42    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
43};
44
45use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
46use crate::min_max::min_max_struct::MinMaxStructAccumulator;
47use datafusion_common::ScalarValue;
48use datafusion_expr::{
49    function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation,
50    SetMonotonicity, Signature, Volatility,
51};
52use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
53use datafusion_macros::user_doc;
54use half::f16;
55use std::mem::size_of_val;
56use std::ops::Deref;
57
58fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
59    // make sure that the input types only has one element.
60    if input_types.len() != 1 {
61        return exec_err!(
62            "min/max was called with {} arguments. It requires only 1.",
63            input_types.len()
64        );
65    }
66    // min and max support the dictionary data type
67    // unpack the dictionary to get the value
68    match &input_types[0] {
69        DataType::Dictionary(_, dict_value_type) => {
70            // TODO add checker, if the value type is complex data type
71            Ok(vec![dict_value_type.deref().clone()])
72        }
73        // TODO add checker for datatype which min and max supported
74        // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
75        _ => Ok(input_types.to_vec()),
76    }
77}
78
79#[user_doc(
80    doc_section(label = "General Functions"),
81    description = "Returns the maximum value in the specified column.",
82    syntax_example = "max(expression)",
83    sql_example = r#"```sql
84> SELECT max(column_name) FROM table_name;
85+----------------------+
86| max(column_name)      |
87+----------------------+
88| 150                  |
89+----------------------+
90```"#,
91    standard_argument(name = "expression",)
92)]
93// MAX aggregate UDF
94#[derive(Debug, PartialEq, Eq, Hash)]
95pub struct Max {
96    signature: Signature,
97}
98
99impl Max {
100    pub fn new() -> Self {
101        Self {
102            signature: Signature::user_defined(Volatility::Immutable),
103        }
104    }
105}
106
107impl Default for Max {
108    fn default() -> Self {
109        Self::new()
110    }
111}
/// Creates a [`PrimitiveGroupsAccumulator`] computing `MAX` for the specified
/// [`ArrowPrimitiveType`].
///
/// Arguments: `$DATA_TYPE` is a variable holding the result type (passed
/// straight to `PrimitiveGroupsAccumulator::new`), `$NATIVE` is the native
/// Rust type, and `$PRIMTYPE` is the Arrow primitive type.
///
/// Every group starts at the smallest representable value so that the first
/// real value replaces it. For floats that is `NEG_INFINITY`, not `MIN`:
/// `fN::MIN` is the most negative *finite* value, so starting there would
/// make a group whose only value is `-inf` report `fN::MIN` instead.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! primitive_max_accumulator {
    // Internal rule: build the accumulator with an explicit starting value.
    (@build $DATA_TYPE:ident, $PRIMTYPE:ident, $START:expr) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                match (new).partial_cmp(cur) {
                    // `None` means a NaN is involved; `new` then replaces `cur`.
                    Some(Ordering::Greater) | None => *cur = new,
                    _ => {}
                }
            })
            .with_starting_value($START),
        ))
    }};
    ($DATA_TYPE:ident, f16, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f16::NEG_INFINITY)
    };
    ($DATA_TYPE:ident, f32, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f32::NEG_INFINITY)
    };
    ($DATA_TYPE:ident, f64, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f64::NEG_INFINITY)
    };
    // Integer-like natives: `MIN` is the smallest representable value.
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@build $DATA_TYPE, $PRIMTYPE, $NATIVE::MIN)
    };
}
133
/// Creates a [`PrimitiveGroupsAccumulator`] computing `MIN` for the specified
/// [`ArrowPrimitiveType`].
///
/// Arguments: `$DATA_TYPE` is a variable holding the result type (passed
/// straight to `PrimitiveGroupsAccumulator::new`, as in
/// `primitive_max_accumulator!`), `$NATIVE` is the native Rust type, and
/// `$PRIMTYPE` is the Arrow primitive type.
///
/// Every group starts at the largest representable value so that the first
/// real value replaces it. For floats that is `INFINITY`, not `MAX`:
/// `fN::MAX` is the largest *finite* value, so starting there would make a
/// group whose only value is `+inf` report `fN::MAX` instead.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! primitive_min_accumulator {
    // Internal rule: build the accumulator with an explicit starting value.
    (@build $DATA_TYPE:ident, $PRIMTYPE:ident, $START:expr) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                match (new).partial_cmp(cur) {
                    // `None` means a NaN is involved; `new` then replaces `cur`.
                    Some(Ordering::Less) | None => *cur = new,
                    _ => {}
                }
            })
            .with_starting_value($START),
        ))
    }};
    ($DATA_TYPE:ident, f16, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f16::INFINITY)
    };
    ($DATA_TYPE:ident, f32, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f32::INFINITY)
    };
    ($DATA_TYPE:ident, f64, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@build $DATA_TYPE, $PRIMTYPE, f64::INFINITY)
    };
    // Integer-like natives: `MAX` is the largest representable value.
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@build $DATA_TYPE, $PRIMTYPE, $NATIVE::MAX)
    };
}
156
157trait FromColumnStatistics {
158    fn value_from_column_statistics(
159        &self,
160        stats: &ColumnStatistics,
161    ) -> Option<ScalarValue>;
162
163    fn value_from_statistics(
164        &self,
165        statistics_args: &StatisticsArgs,
166    ) -> Option<ScalarValue> {
167        if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
168            match *num_rows {
169                0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
170                value if value > 0 => {
171                    let col_stats = &statistics_args.statistics.column_statistics;
172                    if statistics_args.exprs.len() == 1 {
173                        // TODO optimize with exprs other than Column
174                        if let Some(col_expr) = statistics_args.exprs[0]
175                            .as_any()
176                            .downcast_ref::<expressions::Column>()
177                        {
178                            return self.value_from_column_statistics(
179                                &col_stats[col_expr.index()],
180                            );
181                        }
182                    }
183                }
184                _ => {}
185            }
186        }
187        None
188    }
189}
190
191impl FromColumnStatistics for Max {
192    fn value_from_column_statistics(
193        &self,
194        col_stats: &ColumnStatistics,
195    ) -> Option<ScalarValue> {
196        if let Precision::Exact(ref val) = col_stats.max_value {
197            if !val.is_null() {
198                return Some(val.clone());
199            }
200        }
201        None
202    }
203}
204
205impl AggregateUDFImpl for Max {
206    fn as_any(&self) -> &dyn std::any::Any {
207        self
208    }
209
210    fn name(&self) -> &str {
211        "max"
212    }
213
214    fn signature(&self) -> &Signature {
215        &self.signature
216    }
217
218    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
219        Ok(arg_types[0].to_owned())
220    }
221
222    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
223        Ok(Box::new(MaxAccumulator::try_new(
224            acc_args.return_field.data_type(),
225        )?))
226    }
227
228    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
229        use DataType::*;
230        matches!(
231            args.return_field.data_type(),
232            Int8 | Int16
233                | Int32
234                | Int64
235                | UInt8
236                | UInt16
237                | UInt32
238                | UInt64
239                | Float16
240                | Float32
241                | Float64
242                | Decimal32(_, _)
243                | Decimal64(_, _)
244                | Decimal128(_, _)
245                | Decimal256(_, _)
246                | Date32
247                | Date64
248                | Time32(_)
249                | Time64(_)
250                | Timestamp(_, _)
251                | Utf8
252                | LargeUtf8
253                | Utf8View
254                | Binary
255                | LargeBinary
256                | BinaryView
257                | Duration(_)
258                | Struct(_)
259        )
260    }
261
262    fn create_groups_accumulator(
263        &self,
264        args: AccumulatorArgs,
265    ) -> Result<Box<dyn GroupsAccumulator>> {
266        use DataType::*;
267        use TimeUnit::*;
268        let data_type = args.return_field.data_type();
269        match data_type {
270            Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
271            Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
272            Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
273            Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
274            UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
275            UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
276            UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
277            UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
278            Float16 => {
279                primitive_max_accumulator!(data_type, f16, Float16Type)
280            }
281            Float32 => {
282                primitive_max_accumulator!(data_type, f32, Float32Type)
283            }
284            Float64 => {
285                primitive_max_accumulator!(data_type, f64, Float64Type)
286            }
287            Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
288            Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
289            Time32(Second) => {
290                primitive_max_accumulator!(data_type, i32, Time32SecondType)
291            }
292            Time32(Millisecond) => {
293                primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
294            }
295            Time64(Microsecond) => {
296                primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
297            }
298            Time64(Nanosecond) => {
299                primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
300            }
301            Timestamp(Second, _) => {
302                primitive_max_accumulator!(data_type, i64, TimestampSecondType)
303            }
304            Timestamp(Millisecond, _) => {
305                primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
306            }
307            Timestamp(Microsecond, _) => {
308                primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
309            }
310            Timestamp(Nanosecond, _) => {
311                primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
312            }
313            Duration(Second) => {
314                primitive_max_accumulator!(data_type, i64, DurationSecondType)
315            }
316            Duration(Millisecond) => {
317                primitive_max_accumulator!(data_type, i64, DurationMillisecondType)
318            }
319            Duration(Microsecond) => {
320                primitive_max_accumulator!(data_type, i64, DurationMicrosecondType)
321            }
322            Duration(Nanosecond) => {
323                primitive_max_accumulator!(data_type, i64, DurationNanosecondType)
324            }
325            Decimal32(_, _) => {
326                primitive_max_accumulator!(data_type, i32, Decimal32Type)
327            }
328            Decimal64(_, _) => {
329                primitive_max_accumulator!(data_type, i64, Decimal64Type)
330            }
331            Decimal128(_, _) => {
332                primitive_max_accumulator!(data_type, i128, Decimal128Type)
333            }
334            Decimal256(_, _) => {
335                primitive_max_accumulator!(data_type, i256, Decimal256Type)
336            }
337            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
338                Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
339            }
340            Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max(
341                data_type.clone(),
342            ))),
343            // This is only reached if groups_accumulator_supported is out of sync
344            _ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
345        }
346    }
347
348    fn create_sliding_accumulator(
349        &self,
350        args: AccumulatorArgs,
351    ) -> Result<Box<dyn Accumulator>> {
352        Ok(Box::new(SlidingMaxAccumulator::try_new(
353            args.return_field.data_type(),
354        )?))
355    }
356
357    fn is_descending(&self) -> Option<bool> {
358        Some(true)
359    }
360
361    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
362        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
363    }
364
365    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
366        get_min_max_result_type(arg_types)
367    }
368    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
369        datafusion_expr::ReversedUDAF::Identical
370    }
371    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
372        self.value_from_statistics(statistics_args)
373    }
374
375    fn documentation(&self) -> Option<&Documentation> {
376        self.doc()
377    }
378
379    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
380        // `MAX` is monotonically increasing as it always increases or stays
381        // the same as new values are seen.
382        SetMonotonicity::Increasing
383    }
384}
385
386#[derive(Debug)]
387pub struct SlidingMaxAccumulator {
388    max: ScalarValue,
389    moving_max: MovingMax<ScalarValue>,
390}
391
392impl SlidingMaxAccumulator {
393    /// new max accumulator
394    pub fn try_new(datatype: &DataType) -> Result<Self> {
395        Ok(Self {
396            max: ScalarValue::try_from(datatype)?,
397            moving_max: MovingMax::<ScalarValue>::new(),
398        })
399    }
400}
401
402impl Accumulator for SlidingMaxAccumulator {
403    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
404        for idx in 0..values[0].len() {
405            let val = ScalarValue::try_from_array(&values[0], idx)?;
406            self.moving_max.push(val);
407        }
408        if let Some(res) = self.moving_max.max() {
409            self.max = res.clone();
410        }
411        Ok(())
412    }
413
414    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
415        for _idx in 0..values[0].len() {
416            (self.moving_max).pop();
417        }
418        if let Some(res) = self.moving_max.max() {
419            self.max = res.clone();
420        }
421        Ok(())
422    }
423
424    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
425        self.update_batch(states)
426    }
427
428    fn state(&mut self) -> Result<Vec<ScalarValue>> {
429        Ok(vec![self.max.clone()])
430    }
431
432    fn evaluate(&mut self) -> Result<ScalarValue> {
433        Ok(self.max.clone())
434    }
435
436    fn supports_retract_batch(&self) -> bool {
437        true
438    }
439
440    fn size(&self) -> usize {
441        size_of_val(self) - size_of_val(&self.max) + self.max.size()
442    }
443}
444
445#[user_doc(
446    doc_section(label = "General Functions"),
447    description = "Returns the minimum value in the specified column.",
448    syntax_example = "min(expression)",
449    sql_example = r#"```sql
450> SELECT min(column_name) FROM table_name;
451+----------------------+
452| min(column_name)      |
453+----------------------+
454| 12                   |
455+----------------------+
456```"#,
457    standard_argument(name = "expression",)
458)]
459#[derive(Debug, PartialEq, Eq, Hash)]
460pub struct Min {
461    signature: Signature,
462}
463
464impl Min {
465    pub fn new() -> Self {
466        Self {
467            signature: Signature::user_defined(Volatility::Immutable),
468        }
469    }
470}
471
472impl Default for Min {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
478impl FromColumnStatistics for Min {
479    fn value_from_column_statistics(
480        &self,
481        col_stats: &ColumnStatistics,
482    ) -> Option<ScalarValue> {
483        if let Precision::Exact(ref val) = col_stats.min_value {
484            if !val.is_null() {
485                return Some(val.clone());
486            }
487        }
488        None
489    }
490}
491
492impl AggregateUDFImpl for Min {
493    fn as_any(&self) -> &dyn std::any::Any {
494        self
495    }
496
497    fn name(&self) -> &str {
498        "min"
499    }
500
501    fn signature(&self) -> &Signature {
502        &self.signature
503    }
504
505    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
506        Ok(arg_types[0].to_owned())
507    }
508
509    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
510        Ok(Box::new(MinAccumulator::try_new(
511            acc_args.return_field.data_type(),
512        )?))
513    }
514
515    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
516        use DataType::*;
517        matches!(
518            args.return_field.data_type(),
519            Int8 | Int16
520                | Int32
521                | Int64
522                | UInt8
523                | UInt16
524                | UInt32
525                | UInt64
526                | Float16
527                | Float32
528                | Float64
529                | Decimal32(_, _)
530                | Decimal64(_, _)
531                | Decimal128(_, _)
532                | Decimal256(_, _)
533                | Date32
534                | Date64
535                | Time32(_)
536                | Time64(_)
537                | Timestamp(_, _)
538                | Utf8
539                | LargeUtf8
540                | Utf8View
541                | Binary
542                | LargeBinary
543                | BinaryView
544                | Duration(_)
545                | Struct(_)
546        )
547    }
548
549    fn create_groups_accumulator(
550        &self,
551        args: AccumulatorArgs,
552    ) -> Result<Box<dyn GroupsAccumulator>> {
553        use DataType::*;
554        use TimeUnit::*;
555        let data_type = args.return_field.data_type();
556        match data_type {
557            Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
558            Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
559            Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
560            Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
561            UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
562            UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
563            UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
564            UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
565            Float16 => {
566                primitive_min_accumulator!(data_type, f16, Float16Type)
567            }
568            Float32 => {
569                primitive_min_accumulator!(data_type, f32, Float32Type)
570            }
571            Float64 => {
572                primitive_min_accumulator!(data_type, f64, Float64Type)
573            }
574            Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
575            Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
576            Time32(Second) => {
577                primitive_min_accumulator!(data_type, i32, Time32SecondType)
578            }
579            Time32(Millisecond) => {
580                primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
581            }
582            Time64(Microsecond) => {
583                primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
584            }
585            Time64(Nanosecond) => {
586                primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
587            }
588            Timestamp(Second, _) => {
589                primitive_min_accumulator!(data_type, i64, TimestampSecondType)
590            }
591            Timestamp(Millisecond, _) => {
592                primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
593            }
594            Timestamp(Microsecond, _) => {
595                primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
596            }
597            Timestamp(Nanosecond, _) => {
598                primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
599            }
600            Duration(Second) => {
601                primitive_min_accumulator!(data_type, i64, DurationSecondType)
602            }
603            Duration(Millisecond) => {
604                primitive_min_accumulator!(data_type, i64, DurationMillisecondType)
605            }
606            Duration(Microsecond) => {
607                primitive_min_accumulator!(data_type, i64, DurationMicrosecondType)
608            }
609            Duration(Nanosecond) => {
610                primitive_min_accumulator!(data_type, i64, DurationNanosecondType)
611            }
612            Decimal32(_, _) => {
613                primitive_min_accumulator!(data_type, i32, Decimal32Type)
614            }
615            Decimal64(_, _) => {
616                primitive_min_accumulator!(data_type, i64, Decimal64Type)
617            }
618            Decimal128(_, _) => {
619                primitive_min_accumulator!(data_type, i128, Decimal128Type)
620            }
621            Decimal256(_, _) => {
622                primitive_min_accumulator!(data_type, i256, Decimal256Type)
623            }
624            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
625                Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
626            }
627            Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min(
628                data_type.clone(),
629            ))),
630            // This is only reached if groups_accumulator_supported is out of sync
631            _ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
632        }
633    }
634
635    fn create_sliding_accumulator(
636        &self,
637        args: AccumulatorArgs,
638    ) -> Result<Box<dyn Accumulator>> {
639        Ok(Box::new(SlidingMinAccumulator::try_new(
640            args.return_field.data_type(),
641        )?))
642    }
643
644    fn is_descending(&self) -> Option<bool> {
645        Some(false)
646    }
647
648    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
649        self.value_from_statistics(statistics_args)
650    }
651    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
652        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
653    }
654
655    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
656        get_min_max_result_type(arg_types)
657    }
658
659    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
660        datafusion_expr::ReversedUDAF::Identical
661    }
662
663    fn documentation(&self) -> Option<&Documentation> {
664        self.doc()
665    }
666
667    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
668        // `MIN` is monotonically decreasing as it always decreases or stays
669        // the same as new values are seen.
670        SetMonotonicity::Decreasing
671    }
672}
673
674#[derive(Debug)]
675pub struct SlidingMinAccumulator {
676    min: ScalarValue,
677    moving_min: MovingMin<ScalarValue>,
678}
679
680impl SlidingMinAccumulator {
681    pub fn try_new(datatype: &DataType) -> Result<Self> {
682        Ok(Self {
683            min: ScalarValue::try_from(datatype)?,
684            moving_min: MovingMin::<ScalarValue>::new(),
685        })
686    }
687}
688
689impl Accumulator for SlidingMinAccumulator {
690    fn state(&mut self) -> Result<Vec<ScalarValue>> {
691        Ok(vec![self.min.clone()])
692    }
693
694    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
695        for idx in 0..values[0].len() {
696            let val = ScalarValue::try_from_array(&values[0], idx)?;
697            if !val.is_null() {
698                self.moving_min.push(val);
699            }
700        }
701        if let Some(res) = self.moving_min.min() {
702            self.min = res.clone();
703        }
704        Ok(())
705    }
706
707    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
708        for idx in 0..values[0].len() {
709            let val = ScalarValue::try_from_array(&values[0], idx)?;
710            if !val.is_null() {
711                (self.moving_min).pop();
712            }
713        }
714        if let Some(res) = self.moving_min.min() {
715            self.min = res.clone();
716        }
717        Ok(())
718    }
719
720    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
721        self.update_batch(states)
722    }
723
724    fn evaluate(&mut self) -> Result<ScalarValue> {
725        Ok(self.min.clone())
726    }
727
728    fn supports_retract_batch(&self) -> bool {
729        true
730    }
731
732    fn size(&self) -> usize {
733        size_of_val(self) - size_of_val(&self.min) + self.min.size()
734    }
735}
736
/// Tracks the minimum of a sliding window of values (a FIFO queue).
///
/// The implementation is taken from <https://github.com/spebern/moving_min_max/blob/master/src/lib.rs>,
/// which provides this type and its counterpart, [`MovingMax`].
///
/// # How it works
///
/// The queue is built from two stacks. Each stack entry stores a value
/// together with the minimum of that value and every entry beneath it on the
/// same stack:
///
/// - `push` appends to the push stack.
/// - `pop` removes from the pop stack. If the pop stack is empty, the whole
///   push stack is first moved onto it (newest first, so the oldest value ends
///   up on top), recomputing the per-entry minimums along the way.
/// - `min` compares the minimums stored at the tops of the two stacks.
///
/// # Complexity
///
/// - O(1) for getting the minimum
/// - O(1) for push
/// - amortized O(1) for pop
///
/// # Example
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMin;
/// let mut moving_min = MovingMin::<i32>::new();
/// moving_min.push(2);
/// moving_min.push(1);
/// moving_min.push(3);
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(2));
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(1));
///
/// assert_eq!(moving_min.min(), Some(&3));
/// assert_eq!(moving_min.pop(), Some(3));
///
/// assert_eq!(moving_min.min(), None);
/// assert_eq!(moving_min.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMin<T> {
    /// Newly pushed entries as `(value, min of this and all older entries here)`.
    push_stack: Vec<(T, T)>,
    /// Entries awaiting `pop`, oldest on top, as `(value, running minimum)`.
    pop_stack: Vec<(T, T)>,
}
779
780impl<T: Clone + PartialOrd> Default for MovingMin<T> {
781    fn default() -> Self {
782        Self {
783            push_stack: Vec::new(),
784            pop_stack: Vec::new(),
785        }
786    }
787}
788
789impl<T: Clone + PartialOrd> MovingMin<T> {
790    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
791    /// window.
792    #[inline]
793    pub fn new() -> Self {
794        Self::default()
795    }
796
797    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
798    /// window with `capacity` allocated slots.
799    #[inline]
800    pub fn with_capacity(capacity: usize) -> Self {
801        Self {
802            push_stack: Vec::with_capacity(capacity),
803            pop_stack: Vec::with_capacity(capacity),
804        }
805    }
806
807    /// Returns the minimum of the sliding window or `None` if the window is
808    /// empty.
809    #[inline]
810    pub fn min(&self) -> Option<&T> {
811        match (self.push_stack.last(), self.pop_stack.last()) {
812            (None, None) => None,
813            (Some((_, min)), None) => Some(min),
814            (None, Some((_, min))) => Some(min),
815            (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }),
816        }
817    }
818
819    /// Pushes a new element into the sliding window.
820    #[inline]
821    pub fn push(&mut self, val: T) {
822        self.push_stack.push(match self.push_stack.last() {
823            Some((_, min)) => {
824                if val > *min {
825                    (val, min.clone())
826                } else {
827                    (val.clone(), val)
828                }
829            }
830            None => (val.clone(), val),
831        });
832    }
833
834    /// Removes and returns the last value of the sliding window.
835    #[inline]
836    pub fn pop(&mut self) -> Option<T> {
837        if self.pop_stack.is_empty() {
838            match self.push_stack.pop() {
839                Some((val, _)) => {
840                    let mut last = (val.clone(), val);
841                    self.pop_stack.push(last.clone());
842                    while let Some((val, _)) = self.push_stack.pop() {
843                        let min = if last.1 < val {
844                            last.1.clone()
845                        } else {
846                            val.clone()
847                        };
848                        last = (val.clone(), min);
849                        self.pop_stack.push(last.clone());
850                    }
851                }
852                None => return None,
853            }
854        }
855        self.pop_stack.pop().map(|(val, _)| val)
856    }
857
858    /// Returns the number of elements stored in the sliding window.
859    #[inline]
860    pub fn len(&self) -> usize {
861        self.push_stack.len() + self.pop_stack.len()
862    }
863
864    /// Returns `true` if the moving window contains no elements.
865    #[inline]
866    pub fn is_empty(&self) -> bool {
867        self.len() == 0
868    }
869}
870
/// Tracks the maximum of a sliding window of values (a FIFO queue).
///
/// This is the mirror image of [`MovingMin`]; see its documentation for how
/// the two-stack queue works and for the cost of each operation.
///
/// # Example
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMax;
/// let mut moving_max = MovingMax::<i32>::new();
/// moving_max.push(2);
/// moving_max.push(3);
/// moving_max.push(1);
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(2));
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(3));
///
/// assert_eq!(moving_max.max(), Some(&1));
/// assert_eq!(moving_max.pop(), Some(1));
///
/// assert_eq!(moving_max.max(), None);
/// assert_eq!(moving_max.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMax<T> {
    /// Newly pushed entries as `(value, max of this and all older entries here)`.
    push_stack: Vec<(T, T)>,
    /// Entries awaiting `pop`, as `(value, running maximum)`.
    pop_stack: Vec<(T, T)>,
}
899
900impl<T: Clone + PartialOrd> Default for MovingMax<T> {
901    fn default() -> Self {
902        Self {
903            push_stack: Vec::new(),
904            pop_stack: Vec::new(),
905        }
906    }
907}
908
909impl<T: Clone + PartialOrd> MovingMax<T> {
910    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window.
911    #[inline]
912    pub fn new() -> Self {
913        Self::default()
914    }
915
916    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with
917    /// `capacity` allocated slots.
918    #[inline]
919    pub fn with_capacity(capacity: usize) -> Self {
920        Self {
921            push_stack: Vec::with_capacity(capacity),
922            pop_stack: Vec::with_capacity(capacity),
923        }
924    }
925
926    /// Returns the maximum of the sliding window or `None` if the window is empty.
927    #[inline]
928    pub fn max(&self) -> Option<&T> {
929        match (self.push_stack.last(), self.pop_stack.last()) {
930            (None, None) => None,
931            (Some((_, max)), None) => Some(max),
932            (None, Some((_, max))) => Some(max),
933            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
934        }
935    }
936
937    /// Pushes a new element into the sliding window.
938    #[inline]
939    pub fn push(&mut self, val: T) {
940        self.push_stack.push(match self.push_stack.last() {
941            Some((_, max)) => {
942                if val < *max {
943                    (val, max.clone())
944                } else {
945                    (val.clone(), val)
946                }
947            }
948            None => (val.clone(), val),
949        });
950    }
951
952    /// Removes and returns the last value of the sliding window.
953    #[inline]
954    pub fn pop(&mut self) -> Option<T> {
955        if self.pop_stack.is_empty() {
956            match self.push_stack.pop() {
957                Some((val, _)) => {
958                    let mut last = (val.clone(), val);
959                    self.pop_stack.push(last.clone());
960                    while let Some((val, _)) = self.push_stack.pop() {
961                        let max = if last.1 > val {
962                            last.1.clone()
963                        } else {
964                            val.clone()
965                        };
966                        last = (val.clone(), max);
967                        self.pop_stack.push(last.clone());
968                    }
969                }
970                None => return None,
971            }
972        }
973        self.pop_stack.pop().map(|(val, _)| val)
974    }
975
976    /// Returns the number of elements stored in the sliding window.
977    #[inline]
978    pub fn len(&self) -> usize {
979        self.push_stack.len() + self.pop_stack.len()
980    }
981
982    /// Returns `true` if the moving window contains no elements.
983    #[inline]
984    pub fn is_empty(&self) -> bool {
985        self.len() == 0
986    }
987}
988
// Wires the `Max` UDAF into `make_udaf_expr_and_func!` (defined elsewhere).
// NOTE(review): judging by the arguments, this presumably generates a
// `max(expression)` expression builder and a `max_udaf()` accessor, with the
// string as its documentation. Confirm against the macro definition.
make_udaf_expr_and_func!(
    Max,
    max,
    expression,
    "Returns the maximum of a group of values.",
    max_udaf
);
996
// Wires the `Min` UDAF into `make_udaf_expr_and_func!` (defined elsewhere).
// NOTE(review): judging by the arguments, this presumably generates a
// `min(expression)` expression builder and a `min_udaf()` accessor, with the
// string as its documentation. Confirm against the macro definition.
make_udaf_expr_and_func!(
    Min,
    min,
    expression,
    "Returns the minimum of a group of values.",
    min_udaf
);
1004
1005// Re-export accumulators from the common module for backwards compatibility
1006pub use datafusion_functions_aggregate_common::min_max::{
1007    MaxAccumulator, MinAccumulator,
1008};
1009
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::{
            DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray,
            IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray,
        },
        datatypes::{
            IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
            IntervalYearMonthType,
        },
    };
    use std::sync::Arc;

    /// Checks `MinAccumulator` / `MaxAccumulator` on each of the three interval
    /// types, feeding each accumulator a single batch.
    #[test]
    fn interval_min_max() {
        // IntervalYearMonth. (5, 34) is expected to beat (7, -4). That is
        // consistent with comparing total months (5*12+34 = 94 > 7*12-4 = 80),
        // assuming `make_value(years, months)`.
        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(0, 1),
            IntervalYearMonthType::make_value(5, 34),
            IntervalYearMonthType::make_value(-2, 4),
            IntervalYearMonthType::make_value(7, -4),
            IntervalYearMonthType::make_value(0, 1),
        ]);
        let b: ArrayRef = Arc::new(b);

        let mut min =
            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
                .unwrap();
        min.update_batch(&[Arc::clone(&b)]).unwrap();
        let min_res = min.evaluate().unwrap();
        assert_eq!(
            min_res,
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                -2, 4,
            )))
        );

        let mut max =
            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
                .unwrap();
        max.update_batch(&[Arc::clone(&b)]).unwrap();
        let max_res = max.evaluate().unwrap();
        assert_eq!(
            max_res,
            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
                5, 34,
            )))
        );

        // IntervalDayTime. The expectations are consistent with ordering by the
        // first field (days) first: (7, -4000) is the max even though
        // (5, 454000) has a larger second field.
        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(0, 0),
            IntervalDayTimeType::make_value(5, 454000),
            IntervalDayTimeType::make_value(-34, 0),
            IntervalDayTimeType::make_value(7, -4000),
            IntervalDayTimeType::make_value(1, 0),
        ]);
        let b: ArrayRef = Arc::new(b);

        let mut min =
            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
        min.update_batch(&[Arc::clone(&b)]).unwrap();
        let min_res = min.evaluate().unwrap();
        assert_eq!(
            min_res,
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0)))
        );

        let mut max =
            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
        max.update_batch(&[Arc::clone(&b)]).unwrap();
        let max_res = max.evaluate().unwrap();
        assert_eq!(
            max_res,
            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000)))
        );

        // IntervalMonthDayNano. The expectations are consistent with ordering by
        // the first field (months) first.
        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(1, 0, 0),
            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
            IntervalMonthDayNanoType::make_value(1, 0, 0),
        ]);
        let b: ArrayRef = Arc::new(b);

        let mut min =
            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
                .unwrap();
        min.update_batch(&[Arc::clone(&b)]).unwrap();
        let min_res = min.evaluate().unwrap();
        assert_eq!(
            min_res,
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000)
            ))
        );

        let mut max =
            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
                .unwrap();
        max.update_batch(&[Arc::clone(&b)]).unwrap();
        let max_res = max.evaluate().unwrap();
        assert_eq!(
            max_res,
            ScalarValue::IntervalMonthDayNano(Some(
                IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000)
            ))
        );
    }

    /// Checks that float min/max place NaN above every other value, both when
    /// NaN arrives in its own batch and when it shares a batch with other values.
    #[test]
    fn float_min_max_with_nans() {
        let pos_nan = f32::NAN;
        let zero = 0_f32;
        let neg_inf = f32::NEG_INFINITY;

        // Feeds each inner slice to its own `update_batch` call, then compares the
        // final result. NOTE(review): the NaN expectations rely on `ScalarValue`
        // equality treating NaN as equal to NaN. Presumably this is a total-order
        // comparison; confirm.
        let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
            for batch in values.iter() {
                let batch =
                    Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
                acc.update_batch(&[batch]).unwrap();
            }
            let result = acc.evaluate().unwrap();
            assert_eq!(result, ScalarValue::Float32(Some(expected)));
        };

        // This test checks both comparison between batches (each inner slice is fed
        // to a separate `update_batch` call, so the accumulator must merge results
        // across calls) and within a batch (a single `update_batch` call over all
        // values). It verifies both respect the total order comparison for floats,
        // under which NaN sorts above every other value.

        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();

        check(&mut min(), &[&[zero], &[pos_nan]], zero);
        check(&mut min(), &[&[zero, pos_nan]], zero);
        check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
        check(&mut min(), &[&[zero, neg_inf]], neg_inf);
        check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
        check(&mut max(), &[&[zero, pos_nan]], pos_nan);
        check(&mut max(), &[&[zero], &[neg_inf]], zero);
        check(&mut max(), &[&[zero, neg_inf]], zero);
    }

    // The parent module already imports `datafusion_common::Result`, so this
    // explicit import only restates it.
    use datafusion_common::Result;
    use rand::Rng;

    /// Returns `len` random integers in `0..100`. The RNG is not seeded, so the
    /// moving-window tests below see different data on every run.
    fn get_random_vec_i32(len: usize) -> Vec<i32> {
        let mut rng = rand::rng();
        let mut input = Vec::with_capacity(len);
        for _i in 0..len {
            input.push(rng.random_range(0..100));
        }
        input
    }

    /// Compares `MovingMin` against a brute-force minimum over a window of
    /// `n_sliding_window + 1` elements: `data[i - n_sliding_window..=i]`,
    /// clamped at 0. The oldest element is popped only once `i > n_sliding_window`.
    fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut expected = Vec::with_capacity(len);
        let mut moving_min = MovingMin::<i32>::new();
        let mut res = Vec::with_capacity(len);
        for i in 0..len {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..i + 1].iter().min().unwrap());

            moving_min.push(data[i]);
            if i > n_sliding_window {
                moving_min.pop();
            }
            res.push(*moving_min.min().unwrap());
        }
        assert_eq!(res, expected);
        Ok(())
    }

    /// Same as `moving_min_i32`, but for `MovingMax` against a brute-force maximum.
    fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
        let data = get_random_vec_i32(len);
        let mut expected = Vec::with_capacity(len);
        let mut moving_max = MovingMax::<i32>::new();
        let mut res = Vec::with_capacity(len);
        for i in 0..len {
            let start = i.saturating_sub(n_sliding_window);
            expected.push(*data[start..i + 1].iter().max().unwrap());

            moving_max.push(data[i]);
            if i > n_sliding_window {
                moving_max.pop();
            }
            res.push(*moving_max.max().unwrap());
        }
        assert_eq!(res, expected);
        Ok(())
    }

    #[test]
    fn moving_min_tests() -> Result<()> {
        moving_min_i32(100, 10)?;
        moving_min_i32(100, 20)?;
        moving_min_i32(100, 50)?;
        moving_min_i32(100, 100)?;
        Ok(())
    }

    #[test]
    fn moving_max_tests() -> Result<()> {
        moving_max_i32(100, 10)?;
        moving_max_i32(100, 20)?;
        moving_max_i32(100, 50)?;
        moving_max_i32(100, 100)?;
        Ok(())
    }

    #[test]
    fn test_min_max_coerce_types() {
        // The coerced types are the same as the input types.
        let funs: Vec<Box<dyn AggregateUDFImpl>> =
            vec![Box::new(Min::new()), Box::new(Max::new())];
        let input_types = vec![
            vec![DataType::Int32],
            vec![DataType::Decimal128(10, 2)],
            vec![DataType::Decimal256(1, 1)],
            vec![DataType::Utf8],
        ];
        for fun in funs {
            for input_type in &input_types {
                let result = fun.coerce_types(input_type);
                assert_eq!(*input_type, result.unwrap());
            }
        }
    }

    /// A dictionary input is expected to resolve to its value type (here `Utf8`).
    #[test]
    fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
        let data_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        let result = get_min_max_result_type(&[data_type])?;
        assert_eq!(result, vec![DataType::Utf8]);
        Ok(())
    }

    #[test]
    fn test_min_max_dictionary() -> Result<()> {
        let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
        // NOTE(review): key 3 is null, so "🦀" (values[3]) is not referenced by any
        // row. The max assertion below still expects "🦀", which pins behavior
        // where every dictionary value is considered, referenced or not. Confirm
        // this is intended.
        let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
        let dict_array =
            DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
        let dict_array_ref = Arc::new(dict_array) as ArrayRef;
        let rt_type =
            get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone();

        let mut min_acc = MinAccumulator::try_new(&rt_type)?;
        min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
        let min_result = min_acc.evaluate()?;
        assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));

        let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
        max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
        let max_result = max_acc.evaluate()?;
        assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
        Ok(())
    }
}