1use std::cmp::Ordering;
19use std::fmt::{Debug, Formatter};
20use std::mem::{size_of, size_of_val};
21use std::sync::Arc;
22
23use arrow::array::{
24 downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
25 PrimitiveBuilder,
26};
27use arrow::buffer::{OffsetBuffer, ScalarBuffer};
28use arrow::{
29 array::{ArrayRef, AsArray},
30 datatypes::{
31 DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
32 Float64Type,
33 },
34};
35
36use arrow::array::Array;
37use arrow::array::ArrowNativeTypeOp;
38use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef};
39
40use datafusion_common::{
41 internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
42};
43use datafusion_expr::function::StateFieldsArgs;
44use datafusion_expr::{
45 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
46 Documentation, Signature, Volatility,
47};
48use datafusion_expr::{EmitTo, GroupsAccumulator};
49use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
50use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
51use datafusion_functions_aggregate_common::utils::Hashable;
52use datafusion_macros::user_doc;
53
// Registers the `Median` UDAF through `make_udaf_expr_and_func!` (defined elsewhere).
// The `median` and `median_udaf` identifiers are presumably the generated
// expression-builder function and the shared UDAF accessor — confirm against the macro.
make_udaf_expr_and_func!(
    Median,
    median,
    expression,
    "Computes the median of a set of numbers",
    median_udaf
);
61
/// `MEDIAN` aggregate function.
///
/// The non-distinct variant buffers every non-null input value before it can compute
/// a result. The `DISTINCT` variant buffers every distinct value. Memory use therefore
/// grows with the input size, or with the input's cardinality for `DISTINCT`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the median value in the specified column.",
    syntax_example = "median(expression)",
    sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name) |
