Skip to main content

datafusion_functions_window/
nth_value.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! `nth_value` window function implementation, which also backs the
//! `first_value` and `last_value` window functions.

20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::buffer::NullBuffer;
23use arrow::datatypes::FieldRef;
24use datafusion_common::arrow::array::ArrayRef;
25use datafusion_common::arrow::datatypes::{DataType, Field};
26use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
27use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
28use datafusion_expr::window_state::WindowAggState;
29use datafusion_expr::{
30    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
31    TypeSignature, Volatility, WindowUDFImpl,
32};
33use datafusion_functions_window_common::field;
34use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use field::WindowUDFFieldArgs;
37use std::any::Any;
38use std::cmp::Ordering;
39use std::fmt::Debug;
40use std::hash::Hash;
41use std::ops::Range;
42use std::sync::{Arc, LazyLock};
43
44define_udwf_and_expr!(
45    First,
46    first_value,
47    [arg],
48    "Returns the first value in the window frame",
49    NthValue::first
50);
51define_udwf_and_expr!(
52    Last,
53    last_value,
54    [arg],
55    "Returns the last value in the window frame",
56    NthValue::last
57);
58get_or_init_udwf!(
59    NthValue,
60    nth_value,
61    "Returns the nth value in the window frame",
62    NthValue::nth
63);
64
65/// Create an expression to represent the `nth_value` window function
66pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
67    nth_value_udwf().call(vec![arg, n.lit()])
68}
69
/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
///
/// `first_value` and `last_value` are evaluated by the same machinery as
/// `nth_value`; this tag records which of the three a given `NthValue`
/// instance implements.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`: the first row of the window frame.
    First,
    /// `last_value`: the last row of the window frame.
    Last,
    /// `nth_value`: the row at a caller-supplied position `n`.
    Nth,
}
77
78impl NthValueKind {
79    fn name(&self) -> &'static str {
80        match self {
81            NthValueKind::First => "first_value",
82            NthValueKind::Last => "last_value",
83            NthValueKind::Nth => "nth_value",
84        }
85    }
86}
87
88#[derive(Debug, PartialEq, Eq, Hash)]
89pub struct NthValue {
90    signature: Signature,
91    kind: NthValueKind,
92}
93
94impl NthValue {
95    /// Create a new `nth_value` function
96    pub fn new(kind: NthValueKind) -> Self {
97        Self {
98            signature: Signature::one_of(
99                vec![
100                    TypeSignature::Nullary,
101                    TypeSignature::Any(1),
102                    TypeSignature::Any(2),
103                ],
104                Volatility::Immutable,
105            ),
106            kind,
107        }
108    }
109
110    pub fn first() -> Self {
111        Self::new(NthValueKind::First)
112    }
113
114    pub fn last() -> Self {
115        Self::new(NthValueKind::Last)
116    }
117    pub fn nth() -> Self {
118        Self::new(NthValueKind::Nth)
119    }
120
121    pub fn kind(&self) -> &NthValueKind {
122        &self.kind
123    }
124}
125
126static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
127    Documentation::builder(
128        DOC_SECTION_ANALYTICAL,
129        "Returns value evaluated at the row that is the first row of the window \
130            frame.",
131        "first_value(expression)",
132    )
133    .with_argument("expression", "Expression to operate on")
134    .with_sql_example(
135        r#"
136```sql
137-- Example usage of the first_value window function:
138SELECT department,
139  employee_id,
140  salary,
141  first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
142FROM employees;
143
144+-------------+-------------+--------+------------+
145| department  | employee_id | salary | top_salary |
146+-------------+-------------+--------+------------+
147| Sales       | 1           | 70000  | 70000      |
148| Sales       | 2           | 50000  | 70000      |
149| Sales       | 3           | 30000  | 70000      |
150| Engineering | 4           | 90000  | 90000      |
151| Engineering | 5           | 80000  | 90000      |
152+-------------+-------------+--------+------------+
153```
154"#,
155    )
156    .build()
157});
158
159fn get_first_value_doc() -> &'static Documentation {
160    &FIRST_VALUE_DOCUMENTATION
161}
162
163static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
164    Documentation::builder(
165        DOC_SECTION_ANALYTICAL,
166        "Returns value evaluated at the row that is the last row of the window \
167            frame.",
168        "last_value(expression)",
169    )
170    .with_argument("expression", "Expression to operate on")
171        .with_sql_example(r#"```sql
172-- SQL example of last_value:
173SELECT department,
174       employee_id,
175       salary,
176       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
177FROM employees;
178
179+-------------+-------------+--------+---------------------+
180| department  | employee_id | salary | running_last_salary |
181+-------------+-------------+--------+---------------------+
182| Sales       | 1           | 30000  | 30000               |
183| Sales       | 2           | 50000  | 50000               |
184| Sales       | 3           | 70000  | 70000               |
185| Engineering | 4           | 40000  | 40000               |
186| Engineering | 5           | 60000  | 60000               |
187+-------------+-------------+--------+---------------------+
188```
189"#)
190    .build()
191});
192
193fn get_last_value_doc() -> &'static Documentation {
194    &LAST_VALUE_DOCUMENTATION
195}
196
197static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
198    Documentation::builder(
199        DOC_SECTION_ANALYTICAL,
200        "Returns the value evaluated at the nth row of the window frame \
201         (counting from 1). Returns NULL if no such row exists.",
202        "nth_value(expression, n)",
203    )
204    .with_argument(
205        "expression",
206        "The column from which to retrieve the nth value.",
207    )
208    .with_argument(
209        "n",
210        "Integer. Specifies the row number (starting from 1) in the window frame.",
211    )
212    .with_sql_example(
213        r#"
214```sql
215-- Sample employees table:
216CREATE TABLE employees (id INT, salary INT);
217INSERT INTO employees (id, salary) VALUES
218(1, 30000),
219(2, 40000),
220(3, 50000),
221(4, 60000),
222(5, 70000);
223
224-- Example usage of nth_value:
225SELECT nth_value(salary, 2) OVER (
226  ORDER BY salary
227  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
228) AS nth_value
229FROM employees;
230
231+-----------+
232| nth_value |
233+-----------+
234| 40000     |
235| 40000     |
236| 40000     |
237| 40000     |
238| 40000     |
239+-----------+
240```
241"#,
242    )
243    .build()
244});
245
246fn get_nth_value_doc() -> &'static Documentation {
247    &NTH_VALUE_DOCUMENTATION
248}
249
250impl WindowUDFImpl for NthValue {
251    fn as_any(&self) -> &dyn Any {
252        self
253    }
254
255    fn name(&self) -> &str {
256        self.kind.name()
257    }
258
259    fn signature(&self) -> &Signature {
260        &self.signature
261    }
262
263    fn partition_evaluator(
264        &self,
265        partition_evaluator_args: PartitionEvaluatorArgs,
266    ) -> Result<Box<dyn PartitionEvaluator>> {
267        let state = NthValueState {
268            finalized_result: None,
269            kind: self.kind,
270        };
271
272        if self.kind != NthValueKind::Nth {
273            return Ok(Box::new(NthValueEvaluator {
274                state,
275                ignore_nulls: partition_evaluator_args.ignore_nulls(),
276                n: 0,
277            }));
278        }
279
280        let n = match get_scalar_value_from_args(
281            partition_evaluator_args.input_exprs(),
282            1,
283        )
284        .map_err(|_e| {
285            exec_datafusion_err!(
286                "Expected a signed integer literal for the second argument of nth_value"
287            )
288        })?
289        .map(|v| get_signed_integer(&v))
290        {
291            Some(Ok(n)) => {
292                if partition_evaluator_args.is_reversed() {
293                    -n
294                } else {
295                    n
296                }
297            }
298            _ => {
299                return exec_err!(
300                    "Expected a signed integer literal for the second argument of nth_value"
301                );
302            }
303        };
304
305        Ok(Box::new(NthValueEvaluator {
306            state,
307            ignore_nulls: partition_evaluator_args.ignore_nulls(),
308            n,
309        }))
310    }
311
312    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
313        let return_type = field_args
314            .input_fields()
315            .first()
316            .map(|f| f.data_type())
317            .cloned()
318            .unwrap_or(DataType::Null);
319
320        Ok(Field::new(field_args.name(), return_type, true).into())
321    }
322
323    fn reverse_expr(&self) -> ReversedUDWF {
324        match self.kind {
325            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
326            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
327            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
328        }
329    }
330
331    fn documentation(&self) -> Option<&Documentation> {
332        match self.kind {
333            NthValueKind::First => Some(get_first_value_doc()),
334            NthValueKind::Last => Some(get_last_value_doc()),
335            NthValueKind::Nth => Some(get_nth_value_doc()),
336        }
337    }
338
339    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
340        LimitEffect::None // NthValue is causal
341    }
342}
343
344#[derive(Debug, Clone)]
345pub struct NthValueState {
346    // In certain cases, we can finalize the result early. Consider this usage:
347    // ```
348    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
349    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
350    // ```
351    // The result will always be the first entry in the table. We can store such
352    // early-finalizing results and then just reuse them as necessary. This opens
353    // opportunities to prune our datasets.
354    pub finalized_result: Option<ScalarValue>,
355    pub kind: NthValueKind,
356}
357
358#[derive(Debug)]
359pub(crate) struct NthValueEvaluator {
360    state: NthValueState,
361    ignore_nulls: bool,
362    n: i64,
363}
364
365impl PartitionEvaluator for NthValueEvaluator {
366    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
367    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
368    /// can memoize the result.  Once result is calculated, it will always stay
369    /// same. Hence, we do not need to keep past data as we process the entire
370    /// dataset.
371    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
372        let out = &state.out_col;
373        let size = out.len();
374        if self.ignore_nulls {
375            match self.state.kind {
376                // Prune on first non-null output in case of FIRST_VALUE
377                NthValueKind::First => {
378                    if let Some(nulls) = out.nulls() {
379                        if self.state.finalized_result.is_none() {
380                            if let Some(valid_index) = nulls.valid_indices().next() {
381                                let result =
382                                    ScalarValue::try_from_array(out, valid_index)?;
383                                self.state.finalized_result = Some(result);
384                            } else {
385                                // The output is empty or all nulls, ignore
386                            }
387                        }
388                        if state.window_frame_range.start < state.window_frame_range.end {
389                            state.window_frame_range.start =
390                                state.window_frame_range.end - 1;
391                        }
392                        return Ok(());
393                    } else {
394                        // Fall through to the main case because there are no nulls
395                    }
396                }
397                // Do not memoize for other kinds when nulls are ignored
398                NthValueKind::Last | NthValueKind::Nth => return Ok(()),
399            }
400        }
401        let mut buffer_size = 1;
402        // Decide if we arrived at a final result yet:
403        let (is_prunable, is_reverse_direction) = match self.state.kind {
404            NthValueKind::First => {
405                let n_range =
406                    state.window_frame_range.end - state.window_frame_range.start;
407                (n_range > 0 && size > 0, false)
408            }
409            NthValueKind::Last => (true, true),
410            NthValueKind::Nth => {
411                let n_range =
412                    state.window_frame_range.end - state.window_frame_range.start;
413                match self.n.cmp(&0) {
414                    Ordering::Greater => (
415                        n_range >= (self.n as usize) && size > (self.n as usize),
416                        false,
417                    ),
418                    Ordering::Less => {
419                        let reverse_index = (-self.n) as usize;
420                        buffer_size = reverse_index;
421                        // Negative index represents reverse direction.
422                        (n_range >= reverse_index, true)
423                    }
424                    Ordering::Equal => (false, false),
425                }
426            }
427        };
428        if is_prunable {
429            if self.state.finalized_result.is_none() && !is_reverse_direction {
430                let result = ScalarValue::try_from_array(out, size - 1)?;
431                self.state.finalized_result = Some(result);
432            }
433            state.window_frame_range.start =
434                state.window_frame_range.end.saturating_sub(buffer_size);
435        }
436        Ok(())
437    }
438
439    fn evaluate(
440        &mut self,
441        values: &[ArrayRef],
442        range: &Range<usize>,
443    ) -> Result<ScalarValue> {
444        if let Some(ref result) = self.state.finalized_result {
445            Ok(result.clone())
446        } else {
447            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
448            let arr = &values[0];
449            let n_range = range.end - range.start;
450            if n_range == 0 {
451                // We produce None if the window is empty.
452                return ScalarValue::try_from(arr.data_type());
453            }
454            match self.valid_index(arr, range) {
455                Some(index) => ScalarValue::try_from_array(arr, index),
456                None => ScalarValue::try_from(arr.data_type()),
457            }
458        }
459    }
460
461    fn supports_bounded_execution(&self) -> bool {
462        true
463    }
464
465    fn uses_window_frame(&self) -> bool {
466        true
467    }
468}
469
470impl NthValueEvaluator {
471    fn valid_index(&self, array: &ArrayRef, range: &Range<usize>) -> Option<usize> {
472        let n_range = range.end - range.start;
473        if self.ignore_nulls {
474            // Calculate valid indices, inside the window frame boundaries.
475            let slice = array.slice(range.start, n_range);
476            if let Some(nulls) = slice.nulls()
477                && nulls.null_count() > 0
478            {
479                return self.valid_index_with_nulls(nulls, range.start);
480            }
481        }
482        // Either no nulls, or nulls are regarded as valid rows
483        match self.state.kind {
484            NthValueKind::First => Some(range.start),
485            NthValueKind::Last => Some(range.end - 1),
486            NthValueKind::Nth => match self.n.cmp(&0) {
487                Ordering::Greater => {
488                    // SQL indices are not 0-based.
489                    let index = (self.n as usize) - 1;
490                    if index >= n_range {
491                        // Outside the range, return NULL:
492                        None
493                    } else {
494                        Some(range.start + index)
495                    }
496                }
497                Ordering::Less => {
498                    let reverse_index = (-self.n) as usize;
499                    if n_range < reverse_index {
500                        // Outside the range, return NULL:
501                        None
502                    } else {
503                        Some(range.end - reverse_index)
504                    }
505                }
506                Ordering::Equal => None,
507            },
508        }
509    }
510
511    fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option<usize> {
512        match self.state.kind {
513            NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset),
514            NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset),
515            NthValueKind::Nth => {
516                match self.n.cmp(&0) {
517                    Ordering::Greater => {
518                        // SQL indices are not 0-based.
519                        let index = (self.n as usize) - 1;
520                        nulls.valid_indices().nth(index).map(|idx| idx + offset)
521                    }
522                    Ordering::Less => {
523                        let reverse_index = (-self.n) as usize;
524                        let valid_indices_len = nulls.len() - nulls.null_count();
525                        if reverse_index > valid_indices_len {
526                            return None;
527                        }
528                        nulls
529                            .valid_indices()
530                            .nth(valid_indices_len - reverse_index)
531                            .map(|idx| idx + offset)
532                    }
533                    Ordering::Equal => None,
534                }
535            }
536        }
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8`
    /// of the column `[1, -2, 3, -4, 5, -6, 7, 8]` and asserts the eight
    /// results equal `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![arr];
        let ranges: Vec<Range<usize>> = (1..=8).map(|end| 0..end).collect();
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let result = ranges
            .iter()
            .map(|range| evaluator.evaluate(&values, range))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let result = ScalarValue::iter_to_array(result.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Runs `nth_value(c3, n)` (neither reversed nor ignoring nulls) through
    /// `test_i32_result`.
    fn test_nth_value(n: i32, expected: Int32Array) -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let n_value =
            Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[expr, n_value],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn first_value() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        test_nth_value(1, Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        // The first frame has a single row, so there is no 2nd value.
        test_nth_value(
            2,
            Int32Array::from(vec![
                None,
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
            ]),
        )
    }

    #[test]
    fn nth_value_negative_1() -> Result<()> {
        // n = -1 selects the last row of each frame.
        test_nth_value(-1, Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]))
    }

    #[test]
    fn nth_value_negative_2() -> Result<()> {
        // n = -2 selects the second-to-last row; NULL for a one-row frame.
        test_nth_value(
            -2,
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    #[test]
    fn nth_value_beyond_frame() -> Result<()> {
        // No frame has 9 rows, so every result is NULL.
        test_nth_value(9, Int32Array::from(vec![None::<i32>; 8]))
    }
}