1use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25 downcast_integer,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29 array::{ArrayRef, AsArray},
30 datatypes::{
31 DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32 Float64Type,
33 },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{
39 ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40};
41
42use datafusion_common::{
43 DataFusionError, Result, ScalarValue, assert_eq_or_internal_err,
44 internal_datafusion_err,
45};
46use datafusion_expr::function::StateFieldsArgs;
47use datafusion_expr::{
48 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
49 function::AccumulatorArgs, utils::format_state_name,
50};
51use datafusion_expr::{EmitTo, GroupsAccumulator};
52use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
53use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
54use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
55use datafusion_macros::user_doc;
56use std::collections::HashMap;
57
// Registers the `Median` aggregate through the project's
// `make_udaf_expr_and_func!` macro (defined outside this file).
// NOTE(review): judging by the arguments, this presumably generates a
// `median(expression)` expression helper and a `median_udaf` accessor —
// confirm against the macro definition.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
65
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name) |
+----------------------+
| 45.5 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// `median` aggregate function.
///
/// The accumulators created by this UDAF buffer the input values (or, for
/// the distinct variant, the distinct input values) in memory and only
/// compute the median when the result is evaluated, so memory use grows
/// with the number of buffered values.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Median {
    /// Accepted argument signature: one numeric argument, immutable.
    signature: Signature,
}
92
93impl Default for Median {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl Median {
100 pub fn new() -> Self {
101 Self {
102 signature: Signature::numeric(1, Volatility::Immutable),
103 }
104 }
105}
106
107impl AggregateUDFImpl for Median {
108 fn as_any(&self) -> &dyn std::any::Any {
109 self
110 }
111
112 fn name(&self) -> &str {
113 "median"
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
121 Ok(arg_types[0].clone())
122 }
123
124 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
125 let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
127 let state_name = if args.is_distinct {
128 "distinct_median"
129 } else {
130 "median"
131 };
132
133 Ok(vec![
134 Field::new(
135 format_state_name(args.name, state_name),
136 DataType::List(Arc::new(field)),
137 true,
138 )
139 .into(),
140 ])
141 }
142
143 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
144 macro_rules! helper {
145 ($t:ty, $dt:expr) => {
146 if acc_args.is_distinct {
147 Ok(Box::new(DistinctMedianAccumulator::<$t> {
148 data_type: $dt.clone(),
149 distinct_values: GenericDistinctBuffer::new($dt),
150 }))
151 } else {
152 Ok(Box::new(MedianAccumulator::<$t> {
153 data_type: $dt.clone(),
154 all_values: vec![],
155 }))
156 }
157 };
158 }
159
160 let dt = acc_args.expr_fields[0].data_type().clone();
161 downcast_integer! {
162 dt => (helper, dt),
163 DataType::Float16 => helper!(Float16Type, dt),
164 DataType::Float32 => helper!(Float32Type, dt),
165 DataType::Float64 => helper!(Float64Type, dt),
166 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
167 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
168 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
169 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
170 _ => Err(DataFusionError::NotImplemented(format!(
171 "MedianAccumulator not supported for {} with {}",
172 acc_args.name,
173 dt,
174 ))),
175 }
176 }
177
178 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
179 !args.is_distinct
180 }
181
182 fn create_groups_accumulator(
183 &self,
184 args: AccumulatorArgs,
185 ) -> Result<Box<dyn GroupsAccumulator>> {
186 let num_args = args.exprs.len();
187 assert_eq_or_internal_err!(
188 num_args,
189 1,
190 "median should only have 1 arg, but found num args:{}",
191 num_args
192 );
193
194 let dt = args.expr_fields[0].data_type().clone();
195
196 macro_rules! helper {
197 ($t:ty, $dt:expr) => {
198 Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
199 };
200 }
201
202 downcast_integer! {
203 dt => (helper, dt),
204 DataType::Float16 => helper!(Float16Type, dt),
205 DataType::Float32 => helper!(Float32Type, dt),
206 DataType::Float64 => helper!(Float64Type, dt),
207 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
208 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
209 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
210 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
211 _ => Err(DataFusionError::NotImplemented(format!(
212 "MedianGroupsAccumulator not supported for {} with {}",
213 args.name,
214 dt,
215 ))),
216 }
217 }
218
219 fn documentation(&self) -> Option<&Documentation> {
220 self.doc()
221 }
222}
223
/// Non-distinct median accumulator.
///
/// Keeps every non-null input value as a native value in `all_values`; the
/// median is only computed from that buffer when `evaluate` is called. The
/// intermediate state is exchanged as a single list of those values.
struct MedianAccumulator<T: ArrowNumericType> {
    /// Arrow data type of the input (and of the result).
    data_type: DataType,
    /// All non-null values seen so far, in no particular order.
    all_values: Vec<T::Native>,
}
235
236impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
237 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238 write!(f, "MedianAccumulator({})", self.data_type)
239 }
240}
241
242impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
243 fn state(&mut self) -> Result<Vec<ScalarValue>> {
244 let offsets =
248 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
249
250 let values_array = PrimitiveArray::<T>::new(
252 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
253 None,
254 )
255 .with_data_type(self.data_type.clone());
256
257 let list_array = ListArray::new(
259 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
260 offsets,
261 Arc::new(values_array),
262 None,
263 );
264
265 Ok(vec![ScalarValue::List(Arc::new(list_array))])
266 }
267
268 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
269 let values = values[0].as_primitive::<T>();
270 self.all_values.reserve(values.len() - values.null_count());
271 self.all_values.extend(values.iter().flatten());
272 Ok(())
273 }
274
275 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
276 let array = states[0].as_list::<i32>();
277 for v in array.iter().flatten() {
278 self.update_batch(&[v])?
279 }
280 Ok(())
281 }
282
283 fn evaluate(&mut self) -> Result<ScalarValue> {
284 let median = calculate_median::<T>(&mut self.all_values);
285 ScalarValue::new_primitive::<T>(median, &self.data_type)
286 }
287
288 fn size(&self) -> usize {
289 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
290 }
291
292 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
293 let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
294
295 let arr = &values[0];
296 for i in 0..arr.len() {
297 let v = ScalarValue::try_from_array(arr, i)?;
298 if !v.is_null() {
299 *to_remove.entry(v).or_default() += 1;
300 }
301 }
302
303 let mut i = 0;
304 while i < self.all_values.len() {
305 let k = ScalarValue::new_primitive::<T>(
306 Some(self.all_values[i]),
307 &self.data_type,
308 )?;
309 if let Some(count) = to_remove.get_mut(&k)
310 && *count > 0
311 {
312 self.all_values.swap_remove(i);
313 *count -= 1;
314 if *count == 0 {
315 to_remove.remove(&k);
316 if to_remove.is_empty() {
317 break;
318 }
319 }
320 }
321 i += 1;
322 }
323 Ok(())
324 }
325
326 fn supports_retract_batch(&self) -> bool {
327 true
328 }
329}
330
/// Grouped (non-distinct) median accumulator.
///
/// Buffers the values of every group separately; the median of a group is
/// computed from its buffer when the group is evaluated.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    /// Arrow data type of the input (and of the result).
    data_type: DataType,
    /// Buffered values, indexed by group index.
    group_values: Vec<Vec<T::Native>>,
}
342
343impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
344 pub fn new(data_type: DataType) -> Self {
345 Self {
346 data_type,
347 group_values: Vec::new(),
348 }
349 }
350}
351
352impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
353 fn update_batch(
354 &mut self,
355 values: &[ArrayRef],
356 group_indices: &[usize],
357 opt_filter: Option<&BooleanArray>,
358 total_num_groups: usize,
359 ) -> Result<()> {
360 assert_eq!(values.len(), 1, "single argument to update_batch");
361 let values = values[0].as_primitive::<T>();
362
363 self.group_values.resize(total_num_groups, Vec::new());
365 accumulate(
366 group_indices,
367 values,
368 opt_filter,
369 |group_index, new_value| {
370 self.group_values[group_index].push(new_value);
371 },
372 );
373
374 Ok(())
375 }
376
377 fn merge_batch(
378 &mut self,
379 values: &[ArrayRef],
380 group_indices: &[usize],
381 _opt_filter: Option<&BooleanArray>,
383 total_num_groups: usize,
384 ) -> Result<()> {
385 assert_eq!(values.len(), 1, "one argument to merge_batch");
386
387 let input_group_values = values[0].as_list::<i32>();
408
409 self.group_values.resize(total_num_groups, Vec::new());
411
412 group_indices
417 .iter()
418 .zip(input_group_values.iter())
419 .for_each(|(&group_index, values_opt)| {
420 if let Some(values) = values_opt {
421 let values = values.as_primitive::<T>();
422 self.group_values[group_index].extend(values.values().iter());
423 }
424 });
425
426 Ok(())
427 }
428
429 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
430 let emit_group_values = emit_to.take_needed(&mut self.group_values);
432
433 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
435 offsets.push(0);
436 let mut cur_len = 0_i32;
437 for group_value in &emit_group_values {
438 cur_len += group_value.len() as i32;
439 offsets.push(cur_len);
440 }
441 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
449
450 let flatten_group_values =
452 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
453 let group_values_array =
454 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
455 .with_data_type(self.data_type.clone());
456
457 let result_list_array = ListArray::new(
459 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
460 offsets,
461 Arc::new(group_values_array),
462 None,
463 );
464
465 Ok(vec![Arc::new(result_list_array)])
466 }
467
468 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
469 let emit_group_values = emit_to.take_needed(&mut self.group_values);
471
472 let mut evaluate_result_builder =
474 PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
475 for mut values in emit_group_values {
476 let median = calculate_median::<T>(&mut values);
477 evaluate_result_builder.append_option(median);
478 }
479
480 Ok(Arc::new(evaluate_result_builder.finish()))
481 }
482
483 fn convert_to_state(
484 &self,
485 values: &[ArrayRef],
486 opt_filter: Option<&BooleanArray>,
487 ) -> Result<Vec<ArrayRef>> {
488 assert_eq!(values.len(), 1, "one argument to merge_batch");
489
490 let input_array = values[0].as_primitive::<T>();
491
492 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
501 .with_data_type(self.data_type.clone());
502
503 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
505 internal_datafusion_err!(
506 "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
507 )
508 })?;
509 let offsets = (0..=offset_end).collect::<Vec<_>>();
510 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
512
513 let nulls = filtered_null_mask(opt_filter, input_array);
515
516 let converted_list_array = ListArray::new(
517 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
518 offsets,
519 Arc::new(values),
520 nulls,
521 );
522
523 Ok(vec![Arc::new(converted_list_array)])
524 }
525
526 fn supports_convert_to_state(&self) -> bool {
527 true
528 }
529
530 fn size(&self) -> usize {
531 self.group_values
532 .iter()
533 .map(|values| values.capacity() * size_of::<T>())
534 .sum::<usize>()
535 + self.group_values.capacity() * size_of::<Vec<T>>()
537 }
538}
539
/// Median accumulator for `DISTINCT` aggregation.
///
/// Delegates buffering and state handling to `GenericDistinctBuffer`; the
/// median is computed from the buffered values when `evaluate` is called.
#[derive(Debug)]
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Buffer of the values seen so far.
    distinct_values: GenericDistinctBuffer<T>,
    /// Arrow data type of the input (and of the result).
    data_type: DataType,
}
545
546impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
547 fn state(&mut self) -> Result<Vec<ScalarValue>> {
548 self.distinct_values.state()
549 }
550
551 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
552 self.distinct_values.update_batch(values)
553 }
554
555 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
556 self.distinct_values.merge_batch(states)
557 }
558
559 fn evaluate(&mut self) -> Result<ScalarValue> {
560 let mut d: Vec<T::Native> =
561 self.distinct_values.values.iter().map(|v| v.0).collect();
562 let median = calculate_median::<T>(&mut d);
563 ScalarValue::new_primitive::<T>(median, &self.data_type)
564 }
565
566 fn size(&self) -> usize {
567 size_of_val(self) + self.distinct_values.size()
568 }
569}
570
571fn slice_max<T>(array: &[T::Native]) -> T::Native
573where
574 T: ArrowPrimitiveType,
575 T::Native: PartialOrd, {
577 debug_assert!(!array.is_empty());
579 *array
581 .iter()
582 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
583 .unwrap()
584}
585
586fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
587 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
588
589 let len = values.len();
590 if len == 0 {
591 None
592 } else if len % 2 == 0 {
593 let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
594 let left_max = slice_max::<T>(low);
596 let two = T::Native::usize_as(2);
599 let median = match left_max.add_checked(*high) {
600 Ok(sum) => sum.div_wrapping(two),
601 Err(_) => {
602 let half_left = left_max.div_wrapping(two);
606 let half_right = (*high).div_wrapping(two);
607 let rem_left = left_max.mod_wrapping(two);
608 let rem_right = (*high).mod_wrapping(two);
609 let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
612 half_left.add_wrapping(half_right).add_wrapping(correction)
613 }
614 };
615 Some(median)
616 } else {
617 let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
618 Some(*median)
619 }
620}