Skip to main content

datafusion_functions_aggregate/
median.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25    downcast_integer,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29    array::{ArrayRef, AsArray},
30    datatypes::{
31        DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32        Float64Type,
33    },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{
39    ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40};
41
42use datafusion_common::{
43    DataFusionError, Result, ScalarValue, assert_eq_or_internal_err,
44    internal_datafusion_err,
45};
46use datafusion_expr::function::StateFieldsArgs;
47use datafusion_expr::{
48    Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
49    function::AccumulatorArgs, utils::format_state_name,
50};
51use datafusion_expr::{EmitTo, GroupsAccumulator};
52use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
53use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
54use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
55use datafusion_macros::user_doc;
56use std::collections::HashMap;
57
// Registers the `Median` UDAF through the project's `make_udaf_expr_and_func!`
// macro. NOTE(review): presumably this generates the `median(expression)`
// expression builder and the `median_udaf()` accessor named below — confirm
// against the macro definition, which is not part of this file.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
65
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name)   |
+----------------------+
| 45.5                 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// MEDIAN aggregate expression.
///
/// The non-distinct variation uses a lot of memory because all values need to be
/// stored in memory before a result can be computed. If an approximation is
/// sufficient then APPROX_MEDIAN provides a much more efficient solution.
///
/// The distinct variation stores all distinct values in memory before computing
/// the result, so its memory usage is similarly high when the cardinality is
/// high, and lower when the cardinality is low.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Median {
    // Accepted argument signature; built in `Median::new` as a single numeric
    // argument with immutable volatility.
    signature: Signature,
}
92
93impl Default for Median {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl Median {
100    pub fn new() -> Self {
101        Self {
102            signature: Signature::numeric(1, Volatility::Immutable),
103        }
104    }
105}
106
impl AggregateUDFImpl for Median {
    /// Returns `self` as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// The function name: `"median"`.
    fn name(&self) -> &str {
        "median"
    }

    /// The accepted argument signature (see [`Median::new`]).
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result has the same data type as the (single) input argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    /// Describes the intermediate state: a single nullable `List` field whose
    /// items have the input's data type.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        // Intermediate state is a list of the elements we have collected so far
        let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
        // Distinct and non-distinct variants use different state names.
        let state_name = if args.is_distinct {
            "distinct_median"
        } else {
            "median"
        };

        Ok(vec![
            Field::new(
                format_state_name(args.name, state_name),
                DataType::List(Arc::new(field)),
                true,
            )
            .into(),
        ])
    }

    /// Builds the row-wise accumulator for the input's data type.
    ///
    /// Produces a [`DistinctMedianAccumulator`] when `acc_args.is_distinct` is
    /// set and a [`MedianAccumulator`] otherwise. Returns a `NotImplemented`
    /// error for data types not covered by the dispatch below.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // Instantiates the accumulator for the Arrow primitive type `$t`,
        // carrying the concrete `DataType` `$dt` along.
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                if acc_args.is_distinct {
                    Ok(Box::new(DistinctMedianAccumulator::<$t> {
                        data_type: $dt.clone(),
                        distinct_values: GenericDistinctBuffer::new($dt),
                    }))
                } else {
                    Ok(Box::new(MedianAccumulator::<$t> {
                        data_type: $dt.clone(),
                        all_values: vec![],
                    }))
                }
            };
        }

        let dt = acc_args.expr_fields[0].data_type().clone();
        // `downcast_integer!` handles the integer types via `helper`; float and
        // decimal types are listed explicitly.
        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianAccumulator not supported for {} with {}",
                acc_args.name,
                dt,
            ))),
        }
    }

    /// A groups accumulator is only provided for the non-distinct variation.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    /// Builds the [`MedianGroupsAccumulator`] for the input's data type.
    ///
    /// Returns an internal error if the call does not have exactly one argument
    /// and a `NotImplemented` error for data types not covered by the dispatch
    /// below.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let num_args = args.exprs.len();
        assert_eq_or_internal_err!(
            num_args,
            1,
            "median should only have 1 arg, but found num args:{}",
            num_args
        );

        let dt = args.expr_fields[0].data_type().clone();

        // Instantiates the groups accumulator for the Arrow primitive type `$t`.
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
            };
        }

        // Same type dispatch as in `accumulator`.
        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianGroupsAccumulator not supported for {} with {}",
                args.name,
                dt,
            ))),
        }
    }

    /// User-facing documentation generated by the `#[user_doc]` attribute on
    /// [`Median`].
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
223
/// The median accumulator buffers the raw, non-null input values as native
/// primitives (`T::Native`).
///
/// The intermediate state is emitted as a single `List` scalar holding every
/// buffered value, and `merge_batch` appends the elements of incoming lists
/// back into the same buffer. The median itself is only computed in
/// `evaluate`, so `update_batch` avoids expensive conversions and stays a
/// plain append.
struct MedianAccumulator<T: ArrowNumericType> {
    /// Arrow data type of the input; also used for the emitted state and result.
    data_type: DataType,
    /// Every non-null value currently contributing to the median.
    all_values: Vec<T::Native>,
}
235
236impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
237    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238        write!(f, "MedianAccumulator({})", self.data_type)
239    }
240}
241
242impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
243    fn state(&mut self) -> Result<Vec<ScalarValue>> {
244        // Convert `all_values` to `ListArray` and return a single List ScalarValue
245
246        // Build offsets
247        let offsets =
248            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
249
250        // Build inner array
251        let values_array = PrimitiveArray::<T>::new(
252            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
253            None,
254        )
255        .with_data_type(self.data_type.clone());
256
257        // Build the result list array
258        let list_array = ListArray::new(
259            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
260            offsets,
261            Arc::new(values_array),
262            None,
263        );
264
265        Ok(vec![ScalarValue::List(Arc::new(list_array))])
266    }
267
268    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
269        let values = values[0].as_primitive::<T>();
270        self.all_values.reserve(values.len() - values.null_count());
271        self.all_values.extend(values.iter().flatten());
272        Ok(())
273    }
274
275    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
276        let array = states[0].as_list::<i32>();
277        for v in array.iter().flatten() {
278            self.update_batch(&[v])?
279        }
280        Ok(())
281    }
282
283    fn evaluate(&mut self) -> Result<ScalarValue> {
284        let median = calculate_median::<T>(&mut self.all_values);
285        ScalarValue::new_primitive::<T>(median, &self.data_type)
286    }
287
288    fn size(&self) -> usize {
289        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
290    }
291
292    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
293        let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
294
295        let arr = &values[0];
296        for i in 0..arr.len() {
297            let v = ScalarValue::try_from_array(arr, i)?;
298            if !v.is_null() {
299                *to_remove.entry(v).or_default() += 1;
300            }
301        }
302
303        let mut i = 0;
304        while i < self.all_values.len() {
305            let k = ScalarValue::new_primitive::<T>(
306                Some(self.all_values[i]),
307                &self.data_type,
308            )?;
309            if let Some(count) = to_remove.get_mut(&k)
310                && *count > 0
311            {
312                self.all_values.swap_remove(i);
313                *count -= 1;
314                if *count == 0 {
315                    to_remove.remove(&k);
316                    if to_remove.is_empty() {
317                        break;
318                    }
319                }
320            }
321            i += 1;
322        }
323        Ok(())
324    }
325
326    fn supports_retract_batch(&self) -> bool {
327        true
328    }
329}
330
/// The median groups accumulator accumulates the raw input values
///
/// For calculating the accurate medians of groups, we need to store all values
/// of groups before final evaluation.
/// So values in each group will be stored in a `Vec<T::Native>`, and the total
/// group values will be actually organized as a `Vec<Vec<T::Native>>`.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    /// Arrow data type of the input; also used for the emitted state and result.
    data_type: DataType,
    /// Buffered values per group, indexed by group index.
    group_values: Vec<Vec<T::Native>>,
}
342
343impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
344    pub fn new(data_type: DataType) -> Self {
345        Self {
346            data_type,
347            group_values: Vec::new(),
348        }
349    }
350}
351
352impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
353    fn update_batch(
354        &mut self,
355        values: &[ArrayRef],
356        group_indices: &[usize],
357        opt_filter: Option<&BooleanArray>,
358        total_num_groups: usize,
359    ) -> Result<()> {
360        assert_eq!(values.len(), 1, "single argument to update_batch");
361        let values = values[0].as_primitive::<T>();
362
363        // Push the `not nulls + not filtered` row into its group
364        self.group_values.resize(total_num_groups, Vec::new());
365        accumulate(
366            group_indices,
367            values,
368            opt_filter,
369            |group_index, new_value| {
370                self.group_values[group_index].push(new_value);
371            },
372        );
373
374        Ok(())
375    }
376
377    fn merge_batch(
378        &mut self,
379        values: &[ArrayRef],
380        group_indices: &[usize],
381        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
382        _opt_filter: Option<&BooleanArray>,
383        total_num_groups: usize,
384    ) -> Result<()> {
385        assert_eq!(values.len(), 1, "one argument to merge_batch");
386
387        // The merged values should be organized like as a `ListArray` which is nullable
388        // (input with nulls usually generated from `convert_to_state`), but `inner array` of
389        // `ListArray`  is `non-nullable`.
390        //
391        // Following is the possible and impossible input `values`:
392        //
393        // # Possible values
394        // ```text
395        //   group 0: [1, 2, 3]
396        //   group 1: null (list array is nullable)
397        //   group 2: [6, 7, 8]
398        //   ...
399        //   group n: [...]
400        // ```
401        //
402        // # Impossible values
403        // ```text
404        //   group x: [1, 2, null] (values in list array is non-nullable)
405        // ```
406        //
407        let input_group_values = values[0].as_list::<i32>();
408
409        // Ensure group values big enough
410        self.group_values.resize(total_num_groups, Vec::new());
411
412        // Extend values to related groups
413        // TODO: avoid using iterator of the `ListArray`, this will lead to
414        // many calls of `slice` of its ``inner array`, and `slice` is not
415        // so efficient(due to the calculation of `null_count` for each `slice`).
416        group_indices
417            .iter()
418            .zip(input_group_values.iter())
419            .for_each(|(&group_index, values_opt)| {
420                if let Some(values) = values_opt {
421                    let values = values.as_primitive::<T>();
422                    self.group_values[group_index].extend(values.values().iter());
423                }
424            });
425
426        Ok(())
427    }
428
429    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
430        // Emit values
431        let emit_group_values = emit_to.take_needed(&mut self.group_values);
432
433        // Build offsets
434        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
435        offsets.push(0);
436        let mut cur_len = 0_i32;
437        for group_value in &emit_group_values {
438            cur_len += group_value.len() as i32;
439            offsets.push(cur_len);
440        }
441        // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
442        // but safety should be considered more carefully here(and I am not sure if it can get
443        // performance improvement when we introduce checks to keep the safety...).
444        //
445        // Can see more details in:
446        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
447        //
448        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
449
450        // Build inner array
451        let flatten_group_values =
452            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
453        let group_values_array =
454            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
455                .with_data_type(self.data_type.clone());
456
457        // Build the result list array
458        let result_list_array = ListArray::new(
459            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
460            offsets,
461            Arc::new(group_values_array),
462            None,
463        );
464
465        Ok(vec![Arc::new(result_list_array)])
466    }
467
468    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
469        // Emit values
470        let emit_group_values = emit_to.take_needed(&mut self.group_values);
471
472        // Calculate median for each group
473        let mut evaluate_result_builder =
474            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
475        for mut values in emit_group_values {
476            let median = calculate_median::<T>(&mut values);
477            evaluate_result_builder.append_option(median);
478        }
479
480        Ok(Arc::new(evaluate_result_builder.finish()))
481    }
482
483    fn convert_to_state(
484        &self,
485        values: &[ArrayRef],
486        opt_filter: Option<&BooleanArray>,
487    ) -> Result<Vec<ArrayRef>> {
488        assert_eq!(values.len(), 1, "one argument to merge_batch");
489
490        let input_array = values[0].as_primitive::<T>();
491
492        // Directly convert the input array to states, each row will be
493        // seen as a respective group.
494        // For detail, the `input_array` will be converted to a `ListArray`.
495        // And if row is `not null + not filtered`, it will be converted to a list
496        // with only one element; otherwise, this row in `ListArray` will be set
497        // to null.
498
499        // Reuse values buffer in `input_array` to build `values` in `ListArray`
500        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
501            .with_data_type(self.data_type.clone());
502
503        // `offsets` in `ListArray`, each row as a list element
504        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
505            internal_datafusion_err!(
506                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
507            )
508        })?;
509        let offsets = (0..=offset_end).collect::<Vec<_>>();
510        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
511        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
512
513        // `nulls` for converted `ListArray`
514        let nulls = filtered_null_mask(opt_filter, input_array);
515
516        let converted_list_array = ListArray::new(
517            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
518            offsets,
519            Arc::new(values),
520            nulls,
521        );
522
523        Ok(vec![Arc::new(converted_list_array)])
524    }
525
526    fn supports_convert_to_state(&self) -> bool {
527        true
528    }
529
530    fn size(&self) -> usize {
531        self.group_values
532            .iter()
533            .map(|values| values.capacity() * size_of::<T>())
534            .sum::<usize>()
535            // account for size of self.grou_values too
536            + self.group_values.capacity() * size_of::<Vec<T>>()
537    }
538}
539
/// Accumulator for `median(DISTINCT ...)`: keeps the set of distinct input
/// values in a `GenericDistinctBuffer` and computes the median over that set
/// in `evaluate`.
#[derive(Debug)]
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Distinct values seen so far.
    distinct_values: GenericDistinctBuffer<T>,
    /// Arrow data type of the input; also used for the result scalar.
    data_type: DataType,
}
545
546impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
547    fn state(&mut self) -> Result<Vec<ScalarValue>> {
548        self.distinct_values.state()
549    }
550
551    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
552        self.distinct_values.update_batch(values)
553    }
554
555    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
556        self.distinct_values.merge_batch(states)
557    }
558
559    fn evaluate(&mut self) -> Result<ScalarValue> {
560        let mut d: Vec<T::Native> =
561            self.distinct_values.values.iter().map(|v| v.0).collect();
562        let median = calculate_median::<T>(&mut d);
563        ScalarValue::new_primitive::<T>(median, &self.data_type)
564    }
565
566    fn size(&self) -> usize {
567        size_of_val(self) + self.distinct_values.size()
568    }
569}
570
571/// Get maximum entry in the slice,
572fn slice_max<T>(array: &[T::Native]) -> T::Native
573where
574    T: ArrowPrimitiveType,
575    T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison
576{
577    // Make sure that, array is not empty.
578    debug_assert!(!array.is_empty());
579    // `.unwrap()` is safe here as the array is supposed to be non-empty
580    *array
581        .iter()
582        .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
583        .unwrap()
584}
585
586fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
587    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
588
589    let len = values.len();
590    if len == 0 {
591        None
592    } else if len % 2 == 0 {
593        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
594        // Get the maximum of the low (left side after bi-partitioning)
595        let left_max = slice_max::<T>(low);
596        // Calculate median as the average of the two middle values.
597        // Use checked arithmetic to detect overflow and fall back to safe formula.
598        let two = T::Native::usize_as(2);
599        let median = match left_max.add_checked(*high) {
600            Ok(sum) => sum.div_wrapping(two),
601            Err(_) => {
602                // Overflow detected - use safe midpoint formula:
603                // a/2 + b/2 + ((a%2 + b%2) / 2)
604                // This avoids overflow by dividing before adding.
605                let half_left = left_max.div_wrapping(two);
606                let half_right = (*high).div_wrapping(two);
607                let rem_left = left_max.mod_wrapping(two);
608                let rem_right = (*high).mod_wrapping(two);
609                // The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed)
610                // divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed)
611                let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
612                half_left.add_wrapping(half_right).add_wrapping(correction)
613            }
614        };
615        Some(median)
616    } else {
617        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
618        Some(*median)
619    }
620}