1mod min_max_bytes;
22mod min_max_struct;
23
24use arrow::array::ArrayRef;
25use arrow::datatypes::{
26 DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType,
27 DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type,
28 Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
29 UInt32Type, UInt64Type, UInt8Type,
30};
31use datafusion_common::stats::Precision;
32use datafusion_common::{
33 exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
34};
35use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
36use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};
37use datafusion_physical_expr::expressions;
38use std::cmp::Ordering;
39use std::fmt::Debug;
40
41use arrow::datatypes::i256;
42use arrow::datatypes::{
43 Date32Type, Date64Type, Time32MillisecondType, Time32SecondType,
44 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
45 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
46};
47
48use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
49use crate::min_max::min_max_struct::MinMaxStructAccumulator;
50use datafusion_common::ScalarValue;
51use datafusion_expr::{
52 function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation,
53 SetMonotonicity, Signature, Volatility,
54};
55use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
56use datafusion_macros::user_doc;
57use half::f16;
58use std::mem::size_of_val;
59use std::ops::Deref;
60
61fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
62 if input_types.len() != 1 {
64 return exec_err!(
65 "min/max was called with {} arguments. It requires only 1.",
66 input_types.len()
67 );
68 }
69 match &input_types[0] {
72 DataType::Dictionary(_, dict_value_type) => {
73 Ok(vec![dict_value_type.deref().clone()])
75 }
76 _ => Ok(input_types.to_vec()),
79 }
80}
81
82#[user_doc(
83 doc_section(label = "General Functions"),
84 description = "Returns the maximum value in the specified column.",
85 syntax_example = "max(expression)",
86 sql_example = r#"```sql
87> SELECT max(column_name) FROM table_name;
88+----------------------+
89| max(column_name) |
90+----------------------+
91| 150 |
92+----------------------+
93```"#,
94 standard_argument(name = "expression",)
95)]
96#[derive(Debug)]
98pub struct Max {
99 signature: Signature,
100}
101
102impl Max {
103 pub fn new() -> Self {
104 Self {
105 signature: Signature::user_defined(Volatility::Immutable),
106 }
107 }
108}
109
110impl Default for Max {
111 fn default() -> Self {
112 Self::new()
113 }
114}
/// Expands to `Ok(Box::new(..))` wrapping a [`PrimitiveGroupsAccumulator`]
/// that keeps the larger of the stored and the incoming value.
///
/// * `$DATA_TYPE` - identifier bound to the data type passed to `new`
/// * `$NATIVE`    - native Rust type; `$NATIVE::MIN` is the starting value
/// * `$PRIMTYPE`  - Arrow primitive type parameter of the accumulator
macro_rules! primitive_max_accumulator {
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                // Overwrite when the candidate is strictly greater, or when
                // the two values cannot be ordered (`partial_cmp` is `None`).
                if matches!(new.partial_cmp(cur), Some(Ordering::Greater) | None) {
                    *cur = new;
                }
            })
            // Every group starts from the smallest representable value.
            .with_starting_value($NATIVE::MIN),
        ))
    }};
}
136
/// Expands to `Ok(Box::new(..))` wrapping a [`PrimitiveGroupsAccumulator`]
/// that keeps the smaller of the stored and the incoming value.
///
/// * `$DATA_TYPE` - identifier bound to the data type; passed to `new` by reference
/// * `$NATIVE`    - native Rust type; `$NATIVE::MAX` is the starting value
/// * `$PRIMTYPE`  - Arrow primitive type parameter of the accumulator
macro_rules! primitive_min_accumulator {
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
                // Overwrite when the candidate is strictly smaller, or when
                // the two values cannot be ordered (`partial_cmp` is `None`).
                if matches!(new.partial_cmp(cur), Some(Ordering::Less) | None) {
                    *cur = new;
                }
            })
            // Every group starts from the largest representable value.
            .with_starting_value($NATIVE::MAX),
        ))
    }};
}
159
160trait FromColumnStatistics {
161 fn value_from_column_statistics(
162 &self,
163 stats: &ColumnStatistics,
164 ) -> Option<ScalarValue>;
165
166 fn value_from_statistics(
167 &self,
168 statistics_args: &StatisticsArgs,
169 ) -> Option<ScalarValue> {
170 if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
171 match *num_rows {
172 0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
173 value if value > 0 => {
174 let col_stats = &statistics_args.statistics.column_statistics;
175 if statistics_args.exprs.len() == 1 {
176 if let Some(col_expr) = statistics_args.exprs[0]
178 .as_any()
179 .downcast_ref::<expressions::Column>()
180 {
181 return self.value_from_column_statistics(
182 &col_stats[col_expr.index()],
183 );
184 }
185 }
186 }
187 _ => {}
188 }
189 }
190 None
191 }
192}
193
194impl FromColumnStatistics for Max {
195 fn value_from_column_statistics(
196 &self,
197 col_stats: &ColumnStatistics,
198 ) -> Option<ScalarValue> {
199 if let Precision::Exact(ref val) = col_stats.max_value {
200 if !val.is_null() {
201 return Some(val.clone());
202 }
203 }
204 None
205 }
206}
207
/// `AggregateUDFImpl` implementation that wires the `max` function to its
/// row, grouped and sliding-window accumulators.
impl AggregateUDFImpl for Max {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "max"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result has the same type as the first (only) argument.
    ///
    /// NOTE(review): indexes `arg_types[0]` directly, so an empty slice
    /// would panic; presumably `coerce_types` has already enforced exactly
    /// one argument — confirm against the caller.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].to_owned())
    }

    /// Builds a [`MaxAccumulator`] initialised for the return field's type.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MaxAccumulator::try_new(
            acc_args.return_field.data_type(),
        )?))
    }

    /// Reports whether `create_groups_accumulator` has an arm for the
    /// return field's data type.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        matches!(
            args.return_field.data_type(),
            Int8 | Int16
                | Int32
                | Int64
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Float16
                | Float32
                | Float64
                | Decimal128(_, _)
                | Decimal256(_, _)
                | Date32
                | Date64
                | Time32(_)
                | Time64(_)
                | Timestamp(_, _)
                | Utf8
                | LargeUtf8
                | Utf8View
                | Binary
                | LargeBinary
                | BinaryView
                | Duration(_)
                | Struct(_)
        )
    }

    /// Creates the grouped accumulator matching the return field's type.
    ///
    /// Primitive types go through `primitive_max_accumulator!`, which is
    /// given the native Rust type (used for the `MIN` starting value) and
    /// the Arrow primitive type. Byte/string types and structs use their
    /// dedicated accumulators. Any other type yields an internal error.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        use DataType::*;
        use TimeUnit::*;
        let data_type = args.return_field.data_type();
        match data_type {
            Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
            Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
            Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
            Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
            UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
            UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
            UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
            UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
            Float16 => {
                primitive_max_accumulator!(data_type, f16, Float16Type)
            }
            Float32 => {
                primitive_max_accumulator!(data_type, f32, Float32Type)
            }
            Float64 => {
                primitive_max_accumulator!(data_type, f64, Float64Type)
            }
            Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
            Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
            Time32(Second) => {
                primitive_max_accumulator!(data_type, i32, Time32SecondType)
            }
            Time32(Millisecond) => {
                primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
            }
            Time64(Microsecond) => {
                primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
            }
            Time64(Nanosecond) => {
                primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
            }
            Timestamp(Second, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampSecondType)
            }
            Timestamp(Millisecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
            }
            Timestamp(Microsecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
            }
            Timestamp(Nanosecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
            }
            Duration(Second) => {
                primitive_max_accumulator!(data_type, i64, DurationSecondType)
            }
            Duration(Millisecond) => {
                primitive_max_accumulator!(data_type, i64, DurationMillisecondType)
            }
            Duration(Microsecond) => {
                primitive_max_accumulator!(data_type, i64, DurationMicrosecondType)
            }
            Duration(Nanosecond) => {
                primitive_max_accumulator!(data_type, i64, DurationNanosecondType)
            }
            Decimal128(_, _) => {
                primitive_max_accumulator!(data_type, i128, Decimal128Type)
            }
            Decimal256(_, _) => {
                primitive_max_accumulator!(data_type, i256, Decimal256Type)
            }
            // Variable-length string and binary types share one accumulator.
            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
                Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
            }
            Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max(
                data_type.clone(),
            ))),
            // Reached for any type without an arm above, including
            // `Time32`/`Time64` unit combinations not listed here.
            _ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
        }
    }

    /// Builds a [`SlidingMaxAccumulator`], which supports retracting values.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(SlidingMaxAccumulator::try_new(
            args.return_field.data_type(),
        )?))
    }

    /// Always reports `Some(true)` for `max`.
    fn is_descending(&self) -> Option<bool> {
        Some(true)
    }

    /// The result does not depend on the order of the input rows.
    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
    }

    /// Validates the argument count and unwraps dictionary types via
    /// `get_min_max_result_type`.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        get_min_max_result_type(arg_types)
    }
    /// Reversing the input leaves the expression unchanged.
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Identical
    }
    /// Delegates to `FromColumnStatistics::value_from_statistics`.
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        self.value_from_statistics(statistics_args)
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    /// Reports `max` as monotonically increasing, for every data type.
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::Increasing
    }
}
380
/// Picks between `$VALUE` (the current result) and `$DELTA` (the candidate)
/// for values that offer `is_null`, `partial_cmp`, `clone` and `compact`.
///
/// `$DELTA` is chosen when `$VALUE` is null, or when both are non-null and
/// comparing `$VALUE` to `$DELTA` yields the ordering selected by
/// `choose_min_max!($OP)`. In every other case, including an undefined
/// comparison, `$VALUE` is kept.
macro_rules! min_max_generic {
    ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
        let delta_wins = if $VALUE.is_null() {
            true
        } else if $DELTA.is_null() {
            false
        } else {
            $VALUE.partial_cmp(&$DELTA) == Some(choose_min_max!($OP))
        };

        if delta_wins {
            // The candidate becomes the stored value: compact the copy so
            // the stored value does not keep the whole input alive.
            let mut compacted = $DELTA.clone();
            compacted.compact();
            compacted
        } else {
            $VALUE.clone()
        }
    }};
}
405
/// Combines two optional `Copy` values with `$OP` (`min` or `max`) and wraps
/// the result in `ScalarValue::$SCALAR`.
///
/// A missing side yields the other side; two missing sides yield `None`.
/// Any `$EXTRA_ARGS` identifiers are cloned and appended as further
/// arguments of the variant.
macro_rules! typed_min_max {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
        let combined = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => Some((*a).$OP(*b)),
            (Some(only), None) | (None, Some(only)) => Some(*only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(combined $(, $EXTRA_ARGS.clone())*)
    }};
}
/// Float variant of `typed_min_max!` that orders values with `total_cmp`.
///
/// When both sides are present, `$DELTA` is taken if `a.total_cmp(b)` equals
/// the ordering selected by `choose_min_max!($OP)`; otherwise `$VALUE` is
/// kept. A missing side yields the other side.
macro_rules! typed_min_max_float {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let picked = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => {
                if a.total_cmp(b) == choose_min_max!($OP) {
                    Some(*b)
                } else {
                    Some(*a)
                }
            }
            (Some(only), None) | (None, Some(only)) => Some(*only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(picked)
    }};
}
433
/// Variant of `typed_min_max!` for non-`Copy` payloads such as strings and
/// byte buffers: the winning side is selected by reference with `$OP` and
/// cloned once into `ScalarValue::$SCALAR`.
macro_rules! typed_min_max_string {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let winner = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => Some(a.$OP(b)),
            (Some(only), None) | (None, Some(only)) => Some(only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(winner.cloned())
    }};
}
445
/// Like `typed_min_max_string!`, for variants whose first field is an extra
/// argument (`$ARG`) that precedes the optional payload.
macro_rules! typed_min_max_string_arg {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{
        // Evaluate the leading argument before the payload, as the variant
        // constructor would.
        let leading = $ARG;
        let winner = match ($VALUE, $DELTA) {
            (Some(a), Some(b)) => Some(a.$OP(b)),
            (Some(only), None) | (None, Some(only)) => Some(only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(leading, winner.cloned())
    }};
}
460
/// Maps an operation name to the ordering of `current.cmp(candidate)` that
/// means the candidate should replace the current value.
///
/// * `max` - candidate wins when the current value is `Less`
/// * `min` - candidate wins when the current value is `Greater`
///
/// The expansion is a plain path, so it works both as an expression and as
/// a pattern.
macro_rules! choose_min_max {
    (max) => {
        std::cmp::Ordering::Less
    };
    (min) => {
        std::cmp::Ordering::Greater
    };
}
469
/// Selects `$LHS` or `$RHS` by `partial_cmp`, for operands that may not be
/// comparable.
///
/// `$RHS` is cloned when the comparison equals `choose_min_max!($OP)`;
/// `$LHS` is cloned for any other defined ordering. When the comparison is
/// undefined the macro `return`s an internal error from the enclosing
/// function.
macro_rules! interval_min_max {
    ($OP:tt, $LHS:expr, $RHS:expr) => {{
        match $LHS.partial_cmp(&$RHS) {
            None => {
                return internal_err!("Comparison error while computing interval min/max")
            }
            Some(ordering) if ordering == choose_min_max!($OP) => $RHS.clone(),
            Some(_) => $LHS.clone(),
        }
    }};
}
481
/// Computes the min or max (`$OP`) of two scalar values of the same kind.
///
/// Expands to `Ok(<scalar>)`. For operand pairs it does not support, the
/// expansion `return`s an internal error from the enclosing function, so it
/// must be used inside a function returning a compatible `Result`.
///
/// * `$VALUE` - the current result
/// * `$DELTA` - the candidate to fold in
/// * `$OP`    - `min` or `max`
macro_rules! min_max {
    ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
        Ok(match ($VALUE, $DELTA) {
            (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
            // Decimals are only combined when precision and scale agree.
            (
                lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
                rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
            ) => {
                if lhsp.eq(rhsp) && lhss.eq(rhss) {
                    typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss)
                } else {
                    return internal_err!(
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
                    (lhs, rhs)
                );
                }
            }
            (
                lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss),
                rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss)
            ) => {
                if lhsp.eq(rhsp) && lhss.eq(rhss) {
                    typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss)
                } else {
                    return internal_err!(
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
                    (lhs, rhs)
                );
                }
            }
            (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => {
                typed_min_max!(lhs, rhs, Boolean, $OP)
            }
            // Floats go through the `total_cmp`-based helper.
            (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
                typed_min_max_float!(lhs, rhs, Float64, $OP)
            }
            (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
                typed_min_max_float!(lhs, rhs, Float32, $OP)
            }
            (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
                typed_min_max_float!(lhs, rhs, Float16, $OP)
            }
            (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
                typed_min_max!(lhs, rhs, UInt64, $OP)
            }
            (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
                typed_min_max!(lhs, rhs, UInt32, $OP)
            }
            (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
                typed_min_max!(lhs, rhs, UInt16, $OP)
            }
            (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
                typed_min_max!(lhs, rhs, UInt8, $OP)
            }
            (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
                typed_min_max!(lhs, rhs, Int64, $OP)
            }
            (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
                typed_min_max!(lhs, rhs, Int32, $OP)
            }
            (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
                typed_min_max!(lhs, rhs, Int16, $OP)
            }
            (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
                typed_min_max!(lhs, rhs, Int8, $OP)
            }
            // String and binary payloads are compared by reference and the
            // winner is cloned.
            (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => {
                typed_min_max_string!(lhs, rhs, Utf8, $OP)
            }
            (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
                typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
            }
            (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
                typed_min_max_string!(lhs, rhs, Utf8View, $OP)
            }
            (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
                typed_min_max_string!(lhs, rhs, Binary, $OP)
            }
            (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
                typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
            }
            // Fixed-size binaries must have the same declared size.
            (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => {
                if lsize == rsize {
                    typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize)
                }
                else {
                    return internal_err!(
                        "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}",
                        (lsize, rsize))
                }
            }
            (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
                typed_min_max_string!(lhs, rhs, BinaryView, $OP)
            }
            // Timestamps: the result carries the left operand's time zone;
            // the right operand's time zone is ignored.
            (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
                typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
            }
            (
                ScalarValue::TimestampMillisecond(lhs, l_tz),
                ScalarValue::TimestampMillisecond(rhs, _),
            ) => {
                typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz)
            }
            (
                ScalarValue::TimestampMicrosecond(lhs, l_tz),
                ScalarValue::TimestampMicrosecond(rhs, _),
            ) => {
                typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz)
            }
            (
                ScalarValue::TimestampNanosecond(lhs, l_tz),
                ScalarValue::TimestampNanosecond(rhs, _),
            ) => {
                typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz)
            }
            (
                ScalarValue::Date32(lhs),
                ScalarValue::Date32(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Date32, $OP)
            }
            (
                ScalarValue::Date64(lhs),
                ScalarValue::Date64(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Date64, $OP)
            }
            (
                ScalarValue::Time32Second(lhs),
                ScalarValue::Time32Second(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Time32Second, $OP)
            }
            (
                ScalarValue::Time32Millisecond(lhs),
                ScalarValue::Time32Millisecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Time32Millisecond, $OP)
            }
            (
                ScalarValue::Time64Microsecond(lhs),
                ScalarValue::Time64Microsecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Time64Microsecond, $OP)
            }
            (
                ScalarValue::Time64Nanosecond(lhs),
                ScalarValue::Time64Nanosecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, Time64Nanosecond, $OP)
            }
            // Intervals of the same unit are combined like other primitives.
            (
                ScalarValue::IntervalYearMonth(lhs),
                ScalarValue::IntervalYearMonth(rhs),
            ) => {
                typed_min_max!(lhs, rhs, IntervalYearMonth, $OP)
            }
            (
                ScalarValue::IntervalMonthDayNano(lhs),
                ScalarValue::IntervalMonthDayNano(rhs),
            ) => {
                typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP)
            }
            (
                ScalarValue::IntervalDayTime(lhs),
                ScalarValue::IntervalDayTime(rhs),
            ) => {
                typed_min_max!(lhs, rhs, IntervalDayTime, $OP)
            }
            // Intervals of different units are compared as whole scalars via
            // `partial_cmp`; an undefined comparison returns an error.
            (
                ScalarValue::IntervalYearMonth(_),
                ScalarValue::IntervalMonthDayNano(_),
            ) | (
                ScalarValue::IntervalYearMonth(_),
                ScalarValue::IntervalDayTime(_),
            ) | (
                ScalarValue::IntervalMonthDayNano(_),
                ScalarValue::IntervalDayTime(_),
            ) | (
                ScalarValue::IntervalMonthDayNano(_),
                ScalarValue::IntervalYearMonth(_),
            ) | (
                ScalarValue::IntervalDayTime(_),
                ScalarValue::IntervalYearMonth(_),
            ) | (
                ScalarValue::IntervalDayTime(_),
                ScalarValue::IntervalMonthDayNano(_),
            ) => {
                interval_min_max!($OP, $VALUE, $DELTA)
            }
            (
                ScalarValue::DurationSecond(lhs),
                ScalarValue::DurationSecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, DurationSecond, $OP)
            }
            (
                ScalarValue::DurationMillisecond(lhs),
                ScalarValue::DurationMillisecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, DurationMillisecond, $OP)
            }
            (
                ScalarValue::DurationMicrosecond(lhs),
                ScalarValue::DurationMicrosecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, DurationMicrosecond, $OP)
            }
            (
                ScalarValue::DurationNanosecond(lhs),
                ScalarValue::DurationNanosecond(rhs),
            ) => {
                typed_min_max!(lhs, rhs, DurationNanosecond, $OP)
            }

            // Nested types are compared as whole scalars, with explicit
            // null handling, by `min_max_generic!`.
            (
                lhs @ ScalarValue::Struct(_),
                rhs @ ScalarValue::Struct(_),
            ) => {
                min_max_generic!(lhs, rhs, $OP)
            }

            (
                lhs @ ScalarValue::List(_),
                rhs @ ScalarValue::List(_),
            ) => {
                min_max_generic!(lhs, rhs, $OP)
            }


            (
                lhs @ ScalarValue::LargeList(_),
                rhs @ ScalarValue::LargeList(_),
            ) => {
                min_max_generic!(lhs, rhs, $OP)
            }


            (
                lhs @ ScalarValue::FixedSizeList(_),
                rhs @ ScalarValue::FixedSizeList(_),
            ) => {
                min_max_generic!(lhs, rhs, $OP)
            }

            // Any pairing not listed above (e.g. two different variants) is
            // rejected.
            e => {
                return internal_err!(
                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
                    e
                )
            }
        })
    }};
}
737
738#[derive(Debug)]
740pub struct MaxAccumulator {
741 max: ScalarValue,
742}
743
744impl MaxAccumulator {
745 pub fn try_new(datatype: &DataType) -> Result<Self> {
747 Ok(Self {
748 max: ScalarValue::try_from(datatype)?,
749 })
750 }
751}
752
753impl Accumulator for MaxAccumulator {
754 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
755 let values = &values[0];
756 let delta = &max_batch(values)?;
757 let new_max: Result<ScalarValue, DataFusionError> =
758 min_max!(&self.max, delta, max);
759 self.max = new_max?;
760 Ok(())
761 }
762
763 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
764 self.update_batch(states)
765 }
766
767 fn state(&mut self) -> Result<Vec<ScalarValue>> {
768 Ok(vec![self.evaluate()?])
769 }
770 fn evaluate(&mut self) -> Result<ScalarValue> {
771 Ok(self.max.clone())
772 }
773
774 fn size(&self) -> usize {
775 size_of_val(self) - size_of_val(&self.max) + self.max.size()
776 }
777}
778
779#[derive(Debug)]
780pub struct SlidingMaxAccumulator {
781 max: ScalarValue,
782 moving_max: MovingMax<ScalarValue>,
783}
784
785impl SlidingMaxAccumulator {
786 pub fn try_new(datatype: &DataType) -> Result<Self> {
788 Ok(Self {
789 max: ScalarValue::try_from(datatype)?,
790 moving_max: MovingMax::<ScalarValue>::new(),
791 })
792 }
793}
794
795impl Accumulator for SlidingMaxAccumulator {
796 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
797 for idx in 0..values[0].len() {
798 let val = ScalarValue::try_from_array(&values[0], idx)?;
799 self.moving_max.push(val);
800 }
801 if let Some(res) = self.moving_max.max() {
802 self.max = res.clone();
803 }
804 Ok(())
805 }
806
807 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
808 for _idx in 0..values[0].len() {
809 (self.moving_max).pop();
810 }
811 if let Some(res) = self.moving_max.max() {
812 self.max = res.clone();
813 }
814 Ok(())
815 }
816
817 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
818 self.update_batch(states)
819 }
820
821 fn state(&mut self) -> Result<Vec<ScalarValue>> {
822 Ok(vec![self.max.clone()])
823 }
824
825 fn evaluate(&mut self) -> Result<ScalarValue> {
826 Ok(self.max.clone())
827 }
828
829 fn supports_retract_batch(&self) -> bool {
830 true
831 }
832
833 fn size(&self) -> usize {
834 size_of_val(self) - size_of_val(&self.max) + self.max.size()
835 }
836}
837
838#[user_doc(
839 doc_section(label = "General Functions"),
840 description = "Returns the minimum value in the specified column.",
841 syntax_example = "min(expression)",
842 sql_example = r#"```sql
843> SELECT min(column_name) FROM table_name;
844+----------------------+
845| min(column_name) |
846+----------------------+
847| 12 |
848+----------------------+
849```"#,
850 standard_argument(name = "expression",)
851)]
852#[derive(Debug)]
853pub struct Min {
854 signature: Signature,
855}
856
857impl Min {
858 pub fn new() -> Self {
859 Self {
860 signature: Signature::user_defined(Volatility::Immutable),
861 }
862 }
863}
864
865impl Default for Min {
866 fn default() -> Self {
867 Self::new()
868 }
869}
870
871impl FromColumnStatistics for Min {
872 fn value_from_column_statistics(
873 &self,
874 col_stats: &ColumnStatistics,
875 ) -> Option<ScalarValue> {
876 if let Precision::Exact(ref val) = col_stats.min_value {
877 if !val.is_null() {
878 return Some(val.clone());
879 }
880 }
881 None
882 }
883}
884
/// `AggregateUDFImpl` implementation that wires the `min` function to its
/// row, grouped and sliding-window accumulators.
impl AggregateUDFImpl for Min {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "min"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result has the same type as the first (only) argument.
    ///
    /// NOTE(review): indexes `arg_types[0]` directly, so an empty slice
    /// would panic; presumably `coerce_types` has already enforced exactly
    /// one argument — confirm against the caller.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].to_owned())
    }

    /// Builds a [`MinAccumulator`] initialised for the return field's type.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MinAccumulator::try_new(
            acc_args.return_field.data_type(),
        )?))
    }

    /// Reports whether `create_groups_accumulator` has an arm for the
    /// return field's data type.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        matches!(
            args.return_field.data_type(),
            Int8 | Int16
                | Int32
                | Int64
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Float16
                | Float32
                | Float64
                | Decimal128(_, _)
                | Decimal256(_, _)
                | Date32
                | Date64
                | Time32(_)
                | Time64(_)
                | Timestamp(_, _)
                | Utf8
                | LargeUtf8
                | Utf8View
                | Binary
                | LargeBinary
                | BinaryView
                | Duration(_)
                | Struct(_)
        )
    }

    /// Creates the grouped accumulator matching the return field's type.
    ///
    /// Primitive types go through `primitive_min_accumulator!`, which is
    /// given the native Rust type (used for the `MAX` starting value) and
    /// the Arrow primitive type. Byte/string types and structs use their
    /// dedicated accumulators. Any other type yields an internal error.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        use DataType::*;
        use TimeUnit::*;
        let data_type = args.return_field.data_type();
        match data_type {
            Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
            Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
            Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
            Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
            UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
            UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
            UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
            UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
            Float16 => {
                primitive_min_accumulator!(data_type, f16, Float16Type)
            }
            Float32 => {
                primitive_min_accumulator!(data_type, f32, Float32Type)
            }
            Float64 => {
                primitive_min_accumulator!(data_type, f64, Float64Type)
            }
            Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
            Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
            Time32(Second) => {
                primitive_min_accumulator!(data_type, i32, Time32SecondType)
            }
            Time32(Millisecond) => {
                primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
            }
            Time64(Microsecond) => {
                primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
            }
            Time64(Nanosecond) => {
                primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
            }
            Timestamp(Second, _) => {
                primitive_min_accumulator!(data_type, i64, TimestampSecondType)
            }
            Timestamp(Millisecond, _) => {
                primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
            }
            Timestamp(Microsecond, _) => {
                primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
            }
            Timestamp(Nanosecond, _) => {
                primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
            }
            Duration(Second) => {
                primitive_min_accumulator!(data_type, i64, DurationSecondType)
            }
            Duration(Millisecond) => {
                primitive_min_accumulator!(data_type, i64, DurationMillisecondType)
            }
            Duration(Microsecond) => {
                primitive_min_accumulator!(data_type, i64, DurationMicrosecondType)
            }
            Duration(Nanosecond) => {
                primitive_min_accumulator!(data_type, i64, DurationNanosecondType)
            }
            Decimal128(_, _) => {
                primitive_min_accumulator!(data_type, i128, Decimal128Type)
            }
            Decimal256(_, _) => {
                primitive_min_accumulator!(data_type, i256, Decimal256Type)
            }
            // Variable-length string and binary types share one accumulator.
            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
                Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
            }
            Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min(
                data_type.clone(),
            ))),
            // Reached for any type without an arm above, including
            // `Time32`/`Time64` unit combinations not listed here.
            _ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
        }
    }

    /// Builds a [`SlidingMinAccumulator`], which supports retracting values.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(SlidingMinAccumulator::try_new(
            args.return_field.data_type(),
        )?))
    }

    /// Always reports `Some(false)` for `min`.
    fn is_descending(&self) -> Option<bool> {
        Some(false)
    }

    /// Delegates to `FromColumnStatistics::value_from_statistics`.
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        self.value_from_statistics(statistics_args)
    }
    /// The result does not depend on the order of the input rows.
    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
    }

    /// Validates the argument count and unwraps dictionary types via
    /// `get_min_max_result_type`.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        get_min_max_result_type(arg_types)
    }

    /// Reversing the input leaves the expression unchanged.
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Identical
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    /// Reports `min` as monotonically decreasing, for every data type.
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::Decreasing
    }
}
1058
1059#[derive(Debug)]
1061pub struct MinAccumulator {
1062 min: ScalarValue,
1063}
1064
1065impl MinAccumulator {
1066 pub fn try_new(datatype: &DataType) -> Result<Self> {
1068 Ok(Self {
1069 min: ScalarValue::try_from(datatype)?,
1070 })
1071 }
1072}
1073
1074impl Accumulator for MinAccumulator {
1075 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1076 Ok(vec![self.evaluate()?])
1077 }
1078
1079 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1080 let values = &values[0];
1081 let delta = &min_batch(values)?;
1082 let new_min: Result<ScalarValue, DataFusionError> =
1083 min_max!(&self.min, delta, min);
1084 self.min = new_min?;
1085 Ok(())
1086 }
1087
1088 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1089 self.update_batch(states)
1090 }
1091
1092 fn evaluate(&mut self) -> Result<ScalarValue> {
1093 Ok(self.min.clone())
1094 }
1095
1096 fn size(&self) -> usize {
1097 size_of_val(self) - size_of_val(&self.min) + self.min.size()
1098 }
1099}
1100
1101#[derive(Debug)]
1102pub struct SlidingMinAccumulator {
1103 min: ScalarValue,
1104 moving_min: MovingMin<ScalarValue>,
1105}
1106
1107impl SlidingMinAccumulator {
1108 pub fn try_new(datatype: &DataType) -> Result<Self> {
1109 Ok(Self {
1110 min: ScalarValue::try_from(datatype)?,
1111 moving_min: MovingMin::<ScalarValue>::new(),
1112 })
1113 }
1114}
1115
1116impl Accumulator for SlidingMinAccumulator {
1117 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1118 Ok(vec![self.min.clone()])
1119 }
1120
1121 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1122 for idx in 0..values[0].len() {
1123 let val = ScalarValue::try_from_array(&values[0], idx)?;
1124 if !val.is_null() {
1125 self.moving_min.push(val);
1126 }
1127 }
1128 if let Some(res) = self.moving_min.min() {
1129 self.min = res.clone();
1130 }
1131 Ok(())
1132 }
1133
1134 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1135 for idx in 0..values[0].len() {
1136 let val = ScalarValue::try_from_array(&values[0], idx)?;
1137 if !val.is_null() {
1138 (self.moving_min).pop();
1139 }
1140 }
1141 if let Some(res) = self.moving_min.min() {
1142 self.min = res.clone();
1143 }
1144 Ok(())
1145 }
1146
1147 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1148 self.update_batch(states)
1149 }
1150
1151 fn evaluate(&mut self) -> Result<ScalarValue> {
1152 Ok(self.min.clone())
1153 }
1154
1155 fn supports_retract_batch(&self) -> bool {
1156 true
1157 }
1158
1159 fn size(&self) -> usize {
1160 size_of_val(self) - size_of_val(&self.min) + self.min.size()
1161 }
1162}
1163
/// FIFO queue that can report the minimum of the values it currently holds.
///
/// The queue is built from two stacks. Every stack entry is a pair
/// `(value, running_min)`, where `running_min` is the minimum of that entry
/// and all entries beneath it in the same stack. New values go onto
/// `push_stack`; values leave from `pop_stack`, which is refilled from
/// `push_stack` (reversing the order) whenever it runs empty.
#[derive(Debug)]
pub struct MovingMin<T> {
    push_stack: Vec<(T, T)>,
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMin<T> {
    /// Creates an empty queue.
    fn default() -> Self {
        Self {
            push_stack: Vec::default(),
            pop_stack: Vec::default(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMin<T> {
    /// Creates an empty queue.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty queue whose two stacks each reserve room for
    /// `capacity` entries.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let push_stack = Vec::with_capacity(capacity);
        let pop_stack = Vec::with_capacity(capacity);
        Self {
            push_stack,
            pop_stack,
        }
    }

    /// Returns the minimum of the queued values, or `None` when empty.
    #[inline]
    pub fn min(&self) -> Option<&T> {
        let pushed = self.push_stack.last().map(|(_, lowest)| lowest);
        let popped = self.pop_stack.last().map(|(_, lowest)| lowest);
        match (pushed, popped) {
            // Both stacks hold values: the overall minimum is the smaller
            // of the two running minima.
            (Some(a), Some(b)) => Some(if a < b { a } else { b }),
            (only, None) => only,
            (None, only) => only,
        }
    }

    /// Appends `val` to the back of the queue.
    #[inline]
    pub fn push(&mut self, val: T) {
        // The new running minimum is the previous one when `val` is
        // strictly greater than it, otherwise `val` itself.
        let running_min = match self.push_stack.last() {
            Some((_, current)) if val > *current => current.clone(),
            _ => val.clone(),
        };
        self.push_stack.push((val, running_min));
    }

    /// Removes and returns the value at the front of the queue, or `None`
    /// when empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything over, newest first, rebuilding the running
            // minimum for the reversed order as we go.
            let mut running_min: Option<T> = None;
            while let Some((val, _)) = self.push_stack.pop() {
                let lowest = match running_min.take() {
                    Some(previous) if previous < val => previous,
                    _ => val.clone(),
                };
                running_min = Some(lowest.clone());
                self.pop_stack.push((val, lowest));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Number of values currently queued.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` when no values are queued.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1297
