datafusion_functions_aggregate/
median.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25    downcast_integer,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29    array::{ArrayRef, AsArray},
30    datatypes::{
31        DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32        Float64Type,
33    },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{
39    ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40};
41
42use datafusion_common::{
43    DataFusionError, Result, ScalarValue, assert_eq_or_internal_err,
44    internal_datafusion_err,
45};
46use datafusion_expr::function::StateFieldsArgs;
47use datafusion_expr::{
48    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
49    function::AccumulatorArgs, utils::format_state_name,
50};
51use datafusion_expr::{EmitTo, GroupsAccumulator};
52use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
53use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
54use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
55use datafusion_macros::user_doc;
56use std::collections::HashMap;
57
// Declares the public entry points of this aggregate via the shared
// `make_udaf_expr_and_func!` macro. Judging from the arguments, this presumably
// produces an expression builder `median(expression)` (documented as "Computes
// the median of a set of numbers") and a `median_udaf` accessor for the
// `Median` UDAF — NOTE(review): confirm against the macro definition.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
65
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name)   |
+----------------------+
| 45.5                 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a
/// lot of memory because all values need to be stored in memory before a result can be
/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more
/// efficient solution.
///
/// If using the distinct variation, the memory usage will be similarly high if the
/// cardinality is high as it stores all distinct values in memory before computing the
/// result, but if cardinality is low then memory usage will also be lower.
///
/// The result has the same data type as the input (see `return_type`).
#[derive(PartialEq, Eq, Hash)]
pub struct Median {
    /// Accepted argument types: exactly one numeric argument (built in `Median::new`).
    signature: Signature,
}
92
93impl Debug for Median {
94    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
95        f.debug_struct("Median")
96            .field("name", &self.name())
97            .field("signature", &self.signature)
98            .finish()
99    }
100}
101
102impl Default for Median {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl Median {
109    pub fn new() -> Self {
110        Self {
111            signature: Signature::numeric(1, Volatility::Immutable),
112        }
113    }
114}
115
116impl AggregateUDFImpl for Median {
117    fn as_any(&self) -> &dyn std::any::Any {
118        self
119    }
120
121    fn name(&self) -> &str {
122        "median"
123    }
124
125    fn signature(&self) -> &Signature {
126        &self.signature
127    }
128
129    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
130        Ok(arg_types[0].clone())
131    }
132
133    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
134        //Intermediate state is a list of the elements we have collected so far
135        let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
136        let state_name = if args.is_distinct {
137            "distinct_median"
138        } else {
139            "median"
140        };
141
142        Ok(vec![
143            Field::new(
144                format_state_name(args.name, state_name),
145                DataType::List(Arc::new(field)),
146                true,
147            )
148            .into(),
149        ])
150    }
151
152    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
153        macro_rules! helper {
154            ($t:ty, $dt:expr) => {
155                if acc_args.is_distinct {
156                    Ok(Box::new(DistinctMedianAccumulator::<$t> {
157                        data_type: $dt.clone(),
158                        distinct_values: GenericDistinctBuffer::new($dt),
159                    }))
160                } else {
161                    Ok(Box::new(MedianAccumulator::<$t> {
162                        data_type: $dt.clone(),
163                        all_values: vec![],
164                    }))
165                }
166            };
167        }
168
169        let dt = acc_args.expr_fields[0].data_type().clone();
170        downcast_integer! {
171            dt => (helper, dt),
172            DataType::Float16 => helper!(Float16Type, dt),
173            DataType::Float32 => helper!(Float32Type, dt),
174            DataType::Float64 => helper!(Float64Type, dt),
175            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
176            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
177            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
178            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
179            _ => Err(DataFusionError::NotImplemented(format!(
180                "MedianAccumulator not supported for {} with {}",
181                acc_args.name,
182                dt,
183            ))),
184        }
185    }
186
187    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
188        !args.is_distinct
189    }
190
191    fn create_groups_accumulator(
192        &self,
193        args: AccumulatorArgs,
194    ) -> Result<Box<dyn GroupsAccumulator>> {
195        let num_args = args.exprs.len();
196        assert_eq_or_internal_err!(
197            num_args,
198            1,
199            "median should only have 1 arg, but found num args:{}",
200            num_args
201        );
202
203        let dt = args.expr_fields[0].data_type().clone();
204
205        macro_rules! helper {
206            ($t:ty, $dt:expr) => {
207                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
208            };
209        }
210
211        downcast_integer! {
212            dt => (helper, dt),
213            DataType::Float16 => helper!(Float16Type, dt),
214            DataType::Float32 => helper!(Float32Type, dt),
215            DataType::Float64 => helper!(Float64Type, dt),
216            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
217            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
218            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
219            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
220            _ => Err(DataFusionError::NotImplemented(format!(
221                "MedianGroupsAccumulator not supported for {} with {}",
222                args.name,
223                dt,
224            ))),
225        }
226    }
227
228    fn documentation(&self) -> Option<&Documentation> {
229        self.doc()
230    }
231}
232
233/// The median accumulator accumulates the raw input values
234/// as `ScalarValue`s
235///
236/// The intermediate state is represented as a List of scalar values updated by
237/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
238/// in the final evaluation step so that we avoid expensive conversions and
239/// allocations during `update_batch`.
240struct MedianAccumulator<T: ArrowNumericType> {
241    data_type: DataType,
242    all_values: Vec<T::Native>,
243}
244
245impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
246    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
247        write!(f, "MedianAccumulator({})", self.data_type)
248    }
249}
250
251impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
252    fn state(&mut self) -> Result<Vec<ScalarValue>> {
253        // Convert `all_values` to `ListArray` and return a single List ScalarValue
254
255        // Build offsets
256        let offsets =
257            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
258
259        // Build inner array
260        let values_array = PrimitiveArray::<T>::new(
261            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
262            None,
263        )
264        .with_data_type(self.data_type.clone());
265
266        // Build the result list array
267        let list_array = ListArray::new(
268            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
269            offsets,
270            Arc::new(values_array),
271            None,
272        );
273
274        Ok(vec![ScalarValue::List(Arc::new(list_array))])
275    }
276
277    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
278        let values = values[0].as_primitive::<T>();
279        self.all_values.reserve(values.len() - values.null_count());
280        self.all_values.extend(values.iter().flatten());
281        Ok(())
282    }
283
284    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
285        let array = states[0].as_list::<i32>();
286        for v in array.iter().flatten() {
287            self.update_batch(&[v])?
288        }
289        Ok(())
290    }
291
292    fn evaluate(&mut self) -> Result<ScalarValue> {
293        let median = calculate_median::<T>(&mut self.all_values);
294        ScalarValue::new_primitive::<T>(median, &self.data_type)
295    }
296
297    fn size(&self) -> usize {
298        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
299    }
300
301    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
302        let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
303
304        let arr = &values[0];
305        for i in 0..arr.len() {
306            let v = ScalarValue::try_from_array(arr, i)?;
307            if !v.is_null() {
308                *to_remove.entry(v).or_default() += 1;
309            }
310        }
311
312        let mut i = 0;
313        while i < self.all_values.len() {
314            let k = ScalarValue::new_primitive::<T>(
315                Some(self.all_values[i]),
316                &self.data_type,
317            )?;
318            if let Some(count) = to_remove.get_mut(&k)
319                && *count > 0
320            {
321                self.all_values.swap_remove(i);
322                *count -= 1;
323                if *count == 0 {
324                    to_remove.remove(&k);
325                    if to_remove.is_empty() {
326                        break;
327                    }
328                }
329            }
330            i += 1;
331        }
332        Ok(())
333    }
334
335    fn supports_retract_batch(&self) -> bool {
336        true
337    }
338}
339
340/// The median groups accumulator accumulates the raw input values
341///
342/// For calculating the accurate medians of groups, we need to store all values
343/// of groups before final evaluation.
344/// So values in each group will be stored in a `Vec<T>`, and the total group values
345/// will be actually organized as a `Vec<Vec<T>>`.
346#[derive(Debug)]
347struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
348    data_type: DataType,
349    group_values: Vec<Vec<T::Native>>,
350}
351
352impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
353    pub fn new(data_type: DataType) -> Self {
354        Self {
355            data_type,
356            group_values: Vec::new(),
357        }
358    }
359}
360
361impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
362    fn update_batch(
363        &mut self,
364        values: &[ArrayRef],
365        group_indices: &[usize],
366        opt_filter: Option<&BooleanArray>,
367        total_num_groups: usize,
368    ) -> Result<()> {
369        assert_eq!(values.len(), 1, "single argument to update_batch");
370        let values = values[0].as_primitive::<T>();
371
372        // Push the `not nulls + not filtered` row into its group
373        self.group_values.resize(total_num_groups, Vec::new());
374        accumulate(
375            group_indices,
376            values,
377            opt_filter,
378            |group_index, new_value| {
379                self.group_values[group_index].push(new_value);
380            },
381        );
382
383        Ok(())
384    }
385
386    fn merge_batch(
387        &mut self,
388        values: &[ArrayRef],
389        group_indices: &[usize],
390        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
391        _opt_filter: Option<&BooleanArray>,
392        total_num_groups: usize,
393    ) -> Result<()> {
394        assert_eq!(values.len(), 1, "one argument to merge_batch");
395
396        // The merged values should be organized like as a `ListArray` which is nullable
397        // (input with nulls usually generated from `convert_to_state`), but `inner array` of
398        // `ListArray`  is `non-nullable`.
399        //
400        // Following is the possible and impossible input `values`:
401        //
402        // # Possible values
403        // ```text
404        //   group 0: [1, 2, 3]
405        //   group 1: null (list array is nullable)
406        //   group 2: [6, 7, 8]
407        //   ...
408        //   group n: [...]
409        // ```
410        //
411        // # Impossible values
412        // ```text
413        //   group x: [1, 2, null] (values in list array is non-nullable)
414        // ```
415        //
416        let input_group_values = values[0].as_list::<i32>();
417
418        // Ensure group values big enough
419        self.group_values.resize(total_num_groups, Vec::new());
420
421        // Extend values to related groups
422        // TODO: avoid using iterator of the `ListArray`, this will lead to
423        // many calls of `slice` of its ``inner array`, and `slice` is not
424        // so efficient(due to the calculation of `null_count` for each `slice`).
425        group_indices
426            .iter()
427            .zip(input_group_values.iter())
428            .for_each(|(&group_index, values_opt)| {
429                if let Some(values) = values_opt {
430                    let values = values.as_primitive::<T>();
431                    self.group_values[group_index].extend(values.values().iter());
432                }
433            });
434
435        Ok(())
436    }
437
438    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
439        // Emit values
440        let emit_group_values = emit_to.take_needed(&mut self.group_values);
441
442        // Build offsets
443        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
444        offsets.push(0);
445        let mut cur_len = 0_i32;
446        for group_value in &emit_group_values {
447            cur_len += group_value.len() as i32;
448            offsets.push(cur_len);
449        }
450        // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
451        // but safety should be considered more carefully here(and I am not sure if it can get
452        // performance improvement when we introduce checks to keep the safety...).
453        //
454        // Can see more details in:
455        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
456        //
457        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
458
459        // Build inner array
460        let flatten_group_values =
461            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
462        let group_values_array =
463            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
464                .with_data_type(self.data_type.clone());
465
466        // Build the result list array
467        let result_list_array = ListArray::new(
468            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
469            offsets,
470            Arc::new(group_values_array),
471            None,
472        );
473
474        Ok(vec![Arc::new(result_list_array)])
475    }
476
477    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
478        // Emit values
479        let emit_group_values = emit_to.take_needed(&mut self.group_values);
480
481        // Calculate median for each group
482        let mut evaluate_result_builder =
483            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
484        for mut values in emit_group_values {
485            let median = calculate_median::<T>(&mut values);
486            evaluate_result_builder.append_option(median);
487        }
488
489        Ok(Arc::new(evaluate_result_builder.finish()))
490    }
491
492    fn convert_to_state(
493        &self,
494        values: &[ArrayRef],
495        opt_filter: Option<&BooleanArray>,
496    ) -> Result<Vec<ArrayRef>> {
497        assert_eq!(values.len(), 1, "one argument to merge_batch");
498
499        let input_array = values[0].as_primitive::<T>();
500
501        // Directly convert the input array to states, each row will be
502        // seen as a respective group.
503        // For detail, the `input_array` will be converted to a `ListArray`.
504        // And if row is `not null + not filtered`, it will be converted to a list
505        // with only one element; otherwise, this row in `ListArray` will be set
506        // to null.
507
508        // Reuse values buffer in `input_array` to build `values` in `ListArray`
509        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
510            .with_data_type(self.data_type.clone());
511
512        // `offsets` in `ListArray`, each row as a list element
513        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
514            internal_datafusion_err!(
515                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
516            )
517        })?;
518        let offsets = (0..=offset_end).collect::<Vec<_>>();
519        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
520        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
521
522        // `nulls` for converted `ListArray`
523        let nulls = filtered_null_mask(opt_filter, input_array);
524
525        let converted_list_array = ListArray::new(
526            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
527            offsets,
528            Arc::new(values),
529            nulls,
530        );
531
532        Ok(vec![Arc::new(converted_list_array)])
533    }
534
535    fn supports_convert_to_state(&self) -> bool {
536        true
537    }
538
539    fn size(&self) -> usize {
540        self.group_values
541            .iter()
542            .map(|values| values.capacity() * size_of::<T>())
543            .sum::<usize>()
544            // account for size of self.grou_values too
545            + self.group_values.capacity() * size_of::<Vec<T>>()
546    }
547}
548
/// Accumulator for the `DISTINCT` variant of `median`.
///
/// State, updates and merges are all delegated to `distinct_values`; the
/// median itself is computed from its values in `evaluate`.
#[derive(Debug)]
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Buffer of input values seen so far; de-duplication is presumably done by
    /// `GenericDistinctBuffer` (NOTE(review): confirm in its definition).
    distinct_values: GenericDistinctBuffer<T>,
    /// Input (and output) data type.
    data_type: DataType,
}
554
555impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
556    fn state(&mut self) -> Result<Vec<ScalarValue>> {
557        self.distinct_values.state()
558    }
559
560    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
561        self.distinct_values.update_batch(values)
562    }
563
564    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
565        self.distinct_values.merge_batch(states)
566    }
567
568    fn evaluate(&mut self) -> Result<ScalarValue> {
569        let mut d = std::mem::take(&mut self.distinct_values.values)
570            .into_iter()
571            .map(|v| v.0)
572            .collect::<Vec<_>>();
573        let median = calculate_median::<T>(&mut d);
574        ScalarValue::new_primitive::<T>(median, &self.data_type)
575    }
576
577    fn size(&self) -> usize {
578        size_of_val(self) + self.distinct_values.size()
579    }
580}
581
582/// Get maximum entry in the slice,
583fn slice_max<T>(array: &[T::Native]) -> T::Native
584where
585    T: ArrowPrimitiveType,
586    T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison
587{
588    // Make sure that, array is not empty.
589    debug_assert!(!array.is_empty());
590    // `.unwrap()` is safe here as the array is supposed to be non-empty
591    *array
592        .iter()
593        .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
594        .unwrap()
595}
596
597fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
598    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
599
600    let len = values.len();
601    if len == 0 {
602        None
603    } else if len % 2 == 0 {
604        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
605        // Get the maximum of the low (left side after bi-partitioning)
606        let left_max = slice_max::<T>(low);
607        // Calculate median as the average of the two middle values.
608        // Use checked arithmetic to detect overflow and fall back to safe formula.
609        let two = T::Native::usize_as(2);
610        let median = match left_max.add_checked(*high) {
611            Ok(sum) => sum.div_wrapping(two),
612            Err(_) => {
613                // Overflow detected - use safe midpoint formula:
614                // a/2 + b/2 + ((a%2 + b%2) / 2)
615                // This avoids overflow by dividing before adding.
616                let half_left = left_max.div_wrapping(two);
617                let half_right = (*high).div_wrapping(two);
618                let rem_left = left_max.mod_wrapping(two);
619                let rem_right = (*high).mod_wrapping(two);
620                // The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed)
621                // divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed)
622                let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
623                half_left.add_wrapping(half_right).add_wrapping(correction)
624            }
625        };
626        Some(median)
627    } else {
628        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
629        Some(*median)
630    }
631}