datafusion_functions_aggregate/
approx_median.rs1use arrow::datatypes::DataType::{Float64, UInt64};
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::types::NativeType;
23use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
24use std::any::Any;
25use std::fmt::Debug;
26use std::sync::Arc;
27
28use datafusion_common::{Result, not_impl_err};
29use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
30use datafusion_expr::utils::format_state_name;
31use datafusion_expr::{
32 Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature,
33 TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
36
37use crate::approx_percentile_cont::ApproxPercentileAccumulator;
38
// Registers `ApproxMedian` through the crate's `make_udaf_expr_and_func!` macro.
// NOTE(review): presumably this generates the `approx_median(expression)` expression
// builder and the `approx_median_udaf()` accessor, using the string below as the
// generated function's doc text; confirm against the macro definition.
make_udaf_expr_and_func!(
    ApproxMedian,
    approx_median,
    expression,
    "Computes the approximate median of a set of numbers",
    approx_median_udaf
);
46
47#[user_doc(
49 doc_section(label = "Approximate Functions"),
50 description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.",
51 syntax_example = "approx_median(expression)",
52 sql_example = r#"```sql
53> SELECT approx_median(column_name) FROM table_name;
54+-----------------------------------+
55| approx_median(column_name) |
56+-----------------------------------+
57| 23.5 |
58+-----------------------------------+
59```"#,
60 standard_argument(name = "expression",)
61)]
62#[derive(Debug, PartialEq, Eq, Hash)]
63pub struct ApproxMedian {
64 signature: Signature,
65}
66
67impl Default for ApproxMedian {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl ApproxMedian {
74 pub fn new() -> Self {
76 Self {
77 signature: Signature::one_of(
78 vec![
79 TypeSignature::Coercible(vec![Coercion::new_exact(
80 TypeSignatureClass::Integer,
81 )]),
82 TypeSignature::Coercible(vec![Coercion::new_implicit(
83 TypeSignatureClass::Float,
84 vec![TypeSignatureClass::Decimal],
85 NativeType::Float64,
86 )]),
87 ],
88 Volatility::Immutable,
89 ),
90 }
91 }
92}
93
94impl AggregateUDFImpl for ApproxMedian {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
100 if args.input_fields[0].data_type().is_null() {
101 Ok(vec![
102 Field::new(
103 format_state_name(args.name, self.name()),
104 DataType::Null,
105 true,
106 )
107 .into(),
108 ])
109 } else {
110 Ok(vec![
111 Field::new(format_state_name(args.name, "max_size"), UInt64, false),
112 Field::new(format_state_name(args.name, "sum"), Float64, false),
113 Field::new(format_state_name(args.name, "count"), UInt64, false),
114 Field::new(format_state_name(args.name, "max"), Float64, false),
115 Field::new(format_state_name(args.name, "min"), Float64, false),
116 Field::new_list(
117 format_state_name(args.name, "centroids"),
118 Field::new_list_field(Float64, true),
119 false,
120 ),
121 ]
122 .into_iter()
123 .map(Arc::new)
124 .collect())
125 }
126 }
127
128 fn name(&self) -> &str {
129 "approx_median"
130 }
131
132 fn signature(&self) -> &Signature {
133 &self.signature
134 }
135
136 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137 Ok(arg_types[0].clone())
138 }
139
140 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141 if acc_args.is_distinct {
142 return not_impl_err!(
143 "APPROX_MEDIAN(DISTINCT) aggregations are not available"
144 );
145 }
146
147 if acc_args.expr_fields[0].data_type().is_null() {
148 Ok(Box::new(NoopAccumulator::default()))
149 } else {
150 Ok(Box::new(ApproxPercentileAccumulator::new(
151 0.5_f64,
152 acc_args.expr_fields[0].data_type().clone(),
153 )))
154 }
155 }
156
157 fn documentation(&self) -> Option<&Documentation> {
158 self.doc()
159 }
160}