1use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder,
25 downcast_integer,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29 array::{ArrayRef, AsArray},
30 datatypes::{
31 DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32 Float64Type,
33 },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{
39 ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40};
41
42use datafusion_common::{
43 DataFusionError, Result, ScalarValue, assert_eq_or_internal_err,
44 internal_datafusion_err,
45};
46use datafusion_expr::function::StateFieldsArgs;
47use datafusion_expr::{
48 Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
49 function::AccumulatorArgs, utils::format_state_name,
50};
51use datafusion_expr::{EmitTo, GroupsAccumulator};
52use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
53use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
54use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
55use datafusion_macros::user_doc;
56use std::collections::HashMap;
57
// Hooks the `Median` aggregate into the crate through `make_udaf_expr_and_func!`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `median(expression)` expression helper and the `median_udaf()` accessor — confirm
// against the macro definition.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
65
66#[user_doc(
67 doc_section(label = "General Functions"),
68 description = "Returns the median value in the specified column.",
69 syntax_example = "median(expression)",
70 sql_example = r#"```sql
71> SELECT median(column_name) FROM table_name;
72+----------------------+
73| median(column_name) |
74+----------------------+
75| 45.5 |
76+----------------------+
77```"#,
78 standard_argument(name = "expression", prefix = "The")
79)]
80#[derive(PartialEq, Eq, Hash)]
89pub struct Median {
90 signature: Signature,
91}
92
93impl Debug for Median {
94 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
95 f.debug_struct("Median")
96 .field("name", &self.name())
97 .field("signature", &self.signature)
98 .finish()
99 }
100}
101
102impl Default for Median {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl Median {
109 pub fn new() -> Self {
110 Self {
111 signature: Signature::numeric(1, Volatility::Immutable),
112 }
113 }
114}
115
116impl AggregateUDFImpl for Median {
117 fn as_any(&self) -> &dyn std::any::Any {
118 self
119 }
120
121 fn name(&self) -> &str {
122 "median"
123 }
124
125 fn signature(&self) -> &Signature {
126 &self.signature
127 }
128
129 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
130 Ok(arg_types[0].clone())
131 }
132
133 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
134 let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
136 let state_name = if args.is_distinct {
137 "distinct_median"
138 } else {
139 "median"
140 };
141
142 Ok(vec![
143 Field::new(
144 format_state_name(args.name, state_name),
145 DataType::List(Arc::new(field)),
146 true,
147 )
148 .into(),
149 ])
150 }
151
152 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
153 macro_rules! helper {
154 ($t:ty, $dt:expr) => {
155 if acc_args.is_distinct {
156 Ok(Box::new(DistinctMedianAccumulator::<$t> {
157 data_type: $dt.clone(),
158 distinct_values: GenericDistinctBuffer::new($dt),
159 }))
160 } else {
161 Ok(Box::new(MedianAccumulator::<$t> {
162 data_type: $dt.clone(),
163 all_values: vec![],
164 }))
165 }
166 };
167 }
168
169 let dt = acc_args.expr_fields[0].data_type().clone();
170 downcast_integer! {
171 dt => (helper, dt),
172 DataType::Float16 => helper!(Float16Type, dt),
173 DataType::Float32 => helper!(Float32Type, dt),
174 DataType::Float64 => helper!(Float64Type, dt),
175 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
176 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
177 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
178 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
179 _ => Err(DataFusionError::NotImplemented(format!(
180 "MedianAccumulator not supported for {} with {}",
181 acc_args.name,
182 dt,
183 ))),
184 }
185 }
186
187 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
188 !args.is_distinct
189 }
190
191 fn create_groups_accumulator(
192 &self,
193 args: AccumulatorArgs,
194 ) -> Result<Box<dyn GroupsAccumulator>> {
195 let num_args = args.exprs.len();
196 assert_eq_or_internal_err!(
197 num_args,
198 1,
199 "median should only have 1 arg, but found num args:{}",
200 num_args
201 );
202
203 let dt = args.expr_fields[0].data_type().clone();
204
205 macro_rules! helper {
206 ($t:ty, $dt:expr) => {
207 Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
208 };
209 }
210
211 downcast_integer! {
212 dt => (helper, dt),
213 DataType::Float16 => helper!(Float16Type, dt),
214 DataType::Float32 => helper!(Float32Type, dt),
215 DataType::Float64 => helper!(Float64Type, dt),
216 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
217 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
218 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
219 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
220 _ => Err(DataFusionError::NotImplemented(format!(
221 "MedianGroupsAccumulator not supported for {} with {}",
222 args.name,
223 dt,
224 ))),
225 }
226 }
227
228 fn documentation(&self) -> Option<&Documentation> {
229 self.doc()
230 }
231}
232
233struct MedianAccumulator<T: ArrowNumericType> {
241 data_type: DataType,
242 all_values: Vec<T::Native>,
243}
244
245impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
246 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
247 write!(f, "MedianAccumulator({})", self.data_type)
248 }
249}
250
251impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
252 fn state(&mut self) -> Result<Vec<ScalarValue>> {
253 let offsets =
257 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
258
259 let values_array = PrimitiveArray::<T>::new(
261 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
262 None,
263 )
264 .with_data_type(self.data_type.clone());
265
266 let list_array = ListArray::new(
268 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
269 offsets,
270 Arc::new(values_array),
271 None,
272 );
273
274 Ok(vec![ScalarValue::List(Arc::new(list_array))])
275 }
276
277 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
278 let values = values[0].as_primitive::<T>();
279 self.all_values.reserve(values.len() - values.null_count());
280 self.all_values.extend(values.iter().flatten());
281 Ok(())
282 }
283
284 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
285 let array = states[0].as_list::<i32>();
286 for v in array.iter().flatten() {
287 self.update_batch(&[v])?
288 }
289 Ok(())
290 }
291
292 fn evaluate(&mut self) -> Result<ScalarValue> {
293 let median = calculate_median::<T>(&mut self.all_values);
294 ScalarValue::new_primitive::<T>(median, &self.data_type)
295 }
296
297 fn size(&self) -> usize {
298 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
299 }
300
301 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
302 let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
303
304 let arr = &values[0];
305 for i in 0..arr.len() {
306 let v = ScalarValue::try_from_array(arr, i)?;
307 if !v.is_null() {
308 *to_remove.entry(v).or_default() += 1;
309 }
310 }
311
312 let mut i = 0;
313 while i < self.all_values.len() {
314 let k = ScalarValue::new_primitive::<T>(
315 Some(self.all_values[i]),
316 &self.data_type,
317 )?;
318 if let Some(count) = to_remove.get_mut(&k)
319 && *count > 0
320 {
321 self.all_values.swap_remove(i);
322 *count -= 1;
323 if *count == 0 {
324 to_remove.remove(&k);
325 if to_remove.is_empty() {
326 break;
327 }
328 }
329 }
330 i += 1;
331 }
332 Ok(())
333 }
334
335 fn supports_retract_batch(&self) -> bool {
336 true
337 }
338}
339
340#[derive(Debug)]
347struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
348 data_type: DataType,
349 group_values: Vec<Vec<T::Native>>,
350}
351
352impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
353 pub fn new(data_type: DataType) -> Self {
354 Self {
355 data_type,
356 group_values: Vec::new(),
357 }
358 }
359}
360
361impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
362 fn update_batch(
363 &mut self,
364 values: &[ArrayRef],
365 group_indices: &[usize],
366 opt_filter: Option<&BooleanArray>,
367 total_num_groups: usize,
368 ) -> Result<()> {
369 assert_eq!(values.len(), 1, "single argument to update_batch");
370 let values = values[0].as_primitive::<T>();
371
372 self.group_values.resize(total_num_groups, Vec::new());
374 accumulate(
375 group_indices,
376 values,
377 opt_filter,
378 |group_index, new_value| {
379 self.group_values[group_index].push(new_value);
380 },
381 );
382
383 Ok(())
384 }
385
386 fn merge_batch(
387 &mut self,
388 values: &[ArrayRef],
389 group_indices: &[usize],
390 _opt_filter: Option<&BooleanArray>,
392 total_num_groups: usize,
393 ) -> Result<()> {
394 assert_eq!(values.len(), 1, "one argument to merge_batch");
395
396 let input_group_values = values[0].as_list::<i32>();
417
418 self.group_values.resize(total_num_groups, Vec::new());
420
421 group_indices
426 .iter()
427 .zip(input_group_values.iter())
428 .for_each(|(&group_index, values_opt)| {
429 if let Some(values) = values_opt {
430 let values = values.as_primitive::<T>();
431 self.group_values[group_index].extend(values.values().iter());
432 }
433 });
434
435 Ok(())
436 }
437
438 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
439 let emit_group_values = emit_to.take_needed(&mut self.group_values);
441
442 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
444 offsets.push(0);
445 let mut cur_len = 0_i32;
446 for group_value in &emit_group_values {
447 cur_len += group_value.len() as i32;
448 offsets.push(cur_len);
449 }
450 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
458
459 let flatten_group_values =
461 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
462 let group_values_array =
463 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
464 .with_data_type(self.data_type.clone());
465
466 let result_list_array = ListArray::new(
468 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
469 offsets,
470 Arc::new(group_values_array),
471 None,
472 );
473
474 Ok(vec![Arc::new(result_list_array)])
475 }
476
477 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
478 let emit_group_values = emit_to.take_needed(&mut self.group_values);
480
481 let mut evaluate_result_builder =
483 PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
484 for mut values in emit_group_values {
485 let median = calculate_median::<T>(&mut values);
486 evaluate_result_builder.append_option(median);
487 }
488
489 Ok(Arc::new(evaluate_result_builder.finish()))
490 }
491
492 fn convert_to_state(
493 &self,
494 values: &[ArrayRef],
495 opt_filter: Option<&BooleanArray>,
496 ) -> Result<Vec<ArrayRef>> {
497 assert_eq!(values.len(), 1, "one argument to merge_batch");
498
499 let input_array = values[0].as_primitive::<T>();
500
501 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
510 .with_data_type(self.data_type.clone());
511
512 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
514 internal_datafusion_err!(
515 "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
516 )
517 })?;
518 let offsets = (0..=offset_end).collect::<Vec<_>>();
519 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
521
522 let nulls = filtered_null_mask(opt_filter, input_array);
524
525 let converted_list_array = ListArray::new(
526 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
527 offsets,
528 Arc::new(values),
529 nulls,
530 );
531
532 Ok(vec![Arc::new(converted_list_array)])
533 }
534
535 fn supports_convert_to_state(&self) -> bool {
536 true
537 }
538
539 fn size(&self) -> usize {
540 self.group_values
541 .iter()
542 .map(|values| values.capacity() * size_of::<T>())
543 .sum::<usize>()
544 + self.group_values.capacity() * size_of::<Vec<T>>()
546 }
547}
548
549#[derive(Debug)]
550struct DistinctMedianAccumulator<T: ArrowNumericType> {
551 distinct_values: GenericDistinctBuffer<T>,
552 data_type: DataType,
553}
554
555impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
556 fn state(&mut self) -> Result<Vec<ScalarValue>> {
557 self.distinct_values.state()
558 }
559
560 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
561 self.distinct_values.update_batch(values)
562 }
563
564 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
565 self.distinct_values.merge_batch(states)
566 }
567
568 fn evaluate(&mut self) -> Result<ScalarValue> {
569 let mut d = std::mem::take(&mut self.distinct_values.values)
570 .into_iter()
571 .map(|v| v.0)
572 .collect::<Vec<_>>();
573 let median = calculate_median::<T>(&mut d);
574 ScalarValue::new_primitive::<T>(median, &self.data_type)
575 }
576
577 fn size(&self) -> usize {
578 size_of_val(self) + self.distinct_values.size()
579 }
580}
581
582fn slice_max<T>(array: &[T::Native]) -> T::Native
584where
585 T: ArrowPrimitiveType,
586 T::Native: PartialOrd, {
588 debug_assert!(!array.is_empty());
590 *array
592 .iter()
593 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
594 .unwrap()
595}
596
597fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
598 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
599
600 let len = values.len();
601 if len == 0 {
602 None
603 } else if len % 2 == 0 {
604 let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
605 let left_max = slice_max::<T>(low);
607 let two = T::Native::usize_as(2);
610 let median = match left_max.add_checked(*high) {
611 Ok(sum) => sum.div_wrapping(two),
612 Err(_) => {
613 let half_left = left_max.div_wrapping(two);
617 let half_right = (*high).div_wrapping(two);
618 let rem_left = left_max.mod_wrapping(two);
619 let rem_right = (*high).mod_wrapping(two);
620 let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
623 half_left.add_wrapping(half_right).add_wrapping(correction)
624 }
625 };
626 Some(median)
627 } else {
628 let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
629 Some(*median)
630 }
631}