1use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
25 PrimitiveBuilder,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29 array::{ArrayRef, AsArray},
30 datatypes::{
31 DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32 Float64Type,
33 },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{
39 ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40};
41
42use datafusion_common::{
43 internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
44};
45use datafusion_expr::function::StateFieldsArgs;
46use datafusion_expr::{
47 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
48 Documentation, Signature, Volatility,
49};
50use datafusion_expr::{EmitTo, GroupsAccumulator};
51use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
52use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
53use datafusion_functions_aggregate_common::utils::Hashable;
54use datafusion_macros::user_doc;
55
// Registers the `Median` UDAF through the project's `make_udaf_expr_and_func!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it generates the `median`
// expression helper (taking `expression`) and the `median_udaf` accessor — confirm.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
63
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name) |
+----------------------+
| 45.5 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
#[derive(PartialEq, Eq, Hash)]
/// The `median` aggregate function.
///
/// The non-distinct variant buffers every non-null input value in memory until the
/// result is evaluated (see `MedianAccumulator` / `MedianGroupsAccumulator`), so its
/// memory use grows with the number of input rows.
///
/// The distinct variant (see `DistinctMedianAccumulator`) keeps only the set of
/// distinct values, so its memory use grows with the input's cardinality instead.
pub struct Median {
    /// Argument signature, built in [`Median::new`].
    signature: Signature,
}
90
91impl Debug for Median {
92 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
93 f.debug_struct("Median")
94 .field("name", &self.name())
95 .field("signature", &self.signature)
96 .finish()
97 }
98}
99
100impl Default for Median {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl Median {
107 pub fn new() -> Self {
108 Self {
109 signature: Signature::numeric(1, Volatility::Immutable),
110 }
111 }
112}
113
114impl AggregateUDFImpl for Median {
115 fn as_any(&self) -> &dyn std::any::Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "median"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
128 Ok(arg_types[0].clone())
129 }
130
131 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
132 let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
134 let state_name = if args.is_distinct {
135 "distinct_median"
136 } else {
137 "median"
138 };
139
140 Ok(vec![Field::new(
141 format_state_name(args.name, state_name),
142 DataType::List(Arc::new(field)),
143 true,
144 )
145 .into()])
146 }
147
148 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
149 macro_rules! helper {
150 ($t:ty, $dt:expr) => {
151 if acc_args.is_distinct {
152 Ok(Box::new(DistinctMedianAccumulator::<$t> {
153 data_type: $dt.clone(),
154 distinct_values: HashSet::new(),
155 }))
156 } else {
157 Ok(Box::new(MedianAccumulator::<$t> {
158 data_type: $dt.clone(),
159 all_values: vec![],
160 }))
161 }
162 };
163 }
164
165 let dt = acc_args.expr_fields[0].data_type().clone();
166 downcast_integer! {
167 dt => (helper, dt),
168 DataType::Float16 => helper!(Float16Type, dt),
169 DataType::Float32 => helper!(Float32Type, dt),
170 DataType::Float64 => helper!(Float64Type, dt),
171 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
172 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
173 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
174 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
175 _ => Err(DataFusionError::NotImplemented(format!(
176 "MedianAccumulator not supported for {} with {}",
177 acc_args.name,
178 dt,
179 ))),
180 }
181 }
182
183 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
184 !args.is_distinct
185 }
186
187 fn create_groups_accumulator(
188 &self,
189 args: AccumulatorArgs,
190 ) -> Result<Box<dyn GroupsAccumulator>> {
191 let num_args = args.exprs.len();
192 if num_args != 1 {
193 return internal_err!(
194 "median should only have 1 arg, but found num args:{}",
195 args.exprs.len()
196 );
197 }
198
199 let dt = args.expr_fields[0].data_type().clone();
200
201 macro_rules! helper {
202 ($t:ty, $dt:expr) => {
203 Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
204 };
205 }
206
207 downcast_integer! {
208 dt => (helper, dt),
209 DataType::Float16 => helper!(Float16Type, dt),
210 DataType::Float32 => helper!(Float32Type, dt),
211 DataType::Float64 => helper!(Float64Type, dt),
212 DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
213 DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
214 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
215 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
216 _ => Err(DataFusionError::NotImplemented(format!(
217 "MedianGroupsAccumulator not supported for {} with {}",
218 args.name,
219 dt,
220 ))),
221 }
222 }
223
224 fn documentation(&self) -> Option<&Documentation> {
225 self.doc()
226 }
227}
228
/// Row-at-a-time accumulator for non-distinct `median`.
///
/// It buffers every non-null input value and only computes the median when
/// `evaluate` is called, so memory use grows with the number of rows seen.
struct MedianAccumulator<T: ArrowNumericType> {
    /// Data type applied to the arrays and scalars this accumulator emits.
    data_type: DataType,
    /// Every value collected so far, in arrival order.
    all_values: Vec<T::Native>,
}
240
241impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
242 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
243 write!(f, "MedianAccumulator({})", self.data_type)
244 }
245}
246
247impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
248 fn state(&mut self) -> Result<Vec<ScalarValue>> {
249 let offsets =
253 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
254
255 let values_array = PrimitiveArray::<T>::new(
257 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
258 None,
259 )
260 .with_data_type(self.data_type.clone());
261
262 let list_array = ListArray::new(
264 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
265 offsets,
266 Arc::new(values_array),
267 None,
268 );
269
270 Ok(vec![ScalarValue::List(Arc::new(list_array))])
271 }
272
273 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
274 let values = values[0].as_primitive::<T>();
275 self.all_values.reserve(values.len() - values.null_count());
276 self.all_values.extend(values.iter().flatten());
277 Ok(())
278 }
279
280 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
281 let array = states[0].as_list::<i32>();
282 for v in array.iter().flatten() {
283 self.update_batch(&[v])?
284 }
285 Ok(())
286 }
287
288 fn evaluate(&mut self) -> Result<ScalarValue> {
289 let d = std::mem::take(&mut self.all_values);
290 let median = calculate_median::<T>(d);
291 ScalarValue::new_primitive::<T>(median, &self.data_type)
292 }
293
294 fn size(&self) -> usize {
295 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
296 }
297}
298
/// Grouped accumulator for non-distinct `median`.
///
/// It keeps one value buffer per group and computes each group's median only when
/// that group is evaluated.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    /// Data type applied to the arrays this accumulator emits.
    data_type: DataType,
    /// Collected values, indexed by group index.
    group_values: Vec<Vec<T::Native>>,
}
310
311impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
312 pub fn new(data_type: DataType) -> Self {
313 Self {
314 data_type,
315 group_values: Vec::new(),
316 }
317 }
318}
319
320impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
321 fn update_batch(
322 &mut self,
323 values: &[ArrayRef],
324 group_indices: &[usize],
325 opt_filter: Option<&BooleanArray>,
326 total_num_groups: usize,
327 ) -> Result<()> {
328 assert_eq!(values.len(), 1, "single argument to update_batch");
329 let values = values[0].as_primitive::<T>();
330
331 self.group_values.resize(total_num_groups, Vec::new());
333 accumulate(
334 group_indices,
335 values,
336 opt_filter,
337 |group_index, new_value| {
338 self.group_values[group_index].push(new_value);
339 },
340 );
341
342 Ok(())
343 }
344
345 fn merge_batch(
346 &mut self,
347 values: &[ArrayRef],
348 group_indices: &[usize],
349 _opt_filter: Option<&BooleanArray>,
351 total_num_groups: usize,
352 ) -> Result<()> {
353 assert_eq!(values.len(), 1, "one argument to merge_batch");
354
355 let input_group_values = values[0].as_list::<i32>();
376
377 self.group_values.resize(total_num_groups, Vec::new());
379
380 group_indices
385 .iter()
386 .zip(input_group_values.iter())
387 .for_each(|(&group_index, values_opt)| {
388 if let Some(values) = values_opt {
389 let values = values.as_primitive::<T>();
390 self.group_values[group_index].extend(values.values().iter());
391 }
392 });
393
394 Ok(())
395 }
396
397 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
398 let emit_group_values = emit_to.take_needed(&mut self.group_values);
400
401 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
403 offsets.push(0);
404 let mut cur_len = 0_i32;
405 for group_value in &emit_group_values {
406 cur_len += group_value.len() as i32;
407 offsets.push(cur_len);
408 }
409 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
417
418 let flatten_group_values =
420 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
421 let group_values_array =
422 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
423 .with_data_type(self.data_type.clone());
424
425 let result_list_array = ListArray::new(
427 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
428 offsets,
429 Arc::new(group_values_array),
430 None,
431 );
432
433 Ok(vec![Arc::new(result_list_array)])
434 }
435
436 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
437 let emit_group_values = emit_to.take_needed(&mut self.group_values);
439
440 let mut evaluate_result_builder =
442 PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
443 for values in emit_group_values {
444 let median = calculate_median::<T>(values);
445 evaluate_result_builder.append_option(median);
446 }
447
448 Ok(Arc::new(evaluate_result_builder.finish()))
449 }
450
451 fn convert_to_state(
452 &self,
453 values: &[ArrayRef],
454 opt_filter: Option<&BooleanArray>,
455 ) -> Result<Vec<ArrayRef>> {
456 assert_eq!(values.len(), 1, "one argument to merge_batch");
457
458 let input_array = values[0].as_primitive::<T>();
459
460 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
469 .with_data_type(self.data_type.clone());
470
471 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
473 internal_datafusion_err!(
474 "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
475 )
476 })?;
477 let offsets = (0..=offset_end).collect::<Vec<_>>();
478 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
480
481 let nulls = filtered_null_mask(opt_filter, input_array);
483
484 let converted_list_array = ListArray::new(
485 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
486 offsets,
487 Arc::new(values),
488 nulls,
489 );
490
491 Ok(vec![Arc::new(converted_list_array)])
492 }
493
494 fn supports_convert_to_state(&self) -> bool {
495 true
496 }
497
498 fn size(&self) -> usize {
499 self.group_values
500 .iter()
501 .map(|values| values.capacity() * size_of::<T>())
502 .sum::<usize>()
503 + self.group_values.capacity() * size_of::<Vec<T>>()
505 }
506}
507
/// Row-at-a-time accumulator for `median(DISTINCT ...)`.
///
/// It keeps the set of distinct non-null input values and computes the median over
/// that set when `evaluate` is called.
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Data type applied to the scalars this accumulator emits.
    data_type: DataType,
    /// Distinct values seen so far, wrapped in `Hashable` so they can live in a hash set.
    distinct_values: HashSet<Hashable<T::Native>>,
}
519
520impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
521 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
522 write!(f, "DistinctMedianAccumulator({})", self.data_type)
523 }
524}
525
526impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
527 fn state(&mut self) -> Result<Vec<ScalarValue>> {
528 let all_values = self
529 .distinct_values
530 .iter()
531 .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
532 .collect::<Result<Vec<_>>>()?;
533
534 let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
535 Ok(vec![ScalarValue::List(arr)])
536 }
537
538 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
539 if values.is_empty() {
540 return Ok(());
541 }
542
543 let array = values[0].as_primitive::<T>();
544 match array.nulls().filter(|x| x.null_count() > 0) {
545 Some(n) => {
546 for idx in n.valid_indices() {
547 self.distinct_values.insert(Hashable(array.value(idx)));
548 }
549 }
550 None => array.values().iter().for_each(|x| {
551 self.distinct_values.insert(Hashable(*x));
552 }),
553 }
554 Ok(())
555 }
556
557 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
558 let array = states[0].as_list::<i32>();
559 for v in array.iter().flatten() {
560 self.update_batch(&[v])?
561 }
562 Ok(())
563 }
564
565 fn evaluate(&mut self) -> Result<ScalarValue> {
566 let d = std::mem::take(&mut self.distinct_values)
567 .into_iter()
568 .map(|v| v.0)
569 .collect::<Vec<_>>();
570 let median = calculate_median::<T>(d);
571 ScalarValue::new_primitive::<T>(median, &self.data_type)
572 }
573
574 fn size(&self) -> usize {
575 size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
576 }
577}
578
579fn slice_max<T>(array: &[T::Native]) -> T::Native
581where
582 T: ArrowPrimitiveType,
583 T::Native: PartialOrd, {
585 debug_assert!(!array.is_empty());
587 *array
589 .iter()
590 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
591 .unwrap()
592}
593
594fn calculate_median<T: ArrowNumericType>(
595 mut values: Vec<T::Native>,
596) -> Option<T::Native> {
597 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
598
599 let len = values.len();
600 if len == 0 {
601 None
602 } else if len % 2 == 0 {
603 let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
604 let left_max = slice_max::<T>(low);
606 let median = left_max
607 .add_wrapping(*high)
608 .div_wrapping(T::Native::usize_as(2));
609 Some(median)
610 } else {
611 let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
612 Some(*median)
613 }
614}