// datafusion_functions_aggregate/nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines NTH_VALUE aggregate expression which may specify ordering requirement
19//! that can evaluated at runtime during query execution
20
21use std::any::Any;
22use std::collections::VecDeque;
23use std::mem::{size_of, size_of_val};
24use std::sync::Arc;
25
26use arrow::array::{ArrayRef, AsArray, StructArray, new_empty_array};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields};
28
29use datafusion_common::utils::{SingleRowListArrayBuilder, get_row_at_idx};
30use datafusion_common::{
31    Result, ScalarValue, assert_or_internal_err, exec_err, not_impl_err,
32};
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36    Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
37    Signature, SortExpr, Volatility, lit,
38};
39use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
40use datafusion_functions_aggregate_common::utils::ordering_fields;
41use datafusion_macros::user_doc;
42use datafusion_physical_expr::expressions::Literal;
43use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
44
// Generates the singleton accessor function `nth_value_udaf()` for the
// [`NthValueAgg`] aggregate UDF defined below.
create_func!(NthValueAgg, nth_value_udaf);
46
47/// Returns the nth value in a group of values.
48pub fn nth_value(
49    expr: datafusion_expr::Expr,
50    n: i64,
51    order_by: Vec<SortExpr>,
52) -> datafusion_expr::Expr {
53    let args = vec![expr, lit(n)];
54    if !order_by.is_empty() {
55        nth_value_udaf()
56            .call(args)
57            .order_by(order_by)
58            .build()
59            .unwrap()
60    } else {
61        nth_value_udaf().call(args)
62    }
63}
64
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the nth value in a group of values.",
    syntax_example = "nth_value(expression, n ORDER BY expression)",
    sql_example = r#"```sql
> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
  FROM employee;
