datafusion_functions_extra/
max_min_by.rs1use datafusion::logical_expr::AggregateUDFImpl;
2use datafusion::{arrow, common, error, functions_aggregate, logical_expr};
3use std::ops::Deref;
4use std::{any, fmt};
5
// Registers `MaxByFunction` through the crate's `make_udaf_expr_and_func!`
// macro. The macro is defined outside this file; judging by its arguments it
// presumably generates a `max_by(x, y)` expression builder and a
// `max_by_udaf()` accessor that carry the doc string below - confirm against
// the macro definition.
make_udaf_expr_and_func!(
    MaxByFunction,
    max_by,
    x y,
    "Returns the value of the first column corresponding to the maximum value in the second column.",
    max_by_udaf
);
13
/// Aggregate UDF implementation backing `max_by`.
pub struct MaxByFunction {
    /// Signature handed out by `signature()`; `new()` builds it as a
    /// user-defined signature with immutable volatility.
    signature: logical_expr::Signature,
}
17
18impl fmt::Debug for MaxByFunction {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 f.debug_struct("MaxBy")
21 .field("name", &self.name())
22 .field("signature", &self.signature)
23 .field("accumulator", &"<FUNC>")
24 .finish()
25 }
26}
27impl Default for MaxByFunction {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl MaxByFunction {
34 pub fn new() -> Self {
35 Self {
36 signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
37 }
38 }
39}
40
41fn get_min_max_by_result_type(
42 input_types: &[arrow::datatypes::DataType],
43) -> error::Result<Vec<arrow::datatypes::DataType>> {
44 match &input_types[0] {
45 arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
46 Ok(vec![dict_value_type.deref().clone()])
48 }
49 _ => Ok(input_types.to_vec()),
50 }
51}
52
53impl logical_expr::AggregateUDFImpl for MaxByFunction {
54 fn as_any(&self) -> &dyn any::Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "max_by"
60 }
61
62 fn signature(&self) -> &logical_expr::Signature {
63 &self.signature
64 }
65
66 fn return_type(
67 &self,
68 arg_types: &[arrow::datatypes::DataType],
69 ) -> error::Result<arrow::datatypes::DataType> {
70 Ok(arg_types[0].to_owned())
71 }
72
73 fn accumulator(
74 &self,
75 _acc_args: logical_expr::function::AccumulatorArgs,
76 ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
77 common::exec_err!("should not reach here")
78 }
79 fn coerce_types(
80 &self,
81 arg_types: &[arrow::datatypes::DataType],
82 ) -> error::Result<Vec<arrow::datatypes::DataType>> {
83 get_min_max_by_result_type(arg_types)
84 }
85
86 fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
87 let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
88 _: &dyn logical_expr::simplify::SimplifyInfo| {
89 let mut order_by = aggr_func.params.order_by;
90 let (second_arg, first_arg) = (
91 aggr_func.params.args.remove(1),
92 aggr_func.params.args.remove(0),
93 );
94 let sort = logical_expr::expr::Sort::new(second_arg, true, false);
95 order_by.push(sort);
96 let func = logical_expr::expr::Expr::AggregateFunction(
97 logical_expr::expr::AggregateFunction::new_udf(
98 functions_aggregate::first_last::last_value_udaf(),
99 vec![first_arg],
100 aggr_func.params.distinct,
101 aggr_func.params.filter,
102 order_by,
103 aggr_func.params.null_treatment,
104 ),
105 );
106 Ok(func)
107 };
108 Some(Box::new(simplify))
109 }
110}
111
// Registers `MinByFunction` through the crate's `make_udaf_expr_and_func!`
// macro. The macro is defined outside this file; judging by its arguments it
// presumably generates a `min_by(x, y)` expression builder and a
// `min_by_udaf()` accessor that carry the doc string below - confirm against
// the macro definition.
make_udaf_expr_and_func!(
    MinByFunction,
    min_by,
    x y,
    "Returns the value of the first column corresponding to the minimum value in the second column.",
    min_by_udaf
);
119
/// Aggregate UDF implementation backing `min_by`.
pub struct MinByFunction {
    /// Signature handed out by `signature()`; `new()` builds it as a
    /// user-defined signature with immutable volatility.
    signature: logical_expr::Signature,
}
123
124impl fmt::Debug for MinByFunction {
125 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126 f.debug_struct("MinBy")
127 .field("name", &self.name())
128 .field("signature", &self.signature)
129 .field("accumulator", &"<FUNC>")
130 .finish()
131 }
132}
133
134impl Default for MinByFunction {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl MinByFunction {
141 pub fn new() -> Self {
142 Self {
143 signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
144 }
145 }
146}
147
148impl logical_expr::AggregateUDFImpl for MinByFunction {
149 fn as_any(&self) -> &dyn any::Any {
150 self
151 }
152
153 fn name(&self) -> &str {
154 "min_by"
155 }
156
157 fn signature(&self) -> &logical_expr::Signature {
158 &self.signature
159 }
160
161 fn return_type(
162 &self,
163 arg_types: &[arrow::datatypes::DataType],
164 ) -> error::Result<arrow::datatypes::DataType> {
165 Ok(arg_types[0].to_owned())
166 }
167
168 fn accumulator(
169 &self,
170 _acc_args: logical_expr::function::AccumulatorArgs,
171 ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
172 common::exec_err!("should not reach here")
173 }
174
175 fn coerce_types(
176 &self,
177 arg_types: &[arrow::datatypes::DataType],
178 ) -> error::Result<Vec<arrow::datatypes::DataType>> {
179 get_min_max_by_result_type(arg_types)
180 }
181
182 fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
183 let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
184 _: &dyn logical_expr::simplify::SimplifyInfo| {
185 let mut order_by = aggr_func.params.order_by;
186 let (second_arg, first_arg) = (
187 aggr_func.params.args.remove(1),
188 aggr_func.params.args.remove(0),
189 );
190
191 let sort = logical_expr::expr::Sort::new(second_arg, false, false);
192 order_by.push(sort); let func = logical_expr::expr::Expr::AggregateFunction(
194 logical_expr::expr::AggregateFunction::new_udf(
195 functions_aggregate::first_last::last_value_udaf(),
196 vec![first_arg],
197 aggr_func.params.distinct,
198 aggr_func.params.filter,
199 order_by,
200 aggr_func.params.null_treatment,
201 ),
202 );
203 Ok(func)
204 };
205 Some(Box::new(simplify))
206 }
207}