datafusion_functions_window/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `nth_value` window function implementation
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
30    TypeSignature, Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use field::WindowUDFFieldArgs;
36use std::any::Any;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
// Registration of the three window functions implemented in this file.
//
// NOTE(review): the macro definitions are not part of this file. Judging from
// the calls to `first_value_udwf()`, `last_value_udwf()` and `nth_value_udwf()`
// further down, each invocation presumably generates a `<name>_udwf()` accessor
// backed by the constructor given as the last argument -- confirm against the
// macro definitions.

// `first_value(arg)`, constructed through `NthValue::first`.
define_udwf_and_expr!(
    First,
    first_value,
    [arg],
    "Returns the first value in the window frame",
    NthValue::first
);
// `last_value(arg)`, constructed through `NthValue::last`.
define_udwf_and_expr!(
    Last,
    last_value,
    [arg],
    "Returns the last value in the window frame",
    NthValue::last
);
// `nth_value`, constructed through `NthValue::nth`. The expression builder for
// it is written by hand right after this block (it takes an `i64` for `n`).
get_or_init_udwf!(
    NthValue,
    nth_value,
    "Returns the nth value in the window frame",
    NthValue::nth
);
63
64/// Create an expression to represent the `nth_value` window function
65pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
66    nth_value_udwf().call(vec![arg, n.lit()])
67}
68
/// Selects which member of the `NTH_VALUE` family a [`NthValue`] instance
/// implements.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`: the value at the first row of the frame.
    First,
    /// `last_value`: the value at the last row of the frame.
    Last,
    /// `nth_value`: the value at a caller-supplied position in the frame.
    Nth,
}
76
77impl NthValueKind {
78    fn name(&self) -> &'static str {
79        match self {
80            NthValueKind::First => "first_value",
81            NthValueKind::Last => "last_value",
82            NthValueKind::Nth => "nth_value",
83        }
84    }
85}
86
87#[derive(Debug, PartialEq, Eq, Hash)]
88pub struct NthValue {
89    signature: Signature,
90    kind: NthValueKind,
91}
92
93impl NthValue {
94    /// Create a new `nth_value` function
95    pub fn new(kind: NthValueKind) -> Self {
96        Self {
97            signature: Signature::one_of(
98                vec![
99                    TypeSignature::Any(0),
100                    TypeSignature::Any(1),
101                    TypeSignature::Any(2),
102                ],
103                Volatility::Immutable,
104            ),
105            kind,
106        }
107    }
108
109    pub fn first() -> Self {
110        Self::new(NthValueKind::First)
111    }
112
113    pub fn last() -> Self {
114        Self::new(NthValueKind::Last)
115    }
116    pub fn nth() -> Self {
117        Self::new(NthValueKind::Nth)
118    }
119
120    pub fn kind(&self) -> &NthValueKind {
121        &self.kind
122    }
123}
124
125static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
126    Documentation::builder(
127        DOC_SECTION_ANALYTICAL,
128        "Returns value evaluated at the row that is the first row of the window \
129            frame.",
130        "first_value(expression)",
131    )
132    .with_argument("expression", "Expression to operate on")
133    .with_sql_example(
134        r#"
135```sql
136-- Example usage of the first_value window function:
137SELECT department,
138  employee_id,
139  salary,
140  first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
141FROM employees;
142
143+-------------+-------------+--------+------------+
144| department  | employee_id | salary | top_salary |
145+-------------+-------------+--------+------------+
146| Sales       | 1           | 70000  | 70000      |
147| Sales       | 2           | 50000  | 70000      |
148| Sales       | 3           | 30000  | 70000      |
149| Engineering | 4           | 90000  | 90000      |
150| Engineering | 5           | 80000  | 90000      |
151+-------------+-------------+--------+------------+
152```
153"#,
154    )
155    .build()
156});
157
158fn get_first_value_doc() -> &'static Documentation {
159    &FIRST_VALUE_DOCUMENTATION
160}
161
162static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
163    Documentation::builder(
164        DOC_SECTION_ANALYTICAL,
165        "Returns value evaluated at the row that is the last row of the window \
166            frame.",
167        "last_value(expression)",
168    )
169    .with_argument("expression", "Expression to operate on")
170        .with_sql_example(r#"```sql
171-- SQL example of last_value:
172SELECT department,
173       employee_id,
174       salary,
175       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
176FROM employees;
177
178+-------------+-------------+--------+---------------------+
179| department  | employee_id | salary | running_last_salary |
180+-------------+-------------+--------+---------------------+
181| Sales       | 1           | 30000  | 30000               |
182| Sales       | 2           | 50000  | 50000               |
183| Sales       | 3           | 70000  | 70000               |
184| Engineering | 4           | 40000  | 40000               |
185| Engineering | 5           | 60000  | 60000               |
186+-------------+-------------+--------+---------------------+
187```
188"#)
189    .build()
190});
191
192fn get_last_value_doc() -> &'static Documentation {
193    &LAST_VALUE_DOCUMENTATION
194}
195
196static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
197    Documentation::builder(
198        DOC_SECTION_ANALYTICAL,
199        "Returns the value evaluated at the nth row of the window frame \
200         (counting from 1). Returns NULL if no such row exists.",
201        "nth_value(expression, n)",
202    )
203    .with_argument(
204        "expression",
205        "The column from which to retrieve the nth value.",
206    )
207    .with_argument(
208        "n",
209        "Integer. Specifies the row number (starting from 1) in the window frame.",
210    )
211    .with_sql_example(
212        r#"
213```sql
214-- Sample employees table:
215CREATE TABLE employees (id INT, salary INT);
216INSERT INTO employees (id, salary) VALUES
217(1, 30000),
218(2, 40000),
219(3, 50000),
220(4, 60000),
221(5, 70000);
222
223-- Example usage of nth_value:
224SELECT nth_value(salary, 2) OVER (
225  ORDER BY salary
226  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
227) AS nth_value
228FROM employees;
229
230+-----------+
231| nth_value |
232+-----------+
233| 40000     |
234| 40000     |
235| 40000     |
236| 40000     |
237| 40000     |
238+-----------+
239```
240"#,
241    )
242    .build()
243});
244
245fn get_nth_value_doc() -> &'static Documentation {
246    &NTH_VALUE_DOCUMENTATION
247}
248
249impl WindowUDFImpl for NthValue {
250    fn as_any(&self) -> &dyn Any {
251        self
252    }
253
254    fn name(&self) -> &str {
255        self.kind.name()
256    }
257
258    fn signature(&self) -> &Signature {
259        &self.signature
260    }
261
262    fn partition_evaluator(
263        &self,
264        partition_evaluator_args: PartitionEvaluatorArgs,
265    ) -> Result<Box<dyn PartitionEvaluator>> {
266        let state = NthValueState {
267            finalized_result: None,
268            kind: self.kind,
269        };
270
271        if !matches!(self.kind, NthValueKind::Nth) {
272            return Ok(Box::new(NthValueEvaluator {
273                state,
274                ignore_nulls: partition_evaluator_args.ignore_nulls(),
275                n: 0,
276            }));
277        }
278
279        let n = match get_scalar_value_from_args(
280            partition_evaluator_args.input_exprs(),
281            1,
282        )
283        .map_err(|_e| {
284            exec_datafusion_err!(
285                "Expected a signed integer literal for the second argument of nth_value"
286            )
287        })?
288        .map(|v| get_signed_integer(&v))
289        {
290            Some(Ok(n)) => {
291                if partition_evaluator_args.is_reversed() {
292                    -n
293                } else {
294                    n
295                }
296            }
297            _ => {
298                return exec_err!(
299                    "Expected a signed integer literal for the second argument of nth_value"
300                );
301            }
302        };
303
304        Ok(Box::new(NthValueEvaluator {
305            state,
306            ignore_nulls: partition_evaluator_args.ignore_nulls(),
307            n,
308        }))
309    }
310
311    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
312        let return_type = field_args
313            .input_fields()
314            .first()
315            .map(|f| f.data_type())
316            .cloned()
317            .unwrap_or(DataType::Null);
318
319        Ok(Field::new(field_args.name(), return_type, true).into())
320    }
321
322    fn reverse_expr(&self) -> ReversedUDWF {
323        match self.kind {
324            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
325            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
326            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
327        }
328    }
329
330    fn documentation(&self) -> Option<&Documentation> {
331        match self.kind {
332            NthValueKind::First => Some(get_first_value_doc()),
333            NthValueKind::Last => Some(get_last_value_doc()),
334            NthValueKind::Nth => Some(get_nth_value_doc()),
335        }
336    }
337
338    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
339        LimitEffect::None // NthValue is causal
340    }
341}
342
343#[derive(Debug, Clone)]
344pub struct NthValueState {
345    // In certain cases, we can finalize the result early. Consider this usage:
346    // ```
347    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
348    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
349    // ```
350    // The result will always be the first entry in the table. We can store such
351    // early-finalizing results and then just reuse them as necessary. This opens
352    // opportunities to prune our datasets.
353    pub finalized_result: Option<ScalarValue>,
354    pub kind: NthValueKind,
355}
356
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    // Function kind plus any result that was finalized early by `memoize`.
    state: NthValueState,
    // When true, null entries inside the frame are skipped during evaluation.
    ignore_nulls: bool,
    // Position for `nth_value`: positive counts from the frame start (1-based),
    // negative counts from the frame end, 0 always evaluates to NULL.
    // Set to 0 for `first_value` / `last_value`, which do not read it.
    n: i64,
}
363
364impl PartitionEvaluator for NthValueEvaluator {
365    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
366    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
367    /// can memoize the result.  Once result is calculated, it will always stay
368    /// same. Hence, we do not need to keep past data as we process the entire
369    /// dataset.
370    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
371        let out = &state.out_col;
372        let size = out.len();
373        let mut buffer_size = 1;
374        // Decide if we arrived at a final result yet:
375        let (is_prunable, is_reverse_direction) = match self.state.kind {
376            NthValueKind::First => {
377                let n_range =
378                    state.window_frame_range.end - state.window_frame_range.start;
379                (n_range > 0 && size > 0, false)
380            }
381            NthValueKind::Last => (true, true),
382            NthValueKind::Nth => {
383                let n_range =
384                    state.window_frame_range.end - state.window_frame_range.start;
385                match self.n.cmp(&0) {
386                    Ordering::Greater => (
387                        n_range >= (self.n as usize) && size > (self.n as usize),
388                        false,
389                    ),
390                    Ordering::Less => {
391                        let reverse_index = (-self.n) as usize;
392                        buffer_size = reverse_index;
393                        // Negative index represents reverse direction.
394                        (n_range >= reverse_index, true)
395                    }
396                    Ordering::Equal => (false, false),
397                }
398            }
399        };
400        // Do not memoize results when nulls are ignored.
401        if is_prunable && !self.ignore_nulls {
402            if self.state.finalized_result.is_none() && !is_reverse_direction {
403                let result = ScalarValue::try_from_array(out, size - 1)?;
404                self.state.finalized_result = Some(result);
405            }
406            state.window_frame_range.start =
407                state.window_frame_range.end.saturating_sub(buffer_size);
408        }
409        Ok(())
410    }
411
412    fn evaluate(
413        &mut self,
414        values: &[ArrayRef],
415        range: &Range<usize>,
416    ) -> Result<ScalarValue> {
417        if let Some(ref result) = self.state.finalized_result {
418            Ok(result.clone())
419        } else {
420            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
421            let arr = &values[0];
422            let n_range = range.end - range.start;
423            if n_range == 0 {
424                // We produce None if the window is empty.
425                return ScalarValue::try_from(arr.data_type());
426            }
427
428            // If null values exist and need to be ignored, extract the valid indices.
429            let valid_indices = if self.ignore_nulls {
430                // Calculate valid indices, inside the window frame boundaries.
431                let slice = arr.slice(range.start, n_range);
432                match slice.nulls() {
433                    Some(nulls) => {
434                        let valid_indices = nulls
435                            .valid_indices()
436                            .map(|idx| {
437                                // Add offset `range.start` to valid indices, to point correct index in the original arr.
438                                idx + range.start
439                            })
440                            .collect::<Vec<_>>();
441                        if valid_indices.is_empty() {
442                            // If all values are null, return directly.
443                            return ScalarValue::try_from(arr.data_type());
444                        }
445                        Some(valid_indices)
446                    }
447                    None => None,
448                }
449            } else {
450                None
451            };
452            match self.state.kind {
453                NthValueKind::First => {
454                    if let Some(valid_indices) = &valid_indices {
455                        ScalarValue::try_from_array(arr, valid_indices[0])
456                    } else {
457                        ScalarValue::try_from_array(arr, range.start)
458                    }
459                }
460                NthValueKind::Last => {
461                    if let Some(valid_indices) = &valid_indices {
462                        ScalarValue::try_from_array(
463                            arr,
464                            valid_indices[valid_indices.len() - 1],
465                        )
466                    } else {
467                        ScalarValue::try_from_array(arr, range.end - 1)
468                    }
469                }
470                NthValueKind::Nth => {
471                    match self.n.cmp(&0) {
472                        Ordering::Greater => {
473                            // SQL indices are not 0-based.
474                            let index = (self.n as usize) - 1;
475                            if index >= n_range {
476                                // Outside the range, return NULL:
477                                ScalarValue::try_from(arr.data_type())
478                            } else if let Some(valid_indices) = valid_indices {
479                                if index >= valid_indices.len() {
480                                    return ScalarValue::try_from(arr.data_type());
481                                }
482                                ScalarValue::try_from_array(&arr, valid_indices[index])
483                            } else {
484                                ScalarValue::try_from_array(arr, range.start + index)
485                            }
486                        }
487                        Ordering::Less => {
488                            let reverse_index = (-self.n) as usize;
489                            if n_range < reverse_index {
490                                // Outside the range, return NULL:
491                                ScalarValue::try_from(arr.data_type())
492                            } else if let Some(valid_indices) = valid_indices {
493                                if reverse_index > valid_indices.len() {
494                                    return ScalarValue::try_from(arr.data_type());
495                                }
496                                let new_index =
497                                    valid_indices[valid_indices.len() - reverse_index];
498                                ScalarValue::try_from_array(&arr, new_index)
499                            } else {
500                                ScalarValue::try_from_array(
501                                    arr,
502                                    range.start + n_range - reverse_index,
503                                )
504                            }
505                        }
506                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
507                    }
508                }
509            }
510        }
511    }
512
513    fn supports_bounded_execution(&self) -> bool {
514        true
515    }
516
517    fn uses_window_frame(&self) -> bool {
518        true
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression for the single input column used by every test.
    fn input_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Nullable `Int32` field describing the input column.
    fn input_field() -> Arc<Field> {
        Arc::new(Field::new("f", DataType::Int32, true))
    }

    /// `Int32` literal used as the `n` argument of `nth_value`.
    fn int32_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8` of
    /// a fixed eight-element column and compares the outputs with `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let column: ArrayRef =
            Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![column];
        let frames: Vec<Range<usize>> = (1..=8).map(|end| 0..end).collect();

        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let scalars = frames
            .iter()
            .map(|frame| evaluator.evaluate(&values, frame))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let array = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&array)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[input_column()],
                &[input_field()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[input_column()],
                &[input_field()],
                false,
                false,
            ),
            Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), int32_literal(1)],
                &[input_field()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        // The first frame has a single row, so there is no second value yet.
        let mut expected = vec![None];
        expected.extend(std::iter::repeat(Some(-2)).take(7));

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), int32_literal(2)],
                &[input_field()],
                false,
                false,
            ),
            Int32Array::from(expected),
        )
    }
}
639        Ok(())
640    }
641}