Skip to main content

datafusion_functions_window/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `nth_value` window function implementation
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::buffer::NullBuffer;
23use arrow::datatypes::FieldRef;
24use datafusion_common::arrow::array::ArrayRef;
25use datafusion_common::arrow::datatypes::{DataType, Field};
26use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
27use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
28use datafusion_expr::window_state::WindowAggState;
29use datafusion_expr::{
30    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
31    TypeSignature, Volatility, WindowUDFImpl,
32};
33use datafusion_functions_window_common::field;
34use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use field::WindowUDFFieldArgs;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
// Declares the `first_value` window UDF, constructed through `NthValue::first`.
// `first_value_udwf()` is the accessor used by `reverse_expr` further down.
// NOTE(review): presumably the macro also emits a `first_value(arg)` expression
// helper — confirm against the macro definition.
define_udwf_and_expr!(
    First,
    first_value,
    [arg],
    first_value_udwf,
    "Returns the first value in the window frame",
    NthValue::first
);
// Declares the `last_value` window UDF, constructed through `NthValue::last`.
// `last_value_udwf()` is the accessor used by `reverse_expr` further down.
define_udwf_and_expr!(
    Last,
    last_value,
    [arg],
    last_value_udwf,
    "Returns the last value in the window frame",
    NthValue::last
);
// Declares only the `nth_value_udwf()` accessor (constructed through
// `NthValue::nth`); the two-argument `nth_value` expression helper is written
// by hand below because it takes the extra `n` parameter.
get_or_init_udwf!(
    NthValue,
    nth_value,
    nth_value_udwf,
    "Returns the nth value in the window frame",
    NthValue::nth
);
66
67/// Create an expression to represent the `nth_value` window function
68pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
69    nth_value_udwf().call(vec![arg, n.lit()])
70}
71
/// Discriminates the three flavours served by the shared `NTH_VALUE`
/// implementation.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`: first row of the window frame.
    First,
    /// `last_value`: last row of the window frame.
    Last,
    /// `nth_value`: row selected by an explicit index argument.
    Nth,
}