/// Sliding-window maximum over a first-in, first-out queue of values.
///
/// The queue is stored as two stacks. Every stack entry is a pair
/// `(value, max)` where `max` is the largest value among that entry and all
/// entries below it in the same stack. The maximum of the whole queue is
/// therefore the larger of the two stack tops.
///
/// `push` and `max` are O(1). `pop` is amortized O(1): each element is moved
/// from the push stack to the pop stack at most once.
#[derive(Debug)]
pub struct MovingMax<T> {
    /// Newly pushed entries; the most recently pushed element is on top.
    push_stack: Vec<(T, T)>,
    /// Entries waiting to be popped; the oldest element is on top.
    pop_stack: Vec<(T, T)>,
}

impl<T: Clone + PartialOrd> Default for MovingMax<T> {
    /// Creates an empty queue without allocating.
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMax<T> {
    /// Creates an empty queue; equivalent to [`Default::default`].
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty queue whose two internal stacks each pre-allocate
    /// room for `capacity` entries.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns a reference to the largest value currently in the queue, or
    /// `None` if the queue is empty.
    #[inline]
    pub fn max(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, max)), None) => Some(max),
            (None, Some((_, max))) => Some(max),
            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
        }
    }

    /// Appends `val` to the back of the queue.
    #[inline]
    pub fn push(&mut self, val: T) {
        let entry = match self.push_stack.last() {
            // The running maximum of the push stack is still larger.
            Some((_, max)) if val < *max => (val, max.clone()),
            // Empty stack, or `val` becomes the new running maximum.
            _ => (val.clone(), val),
        };
        self.push_stack.push(entry);
    }

    /// Removes and returns the oldest value in the queue, or `None` if the
    /// queue is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Move everything over in reverse order so the oldest element ends
            // up on top, recomputing the running maximum on the way. Each
            // element is cloned exactly once (for its `max` slot).
            self.pop_stack.reserve(self.push_stack.len());
            while let Some((val, _)) = self.push_stack.pop() {
                let max = match self.pop_stack.last() {
                    Some((_, prev_max)) if *prev_max > val => prev_max.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, max));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of values currently in the queue.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the queue holds no values.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1415
