datafusion_functions_aggregate/nth_value.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the NTH_VALUE aggregate expression, which may specify an ordering
//! requirement that can be evaluated at runtime during query execution.

use std::any::Any;
use std::collections::VecDeque;
use std::mem::{size_of, size_of_val};
use std::sync::Arc;

use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
use arrow::datatypes::{DataType, Field, FieldRef, Fields};

use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
    lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
    Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
use datafusion_functions_aggregate_common::utils::ordering_fields;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

create_func!(NthValueAgg, nth_value_udaf);

/// Returns the nth value in a group of values.
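///
/// A minimal, illustrative sketch (not compiled as a doctest); it assumes the
/// `col` helper and the `Expr::sort` builder from `datafusion_expr`.
///
/// ```ignore
/// use datafusion_expr::col;
///
/// // Third-lowest salary per group: NTH_VALUE(salary, 3 ORDER BY salary ASC)
/// let expr = nth_value(col("salary"), 3, vec![col("salary").sort(true, false)]);
/// ```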
pub fn nth_value(
    expr: datafusion_expr::Expr,
    n: i64,
    order_by: Vec<SortExpr>,
) -> datafusion_expr::Expr {
    let args = vec![expr, lit(n)];
    if !order_by.is_empty() {
        nth_value_udaf()
            .call(args)
            .order_by(order_by)
            .build()
            .unwrap()
    } else {
        nth_value_udaf().call(args)
    }
}

#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the nth value in a group of values.",
    syntax_example = "nth_value(expression, n ORDER BY expression)",
    sql_example = r#"```sql
> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
  FROM employee;
+---------+--------+-------------------------+
| dept_id | salary | second_salary_by_dept   |
+---------+--------+-------------------------+
| 1       | 30000  | NULL                    |
| 1       | 40000  | 40000                   |
| 1       | 50000  | 40000                   |
| 2       | 35000  | NULL                    |
| 2       | 45000  | 45000                   |
+---------+--------+-------------------------+
```"#,
    argument(
        name = "expression",
        description = "The column or expression to retrieve the nth value from."
    ),
    argument(
        name = "n",
        description = "The position (nth) of the value to retrieve, based on the ordering."
    )
)]
/// Expression for a `NTH_VALUE(..., ... ORDER BY ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub struct NthValueAgg {
    signature: Signature,
}

impl NthValueAgg {
    /// Create a new `NthValueAgg` aggregate function
    pub fn new() -> Self {
        Self {
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl Default for NthValueAgg {
    fn default() -> Self {
        Self::new()
    }
}

impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
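        // Note: `n` is the 1-based position of the value to return; a negative `n`
        // counts from the end. When the plan marks this aggregate as reversed (its
        // input arrives in the opposite order), negating `n` below preserves the
        // semantics, since the n-th value from the start of the original ordering
        // is the n-th value from the end of the reversed ordering.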
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                if acc_args.is_reversed {
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                )
            }
        };

        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialNthValueAccumulator::try_new(
                n,
                acc_args.return_field.data_type(),
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        NthValueAccumulator::try_new(n, &data_type, &ordering_dtypes, ordering)
            .map(|acc| Box::new(acc) as _)
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            // See COMMENTS.md to understand why nullable is set to true
            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields.into_iter().map(Arc::new).collect())
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}

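/// Accumulator for `NTH_VALUE` when there is no `ORDER BY` requirement: values are
/// simply buffered in arrival order, and no per-row ordering information is kept.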
#[derive(Debug)]
pub struct TrivialNthValueAccumulator {
    /// The `N` value.
    n: i64,
    /// Stores entries in the `NTH_VALUE` result.
    values: VecDeque<ScalarValue>,
    /// Data type of the value.
    datatype: DataType,
}

impl TrivialNthValueAccumulator {
    /// Create a new order-insensitive NTH_VALUE accumulator based on the given
    /// item data type.
    pub fn try_new(n: i64, datatype: &DataType) -> Result<Self> {
        if n == 0 {
            // n cannot be 0
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
        }
        Ok(Self {
            n,
            values: VecDeque::new(),
            datatype: datatype.clone(),
        })
    }

    /// Updates the state with `values`. `fetch` is the number of entries still
    /// missing for the state to be complete; `None` means all new `values` are added.
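    ///
    /// For example, if the state still needs two more entries and `values` holds five
    /// rows, only the first two rows are appended.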
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        let n_to_add = if let Some(fetch) = fetch {
            std::cmp::min(fetch, n_row)
        } else {
            n_row
        };
        for index in 0..n_to_add {
            let mut row = get_row_at_idx(values, index)?;
            self.values.push_back(row.swap_remove(0));
            // Index 1 of the row holds the constant `n` argument, which is ignored here.
        }
        Ok(())
    }
}

impl Accumulator for TrivialNthValueAccumulator {
    /// Updates its state with the `values`. Assumes data in the `values` satisfies the
    /// required ordering for the accumulator (across consecutive batches, not just
    /// within each batch).
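    ///
    /// Only as much data as the result can possibly need is buffered: for a positive
    /// `n`, at most the first `n` values seen so far; for a negative `n` (counting
    /// from the end), only the last `|n|` values seen so far.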
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if !values.is_empty() {
            let n_required = self.n.unsigned_abs() as usize;
            let from_start = self.n > 0;
            if from_start {
                // direction is from start
                let n_remaining = n_required.saturating_sub(self.values.len());
                self.append_new_data(values, Some(n_remaining))?;
            } else {
                // direction is from end
                self.append_new_data(values, None)?;
                let start_offset = self.values.len().saturating_sub(n_required);
                if start_offset > 0 {
                    self.values.drain(0..start_offset);
                }
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if !states.is_empty() {
            // First entry in the state is the aggregation result.
            let n_required = self.n.unsigned_abs() as usize;
            let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
            for v in array_agg_res.into_iter() {
                self.values.extend(v);
                if self.values.len() > n_required {
                    // There is enough data collected, can stop merging:
                    break;
                }
            }
        }
        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatype,
        ))])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // index is from start
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // index is from end
            self.values.len().checked_sub(n_required)
        };
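        // For example, with buffered values [10, 20, 30]:
        //   n =  2 -> index 1 -> 20
        //   n = -1 -> 3 - 1 = 2 -> 30
        //   n =  4 -> out of range -> NULL of the value's data type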
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            ScalarValue::try_from(self.datatype.clone())
        }
    }

    fn size(&self) -> usize {
        size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values)
            + size_of::<DataType>()
    }
}

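/// Accumulator for `NTH_VALUE` with an `ORDER BY` requirement: alongside each buffered
/// value it also stores the values of the ordering expressions, so that partial results
/// from different partitions can be merged in the required order.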
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// The `N` value.
    n: i64,
    /// Stores entries in the `NTH_VALUE` result.
    values: VecDeque<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information on how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// Stores datatypes of the value expression and the ordering requirement
    /// expressions.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
}

impl NthValueAccumulator {
    /// Create a new order-sensitive NTH_VALUE accumulator based on the given
    /// item data type.
    pub fn try_new(
        n: i64,
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        if n == 0 {
            // n cannot be 0
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
        }
        let mut datatypes = vec![datatype.clone()];
        datatypes.extend(ordering_dtypes.iter().cloned());
        Ok(Self {
            n,
            values: VecDeque::new(),
            ordering_values: VecDeque::new(),
            datatypes,
            ordering_req,
        })
    }

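    /// Packs the buffered ordering values into a single-row list of structs, one
    /// struct column per ordering expression, for use in the serialized state.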
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);

        let mut column_wise_ordering_values = vec![];
        let num_columns = fields.len();
        for i in 0..num_columns {
            let column_values = self
                .ordering_values
                .iter()
                .map(|x| x[i].clone())
                .collect::<Vec<_>>();
            let array = if column_values.is_empty() {
                new_empty_array(fields[i].data_type())
            } else {
                ScalarValue::iter_to_array(column_values.into_iter())?
            };
            column_wise_ordering_values.push(array);
        }

        let struct_field = Fields::from(fields);
        let ordering_array =
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;

        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
    }

    fn evaluate_values(&self) -> ScalarValue {
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatypes[0],
        ))
    }

    /// Updates the state with `values`. `fetch` is the number of entries still
    /// missing for the state to be complete; `None` means all new `values` are added.
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        let n_to_add = if let Some(fetch) = fetch {
            std::cmp::min(fetch, n_row)
        } else {
            n_row
        };
        for index in 0..n_to_add {
            let row = get_row_at_idx(values, index)?;
            self.values.push_back(row[0].clone());
            // Index 1 of the row holds the constant `n` argument; ordering values
            // start at index 2 and run to the end of the row.
            self.ordering_values.push_back(row[2..].to_vec());
        }
        Ok(())
    }
}

impl Accumulator for NthValueAccumulator {
    /// Updates its state with the `values`. Assumes data in the `values` satisfies the
    /// required ordering for the accumulator (across consecutive batches, not just
    /// within each batch).
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // direction is from start
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // direction is from end
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // The second entry in the state stores, for each aggregation value inside the
        // NTH_VALUE list, the values of the ordering requirement columns. For each
        // `StructArray` inside this list, we receive an `Array` with the values of its
        // ordering requirement expression. This information is necessary during merging.
        let Some(agg_orderings) = states[1].as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Stores NTH_VALUE results coming from each partition
        let mut partition_values = vec![self.values.clone()];
        // First entry in the state is the aggregation result.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
        for v in array_agg_res.into_iter() {
            partition_values.push(v.into());
        }
        // Stores ordering requirement expression results coming from each partition:
        let mut partition_ordering_values = vec![self.ordering_values.clone()];
        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        // Extract the values from each struct into ordering rows, one per group/partition:
        for partition_ordering_rows in orderings.into_iter() {
            let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
                let ScalarValue::Struct(s_array) = ordering_row else {
                    return exec_err!(
                        "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                        ordering_row.data_type()
                    );
                };
                s_array
                    .columns()
                    .iter()
                    .map(|column| ScalarValue::try_from_array(column, 0))
                    .collect()
            }).collect::<Result<VecDeque<_>>>()?;
            partition_ordering_values.push(ordering_values);
        }

        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        let (new_values, new_orderings) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;
        self.values = new_values.into();
        self.ordering_values = new_orderings.into();
        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate_values(), self.evaluate_orderings()?])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // index is from start
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // index is from end
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}