1use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
25 PrimitiveBuilder,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29 array::{ArrayRef, AsArray},
30 datatypes::{
31 DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32 Float64Type,
33 },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef};
39
40use datafusion_common::{
41 internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
42};
43use datafusion_expr::function::StateFieldsArgs;
44use datafusion_expr::{
45 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
46 Documentation, Signature, Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
51use datafusion_functions_aggregate_common::utils::Hashable;
52use datafusion_macros::user_doc;
53
// Invokes the project's `make_udaf_expr_and_func!` helper for the `Median`
// UDAF, passing the names `median` / `median_udaf`, the argument name
// `expression`, and a one-line description.
// NOTE(review): presumably this generates the `median(expression)` expression
// builder and the `median_udaf()` accessor — confirm against the macro
// definition, which is not visible in this file.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
61
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name) |
+----------------------+
| 45.5 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
/// `median` aggregate function.
///
/// The non-distinct accumulators in this file buffer every non-null input
/// value until the result is evaluated, and the distinct accumulator buffers
/// every distinct value, so memory use grows with the amount of input.
#[derive(PartialEq, Eq, Hash)]
pub struct Median {
    /// Accepted argument signature, built in [`Median::new`].
    signature: Signature,
}
88
89impl Debug for Median {
90 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
91 f.debug_struct("Median")
92 .field("name", &self.name())
93 .field("signature", &self.signature)
94 .finish()
95 }
96}
97
98impl Default for Median {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl Median {
105 pub fn new() -> Self {
106 Self {
107 signature: Signature::numeric(1, Volatility::Immutable),
108 }
109 }
110}
111
112impl AggregateUDFImpl for Median {
113 fn as_any(&self) -> &dyn std::any::Any {
114 self
115 }
116
117 fn name(&self) -> &str {
118 "median"
119 }
120
121 fn signature(&self) -> &Signature {
122 &self.signature
123 }
124
125 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
126 Ok(arg_types[0].clone())
127 }
128
129 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
130 let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
132 let state_name = if args.is_distinct {
133 "distinct_median"
134 } else {
135 "median"
136 };
137
138 Ok(vec![Field::new(
139 format_state_name(args.name, state_name),
140 DataType::List(Arc::new(field)),
141 true,
142 )
143 .into()])
144 }
145
146 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
147 macro_rules! helper {
148 ($t:ty, $dt:expr) => {
149 if acc_args.is_distinct {
150 Ok(Box::new(DistinctMedianAccumulator::<$t> {
151 data_type: $dt.clone(),
152 distinct_values: HashSet::new(),
153 }))
154 } else {
155 Ok(Box::new(MedianAccumulator::<$t> {
156 data_type: $dt.clone(),
157 all_values: vec![],
158 }))
159 }
160 };
161 }
162
163 let dt = acc_args.exprs[0].data_type(acc_args.schema)?;
164 downcast_integer! {
165 dt => (helper, dt),
166 DataType::Float16 => helper!(Float16Type, dt),
167 DataType::Float32 => helper!(Float32Type, dt),
168 DataType::Float64 => helper!(Float64Type, dt),
169 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
170 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
171 _ => Err(DataFusionError::NotImplemented(format!(
172 "MedianAccumulator not supported for {} with {}",
173 acc_args.name,
174 dt,
175 ))),
176 }
177 }
178
179 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
180 !args.is_distinct
181 }
182
183 fn create_groups_accumulator(
184 &self,
185 args: AccumulatorArgs,
186 ) -> Result<Box<dyn GroupsAccumulator>> {
187 let num_args = args.exprs.len();
188 if num_args != 1 {
189 return internal_err!(
190 "median should only have 1 arg, but found num args:{}",
191 args.exprs.len()
192 );
193 }
194
195 let dt = args.exprs[0].data_type(args.schema)?;
196
197 macro_rules! helper {
198 ($t:ty, $dt:expr) => {
199 Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
200 };
201 }
202
203 downcast_integer! {
204 dt => (helper, dt),
205 DataType::Float16 => helper!(Float16Type, dt),
206 DataType::Float32 => helper!(Float32Type, dt),
207 DataType::Float64 => helper!(Float64Type, dt),
208 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
209 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
210 _ => Err(DataFusionError::NotImplemented(format!(
211 "MedianGroupsAccumulator not supported for {} with {}",
212 args.name,
213 dt,
214 ))),
215 }
216 }
217
218 fn documentation(&self) -> Option<&Documentation> {
219 self.doc()
220 }
221}
222
/// Row-wise (non-grouped) accumulator for `median`.
///
/// Every non-null input value is kept as a native value in `all_values`; the
/// median is only computed when the accumulator is evaluated.
struct MedianAccumulator<T: ArrowNumericType> {
    /// Data type of the input, re-applied to arrays and scalars produced here.
    data_type: DataType,
    /// All non-null values seen so far, in arrival order.
    all_values: Vec<T::Native>,
}
234
235impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
236 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237 write!(f, "MedianAccumulator({})", self.data_type)
238 }
239}
240
241impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
242 fn state(&mut self) -> Result<Vec<ScalarValue>> {
243 let offsets =
247 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
248
249 let values_array = PrimitiveArray::<T>::new(
251 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
252 None,
253 )
254 .with_data_type(self.data_type.clone());
255
256 let list_array = ListArray::new(
258 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
259 offsets,
260 Arc::new(values_array),
261 None,
262 );
263
264 Ok(vec![ScalarValue::List(Arc::new(list_array))])
265 }
266
267 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
268 let values = values[0].as_primitive::<T>();
269 self.all_values.reserve(values.len() - values.null_count());
270 self.all_values.extend(values.iter().flatten());
271 Ok(())
272 }
273
274 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
275 let array = states[0].as_list::<i32>();
276 for v in array.iter().flatten() {
277 self.update_batch(&[v])?
278 }
279 Ok(())
280 }
281
282 fn evaluate(&mut self) -> Result<ScalarValue> {
283 let d = std::mem::take(&mut self.all_values);
284 let median = calculate_median::<T>(d);
285 ScalarValue::new_primitive::<T>(median, &self.data_type)
286 }
287
288 fn size(&self) -> usize {
289 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
290 }
291}
292
/// Grouped accumulator for `median`.
///
/// Keeps one buffer of native values per group; medians are only computed
/// when groups are evaluated.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    /// Data type of the input, re-applied to arrays produced here.
    data_type: DataType,
    /// `group_values[i]` holds the values accumulated for group index `i`.
    group_values: Vec<Vec<T::Native>>,
}
305
306impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
307 pub fn new(data_type: DataType) -> Self {
308 Self {
309 data_type,
310 group_values: Vec::new(),
311 }
312 }
313}
314
315impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
316 fn update_batch(
317 &mut self,
318 values: &[ArrayRef],
319 group_indices: &[usize],
320 opt_filter: Option<&BooleanArray>,
321 total_num_groups: usize,
322 ) -> Result<()> {
323 assert_eq!(values.len(), 1, "single argument to update_batch");
324 let values = values[0].as_primitive::<T>();
325
326 self.group_values.resize(total_num_groups, Vec::new());
328 accumulate(
329 group_indices,
330 values,
331 opt_filter,
332 |group_index, new_value| {
333 self.group_values[group_index].push(new_value);
334 },
335 );
336
337 Ok(())
338 }
339
340 fn merge_batch(
341 &mut self,
342 values: &[ArrayRef],
343 group_indices: &[usize],
344 _opt_filter: Option<&BooleanArray>,
346 total_num_groups: usize,
347 ) -> Result<()> {
348 assert_eq!(values.len(), 1, "one argument to merge_batch");
349
350 let input_group_values = values[0].as_list::<i32>();
371
372 self.group_values.resize(total_num_groups, Vec::new());
374
375 group_indices
380 .iter()
381 .zip(input_group_values.iter())
382 .for_each(|(&group_index, values_opt)| {
383 if let Some(values) = values_opt {
384 let values = values.as_primitive::<T>();
385 self.group_values[group_index].extend(values.values().iter());
386 }
387 });
388
389 Ok(())
390 }
391
392 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
393 let emit_group_values = emit_to.take_needed(&mut self.group_values);
395
396 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
398 offsets.push(0);
399 let mut cur_len = 0_i32;
400 for group_value in &emit_group_values {
401 cur_len += group_value.len() as i32;
402 offsets.push(cur_len);
403 }
404 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
412
413 let flatten_group_values =
415 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
416 let group_values_array =
417 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
418 .with_data_type(self.data_type.clone());
419
420 let result_list_array = ListArray::new(
422 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
423 offsets,
424 Arc::new(group_values_array),
425 None,
426 );
427
428 Ok(vec![Arc::new(result_list_array)])
429 }
430
431 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
432 let emit_group_values = emit_to.take_needed(&mut self.group_values);
434
435 let mut evaluate_result_builder =
437 PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
438 for values in emit_group_values {
439 let median = calculate_median::<T>(values);
440 evaluate_result_builder.append_option(median);
441 }
442
443 Ok(Arc::new(evaluate_result_builder.finish()))
444 }
445
446 fn convert_to_state(
447 &self,
448 values: &[ArrayRef],
449 opt_filter: Option<&BooleanArray>,
450 ) -> Result<Vec<ArrayRef>> {
451 assert_eq!(values.len(), 1, "one argument to merge_batch");
452
453 let input_array = values[0].as_primitive::<T>();
454
455 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
464 .with_data_type(self.data_type.clone());
465
466 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
468 internal_datafusion_err!(
469 "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
470 )
471 })?;
472 let offsets = (0..=offset_end).collect::<Vec<_>>();
473 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
475
476 let nulls = filtered_null_mask(opt_filter, input_array);
478
479 let converted_list_array = ListArray::new(
480 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
481 offsets,
482 Arc::new(values),
483 nulls,
484 );
485
486 Ok(vec![Arc::new(converted_list_array)])
487 }
488
489 fn supports_convert_to_state(&self) -> bool {
490 true
491 }
492
493 fn size(&self) -> usize {
494 self.group_values
495 .iter()
496 .map(|values| values.capacity() * size_of::<T>())
497 .sum::<usize>()
498 + self.group_values.capacity() * size_of::<Vec<T>>()
500 }
501}
502
/// Row-wise accumulator for `median(DISTINCT ...)`.
///
/// Keeps each distinct non-null input value once in a hash set; the median is
/// only computed when the accumulator is evaluated.
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Data type of the input, re-applied to scalars produced here.
    data_type: DataType,
    /// Distinct values seen so far, wrapped in `Hashable` for use as set keys.
    distinct_values: HashSet<Hashable<T::Native>>,
}
514
515impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
516 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
517 write!(f, "DistinctMedianAccumulator({})", self.data_type)
518 }
519}
520
521impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
522 fn state(&mut self) -> Result<Vec<ScalarValue>> {
523 let all_values = self
524 .distinct_values
525 .iter()
526 .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
527 .collect::<Result<Vec<_>>>()?;
528
529 let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
530 Ok(vec![ScalarValue::List(arr)])
531 }
532
533 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
534 if values.is_empty() {
535 return Ok(());
536 }
537
538 let array = values[0].as_primitive::<T>();
539 match array.nulls().filter(|x| x.null_count() > 0) {
540 Some(n) => {
541 for idx in n.valid_indices() {
542 self.distinct_values.insert(Hashable(array.value(idx)));
543 }
544 }
545 None => array.values().iter().for_each(|x| {
546 self.distinct_values.insert(Hashable(*x));
547 }),
548 }
549 Ok(())
550 }
551
552 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
553 let array = states[0].as_list::<i32>();
554 for v in array.iter().flatten() {
555 self.update_batch(&[v])?
556 }
557 Ok(())
558 }
559
560 fn evaluate(&mut self) -> Result<ScalarValue> {
561 let d = std::mem::take(&mut self.distinct_values)
562 .into_iter()
563 .map(|v| v.0)
564 .collect::<Vec<_>>();
565 let median = calculate_median::<T>(d);
566 ScalarValue::new_primitive::<T>(median, &self.data_type)
567 }
568
569 fn size(&self) -> usize {
570 size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
571 }
572}
573
574fn slice_max<T>(array: &[T::Native]) -> T::Native
576where
577 T: ArrowPrimitiveType,
578 T::Native: PartialOrd, {
580 debug_assert!(!array.is_empty());
582 *array
584 .iter()
585 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
586 .unwrap()
587}
588
589fn calculate_median<T: ArrowNumericType>(
590 mut values: Vec<T::Native>,
591) -> Option<T::Native> {
592 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
593
594 let len = values.len();
595 if len == 0 {
596 None
597 } else if len % 2 == 0 {
598 let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
599 let left_max = slice_max::<T>(low);
601 let median = left_max
602 .add_wrapping(*high)
603 .div_wrapping(T::Native::usize_as(2));
604 Some(median)
605 } else {
606 let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
607 Some(*median)
608 }
609}