datafusion_functions_window/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! `first_value`, `last_value` and `nth_value` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
30    Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::hash::{DefaultHasher, Hash, Hasher};
39use std::ops::Range;
40use std::sync::LazyLock;
41
// Declare the three window UDFs implemented in this file. Each invocation
// passes, in order: an identifier, the function name, a short description and
// the `NthValue` constructor that builds the UDF implementation.
// NOTE(review): `get_or_init_udwf!` is defined outside this file; presumably it
// generates the `first_value_udwf()`, `last_value_udwf()` and
// `nth_value_udwf()` accessors used further down — confirm against the macro.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
60
61/// Create an expression to represent the `first_value` window function
62///
63pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
64    first_value_udwf().call(vec![arg])
65}
66
67/// Create an expression to represent the `last_value` window function
68///
69pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
70    last_value_udwf().call(vec![arg])
71}
72
73/// Create an expression to represent the `nth_value` window function
74///
75pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
76    nth_value_udwf().call(vec![arg, n.lit()])
77}
78
/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// The `first_value` flavour.
    First,
    /// The `last_value` flavour.
    Last,
    /// The general `nth_value` flavour.
    Nth,
}

impl NthValueKind {
    /// Returns the function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
96
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; `kind` selects which of the three an instance represents.
#[derive(Debug)]
pub struct NthValue {
    /// Accepted argument signature, built in [`NthValue::new`].
    signature: Signature,
    /// Which of the three functions this instance implements.
    kind: NthValueKind,
}
102
103impl NthValue {
104    /// Create a new `nth_value` function
105    pub fn new(kind: NthValueKind) -> Self {
106        Self {
107            signature: Signature::one_of(
108                vec![
109                    TypeSignature::Any(0),
110                    TypeSignature::Any(1),
111                    TypeSignature::Any(2),
112                ],
113                Volatility::Immutable,
114            ),
115            kind,
116        }
117    }
118
119    pub fn first() -> Self {
120        Self::new(NthValueKind::First)
121    }
122
123    pub fn last() -> Self {
124        Self::new(NthValueKind::Last)
125    }
126    pub fn nth() -> Self {
127        Self::new(NthValueKind::Nth)
128    }
129}
130
131static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
132    Documentation::builder(
133        DOC_SECTION_ANALYTICAL,
134        "Returns value evaluated at the row that is the first row of the window \
135            frame.",
136        "first_value(expression)",
137    )
138    .with_argument("expression", "Expression to operate on")
139        .with_sql_example(r#"```sql
140    --Example usage of the first_value window function:
141    SELECT department,
142           employee_id,
143           salary,
144           first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
145    FROM employees;
146```
147
148```sql
149+-------------+-------------+--------+------------+
150| department  | employee_id | salary | top_salary |
151+-------------+-------------+--------+------------+
152| Sales       | 1           | 70000  | 70000      |
153| Sales       | 2           | 50000  | 70000      |
154| Sales       | 3           | 30000  | 70000      |
155| Engineering | 4           | 90000  | 90000      |
156| Engineering | 5           | 80000  | 90000      |
157+-------------+-------------+--------+------------+
158```"#)
159    .build()
160});
161
162fn get_first_value_doc() -> &'static Documentation {
163    &FIRST_VALUE_DOCUMENTATION
164}
165
166static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
167    Documentation::builder(
168        DOC_SECTION_ANALYTICAL,
169        "Returns value evaluated at the row that is the last row of the window \
170            frame.",
171        "last_value(expression)",
172    )
173    .with_argument("expression", "Expression to operate on")
174        .with_sql_example(r#"```sql
175-- SQL example of last_value:
176SELECT department,
177       employee_id,
178       salary,
179       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
180FROM employees;
181```
182
183```sql
184+-------------+-------------+--------+---------------------+
185| department  | employee_id | salary | running_last_salary |
186+-------------+-------------+--------+---------------------+
187| Sales       | 1           | 30000  | 30000               |
188| Sales       | 2           | 50000  | 50000               |
189| Sales       | 3           | 70000  | 70000               |
190| Engineering | 4           | 40000  | 40000               |
191| Engineering | 5           | 60000  | 60000               |
192+-------------+-------------+--------+---------------------+
193```"#)
194    .build()
195});
196
197fn get_last_value_doc() -> &'static Documentation {
198    &LAST_VALUE_DOCUMENTATION
199}
200
201static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
202    Documentation::builder(
203        DOC_SECTION_ANALYTICAL,
204        "Returns the value evaluated at the nth row of the window frame \
205         (counting from 1). Returns NULL if no such row exists.",
206        "nth_value(expression, n)",
207    )
208    .with_argument(
209        "expression",
210        "The column from which to retrieve the nth value.",
211    )
212    .with_argument(
213        "n",
214        "Integer. Specifies the row number (starting from 1) in the window frame.",
215    )
216    .with_sql_example(
217        r#"```sql
218-- Sample employees table:
219CREATE TABLE employees (id INT, salary INT);
220INSERT INTO employees (id, salary) VALUES
221(1, 30000),
222(2, 40000),
223(3, 50000),
224(4, 60000),
225(5, 70000);
226
227-- Example usage of nth_value:
228SELECT nth_value(salary, 2) OVER (
229  ORDER BY salary
230  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
231) AS nth_value
232FROM employees;
233```
234
235```text
236+-----------+
237| nth_value |
238+-----------+
239| 40000     |
240| 40000     |
241| 40000     |
242| 40000     |
243| 40000     |
244+-----------+
245```"#,
246    )
247    .build()
248});
249
250fn get_nth_value_doc() -> &'static Documentation {
251    &NTH_VALUE_DOCUMENTATION
252}
253
254impl WindowUDFImpl for NthValue {
255    fn as_any(&self) -> &dyn Any {
256        self
257    }
258
259    fn name(&self) -> &str {
260        self.kind.name()
261    }
262
263    fn signature(&self) -> &Signature {
264        &self.signature
265    }
266
267    fn partition_evaluator(
268        &self,
269        partition_evaluator_args: PartitionEvaluatorArgs,
270    ) -> Result<Box<dyn PartitionEvaluator>> {
271        let state = NthValueState {
272            finalized_result: None,
273            kind: self.kind,
274        };
275
276        if !matches!(self.kind, NthValueKind::Nth) {
277            return Ok(Box::new(NthValueEvaluator {
278                state,
279                ignore_nulls: partition_evaluator_args.ignore_nulls(),
280                n: 0,
281            }));
282        }
283
284        let n =
285            match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
286                .map_err(|_e| {
287                    exec_datafusion_err!(
288                "Expected a signed integer literal for the second argument of nth_value")
289                })?
290                .map(get_signed_integer)
291            {
292                Some(Ok(n)) => {
293                    if partition_evaluator_args.is_reversed() {
294                        -n
295                    } else {
296                        n
297                    }
298                }
299                _ => {
300                    return exec_err!(
301                "Expected a signed integer literal for the second argument of nth_value"
302            )
303                }
304            };
305
306        Ok(Box::new(NthValueEvaluator {
307            state,
308            ignore_nulls: partition_evaluator_args.ignore_nulls(),
309            n,
310        }))
311    }
312
313    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
314        let return_type = field_args
315            .input_fields()
316            .first()
317            .map(|f| f.data_type())
318            .cloned()
319            .unwrap_or(DataType::Null);
320
321        Ok(Field::new(field_args.name(), return_type, true).into())
322    }
323
324    fn reverse_expr(&self) -> ReversedUDWF {
325        match self.kind {
326            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
327            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
328            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
329        }
330    }
331
332    fn documentation(&self) -> Option<&Documentation> {
333        match self.kind {
334            NthValueKind::First => Some(get_first_value_doc()),
335            NthValueKind::Last => Some(get_last_value_doc()),
336            NthValueKind::Nth => Some(get_nth_value_doc()),
337        }
338    }
339
340    fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
341        let Some(other) = other.as_any().downcast_ref::<Self>() else {
342            return false;
343        };
344        let Self { signature, kind } = self;
345        signature == &other.signature && kind == &other.kind
346    }
347
348    fn hash_value(&self) -> u64 {
349        let Self { signature, kind } = self;
350        let mut hasher = DefaultHasher::new();
351        std::any::type_name::<Self>().hash(&mut hasher);
352        signature.hash(&mut hasher);
353        kind.hash(&mut hasher);
354        hasher.finish()
355    }
356}
357
/// Mutable state carried by an `nth_value`-family evaluator.
#[derive(Debug, Clone)]
pub struct NthValueState {
    // In certain cases, we can finalize the result early. Consider this usage:
    // ```
    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
    // ```
    // The result will always be the first entry in the table. We can store such
    // early-finalizing results and then just reuse them as necessary. This opens
    // opportunities to prune our datasets.
    /// Result that has been finalized early, if any; once set, evaluation
    /// returns a clone of it instead of inspecting the frame.
    pub finalized_result: Option<ScalarValue>,
    /// Which of `first_value`, `last_value` or `nth_value` is being evaluated.
    pub kind: NthValueKind,
}
371
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    /// Function kind plus any early-finalized result.
    state: NthValueState,
    /// When true, null entries inside the frame are skipped during evaluation
    /// and results are never memoized.
    ignore_nulls: bool,
    /// Offset for `nth_value`: positive values count from the start of the
    /// frame (1-based), negative values count from its end, and 0 yields NULL.
    /// Set to 0 and not consulted for `first_value` / `last_value`.
    n: i64,
}
378
379impl PartitionEvaluator for NthValueEvaluator {
380    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
381    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
382    /// can memoize the result.  Once result is calculated, it will always stay
383    /// same. Hence, we do not need to keep past data as we process the entire
384    /// dataset.
385    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
386        let out = &state.out_col;
387        let size = out.len();
388        let mut buffer_size = 1;
389        // Decide if we arrived at a final result yet:
390        let (is_prunable, is_reverse_direction) = match self.state.kind {
391            NthValueKind::First => {
392                let n_range =
393                    state.window_frame_range.end - state.window_frame_range.start;
394                (n_range > 0 && size > 0, false)
395            }
396            NthValueKind::Last => (true, true),
397            NthValueKind::Nth => {
398                let n_range =
399                    state.window_frame_range.end - state.window_frame_range.start;
400                match self.n.cmp(&0) {
401                    Ordering::Greater => (
402                        n_range >= (self.n as usize) && size > (self.n as usize),
403                        false,
404                    ),
405                    Ordering::Less => {
406                        let reverse_index = (-self.n) as usize;
407                        buffer_size = reverse_index;
408                        // Negative index represents reverse direction.
409                        (n_range >= reverse_index, true)
410                    }
411                    Ordering::Equal => (false, false),
412                }
413            }
414        };
415        // Do not memoize results when nulls are ignored.
416        if is_prunable && !self.ignore_nulls {
417            if self.state.finalized_result.is_none() && !is_reverse_direction {
418                let result = ScalarValue::try_from_array(out, size - 1)?;
419                self.state.finalized_result = Some(result);
420            }
421            state.window_frame_range.start =
422                state.window_frame_range.end.saturating_sub(buffer_size);
423        }
424        Ok(())
425    }
426
427    fn evaluate(
428        &mut self,
429        values: &[ArrayRef],
430        range: &Range<usize>,
431    ) -> Result<ScalarValue> {
432        if let Some(ref result) = self.state.finalized_result {
433            Ok(result.clone())
434        } else {
435            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
436            let arr = &values[0];
437            let n_range = range.end - range.start;
438            if n_range == 0 {
439                // We produce None if the window is empty.
440                return ScalarValue::try_from(arr.data_type());
441            }
442
443            // If null values exist and need to be ignored, extract the valid indices.
444            let valid_indices = if self.ignore_nulls {
445                // Calculate valid indices, inside the window frame boundaries.
446                let slice = arr.slice(range.start, n_range);
447                match slice.nulls() {
448                    Some(nulls) => {
449                        let valid_indices = nulls
450                            .valid_indices()
451                            .map(|idx| {
452                                // Add offset `range.start` to valid indices, to point correct index in the original arr.
453                                idx + range.start
454                            })
455                            .collect::<Vec<_>>();
456                        if valid_indices.is_empty() {
457                            // If all values are null, return directly.
458                            return ScalarValue::try_from(arr.data_type());
459                        }
460                        Some(valid_indices)
461                    }
462                    None => None,
463                }
464            } else {
465                None
466            };
467            match self.state.kind {
468                NthValueKind::First => {
469                    if let Some(valid_indices) = &valid_indices {
470                        ScalarValue::try_from_array(arr, valid_indices[0])
471                    } else {
472                        ScalarValue::try_from_array(arr, range.start)
473                    }
474                }
475                NthValueKind::Last => {
476                    if let Some(valid_indices) = &valid_indices {
477                        ScalarValue::try_from_array(
478                            arr,
479                            valid_indices[valid_indices.len() - 1],
480                        )
481                    } else {
482                        ScalarValue::try_from_array(arr, range.end - 1)
483                    }
484                }
485                NthValueKind::Nth => {
486                    match self.n.cmp(&0) {
487                        Ordering::Greater => {
488                            // SQL indices are not 0-based.
489                            let index = (self.n as usize) - 1;
490                            if index >= n_range {
491                                // Outside the range, return NULL:
492                                ScalarValue::try_from(arr.data_type())
493                            } else if let Some(valid_indices) = valid_indices {
494                                if index >= valid_indices.len() {
495                                    return ScalarValue::try_from(arr.data_type());
496                                }
497                                ScalarValue::try_from_array(&arr, valid_indices[index])
498                            } else {
499                                ScalarValue::try_from_array(arr, range.start + index)
500                            }
501                        }
502                        Ordering::Less => {
503                            let reverse_index = (-self.n) as usize;
504                            if n_range < reverse_index {
505                                // Outside the range, return NULL:
506                                ScalarValue::try_from(arr.data_type())
507                            } else if let Some(valid_indices) = valid_indices {
508                                if reverse_index > valid_indices.len() {
509                                    return ScalarValue::try_from(arr.data_type());
510                                }
511                                let new_index =
512                                    valid_indices[valid_indices.len() - reverse_index];
513                                ScalarValue::try_from_array(&arr, new_index)
514                            } else {
515                                ScalarValue::try_from_array(
516                                    arr,
517                                    range.start + n_range - reverse_index,
518                                )
519                            }
520                        }
521                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
522                    }
523                }
524            }
525        }
526    }
527
528    fn supports_bounded_execution(&self) -> bool {
529        true
530    }
531
532    fn uses_window_frame(&self) -> bool {
533        true
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression referring to the single value column.
    fn value_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Physical `Int32` literal, used as the `n` argument of `nth_value`.
    fn int32_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Nullable `Int32` field describing the value column.
    fn input_field() -> FieldRef {
        Field::new("f", DataType::Int32, true).into()
    }

    /// Evaluates `expr` over the frames `0..1`, `0..2`, ..., `0..8` of a fixed
    /// eight-element `Int32` column and compares the collected results with
    /// `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let column: ArrayRef =
            Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![column];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;

        let mut scalars = Vec::with_capacity(8);
        for end in 1..=8 {
            scalars.push(evaluator.evaluate(&values, &(0..end))?);
        }

        let actual = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let exprs = [value_column()];
        let fields = [input_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);
        let expected = Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>();

        test_i32_result(NthValue::first(), args, expected)
    }

    #[test]
    fn last_value() -> Result<()> {
        let exprs = [value_column()];
        let fields = [input_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);

        test_i32_result(NthValue::last(), args, expected)
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let exprs = [value_column(), int32_literal(1)];
        let fields = [input_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);
        let expected = Int32Array::from(vec![1; 8]);

        test_i32_result(NthValue::nth(), args, expected)
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let exprs = [value_column(), int32_literal(2)];
        let fields = [input_field()];
        let args = PartitionEvaluatorArgs::new(&exprs, &fields, false, false);
        // The first frame holds a single row, so it has no second value.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);

        test_i32_result(NthValue::nth(), args, expected)
    }
}