+---------+--------+-------------------------+
| dept_id | salary | second_salary_by_dept   |
+---------+--------+-------------------------+
| 1       | 30000  | NULL                    |
| 1       | 40000  | 40000                   |
| 1       | 50000  | 40000                   |
| 2       | 35000  | NULL                    |
| 2       | 45000  | 45000                   |
+---------+--------+-------------------------+
```"#,
    argument(
        name = "expression",
        description = "The column or expression to retrieve the nth value from."
    ),
    argument(
        name = "n",
        description = "The position (nth) of the value to retrieve, based on the ordering."
    )
)]
/// Expression for a `NTH_VALUE(..., ... ORDER BY ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValueAgg {
    /// Function signature: any two arguments (value expression and `n`).
    signature: Signature,
}
98
99impl NthValueAgg {
100    /// Create a new `NthValueAgg` aggregate function
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::any(2, Volatility::Immutable),
104        }
105    }
106}
107
108impl Default for NthValueAgg {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is the type of the first (value) argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    /// Builds the accumulator: an order-insensitive [`TrivialNthValueAccumulator`]
    /// when no ordering requirement is present, otherwise an order-sensitive
    /// [`NthValueAccumulator`].
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `n` must be a literal Int64; anything else is unsupported.
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                // A reversed aggregation flips the direction, so NTH_VALUE(n)
                // becomes NTH_VALUE(-n) (counting from the other end).
                // NOTE(review): `-*value` overflows for `i64::MIN` —
                // presumably unreachable in practice; confirm upstream
                // validation of `n`.
                if acc_args.is_reversed {
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                );
            }
        };

        // Without an ordering requirement, fall back to the cheaper
        // order-insensitive accumulator.
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialNthValueAccumulator::try_new(
                n,
                acc_args.return_field.data_type(),
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        let data_type = acc_args.expr_fields[0].data_type();
        NthValueAccumulator::try_new(n, data_type, &ordering_dtypes, ordering)
            .map(|acc| Box::new(acc) as _)
    }

    /// Intermediate state: a list of accumulated values, plus (when an
    /// ordering requirement exists) a list of structs holding the ordering
    /// expression values for each entry.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            // See COMMENTS.md to understand why nullable is set to true
            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields.into_iter().map(Arc::new).collect())
    }

    /// NTH_VALUE is its own reverse (with the sign of `n` flipped inside
    /// `accumulator` via `is_reversed`).
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
196
/// Order-insensitive NTH_VALUE accumulator, used when the aggregate carries no
/// ordering requirement (the result is then picked from the values in arrival
/// order).
#[derive(Debug)]
pub struct TrivialNthValueAccumulator {
    /// The `N` value. Positive counts from the start, negative from the end;
    /// zero is rejected at construction.
    n: i64,
    /// Stores entries in the `NTH_VALUE` result.
    values: VecDeque<ScalarValue>,
    /// Data types of the value.
    datatype: DataType,
}
206
207impl TrivialNthValueAccumulator {
208    /// Create a new order-insensitive NTH_VALUE accumulator based on the given
209    /// item data type.
210    pub fn try_new(n: i64, datatype: &DataType) -> Result<Self> {
211        // n cannot be 0
212        assert_or_internal_err!(
213            n != 0,
214            "Nth value indices are 1 based. 0 is invalid index"
215        );
216        Ok(Self {
217            n,
218            values: VecDeque::new(),
219            datatype: datatype.clone(),
220        })
221    }
222
223    /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete
224    /// None represents all of the new `values` need to be added to the state.
225    fn append_new_data(
226        &mut self,
227        values: &[ArrayRef],
228        fetch: Option<usize>,
229    ) -> Result<()> {
230        let n_row = values[0].len();
231        let n_to_add = if let Some(fetch) = fetch {
232            std::cmp::min(fetch, n_row)
233        } else {
234            n_row
235        };
236        for index in 0..n_to_add {
237            let mut row = get_row_at_idx(values, index)?;
238            self.values.push_back(row.swap_remove(0));
239            // At index 1, we have n index argument, which is constant.
240        }
241        Ok(())
242    }
243}
244
245impl Accumulator for TrivialNthValueAccumulator {
246    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
247    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
248    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
249        if !values.is_empty() {
250            let n_required = self.n.unsigned_abs() as usize;
251            let from_start = self.n > 0;
252            if from_start {
253                // direction is from start
254                let n_remaining = n_required.saturating_sub(self.values.len());
255                self.append_new_data(values, Some(n_remaining))?;
256            } else {
257                // direction is from end
258                self.append_new_data(values, None)?;
259                let start_offset = self.values.len().saturating_sub(n_required);
260                if start_offset > 0 {
261                    self.values.drain(0..start_offset);
262                }
263            }
264        }
265        Ok(())
266    }
267
268    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
269        if !states.is_empty() {
270            // First entry in the state is the aggregation result.
271            let n_required = self.n.unsigned_abs() as usize;
272            let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
273            for v in array_agg_res.into_iter().flatten() {
274                self.values.extend(v);
275                if self.values.len() > n_required {
276                    // There is enough data collected, can stop merging:
277                    break;
278                }
279            }
280        }
281        Ok(())
282    }
283
284    fn state(&mut self) -> Result<Vec<ScalarValue>> {
285        let mut values_cloned = self.values.clone();
286        let values_slice = values_cloned.make_contiguous();
287        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
288            values_slice,
289            &self.datatype,
290        ))])
291    }
292
293    fn evaluate(&mut self) -> Result<ScalarValue> {
294        let n_required = self.n.unsigned_abs() as usize;
295        let from_start = self.n > 0;
296        let nth_value_idx = if from_start {
297            // index is from start
298            let forward_idx = n_required - 1;
299            (forward_idx < self.values.len()).then_some(forward_idx)
300        } else {
301            // index is from end
302            self.values.len().checked_sub(n_required)
303        };
304        if let Some(idx) = nth_value_idx {
305            Ok(self.values[idx].clone())
306        } else {
307            ScalarValue::try_from(self.datatype.clone())
308        }
309    }
310
311    fn size(&self) -> usize {
312        size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
313            - size_of_val(&self.values)
314            + size_of::<DataType>()
315    }
316}
317
/// Order-sensitive NTH_VALUE accumulator: tracks each entry together with the
/// values of its ordering requirement expressions so that partial results from
/// different partitions can be merged in order.
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// The `N` value. Positive counts from the start, negative from the end;
    /// zero is rejected at construction.
    n: i64,
    /// Stores entries in the `NTH_VALUE` result.
    values: VecDeque<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
}
335
336impl NthValueAccumulator {
337    /// Create a new order-sensitive NTH_VALUE accumulator based on the given
338    /// item data type.
339    pub fn try_new(
340        n: i64,
341        datatype: &DataType,
342        ordering_dtypes: &[DataType],
343        ordering_req: LexOrdering,
344    ) -> Result<Self> {
345        // n cannot be 0
346        assert_or_internal_err!(
347            n != 0,
348            "Nth value indices are 1 based. 0 is invalid index"
349        );
350        let mut datatypes = vec![datatype.clone()];
351        datatypes.extend(ordering_dtypes.iter().cloned());
352        Ok(Self {
353            n,
354            values: VecDeque::new(),
355            ordering_values: VecDeque::new(),
356            datatypes,
357            ordering_req,
358        })
359    }
360
361    fn evaluate_orderings(&self) -> Result<ScalarValue> {
362        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
363
364        let mut column_wise_ordering_values = vec![];
365        let num_columns = fields.len();
366        for i in 0..num_columns {
367            let column_values = self
368                .ordering_values
369                .iter()
370                .map(|x| x[i].clone())
371                .collect::<Vec<_>>();
372            let array = if column_values.is_empty() {
373                new_empty_array(fields[i].data_type())
374            } else {
375                ScalarValue::iter_to_array(column_values.into_iter())?
376            };
377            column_wise_ordering_values.push(array);
378        }
379
380        let struct_field = Fields::from(fields);
381        let ordering_array =
382            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
383
384        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
385    }
386
387    fn evaluate_values(&self) -> ScalarValue {
388        let mut values_cloned = self.values.clone();
389        let values_slice = values_cloned.make_contiguous();
390        ScalarValue::List(ScalarValue::new_list_nullable(
391            values_slice,
392            &self.datatypes[0],
393        ))
394    }
395
396    /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete
397    /// None represents all of the new `values` need to be added to the state.
398    fn append_new_data(
399        &mut self,
400        values: &[ArrayRef],
401        fetch: Option<usize>,
402    ) -> Result<()> {
403        let n_row = values[0].len();
404        let n_to_add = if let Some(fetch) = fetch {
405            std::cmp::min(fetch, n_row)
406        } else {
407            n_row
408        };
409        for index in 0..n_to_add {
410            let row = get_row_at_idx(values, index)?;
411            self.values.push_back(row[0].clone());
412            // At index 1, we have n index argument.
413            // Ordering values cover starting from 2nd index to end
414            self.ordering_values.push_back(row[2..].to_vec());
415        }
416        Ok(())
417    }
418}
419
impl Accumulator for NthValueAccumulator {
    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // direction is from start: only the first `n_required` entries can
            // ever be the answer, so cap how many rows we append.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // direction is from end: append everything, then keep only the
            // trailing `n_required` entries (values and orderings in lockstep).
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }

        Ok(())
    }

    /// Merges partial states from other partitions, using the stored ordering
    /// values to interleave entries in the globally correct order (see
    /// [`merge_ordered_arrays`]).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // Second entry stores values received for ordering requirement columns
        // for each aggregation value inside NTH_VALUE list. For each `StructArray`
        // inside this list, we will receive an `Array` that stores values received
        // from its ordering requirement expression. This information is necessary
        // during merging.
        let Some(agg_orderings) = states[1].as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Stores NTH_VALUE results coming from each partition
        let mut partition_values = vec![self.values.clone()];
        // First entry in the state is the aggregation result.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
        for v in array_agg_res.into_iter().flatten() {
            partition_values.push(v.into());
        }
        // Stores ordering requirement expression results coming from each partition:
        let mut partition_ordering_values = vec![self.ordering_values.clone()];
        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        // Extract value from struct to ordering_rows for each group/partition:
        for partition_ordering_rows in orderings.into_iter().flatten() {
            let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
                let ScalarValue::Struct(s_array) = ordering_row else {
                    return exec_err!(
                        "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                        ordering_row.data_type()
                    );
                };
                // Each struct holds one row of ordering expression values;
                // pull the scalar out of each column.
                s_array
                    .columns()
                    .iter()
                    .map(|column| ScalarValue::try_from_array(column, 0))
                    .collect()
            }).collect::<Result<VecDeque<_>>>()?;
            partition_ordering_values.push(ordering_values);
        }

        // Merge all partitions' entries into a single ordered sequence, using
        // the sort options of the ordering requirement.
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        let (new_values, new_orderings) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;
        self.values = new_values.into();
        self.ordering_values = new_orderings.into();
        Ok(())
    }

    /// State columns: [values list, ordering-values list-of-structs].
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate_values(), self.evaluate_orderings()?])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // index is from start; `n != 0` is guaranteed by `try_new`, so
            // this subtraction cannot underflow.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // index is from end
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough entries collected: return a NULL of the result type.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}