+----------------------+
| 45.5 |
+----------------------+
```"#,
    standard_argument(name = "expression", prefix = "The")
)]
pub struct Median {
    /// Single numeric argument, immutable volatility (see `Median::new`).
    signature: Signature,
}
87
88impl Debug for Median {
89 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
90 f.debug_struct("Median")
91 .field("name", &self.name())
92 .field("signature", &self.signature)
93 .finish()
94 }
95}
96
97impl Default for Median {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl Median {
104 pub fn new() -> Self {
105 Self {
106 signature: Signature::numeric(1, Volatility::Immutable),
107 }
108 }
109}
110
111impl AggregateUDFImpl for Median {
112 fn as_any(&self) -> &dyn std::any::Any {
113 self
114 }
115
116 fn name(&self) -> &str {
117 "median"
118 }
119
120 fn signature(&self) -> &Signature {
121 &self.signature
122 }
123
124 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
125 Ok(arg_types[0].clone())
126 }
127
128 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
129 let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true);
131 let state_name = if args.is_distinct {
132 "distinct_median"
133 } else {
134 "median"
135 };
136
137 Ok(vec![Field::new(
138 format_state_name(args.name, state_name),
139 DataType::List(Arc::new(field)),
140 true,
141 )
142 .into()])
143 }
144
145 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
146 macro_rules! helper {
147 ($t:ty, $dt:expr) => {
148 if acc_args.is_distinct {
149 Ok(Box::new(DistinctMedianAccumulator::<$t> {
150 data_type: $dt.clone(),
151 distinct_values: HashSet::new(),
152 }))
153 } else {
154 Ok(Box::new(MedianAccumulator::<$t> {
155 data_type: $dt.clone(),
156 all_values: vec![],
157 }))
158 }
159 };
160 }
161
162 let dt = acc_args.exprs[0].data_type(acc_args.schema)?;
163 downcast_integer! {
164 dt => (helper, dt),
165 DataType::Float16 => helper!(Float16Type, dt),
166 DataType::Float32 => helper!(Float32Type, dt),
167 DataType::Float64 => helper!(Float64Type, dt),
168 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
169 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
170 _ => Err(DataFusionError::NotImplemented(format!(
171 "MedianAccumulator not supported for {} with {}",
172 acc_args.name,
173 dt,
174 ))),
175 }
176 }
177
178 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
179 !args.is_distinct
180 }
181
182 fn create_groups_accumulator(
183 &self,
184 args: AccumulatorArgs,
185 ) -> Result<Box<dyn GroupsAccumulator>> {
186 let num_args = args.exprs.len();
187 if num_args != 1 {
188 return internal_err!(
189 "median should only have 1 arg, but found num args:{}",
190 args.exprs.len()
191 );
192 }
193
194 let dt = args.exprs[0].data_type(args.schema)?;
195
196 macro_rules! helper {
197 ($t:ty, $dt:expr) => {
198 Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
199 };
200 }
201
202 downcast_integer! {
203 dt => (helper, dt),
204 DataType::Float16 => helper!(Float16Type, dt),
205 DataType::Float32 => helper!(Float32Type, dt),
206 DataType::Float64 => helper!(Float64Type, dt),
207 DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
208 DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
209 _ => Err(DataFusionError::NotImplemented(format!(
210 "MedianGroupsAccumulator not supported for {} with {}",
211 args.name,
212 dt,
213 ))),
214 }
215 }
216
217 fn aliases(&self) -> &[String] {
218 &[]
219 }
220
221 fn documentation(&self) -> Option<&Documentation> {
222 self.doc()
223 }
224}
225
/// Row-based, non-distinct `MEDIAN` accumulator.
///
/// It buffers every non-null input value and computes the median only in `evaluate`,
/// so memory use grows linearly with the number of input rows.
struct MedianAccumulator<T: ArrowNumericType> {
    /// Input (and output) arrow type. It is applied with `with_data_type` to the
    /// arrays this accumulator emits.
    data_type: DataType,
    /// Every non-null value seen so far (from `update_batch` and `merge_batch`).
    all_values: Vec<T::Native>,
}
237
238impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
239 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
240 write!(f, "MedianAccumulator({})", self.data_type)
241 }
242}
243
244impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
245 fn state(&mut self) -> Result<Vec<ScalarValue>> {
246 let offsets =
250 OffsetBuffer::new(ScalarBuffer::from(vec![0, self.all_values.len() as i32]));
251
252 let values_array = PrimitiveArray::<T>::new(
254 ScalarBuffer::from(std::mem::take(&mut self.all_values)),
255 None,
256 )
257 .with_data_type(self.data_type.clone());
258
259 let list_array = ListArray::new(
261 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
262 offsets,
263 Arc::new(values_array),
264 None,
265 );
266
267 Ok(vec![ScalarValue::List(Arc::new(list_array))])
268 }
269
270 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
271 let values = values[0].as_primitive::<T>();
272 self.all_values.reserve(values.len() - values.null_count());
273 self.all_values.extend(values.iter().flatten());
274 Ok(())
275 }
276
277 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
278 let array = states[0].as_list::<i32>();
279 for v in array.iter().flatten() {
280 self.update_batch(&[v])?
281 }
282 Ok(())
283 }
284
285 fn evaluate(&mut self) -> Result<ScalarValue> {
286 let d = std::mem::take(&mut self.all_values);
287 let median = calculate_median::<T>(d);
288 ScalarValue::new_primitive::<T>(median, &self.data_type)
289 }
290
291 fn size(&self) -> usize {
292 size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
293 }
294}
295
/// Grouped, non-distinct `MEDIAN` accumulator.
///
/// It keeps one buffer of values per group and computes each group's median
/// in `evaluate`.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
    /// Input (and output) arrow type, applied to emitted arrays.
    data_type: DataType,
    /// `group_values[g]` holds every value accumulated so far for group index `g`.
    group_values: Vec<Vec<T::Native>>,
}
308
309impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
310 pub fn new(data_type: DataType) -> Self {
311 Self {
312 data_type,
313 group_values: Vec::new(),
314 }
315 }
316}
317
318impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
319 fn update_batch(
320 &mut self,
321 values: &[ArrayRef],
322 group_indices: &[usize],
323 opt_filter: Option<&BooleanArray>,
324 total_num_groups: usize,
325 ) -> Result<()> {
326 assert_eq!(values.len(), 1, "single argument to update_batch");
327 let values = values[0].as_primitive::<T>();
328
329 self.group_values.resize(total_num_groups, Vec::new());
331 accumulate(
332 group_indices,
333 values,
334 opt_filter,
335 |group_index, new_value| {
336 self.group_values[group_index].push(new_value);
337 },
338 );
339
340 Ok(())
341 }
342
343 fn merge_batch(
344 &mut self,
345 values: &[ArrayRef],
346 group_indices: &[usize],
347 _opt_filter: Option<&BooleanArray>,
349 total_num_groups: usize,
350 ) -> Result<()> {
351 assert_eq!(values.len(), 1, "one argument to merge_batch");
352
353 let input_group_values = values[0].as_list::<i32>();
374
375 self.group_values.resize(total_num_groups, Vec::new());
377
378 group_indices
383 .iter()
384 .zip(input_group_values.iter())
385 .for_each(|(&group_index, values_opt)| {
386 if let Some(values) = values_opt {
387 let values = values.as_primitive::<T>();
388 self.group_values[group_index].extend(values.values().iter());
389 }
390 });
391
392 Ok(())
393 }
394
395 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
396 let emit_group_values = emit_to.take_needed(&mut self.group_values);
398
399 let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
401 offsets.push(0);
402 let mut cur_len = 0_i32;
403 for group_value in &emit_group_values {
404 cur_len += group_value.len() as i32;
405 offsets.push(cur_len);
406 }
407 let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
415
416 let flatten_group_values =
418 emit_group_values.into_iter().flatten().collect::<Vec<_>>();
419 let group_values_array =
420 PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
421 .with_data_type(self.data_type.clone());
422
423 let result_list_array = ListArray::new(
425 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
426 offsets,
427 Arc::new(group_values_array),
428 None,
429 );
430
431 Ok(vec![Arc::new(result_list_array)])
432 }
433
434 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
435 let emit_group_values = emit_to.take_needed(&mut self.group_values);
437
438 let mut evaluate_result_builder =
440 PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
441 for values in emit_group_values {
442 let median = calculate_median::<T>(values);
443 evaluate_result_builder.append_option(median);
444 }
445
446 Ok(Arc::new(evaluate_result_builder.finish()))
447 }
448
449 fn convert_to_state(
450 &self,
451 values: &[ArrayRef],
452 opt_filter: Option<&BooleanArray>,
453 ) -> Result<Vec<ArrayRef>> {
454 assert_eq!(values.len(), 1, "one argument to merge_batch");
455
456 let input_array = values[0].as_primitive::<T>();
457
458 let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
467 .with_data_type(self.data_type.clone());
468
469 let offset_end = i32::try_from(input_array.len()).map_err(|e| {
471 internal_datafusion_err!(
472 "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
473 )
474 })?;
475 let offsets = (0..=offset_end).collect::<Vec<_>>();
476 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
478
479 let nulls = filtered_null_mask(opt_filter, input_array);
481
482 let converted_list_array = ListArray::new(
483 Arc::new(Field::new_list_field(self.data_type.clone(), true)),
484 offsets,
485 Arc::new(values),
486 nulls,
487 );
488
489 Ok(vec![Arc::new(converted_list_array)])
490 }
491
492 fn supports_convert_to_state(&self) -> bool {
493 true
494 }
495
496 fn size(&self) -> usize {
497 self.group_values
498 .iter()
499 .map(|values| values.capacity() * size_of::<T>())
500 .sum::<usize>()
501 + self.group_values.capacity() * size_of::<Vec<T>>()
503 }
504}
505
/// Row-based `MEDIAN(DISTINCT ...)` accumulator.
///
/// It keeps the set of distinct non-null values and computes their median in
/// `evaluate`. Memory use grows with the number of distinct values.
struct DistinctMedianAccumulator<T: ArrowNumericType> {
    /// Input (and output) arrow type.
    data_type: DataType,
    /// Distinct values seen so far, wrapped in `Hashable` for use as `HashSet` keys.
    distinct_values: HashSet<Hashable<T::Native>>,
}
517
518impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
519 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520 write!(f, "DistinctMedianAccumulator({})", self.data_type)
521 }
522}
523
524impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
525 fn state(&mut self) -> Result<Vec<ScalarValue>> {
526 let all_values = self
527 .distinct_values
528 .iter()
529 .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
530 .collect::<Result<Vec<_>>>()?;
531
532 let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
533 Ok(vec![ScalarValue::List(arr)])
534 }
535
536 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
537 if values.is_empty() {
538 return Ok(());
539 }
540
541 let array = values[0].as_primitive::<T>();
542 match array.nulls().filter(|x| x.null_count() > 0) {
543 Some(n) => {
544 for idx in n.valid_indices() {
545 self.distinct_values.insert(Hashable(array.value(idx)));
546 }
547 }
548 None => array.values().iter().for_each(|x| {
549 self.distinct_values.insert(Hashable(*x));
550 }),
551 }
552 Ok(())
553 }
554
555 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
556 let array = states[0].as_list::<i32>();
557 for v in array.iter().flatten() {
558 self.update_batch(&[v])?
559 }
560 Ok(())
561 }
562
563 fn evaluate(&mut self) -> Result<ScalarValue> {
564 let d = std::mem::take(&mut self.distinct_values)
565 .into_iter()
566 .map(|v| v.0)
567 .collect::<Vec<_>>();
568 let median = calculate_median::<T>(d);
569 ScalarValue::new_primitive::<T>(median, &self.data_type)
570 }
571
572 fn size(&self) -> usize {
573 size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
574 }
575}
576
577fn slice_max<T>(array: &[T::Native]) -> T::Native
579where
580 T: ArrowPrimitiveType,
581 T::Native: PartialOrd, {
583 debug_assert!(!array.is_empty());
585 *array
587 .iter()
588 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Less))
589 .unwrap()
590}
591
592fn calculate_median<T: ArrowNumericType>(
593 mut values: Vec<T::Native>,
594) -> Option<T::Native> {
595 let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
596
597 let len = values.len();
598 if len == 0 {
599 None
600 } else if len % 2 == 0 {
601 let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
602 let left_max = slice_max::<T>(low);
604 let median = left_max
605 .add_wrapping(*high)
606 .div_wrapping(T::Native::usize_as(2));
607 Some(median)
608 } else {
609 let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
610 Some(*median)
611 }
612}