datafusion_functions_aggregate/median.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;

use arrow::array::{
    downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
    PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
    array::{ArrayRef, AsArray},
    datatypes::{
        DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
        Float64Type,
    },
};

use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{
    ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
};

use datafusion_common::{
    internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
    Documentation, Signature, Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;

make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name)  |
+----------------------+
| 45.5                 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// MEDIAN aggregate expression. The non-distinct variant uses a lot of memory because
/// all values must be held in memory before a result can be computed. If an
/// approximation is sufficient, APPROX_MEDIAN provides a much more efficient solution.
///
/// The distinct variant stores all distinct values in memory before computing the
/// result, so its memory usage is similarly high when the cardinality is high, but
/// lower when the cardinality is low.
#[derive(PartialEq, Eq, Hash)]
pub struct Median {
    signature: Signature,
}

impl Debug for Median {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        f.debug_struct("Median")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Default for Median {
    fn default() -> Self {
        Self::new()
    }
}

impl Median {
    pub fn new() -> Self {
        Self {
            signature: Signature::numeric(1, Volatility::Immutable),
        }
    }
}

impl AggregateUDFImpl for Median {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "median"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        // Intermediate state is a list of the elements we have collected so far
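        // Illustrative example (field name assumed, not taken from the source): for
        // `median(c1)` over a Float64 column, this produces a single state field of
        // type `List(Float64)`, named via `format_state_name` (e.g. `median(c1)[median]`).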
        let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
        let state_name = if args.is_distinct {
            "distinct_median"
        } else {
            "median"
        };

        Ok(vec![Field::new(
            format_state_name(args.name, state_name),
            DataType::List(Arc::new(field)),
            true,
        )
        .into()])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                if acc_args.is_distinct {
                    Ok(Box::new(DistinctMedianAccumulator::<$t> {
                        data_type: $dt.clone(),
                        distinct_values: HashSet::new(),
                    }))
                } else {
                    Ok(Box::new(MedianAccumulator::<$t> {
                        data_type: $dt.clone(),
                        all_values: vec![],
                    }))
                }
            };
        }

        let dt = acc_args.expr_fields[0].data_type().clone();
        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianAccumulator not supported for {} with {}",
                acc_args.name,
                dt,
            ))),
        }
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        !args.is_distinct
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let num_args = args.exprs.len();
        if num_args != 1 {
            return internal_err!(
                "median should only have 1 arg, but found num args:{}",
                args.exprs.len()
            );
        }

        let dt = args.expr_fields[0].data_type().clone();

        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
            };
        }

        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
            DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianGroupsAccumulator not supported for {} with {}",
                args.name,
                dt,
            ))),
        }
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}

/// The median accumulator accumulates the raw input values as native values in
/// a `Vec<T::Native>`
///
/// The intermediate state is represented as a List of those values, built in
/// `state` and appended to by `merge_batch`. Keeping the values in their native
/// form avoids expensive `ScalarValue` conversions and allocations during
/// `update_batch`.
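///
/// Purely illustrative example (values made up): after `update_batch` sees the
/// batches `[1, 3]` and `[2]`, `all_values` holds `[1, 3, 2]`, `state` would
/// return the single scalar `List([1, 3, 2])`, and evaluating at that point
/// would yield `2`.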
struct MedianAccumulator<T: ArrowNumericType> {
    data_type: DataType,
    all_values: Vec<T::Native>,
}

impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "MedianAccumulator({})", self.data_type)
    }
}

impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // Convert `all_values` to `ListArray` and return a single List ScalarValue

        // Build offsets
        let offsets =
            OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));

        // Build inner array
        let values_array = PrimitiveArray::<T>::new(
            ScalarBuffer::from(std::mem::take(&mut self.all_values)),
            None,
        )
        .with_data_type(self.data_type.clone());

        // Build the result list array
        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(values_array),
            None,
        );

        Ok(vec![ScalarValue::List(Arc::new(list_array))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let values = values[0].as_primitive::<T>();
        self.all_values.reserve(values.len() - values.null_count());
        self.all_values.extend(values.iter().flatten());
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let array = states[0].as_list::<i32>();
        for v in array.iter().flatten() {
            self.update_batch(&[v])?
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let d = std::mem::take(&mut self.all_values);
        let median = calculate_median::<T>(d);
        ScalarValue::new_primitive::<T>(median, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
    }
}

/// The median groups accumulator accumulates the raw input values
///
/// Calculating the exact median of a group requires keeping all of its values
/// until final evaluation, so each group's values are stored in a `Vec<T>` and
/// the accumulator as a whole holds a `Vec<Vec<T>>`.
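///
/// A purely illustrative sketch of that layout (values made up): after a few
/// batches for three groups, `group_values` might look like
///
/// ```text
///   group_values[0]: [1, 5, 2]
///   group_values[1]: []         (no non-null, non-filtered rows seen yet)
///   group_values[2]: [9]
/// ```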
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    data_type: DataType,
    group_values: Vec<Vec<T::Native>>,
}

impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
    pub fn new(data_type: DataType) -> Self {
        Self {
            data_type,
            group_values: Vec::new(),
        }
    }
}

impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = values[0].as_primitive::<T>();

        // Push each `not null + not filtered` row into its group
        self.group_values.resize(total_num_groups, Vec::new());
        accumulate(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value| {
                self.group_values[group_index].push(new_value);
            },
        );

        Ok(())
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        // The aggregate filter is applied in the partial stage, so there should be no filter in the final stage
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");

        // The merged values are organized as a `ListArray` that is itself nullable
        // (inputs with nulls are usually generated from `convert_to_state`), but whose
        // inner array is non-nullable.
        //
        // Possible and impossible input `values` look like this:
        //
        // # Possible values
        // ```text
        //   group 0: [1, 2, 3]
        //   group 1: null (list array is nullable)
        //   group 2: [6, 7, 8]
        //   ...
        //   group n: [...]
        // ```
        //
        // # Impossible values
        // ```text
        //   group x: [1, 2, null] (values in the list array are non-nullable)
        // ```
        //
        let input_group_values = values[0].as_list::<i32>();

        // Ensure `group_values` is big enough
        self.group_values.resize(total_num_groups, Vec::new());

        // Extend the values into their corresponding groups
        // TODO: avoid using the iterator of the `ListArray`; it leads to many
        // `slice` calls on its inner array, and `slice` is not very efficient
        // (due to the `null_count` calculation done for each slice).
        group_indices
            .iter()
            .zip(input_group_values.iter())
            .for_each(|(&group_index, values_opt)| {
                if let Some(values) = values_opt {
                    let values = values.as_primitive::<T>();
                    self.group_values[group_index].extend(values.values().iter());
                }
            });

        Ok(())
    }

    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        // Emit values
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        // Build offsets
        let mut offsets = Vec::with_capacity(emit_group_values.len() + 1);
        offsets.push(0);
        let mut cur_len = 0_i32;
        for group_value in &emit_group_values {
            cur_len += group_value.len() as i32;
            offsets.push(cur_len);
        }
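        // Illustrative example (not from the source): for emitted groups with
        // lengths [3, 0, 2], the offsets built above are [0, 3, 3, 5].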
        // TODO: maybe we can use `OffsetBuffer::new_unchecked` as in `convert_to_state`,
        // but the safety argument needs more careful consideration here (and it is not
        // clear that it would improve performance once the checks needed to keep it
        // safe are introduced).
        //
        // More details in:
        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
        //
        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

        // Build inner array
        let flatten_group_values =
            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
        let group_values_array =
            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
                .with_data_type(self.data_type.clone());

        // Build the result list array
        let result_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(group_values_array),
            None,
        );

        Ok(vec![Arc::new(result_list_array)])
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        // Emit values
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        // Calculate median for each group
        let mut evaluate_result_builder =
            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
        for values in emit_group_values {
            let median = calculate_median::<T>(values);
            evaluate_result_builder.append_option(median);
        }

        Ok(Arc::new(evaluate_result_builder.finish()))
    }

    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "one argument to convert_to_state");

        let input_array = values[0].as_primitive::<T>();

        // Directly convert the input array to states; each row is treated as its
        // own group. Concretely, `input_array` is converted to a `ListArray`: a
        // row that is `not null + not filtered` becomes a list with a single
        // element, otherwise its entry in the `ListArray` is set to null.
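        //
        // Illustrative example (not from the source): input values [1, NULL, 3]
        // with filter [true, true, false] become the ListArray
        // [[1], NULL, NULL], whose offsets are [0, 1, 2, 3].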

        // Reuse values buffer in `input_array` to build `values` in `ListArray`
        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
            .with_data_type(self.data_type.clone());

        // `offsets` in `ListArray`, each row as a list element
        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
            internal_datafusion_err!(
                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
            )
        })?;
        let offsets = (0..=offset_end).collect::<Vec<_>>();
        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };

        // `nulls` for converted `ListArray`
        let nulls = filtered_null_mask(opt_filter, input_array);

        let converted_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(values),
            nulls,
        );

        Ok(vec![Arc::new(converted_list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        self.group_values
            .iter()
            .map(|values| values.capacity() * size_of::<T::Native>())
            .sum::<usize>()
            // account for the size of self.group_values too
            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
    }
}

/// The distinct median accumulator accumulates the raw input values as a
/// `HashSet` of distinct native values
///
/// The intermediate state is represented as a List of the distinct values,
/// built in `state` and appended to by `merge_batch`. Keeping the values in
/// their native form avoids expensive `ScalarValue` conversions and allocations
/// during `update_batch`.
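///
/// Purely illustrative example (values made up): after seeing the input
/// `[1, 1, 2, 2, 3]`, `distinct_values` holds `{1, 2, 3}` and the evaluated
/// median is `2`.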
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    data_type: DataType,
    distinct_values: HashSet<Hashable<T::Native>>,
}

impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "DistinctMedianAccumulator({})", self.data_type)
    }
}

impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let all_values = self
            .distinct_values
            .iter()
            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
            .collect::<Result<Vec<_>>>()?;

        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let array = values[0].as_primitive::<T>();
        match array.nulls().filter(|x| x.null_count() > 0) {
            Some(n) => {
                for idx in n.valid_indices() {
                    self.distinct_values.insert(Hashable(array.value(idx)));
                }
            }
            None => array.values().iter().for_each(|x| {
                self.distinct_values.insert(Hashable(*x));
            }),
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let array = states[0].as_list::<i32>();
        for v in array.iter().flatten() {
            self.update_batch(&[v])?
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let d = std::mem::take(&mut self.distinct_values)
            .into_iter()
            .map(|v| v.0)
            .collect::<Vec<_>>();
        let median = calculate_median::<T>(d);
        ScalarValue::new_primitive::<T>(median, &self.data_type)
    }

    fn size(&self) -> usize {
        size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
    }
}

/// Get the maximum entry in the slice.
fn slice_max<T>(array: &[T::Native]) -> T::Native
where
    T: ArrowPrimitiveType,
    T::Native: PartialOrd, // Ensure the type supports PartialOrd for comparison
{
    // The caller must ensure the array is not empty.
    debug_assert!(!array.is_empty());
    // `.unwrap()` is safe here as the array is supposed to be non-empty
    *array
        .iter()
        .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
        .unwrap()
}

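/// Calculate the exact median of `values` using `select_nth_unstable_by`.
///
/// A purely illustrative walk-through (values made up): for the even-length input
/// `[3, 1, 4, 2]`, selecting the element at index `len / 2 = 2` partitions the
/// data into `low = [1, 2]` and the selected element `3`; the median is
/// `(max(low) + 3) / 2`, computed with wrapping add and divide (so integer
/// inputs truncate). For the odd-length input `[3, 1, 2]`, the selected element
/// at index `1`, which is `2`, is the median.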
fn calculate_median<T: ArrowNumericType>(
    mut values: Vec<T::Native>,
) -> Option<T::Native> {
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

    let len = values.len();
    if len == 0 {
        None
    } else if len % 2 == 0 {
        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
        // Get the maximum of the low (left side after bi-partitioning)
        let left_max = slice_max::<T>(low);
        let median = left_max
            .add_wrapping(*high)
            .div_wrapping(T::Native::usize_as(2));
        Some(median)
    } else {
        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
        Some(*median)
    }
}