// Invokes `make_udaf_expr_and_func!` for the `Max` aggregate, passing the
// names `max`, `expression` and `max_udaf` together with the doc string.
// NOTE(review): the macro is defined outside this chunk; presumably it
// generates the `max` expression builder and the `max_udaf` accessor —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    Max,
    max,
    expression,
    "Returns the maximum of a group of values.",
    max_udaf
);
1423
// Invokes `make_udaf_expr_and_func!` for the `Min` aggregate, passing the
// names `min`, `expression` and `min_udaf` together with the doc string.
// NOTE(review): the macro is defined outside this chunk; presumably it
// generates the `min` expression builder and the `min_udaf` accessor —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    Min,
    min,
    expression,
    "Returns the minimum of a group of values.",
    min_udaf
);
1431
1432#[cfg(test)]
1433mod tests {
1434 use super::*;
1435 use arrow::{
1436 array::{
1437 DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray,
1438 IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray,
1439 },
1440 datatypes::{
1441 IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
1442 IntervalYearMonthType,
1443 },
1444 };
1445 use std::sync::Arc;
1446
1447 #[test]
1448 fn interval_min_max() {
1449 let b = IntervalYearMonthArray::from(vec![
1451 IntervalYearMonthType::make_value(0, 1),
1452 IntervalYearMonthType::make_value(5, 34),
1453 IntervalYearMonthType::make_value(-2, 4),
1454 IntervalYearMonthType::make_value(7, -4),
1455 IntervalYearMonthType::make_value(0, 1),
1456 ]);
1457 let b: ArrayRef = Arc::new(b);
1458
1459 let mut min =
1460 MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1461 .unwrap();
1462 min.update_batch(&[Arc::clone(&b)]).unwrap();
1463 let min_res = min.evaluate().unwrap();
1464 assert_eq!(
1465 min_res,
1466 ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1467 -2, 4,
1468 )))
1469 );
1470
1471 let mut max =
1472 MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1473 .unwrap();
1474 max.update_batch(&[Arc::clone(&b)]).unwrap();
1475 let max_res = max.evaluate().unwrap();
1476 assert_eq!(
1477 max_res,
1478 ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1479 5, 34,
1480 )))
1481 );
1482
1483 let b = IntervalDayTimeArray::from(vec![
1485 IntervalDayTimeType::make_value(0, 0),
1486 IntervalDayTimeType::make_value(5, 454000),
1487 IntervalDayTimeType::make_value(-34, 0),
1488 IntervalDayTimeType::make_value(7, -4000),
1489 IntervalDayTimeType::make_value(1, 0),
1490 ]);
1491 let b: ArrayRef = Arc::new(b);
1492
1493 let mut min =
1494 MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1495 min.update_batch(&[Arc::clone(&b)]).unwrap();
1496 let min_res = min.evaluate().unwrap();
1497 assert_eq!(
1498 min_res,
1499 ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0)))
1500 );
1501
1502 let mut max =
1503 MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1504 max.update_batch(&[Arc::clone(&b)]).unwrap();
1505 let max_res = max.evaluate().unwrap();
1506 assert_eq!(
1507 max_res,
1508 ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000)))
1509 );
1510
1511 let b = IntervalMonthDayNanoArray::from(vec![
1513 IntervalMonthDayNanoType::make_value(1, 0, 0),
1514 IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
1515 IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
1516 IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
1517 IntervalMonthDayNanoType::make_value(1, 0, 0),
1518 ]);
1519 let b: ArrayRef = Arc::new(b);
1520
1521 let mut min =
1522 MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1523 .unwrap();
1524 min.update_batch(&[Arc::clone(&b)]).unwrap();
1525 let min_res = min.evaluate().unwrap();
1526 assert_eq!(
1527 min_res,
1528 ScalarValue::IntervalMonthDayNano(Some(
1529 IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000)
1530 ))
1531 );
1532
1533 let mut max =
1534 MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1535 .unwrap();
1536 max.update_batch(&[Arc::clone(&b)]).unwrap();
1537 let max_res = max.evaluate().unwrap();
1538 assert_eq!(
1539 max_res,
1540 ScalarValue::IntervalMonthDayNano(Some(
1541 IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000)
1542 ))
1543 );
1544 }
1545
1546 #[test]
1547 fn float_min_max_with_nans() {
1548 let pos_nan = f32::NAN;
1549 let zero = 0_f32;
1550 let neg_inf = f32::NEG_INFINITY;
1551
1552 let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
1553 for batch in values.iter() {
1554 let batch =
1555 Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
1556 acc.update_batch(&[batch]).unwrap();
1557 }
1558 let result = acc.evaluate().unwrap();
1559 assert_eq!(result, ScalarValue::Float32(Some(expected)));
1560 };
1561
1562 let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
1567 let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
1568
1569 check(&mut min(), &[&[zero], &[pos_nan]], zero);
1570 check(&mut min(), &[&[zero, pos_nan]], zero);
1571 check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
1572 check(&mut min(), &[&[zero, neg_inf]], neg_inf);
1573 check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
1574 check(&mut max(), &[&[zero, pos_nan]], pos_nan);
1575 check(&mut max(), &[&[zero], &[neg_inf]], zero);
1576 check(&mut max(), &[&[zero, neg_inf]], zero);
1577 }
1578
1579 use datafusion_common::Result;
1580 use rand::Rng;
1581
1582 fn get_random_vec_i32(len: usize) -> Vec<i32> {
1583 let mut rng = rand::rng();
1584 let mut input = Vec::with_capacity(len);
1585 for _i in 0..len {
1586 input.push(rng.random_range(0..100));
1587 }
1588 input
1589 }
1590
1591 fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1592 let data = get_random_vec_i32(len);
1593 let mut expected = Vec::with_capacity(len);
1594 let mut moving_min = MovingMin::<i32>::new();
1595 let mut res = Vec::with_capacity(len);
1596 for i in 0..len {
1597 let start = i.saturating_sub(n_sliding_window);
1598 expected.push(*data[start..i + 1].iter().min().unwrap());
1599
1600 moving_min.push(data[i]);
1601 if i > n_sliding_window {
1602 moving_min.pop();
1603 }
1604 res.push(*moving_min.min().unwrap());
1605 }
1606 assert_eq!(res, expected);
1607 Ok(())
1608 }
1609
1610 fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1611 let data = get_random_vec_i32(len);
1612 let mut expected = Vec::with_capacity(len);
1613 let mut moving_max = MovingMax::<i32>::new();
1614 let mut res = Vec::with_capacity(len);
1615 for i in 0..len {
1616 let start = i.saturating_sub(n_sliding_window);
1617 expected.push(*data[start..i + 1].iter().max().unwrap());
1618
1619 moving_max.push(data[i]);
1620 if i > n_sliding_window {
1621 moving_max.pop();
1622 }
1623 res.push(*moving_max.max().unwrap());
1624 }
1625 assert_eq!(res, expected);
1626 Ok(())
1627 }
1628
1629 #[test]
1630 fn moving_min_tests() -> Result<()> {
1631 moving_min_i32(100, 10)?;
1632 moving_min_i32(100, 20)?;
1633 moving_min_i32(100, 50)?;
1634 moving_min_i32(100, 100)?;
1635 Ok(())
1636 }
1637
1638 #[test]
1639 fn moving_max_tests() -> Result<()> {
1640 moving_max_i32(100, 10)?;
1641 moving_max_i32(100, 20)?;
1642 moving_max_i32(100, 50)?;
1643 moving_max_i32(100, 100)?;
1644 Ok(())
1645 }
1646
1647 #[test]
1648 fn test_min_max_coerce_types() {
1649 let funs: Vec<Box<dyn AggregateUDFImpl>> =
1651 vec![Box::new(Min::new()), Box::new(Max::new())];
1652 let input_types = vec![
1653 vec![DataType::Int32],
1654 vec![DataType::Decimal128(10, 2)],
1655 vec![DataType::Decimal256(1, 1)],
1656 vec![DataType::Utf8],
1657 ];
1658 for fun in funs {
1659 for input_type in &input_types {
1660 let result = fun.coerce_types(input_type);
1661 assert_eq!(*input_type, result.unwrap());
1662 }
1663 }
1664 }
1665
1666 #[test]
1667 fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
1668 let data_type =
1669 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1670 let result = get_min_max_result_type(&[data_type])?;
1671 assert_eq!(result, vec![DataType::Utf8]);
1672 Ok(())
1673 }
1674
1675 #[test]
1676 fn test_min_max_dictionary() -> Result<()> {
1677 let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
1678 let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
1679 let dict_array =
1680 DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1681 let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1682 let rt_type =
1683 get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone();
1684
1685 let mut min_acc = MinAccumulator::try_new(&rt_type)?;
1686 min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1687 let min_result = min_acc.evaluate()?;
1688 assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1689
1690 let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
1691 max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1692 let max_result = max_acc.evaluate()?;
1693 assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
1694 Ok(())
1695 }
1696}