datafusion_functions_aggregate/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines the `NTH_VALUE` aggregate expression, which may specify an ordering
//! requirement that can be evaluated at runtime during query execution.
20
21use std::any::Any;
22use std::collections::VecDeque;
23use std::mem::{size_of, size_of_val};
24use std::sync::Arc;
25
26use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields};
28
29use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
30use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34    lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
35    Signature, SortExpr, Volatility,
36};
37use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
38use datafusion_functions_aggregate_common::utils::ordering_fields;
39use datafusion_macros::user_doc;
40use datafusion_physical_expr::expressions::Literal;
41use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
42
43create_func!(NthValueAgg, nth_value_udaf);
44
45/// Returns the nth value in a group of values.
46pub fn nth_value(
47    expr: datafusion_expr::Expr,
48    n: i64,
49    order_by: Vec<SortExpr>,
50) -> datafusion_expr::Expr {
51    let args = vec![expr, lit(n)];
52    if !order_by.is_empty() {
53        nth_value_udaf()
54            .call(args)
55            .order_by(order_by)
56            .build()
57            .unwrap()
58    } else {
59        nth_value_udaf().call(args)
60    }
61}
62
63#[user_doc(
64    doc_section(label = "Statistical Functions"),
65    description = "Returns the nth value in a group of values.",
66    syntax_example = "nth_value(expression, n ORDER BY expression)",
67    sql_example = r#"```sql
68> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
69  FROM employee;
70+---------+--------+-------------------------+
71| dept_id | salary | second_salary_by_dept   |
72+---------+--------+-------------------------+
73| 1       | 30000  | NULL                    |
74| 1       | 40000  | 40000                   |
75| 1       | 50000  | 40000                   |
76| 2       | 35000  | NULL                    |
77| 2       | 45000  | 45000                   |
78+---------+--------+-------------------------+
79```"#,
80    argument(
81        name = "expression",
82        description = "The column or expression to retrieve the nth value from."
83    ),
84    argument(
85        name = "n",
86        description = "The position (nth) of the value to retrieve, based on the ordering."
87    )
88)]
89/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
90/// partition setting, partial aggregations are computed for every partition,
91/// and then their results are merged.
92#[derive(Debug)]
93pub struct NthValueAgg {
94    signature: Signature,
95}
96
97impl NthValueAgg {
98    /// Create a new `NthValueAgg` aggregate function
99    pub fn new() -> Self {
100        Self {
101            signature: Signature::any(2, Volatility::Immutable),
102        }
103    }
104}
105
106impl Default for NthValueAgg {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl AggregateUDFImpl for NthValueAgg {
113    fn as_any(&self) -> &dyn Any {
114        self
115    }
116
117    fn name(&self) -> &str {
118        "nth_value"
119    }
120
121    fn signature(&self) -> &Signature {
122        &self.signature
123    }
124
125    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
126        Ok(arg_types[0].clone())
127    }
128
129    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
130        let n = match acc_args.exprs[1]
131            .as_any()
132            .downcast_ref::<Literal>()
133            .map(|lit| lit.value())
134        {
135            Some(ScalarValue::Int64(Some(value))) => {
136                if acc_args.is_reversed {
137                    -*value
138                } else {
139                    *value
140                }
141            }
142            _ => {
143                return not_impl_err!(
144                    "{} not supported for n: {}",
145                    self.name(),
146                    &acc_args.exprs[1]
147                )
148            }
149        };
150
151        let ordering_dtypes = acc_args
152            .ordering_req
153            .iter()
154            .map(|e| e.expr.data_type(acc_args.schema))
155            .collect::<Result<Vec<_>>>()?;
156
157        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
158        NthValueAccumulator::try_new(
159            n,
160            &data_type,
161            &ordering_dtypes,
162            acc_args.ordering_req.clone(),
163        )
164        .map(|acc| Box::new(acc) as _)
165    }
166
167    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
168        let mut fields = vec![Field::new_list(
169            format_state_name(self.name(), "nth_value"),
170            // See COMMENTS.md to understand why nullable is set to true
171            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
172            false,
173        )];
174        let orderings = args.ordering_fields.to_vec();
175        if !orderings.is_empty() {
176            fields.push(Field::new_list(
177                format_state_name(self.name(), "nth_value_orderings"),
178                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
179                false,
180            ));
181        }
182        Ok(fields.into_iter().map(Arc::new).collect())
183    }
184
185    fn aliases(&self) -> &[String] {
186        &[]
187    }
188
189    fn reverse_expr(&self) -> ReversedUDAF {
190        ReversedUDAF::Reversed(nth_value_udaf())
191    }
192
193    fn documentation(&self) -> Option<&Documentation> {
194        self.doc()
195    }
196}
197
198#[derive(Debug)]
199pub struct NthValueAccumulator {
200    /// The `N` value.
201    n: i64,
202    /// Stores entries in the `NTH_VALUE` result.
203    values: VecDeque<ScalarValue>,
204    /// Stores values of ordering requirement expressions corresponding to each
205    /// entry in `values`. This information is used when merging results from
206    /// different partitions. For detailed information how merging is done, see
207    /// [`merge_ordered_arrays`].
208    ordering_values: VecDeque<Vec<ScalarValue>>,
209    /// Stores datatypes of expressions inside values and ordering requirement
210    /// expressions.
211    datatypes: Vec<DataType>,
212    /// Stores the ordering requirement of the `Accumulator`.
213    ordering_req: LexOrdering,
214}
215
216impl NthValueAccumulator {
217    /// Create a new order-sensitive NTH_VALUE accumulator based on the given
218    /// item data type.
219    pub fn try_new(
220        n: i64,
221        datatype: &DataType,
222        ordering_dtypes: &[DataType],
223        ordering_req: LexOrdering,
224    ) -> Result<Self> {
225        if n == 0 {
226            // n cannot be 0
227            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
228        }
229        let mut datatypes = vec![datatype.clone()];
230        datatypes.extend(ordering_dtypes.iter().cloned());
231        Ok(Self {
232            n,
233            values: VecDeque::new(),
234            ordering_values: VecDeque::new(),
235            datatypes,
236            ordering_req,
237        })
238    }
239}
240
241impl Accumulator for NthValueAccumulator {
242    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
243    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
244    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
245        if values.is_empty() {
246            return Ok(());
247        }
248
249        let n_required = self.n.unsigned_abs() as usize;
250        let from_start = self.n > 0;
251        if from_start {
252            // direction is from start
253            let n_remaining = n_required.saturating_sub(self.values.len());
254            self.append_new_data(values, Some(n_remaining))?;
255        } else {
256            // direction is from end
257            self.append_new_data(values, None)?;
258            let start_offset = self.values.len().saturating_sub(n_required);
259            if start_offset > 0 {
260                self.values.drain(0..start_offset);
261                self.ordering_values.drain(0..start_offset);
262            }
263        }
264
265        Ok(())
266    }
267
268    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
269        if states.is_empty() {
270            return Ok(());
271        }
272        // First entry in the state is the aggregation result.
273        let array_agg_values = &states[0];
274        let n_required = self.n.unsigned_abs() as usize;
275        if self.ordering_req.is_empty() {
276            let array_agg_res =
277                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
278            for v in array_agg_res.into_iter() {
279                self.values.extend(v);
280                if self.values.len() > n_required {
281                    // There is enough data collected can stop merging
282                    break;
283                }
284            }
285        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
286            // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list.
287            // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores
288            // values received from its ordering requirement expression. (This information is necessary for during merging).
289
290            // Stores NTH_VALUE results coming from each partition
291            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
292            // Stores ordering requirement expression results coming from each partition
293            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
294
295            // Existing values should be merged also.
296            partition_values.push(self.values.clone());
297
298            partition_ordering_values.push(self.ordering_values.clone());
299
300            let array_agg_res =
301                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
302
303            for v in array_agg_res.into_iter() {
304                partition_values.push(v.into());
305            }
306
307            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
308
309            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
310                // Extract value from struct to ordering_rows for each group/partition
311                partition_ordering_rows.into_iter().map(|ordering_row| {
312                    if let ScalarValue::Struct(s) = ordering_row {
313                        let mut ordering_columns_per_row = vec![];
314
315                        for column in s.columns() {
316                            let sv = ScalarValue::try_from_array(column, 0)?;
317                            ordering_columns_per_row.push(sv);
318                        }
319
320                        Ok(ordering_columns_per_row)
321                    } else {
322                        exec_err!(
323                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
324                            ordering_row.data_type()
325                        )
326                    }
327                }).collect::<Result<Vec<_>>>()
328            }).collect::<Result<Vec<_>>>()?;
329            for ordering_values in ordering_values.into_iter() {
330                partition_ordering_values.push(ordering_values.into());
331            }
332
333            let sort_options = self
334                .ordering_req
335                .iter()
336                .map(|sort_expr| sort_expr.options)
337                .collect::<Vec<_>>();
338            let (new_values, new_orderings) = merge_ordered_arrays(
339                &mut partition_values,
340                &mut partition_ordering_values,
341                &sort_options,
342            )?;
343            self.values = new_values.into();
344            self.ordering_values = new_orderings.into();
345        } else {
346            return exec_err!("Expects to receive a list array");
347        }
348        Ok(())
349    }
350
351    fn state(&mut self) -> Result<Vec<ScalarValue>> {
352        let mut result = vec![self.evaluate_values()];
353        if !self.ordering_req.is_empty() {
354            result.push(self.evaluate_orderings()?);
355        }
356        Ok(result)
357    }
358
359    fn evaluate(&mut self) -> Result<ScalarValue> {
360        let n_required = self.n.unsigned_abs() as usize;
361        let from_start = self.n > 0;
362        let nth_value_idx = if from_start {
363            // index is from start
364            let forward_idx = n_required - 1;
365            (forward_idx < self.values.len()).then_some(forward_idx)
366        } else {
367            // index is from end
368            self.values.len().checked_sub(n_required)
369        };
370        if let Some(idx) = nth_value_idx {
371            Ok(self.values[idx].clone())
372        } else {
373            ScalarValue::try_from(self.datatypes[0].clone())
374        }
375    }
376
377    fn size(&self) -> usize {
378        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
379            - size_of_val(&self.values);
380
381        // Add size of the `self.ordering_values`
382        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
383        for row in &self.ordering_values {
384            total += ScalarValue::size_of_vec(row) - size_of_val(row);
385        }
386
387        // Add size of the `self.datatypes`
388        total += size_of::<DataType>() * self.datatypes.capacity();
389        for dtype in &self.datatypes {
390            total += dtype.size() - size_of_val(dtype);
391        }
392
393        // Add size of the `self.ordering_req`
394        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
395        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
396        total
397    }
398}
399
400impl NthValueAccumulator {
401    fn evaluate_orderings(&self) -> Result<ScalarValue> {
402        let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
403
404        let mut column_wise_ordering_values = vec![];
405        let num_columns = fields.len();
406        for i in 0..num_columns {
407            let column_values = self
408                .ordering_values
409                .iter()
410                .map(|x| x[i].clone())
411                .collect::<Vec<_>>();
412            let array = if column_values.is_empty() {
413                new_empty_array(fields[i].data_type())
414            } else {
415                ScalarValue::iter_to_array(column_values.into_iter())?
416            };
417            column_wise_ordering_values.push(array);
418        }
419
420        let struct_field = Fields::from(fields);
421        let ordering_array =
422            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
423
424        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
425    }
426
427    fn evaluate_values(&self) -> ScalarValue {
428        let mut values_cloned = self.values.clone();
429        let values_slice = values_cloned.make_contiguous();
430        ScalarValue::List(ScalarValue::new_list_nullable(
431            values_slice,
432            &self.datatypes[0],
433        ))
434    }
435
436    /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete
437    /// None represents all of the new `values` need to be added to the state.
438    fn append_new_data(
439        &mut self,
440        values: &[ArrayRef],
441        fetch: Option<usize>,
442    ) -> Result<()> {
443        let n_row = values[0].len();
444        let n_to_add = if let Some(fetch) = fetch {
445            std::cmp::min(fetch, n_row)
446        } else {
447            n_row
448        };
449        for index in 0..n_to_add {
450            let row = get_row_at_idx(values, index)?;
451            self.values.push_back(row[0].clone());
452            // At index 1, we have n index argument.
453            // Ordering values cover starting from 2nd index to end
454            self.ordering_values.push_back(row[2..].to_vec());
455        }
456        Ok(())
457    }
458}