impl NthValueKind {
    /// Returns the function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
89
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; the concrete behaviour is selected by `kind`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValue {
    /// Accepted argument signature (see [`NthValue::new`]).
    signature: Signature,
    /// Which of the three functions this instance represents.
    kind: NthValueKind,
}
95
96impl NthValue {
97    /// Create a new `nth_value` function
98    pub fn new(kind: NthValueKind) -> Self {
99        Self {
100            signature: Signature::one_of(
101                vec![
102                    TypeSignature::Nullary,
103                    TypeSignature::Any(1),
104                    TypeSignature::Any(2),
105                ],
106                Volatility::Immutable,
107            ),
108            kind,
109        }
110    }
111
112    pub fn first() -> Self {
113        Self::new(NthValueKind::First)
114    }
115
116    pub fn last() -> Self {
117        Self::new(NthValueKind::Last)
118    }
119    pub fn nth() -> Self {
120        Self::new(NthValueKind::Nth)
121    }
122
123    pub fn kind(&self) -> &NthValueKind {
124        &self.kind
125    }
126}
127
128static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
129    Documentation::builder(
130        DOC_SECTION_ANALYTICAL,
131        "Returns value evaluated at the row that is the first row of the window \
132            frame.",
133        "first_value(expression)",
134    )
135    .with_argument("expression", "Expression to operate on")
136    .with_sql_example(
137        r#"
138```sql
139-- Example usage of the first_value window function:
140SELECT department,
141  employee_id,
142  salary,
143  first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
144FROM employees;
145
146+-------------+-------------+--------+------------+
147| department  | employee_id | salary | top_salary |
148+-------------+-------------+--------+------------+
149| Sales       | 1           | 70000  | 70000      |
150| Sales       | 2           | 50000  | 70000      |
151| Sales       | 3           | 30000  | 70000      |
152| Engineering | 4           | 90000  | 90000      |
153| Engineering | 5           | 80000  | 90000      |
154+-------------+-------------+--------+------------+
155```
156"#,
157    )
158    .build()
159});
160
161fn get_first_value_doc() -> &'static Documentation {
162    &FIRST_VALUE_DOCUMENTATION
163}
164
165static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
166    Documentation::builder(
167        DOC_SECTION_ANALYTICAL,
168        "Returns value evaluated at the row that is the last row of the window \
169            frame.",
170        "last_value(expression)",
171    )
172    .with_argument("expression", "Expression to operate on")
173        .with_sql_example(r#"```sql
174-- SQL example of last_value:
175SELECT department,
176       employee_id,
177       salary,
178       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
179FROM employees;
180
181+-------------+-------------+--------+---------------------+
182| department  | employee_id | salary | running_last_salary |
183+-------------+-------------+--------+---------------------+
184| Sales       | 1           | 30000  | 30000               |
185| Sales       | 2           | 50000  | 50000               |
186| Sales       | 3           | 70000  | 70000               |
187| Engineering | 4           | 40000  | 40000               |
188| Engineering | 5           | 60000  | 60000               |
189+-------------+-------------+--------+---------------------+
190```
191"#)
192    .build()
193});
194
195fn get_last_value_doc() -> &'static Documentation {
196    &LAST_VALUE_DOCUMENTATION
197}
198
199static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
200    Documentation::builder(
201        DOC_SECTION_ANALYTICAL,
202        "Returns the value evaluated at the nth row of the window frame \
203         (counting from 1). Returns NULL if no such row exists.",
204        "nth_value(expression, n)",
205    )
206    .with_argument(
207        "expression",
208        "The column from which to retrieve the nth value.",
209    )
210    .with_argument(
211        "n",
212        "Integer. Specifies the row number (starting from 1) in the window frame.",
213    )
214    .with_sql_example(
215        r#"
216```sql
217-- Sample employees table:
218CREATE TABLE employees (id INT, salary INT);
219INSERT INTO employees (id, salary) VALUES
220(1, 30000),
221(2, 40000),
222(3, 50000),
223(4, 60000),
224(5, 70000);
225
226-- Example usage of nth_value:
227SELECT nth_value(salary, 2) OVER (
228  ORDER BY salary
229  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
230) AS nth_value
231FROM employees;
232
233+-----------+
234| nth_value |
235+-----------+
236| 40000     |
237| 40000     |
238| 40000     |
239| 40000     |
240| 40000     |
241+-----------+
242```
243"#,
244    )
245    .build()
246});
247
248fn get_nth_value_doc() -> &'static Documentation {
249    &NTH_VALUE_DOCUMENTATION
250}
251
252impl WindowUDFImpl for NthValue {
253    fn name(&self) -> &str {
254        self.kind.name()
255    }
256
257    fn signature(&self) -> &Signature {
258        &self.signature
259    }
260
261    fn partition_evaluator(
262        &self,
263        partition_evaluator_args: PartitionEvaluatorArgs,
264    ) -> Result<Box<dyn PartitionEvaluator>> {
265        let state = NthValueState {
266            finalized_result: None,
267            kind: self.kind,
268        };
269
270        if self.kind != NthValueKind::Nth {
271            return Ok(Box::new(NthValueEvaluator {
272                state,
273                ignore_nulls: partition_evaluator_args.ignore_nulls(),
274                n: 0,
275            }));
276        }
277
278        let n = match get_scalar_value_from_args(
279            partition_evaluator_args.input_exprs(),
280            1,
281        )
282        .map_err(|_e| {
283            exec_datafusion_err!(
284                "Expected a signed integer literal for the second argument of nth_value"
285            )
286        })?
287        .map(|v| get_signed_integer(&v))
288        {
289            Some(Ok(n)) => {
290                if partition_evaluator_args.is_reversed() {
291                    -n
292                } else {
293                    n
294                }
295            }
296            _ => {
297                return exec_err!(
298                    "Expected a signed integer literal for the second argument of nth_value"
299                );
300            }
301        };
302
303        Ok(Box::new(NthValueEvaluator {
304            state,
305            ignore_nulls: partition_evaluator_args.ignore_nulls(),
306            n,
307        }))
308    }
309
310    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
311        let input_field =
312            field_args
313                .input_fields()
314                .first()
315                .cloned()
316                .unwrap_or_else(|| {
317                    Arc::new(Field::new(field_args.name(), DataType::Null, true))
318                });
319
320        // Clone the input field to preserve metadata, update name and nullability
321        Ok(input_field
322            .as_ref()
323            .clone()
324            .with_name(field_args.name())
325            .with_nullable(true)
326            .into())
327    }
328
329    fn reverse_expr(&self) -> ReversedUDWF {
330        match self.kind {
331            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
332            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
333            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
334        }
335    }
336
337    fn documentation(&self) -> Option<&Documentation> {
338        match self.kind {
339            NthValueKind::First => Some(get_first_value_doc()),
340            NthValueKind::Last => Some(get_last_value_doc()),
341            NthValueKind::Nth => Some(get_nth_value_doc()),
342        }
343    }
344
345    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
346        LimitEffect::None // NthValue is causal
347    }
348}
349
/// Per-partition state kept by the `first_value` / `last_value` / `nth_value`
/// evaluator.
#[derive(Debug, Clone)]
pub struct NthValueState {
    // In certain cases, we can finalize the result early. Consider this usage:
    // ```
    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
    // ```
    // The result will always be the first entry in the table. We can store such
    // early-finalizing results and then just reuse them as necessary. This opens
    // opportunities to prune our datasets.
    /// Result that has been finalized early, if any; once set, `evaluate`
    /// returns a clone of it for every subsequent row.
    pub finalized_result: Option<ScalarValue>,
    /// Which of the three functions this state belongs to.
    pub kind: NthValueKind,
}
363
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    /// Function kind plus any early-finalized result.
    state: NthValueState,
    /// When true, null rows are skipped while locating the requested row.
    ignore_nulls: bool,
    /// 1-based row index for `nth_value`; negative values count from the end
    /// of the frame. Only read when the kind is `Nth`.
    n: i64,
}
370
371impl PartitionEvaluator for NthValueEvaluator {
372    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
373    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
374    /// can memoize the result.  Once result is calculated, it will always stay
375    /// same. Hence, we do not need to keep past data as we process the entire
376    /// dataset.
377    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
378        let out = &state.out_col;
379        let size = out.len();
380        if self.ignore_nulls {
381            match self.state.kind {
382                // Prune on first non-null output in case of FIRST_VALUE
383                NthValueKind::First => {
384                    if let Some(nulls) = out.nulls() {
385                        if self.state.finalized_result.is_none() {
386                            if let Some(valid_index) = nulls.valid_indices().next() {
387                                let result =
388                                    ScalarValue::try_from_array(out, valid_index)?;
389                                self.state.finalized_result = Some(result);
390                            } else {
391                                // The output is empty or all nulls, ignore
392                            }
393                        }
394                        if state.window_frame_range.start < state.window_frame_range.end {
395                            state.window_frame_range.start =
396                                state.window_frame_range.end - 1;
397                        }
398                        return Ok(());
399                    } else {
400                        // Fall through to the main case because there are no nulls
401                    }
402                }
403                // Do not memoize for other kinds when nulls are ignored
404                NthValueKind::Last | NthValueKind::Nth => return Ok(()),
405            }
406        }
407        let mut buffer_size = 1;
408        // Decide if we arrived at a final result yet:
409        let (is_prunable, is_reverse_direction) = match self.state.kind {
410            NthValueKind::First => {
411                let n_range =
412                    state.window_frame_range.end - state.window_frame_range.start;
413                (n_range > 0 && size > 0, false)
414            }
415            NthValueKind::Last => (true, true),
416            NthValueKind::Nth => {
417                let n_range =
418                    state.window_frame_range.end - state.window_frame_range.start;
419                match self.n.cmp(&0) {
420                    Ordering::Greater => (
421                        n_range >= (self.n as usize) && size > (self.n as usize),
422                        false,
423                    ),
424                    Ordering::Less => {
425                        let reverse_index = (-self.n) as usize;
426                        buffer_size = reverse_index;
427                        // Negative index represents reverse direction.
428                        (n_range >= reverse_index, true)
429                    }
430                    Ordering::Equal => (false, false),
431                }
432            }
433        };
434        if is_prunable {
435            if self.state.finalized_result.is_none() && !is_reverse_direction {
436                let result = ScalarValue::try_from_array(out, size - 1)?;
437                self.state.finalized_result = Some(result);
438            }
439            state.window_frame_range.start =
440                state.window_frame_range.end.saturating_sub(buffer_size);
441        }
442        Ok(())
443    }
444
445    fn evaluate(
446        &mut self,
447        values: &[ArrayRef],
448        range: &Range<usize>,
449    ) -> Result<ScalarValue> {
450        if let Some(ref result) = self.state.finalized_result {
451            Ok(result.clone())
452        } else {
453            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
454            let arr = &values[0];
455            let n_range = range.end - range.start;
456            if n_range == 0 {
457                // We produce None if the window is empty.
458                return ScalarValue::try_from(arr.data_type());
459            }
460            match self.valid_index(arr, range) {
461                Some(index) => ScalarValue::try_from_array(arr, index),
462                None => ScalarValue::try_from(arr.data_type()),
463            }
464        }
465    }
466
467    fn supports_bounded_execution(&self) -> bool {
468        true
469    }
470
471    fn uses_window_frame(&self) -> bool {
472        true
473    }
474}
475
476impl NthValueEvaluator {
477    fn valid_index(&self, array: &ArrayRef, range: &Range<usize>) -> Option<usize> {
478        let n_range = range.end - range.start;
479        if self.ignore_nulls {
480            // Calculate valid indices, inside the window frame boundaries.
481            let slice = array.slice(range.start, n_range);
482            if let Some(nulls) = slice.nulls()
483                && nulls.null_count() > 0
484            {
485                return self.valid_index_with_nulls(nulls, range.start);
486            }
487        }
488        // Either no nulls, or nulls are regarded as valid rows
489        match self.state.kind {
490            NthValueKind::First => Some(range.start),
491            NthValueKind::Last => Some(range.end - 1),
492            NthValueKind::Nth => match self.n.cmp(&0) {
493                Ordering::Greater => {
494                    // SQL indices are not 0-based.
495                    let index = (self.n as usize) - 1;
496                    if index >= n_range {
497                        // Outside the range, return NULL:
498                        None
499                    } else {
500                        Some(range.start + index)
501                    }
502                }
503                Ordering::Less => {
504                    let reverse_index = (-self.n) as usize;
505                    if n_range < reverse_index {
506                        // Outside the range, return NULL:
507                        None
508                    } else {
509                        Some(range.end - reverse_index)
510                    }
511                }
512                Ordering::Equal => None,
513            },
514        }
515    }
516
517    fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option<usize> {
518        match self.state.kind {
519            NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset),
520            NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset),
521            NthValueKind::Nth => {
522                match self.n.cmp(&0) {
523                    Ordering::Greater => {
524                        // SQL indices are not 0-based.
525                        let index = (self.n as usize) - 1;
526                        nulls.valid_indices().nth(index).map(|idx| idx + offset)
527                    }
528                    Ordering::Less => {
529                        let reverse_index = (-self.n) as usize;
530                        let valid_indices_len = nulls.len() - nulls.null_count();
531                        if reverse_index > valid_indices_len {
532                            return None;
533                        }
534                        nulls
535                            .valid_indices()
536                            .nth(valid_indices_len - reverse_index)
537                            .map(|idx| idx + offset)
538                    }
539                    Ordering::Equal => None,
540                }
541            }
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};

    /// Physical expression referring to the single input column.
    fn input_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Physical literal used as the `n` argument of `nth_value`.
    fn n_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Nullable Int32 field describing the input column.
    fn nullable_i32_field() -> FieldRef {
        Arc::new(Field::new("f", DataType::Int32, true))
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8`
    /// of a fixed Int32 column and compares the collected results with
    /// `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let column: ArrayRef =
            Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = [column];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;

        let mut scalars: Vec<ScalarValue> = Vec::with_capacity(8);
        for end in 1..=8 {
            scalars.push(evaluator.evaluate(&values, &(0..end))?);
        }

        let actual = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let exprs = [input_column()];
        let fields = [nullable_i32_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);

        test_i32_result(NthValue::first(), args, Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn last_value() -> Result<()> {
        let exprs = [input_column()];
        let fields = [nullable_i32_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);

        // With frames that always end at the current row, the last value is
        // the current row itself.
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(NthValue::last(), args, expected)
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let exprs = [input_column(), n_literal(1)];
        let fields = [nullable_i32_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);

        test_i32_result(NthValue::nth(), args, Int32Array::from(vec![1; 8]))?;
        Ok(())
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let exprs = [input_column(), n_literal(2)];
        let fields = [nullable_i32_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);

        // The first frame holds a single row, so there is no second value yet.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(NthValue::nth(), args, expected)?;
        Ok(())
    }
}