datafusion_functions_extra/
max_min_by.rs

1use datafusion::logical_expr::AggregateUDFImpl;
2use datafusion::{arrow, common, error, functions_aggregate, logical_expr};
3use std::ops::Deref;
4use std::{any, fmt};
5
// Declares the public entry points for `MaxByFunction`. The macro is defined
// outside this file; presumably it generates a `max_by(x, y)` expression builder
// and a `max_by_udaf()` accessor to a shared `MaxByFunction` instance, using the
// string below as their doc text — TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    MaxByFunction,
    max_by,
    x y,
    "Returns the value of the first column corresponding to the maximum value in the second column.",
    max_by_udaf
);
13
14pub struct MaxByFunction {
15    signature: logical_expr::Signature,
16}
17
18impl fmt::Debug for MaxByFunction {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        f.debug_struct("MaxBy")
21            .field("name", &self.name())
22            .field("signature", &self.signature)
23            .field("accumulator", &"<FUNC>")
24            .finish()
25    }
26}
27impl Default for MaxByFunction {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl MaxByFunction {
34    pub fn new() -> Self {
35        Self {
36            signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
37        }
38    }
39}
40
41fn get_min_max_by_result_type(
42    input_types: &[arrow::datatypes::DataType],
43) -> error::Result<Vec<arrow::datatypes::DataType>> {
44    match &input_types[0] {
45        arrow::datatypes::DataType::Dictionary(_, dict_value_type) => {
46            // TODO add checker, if the value type is complex data type
47            Ok(vec![dict_value_type.deref().clone()])
48        }
49        _ => Ok(input_types.to_vec()),
50    }
51}
52
53impl logical_expr::AggregateUDFImpl for MaxByFunction {
54    fn as_any(&self) -> &dyn any::Any {
55        self
56    }
57
58    fn name(&self) -> &str {
59        "max_by"
60    }
61
62    fn signature(&self) -> &logical_expr::Signature {
63        &self.signature
64    }
65
66    fn return_type(
67        &self,
68        arg_types: &[arrow::datatypes::DataType],
69    ) -> error::Result<arrow::datatypes::DataType> {
70        Ok(arg_types[0].to_owned())
71    }
72
73    fn accumulator(
74        &self,
75        _acc_args: logical_expr::function::AccumulatorArgs,
76    ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
77        common::exec_err!("should not reach here")
78    }
79    fn coerce_types(
80        &self,
81        arg_types: &[arrow::datatypes::DataType],
82    ) -> error::Result<Vec<arrow::datatypes::DataType>> {
83        get_min_max_by_result_type(arg_types)
84    }
85
86    fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
87        let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
88                        _: &dyn logical_expr::simplify::SimplifyInfo| {
89            let mut order_by = aggr_func.params.order_by;
90            let (second_arg, first_arg) = (
91                aggr_func.params.args.remove(1),
92                aggr_func.params.args.remove(0),
93            );
94            let sort = logical_expr::expr::Sort::new(second_arg, true, false);
95            order_by.push(sort);
96            let func = logical_expr::expr::Expr::AggregateFunction(
97                logical_expr::expr::AggregateFunction::new_udf(
98                    functions_aggregate::first_last::last_value_udaf(),
99                    vec![first_arg],
100                    aggr_func.params.distinct,
101                    aggr_func.params.filter,
102                    order_by,
103                    aggr_func.params.null_treatment,
104                ),
105            );
106            Ok(func)
107        };
108        Some(Box::new(simplify))
109    }
110}
111
// Declares the public entry points for `MinByFunction`. The macro is defined
// outside this file; presumably it generates a `min_by(x, y)` expression builder
// and a `min_by_udaf()` accessor to a shared `MinByFunction` instance, using the
// string below as their doc text — TODO confirm against the macro definition.
make_udaf_expr_and_func!(
    MinByFunction,
    min_by,
    x y,
    "Returns the value of the first column corresponding to the minimum value in the second column.",
    min_by_udaf
);
119
120pub struct MinByFunction {
121    signature: logical_expr::Signature,
122}
123
124impl fmt::Debug for MinByFunction {
125    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126        f.debug_struct("MinBy")
127            .field("name", &self.name())
128            .field("signature", &self.signature)
129            .field("accumulator", &"<FUNC>")
130            .finish()
131    }
132}
133
134impl Default for MinByFunction {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl MinByFunction {
141    pub fn new() -> Self {
142        Self {
143            signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
144        }
145    }
146}
147
148impl logical_expr::AggregateUDFImpl for MinByFunction {
149    fn as_any(&self) -> &dyn any::Any {
150        self
151    }
152
153    fn name(&self) -> &str {
154        "min_by"
155    }
156
157    fn signature(&self) -> &logical_expr::Signature {
158        &self.signature
159    }
160
161    fn return_type(
162        &self,
163        arg_types: &[arrow::datatypes::DataType],
164    ) -> error::Result<arrow::datatypes::DataType> {
165        Ok(arg_types[0].to_owned())
166    }
167
168    fn accumulator(
169        &self,
170        _acc_args: logical_expr::function::AccumulatorArgs,
171    ) -> error::Result<Box<dyn logical_expr::Accumulator>> {
172        common::exec_err!("should not reach here")
173    }
174
175    fn coerce_types(
176        &self,
177        arg_types: &[arrow::datatypes::DataType],
178    ) -> error::Result<Vec<arrow::datatypes::DataType>> {
179        get_min_max_by_result_type(arg_types)
180    }
181
182    fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
183        let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
184                        _: &dyn logical_expr::simplify::SimplifyInfo| {
185            let mut order_by = aggr_func.params.order_by;
186            let (second_arg, first_arg) = (
187                aggr_func.params.args.remove(1),
188                aggr_func.params.args.remove(0),
189            );
190
191            let sort = logical_expr::expr::Sort::new(second_arg, false, false);
192            order_by.push(sort); // false for ascending sort
193            let func = logical_expr::expr::Expr::AggregateFunction(
194                logical_expr::expr::AggregateFunction::new_udf(
195                    functions_aggregate::first_last::last_value_udaf(),
196                    vec![first_arg],
197                    aggr_func.params.distinct,
198                    aggr_func.params.filter,
199                    order_by,
200                    aggr_func.params.null_treatment,
201                ),
202            );
203            Ok(func)
204        };
205        Some(Box::new(simplify))
206    }
207}