datafusion_functions_window/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `nth_value` window function implementation
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
30    TypeSignature, Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use field::WindowUDFFieldArgs;
36use std::any::Any;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
43get_or_init_udwf!(
44    First,
45    first_value,
46    "returns the first value in the window frame",
47    NthValue::first
48);
49get_or_init_udwf!(
50    Last,
51    last_value,
52    "returns the last value in the window frame",
53    NthValue::last
54);
55get_or_init_udwf!(
56    NthValue,
57    nth_value,
58    "returns the nth value in the window frame",
59    NthValue::nth
60);
61
62/// Create an expression to represent the `first_value` window function
63///
64pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
65    first_value_udwf().call(vec![arg])
66}
67
68/// Create an expression to represent the `last_value` window function
69///
70pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
71    last_value_udwf().call(vec![arg])
72}
73
74/// Create an expression to represent the `nth_value` window function
75///
76pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
77    nth_value_udwf().call(vec![arg, n.lit()])
78}
79
/// Selects which flavour of the shared `NTH_VALUE` implementation is used:
/// `first_value`, `last_value` or the general `nth_value`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    First,
    Last,
    Nth,
}

impl NthValueKind {
    /// The SQL function name for this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
97
98#[derive(Debug, PartialEq, Eq, Hash)]
99pub struct NthValue {
100    signature: Signature,
101    kind: NthValueKind,
102}
103
104impl NthValue {
105    /// Create a new `nth_value` function
106    pub fn new(kind: NthValueKind) -> Self {
107        Self {
108            signature: Signature::one_of(
109                vec![
110                    TypeSignature::Any(0),
111                    TypeSignature::Any(1),
112                    TypeSignature::Any(2),
113                ],
114                Volatility::Immutable,
115            ),
116            kind,
117        }
118    }
119
120    pub fn first() -> Self {
121        Self::new(NthValueKind::First)
122    }
123
124    pub fn last() -> Self {
125        Self::new(NthValueKind::Last)
126    }
127    pub fn nth() -> Self {
128        Self::new(NthValueKind::Nth)
129    }
130
131    pub fn kind(&self) -> &NthValueKind {
132        &self.kind
133    }
134}
135
136static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
137    Documentation::builder(
138        DOC_SECTION_ANALYTICAL,
139        "Returns value evaluated at the row that is the first row of the window \
140            frame.",
141        "first_value(expression)",
142    )
143    .with_argument("expression", "Expression to operate on")
144    .with_sql_example(
145        r#"
146```sql
147-- Example usage of the first_value window function:
148SELECT department,
149  employee_id,
150  salary,
151  first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
152FROM employees;
153
154+-------------+-------------+--------+------------+
155| department  | employee_id | salary | top_salary |
156+-------------+-------------+--------+------------+
157| Sales       | 1           | 70000  | 70000      |
158| Sales       | 2           | 50000  | 70000      |
159| Sales       | 3           | 30000  | 70000      |
160| Engineering | 4           | 90000  | 90000      |
161| Engineering | 5           | 80000  | 90000      |
162+-------------+-------------+--------+------------+
163```
164"#,
165    )
166    .build()
167});
168
169fn get_first_value_doc() -> &'static Documentation {
170    &FIRST_VALUE_DOCUMENTATION
171}
172
173static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
174    Documentation::builder(
175        DOC_SECTION_ANALYTICAL,
176        "Returns value evaluated at the row that is the last row of the window \
177            frame.",
178        "last_value(expression)",
179    )
180    .with_argument("expression", "Expression to operate on")
181        .with_sql_example(r#"```sql
182-- SQL example of last_value:
183SELECT department,
184       employee_id,
185       salary,
186       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
187FROM employees;
188
189+-------------+-------------+--------+---------------------+
190| department  | employee_id | salary | running_last_salary |
191+-------------+-------------+--------+---------------------+
192| Sales       | 1           | 30000  | 30000               |
193| Sales       | 2           | 50000  | 50000               |
194| Sales       | 3           | 70000  | 70000               |
195| Engineering | 4           | 40000  | 40000               |
196| Engineering | 5           | 60000  | 60000               |
197+-------------+-------------+--------+---------------------+
198```
199"#)
200    .build()
201});
202
203fn get_last_value_doc() -> &'static Documentation {
204    &LAST_VALUE_DOCUMENTATION
205}
206
207static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
208    Documentation::builder(
209        DOC_SECTION_ANALYTICAL,
210        "Returns the value evaluated at the nth row of the window frame \
211         (counting from 1). Returns NULL if no such row exists.",
212        "nth_value(expression, n)",
213    )
214    .with_argument(
215        "expression",
216        "The column from which to retrieve the nth value.",
217    )
218    .with_argument(
219        "n",
220        "Integer. Specifies the row number (starting from 1) in the window frame.",
221    )
222    .with_sql_example(
223        r#"
224```sql
225-- Sample employees table:
226CREATE TABLE employees (id INT, salary INT);
227INSERT INTO employees (id, salary) VALUES
228(1, 30000),
229(2, 40000),
230(3, 50000),
231(4, 60000),
232(5, 70000);
233
234-- Example usage of nth_value:
235SELECT nth_value(salary, 2) OVER (
236  ORDER BY salary
237  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
238) AS nth_value
239FROM employees;
240
241+-----------+
242| nth_value |
243+-----------+
244| 40000     |
245| 40000     |
246| 40000     |
247| 40000     |
248| 40000     |
249+-----------+
250```
251"#,
252    )
253    .build()
254});
255
256fn get_nth_value_doc() -> &'static Documentation {
257    &NTH_VALUE_DOCUMENTATION
258}
259
260impl WindowUDFImpl for NthValue {
261    fn as_any(&self) -> &dyn Any {
262        self
263    }
264
265    fn name(&self) -> &str {
266        self.kind.name()
267    }
268
269    fn signature(&self) -> &Signature {
270        &self.signature
271    }
272
273    fn partition_evaluator(
274        &self,
275        partition_evaluator_args: PartitionEvaluatorArgs,
276    ) -> Result<Box<dyn PartitionEvaluator>> {
277        let state = NthValueState {
278            finalized_result: None,
279            kind: self.kind,
280        };
281
282        if !matches!(self.kind, NthValueKind::Nth) {
283            return Ok(Box::new(NthValueEvaluator {
284                state,
285                ignore_nulls: partition_evaluator_args.ignore_nulls(),
286                n: 0,
287            }));
288        }
289
290        let n =
291            match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
292                .map_err(|_e| {
293                    exec_datafusion_err!(
294                "Expected a signed integer literal for the second argument of nth_value")
295                })?
296                .map(get_signed_integer)
297            {
298                Some(Ok(n)) => {
299                    if partition_evaluator_args.is_reversed() {
300                        -n
301                    } else {
302                        n
303                    }
304                }
305                _ => {
306                    return exec_err!(
307                "Expected a signed integer literal for the second argument of nth_value"
308            )
309                }
310            };
311
312        Ok(Box::new(NthValueEvaluator {
313            state,
314            ignore_nulls: partition_evaluator_args.ignore_nulls(),
315            n,
316        }))
317    }
318
319    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
320        let return_type = field_args
321            .input_fields()
322            .first()
323            .map(|f| f.data_type())
324            .cloned()
325            .unwrap_or(DataType::Null);
326
327        Ok(Field::new(field_args.name(), return_type, true).into())
328    }
329
330    fn reverse_expr(&self) -> ReversedUDWF {
331        match self.kind {
332            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
333            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
334            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
335        }
336    }
337
338    fn documentation(&self) -> Option<&Documentation> {
339        match self.kind {
340            NthValueKind::First => Some(get_first_value_doc()),
341            NthValueKind::Last => Some(get_last_value_doc()),
342            NthValueKind::Nth => Some(get_nth_value_doc()),
343        }
344    }
345
346    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
347        LimitEffect::None // NthValue is causal
348    }
349}
350
351#[derive(Debug, Clone)]
352pub struct NthValueState {
353    // In certain cases, we can finalize the result early. Consider this usage:
354    // ```
355    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
356    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
357    // ```
358    // The result will always be the first entry in the table. We can store such
359    // early-finalizing results and then just reuse them as necessary. This opens
360    // opportunities to prune our datasets.
361    pub finalized_result: Option<ScalarValue>,
362    pub kind: NthValueKind,
363}
364
365#[derive(Debug)]
366pub(crate) struct NthValueEvaluator {
367    state: NthValueState,
368    ignore_nulls: bool,
369    n: i64,
370}
371
372impl PartitionEvaluator for NthValueEvaluator {
373    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
374    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
375    /// can memoize the result.  Once result is calculated, it will always stay
376    /// same. Hence, we do not need to keep past data as we process the entire
377    /// dataset.
378    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
379        let out = &state.out_col;
380        let size = out.len();
381        let mut buffer_size = 1;
382        // Decide if we arrived at a final result yet:
383        let (is_prunable, is_reverse_direction) = match self.state.kind {
384            NthValueKind::First => {
385                let n_range =
386                    state.window_frame_range.end - state.window_frame_range.start;
387                (n_range > 0 && size > 0, false)
388            }
389            NthValueKind::Last => (true, true),
390            NthValueKind::Nth => {
391                let n_range =
392                    state.window_frame_range.end - state.window_frame_range.start;
393                match self.n.cmp(&0) {
394                    Ordering::Greater => (
395                        n_range >= (self.n as usize) && size > (self.n as usize),
396                        false,
397                    ),
398                    Ordering::Less => {
399                        let reverse_index = (-self.n) as usize;
400                        buffer_size = reverse_index;
401                        // Negative index represents reverse direction.
402                        (n_range >= reverse_index, true)
403                    }
404                    Ordering::Equal => (false, false),
405                }
406            }
407        };
408        // Do not memoize results when nulls are ignored.
409        if is_prunable && !self.ignore_nulls {
410            if self.state.finalized_result.is_none() && !is_reverse_direction {
411                let result = ScalarValue::try_from_array(out, size - 1)?;
412                self.state.finalized_result = Some(result);
413            }
414            state.window_frame_range.start =
415                state.window_frame_range.end.saturating_sub(buffer_size);
416        }
417        Ok(())
418    }
419
420    fn evaluate(
421        &mut self,
422        values: &[ArrayRef],
423        range: &Range<usize>,
424    ) -> Result<ScalarValue> {
425        if let Some(ref result) = self.state.finalized_result {
426            Ok(result.clone())
427        } else {
428            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
429            let arr = &values[0];
430            let n_range = range.end - range.start;
431            if n_range == 0 {
432                // We produce None if the window is empty.
433                return ScalarValue::try_from(arr.data_type());
434            }
435
436            // If null values exist and need to be ignored, extract the valid indices.
437            let valid_indices = if self.ignore_nulls {
438                // Calculate valid indices, inside the window frame boundaries.
439                let slice = arr.slice(range.start, n_range);
440                match slice.nulls() {
441                    Some(nulls) => {
442                        let valid_indices = nulls
443                            .valid_indices()
444                            .map(|idx| {
445                                // Add offset `range.start` to valid indices, to point correct index in the original arr.
446                                idx + range.start
447                            })
448                            .collect::<Vec<_>>();
449                        if valid_indices.is_empty() {
450                            // If all values are null, return directly.
451                            return ScalarValue::try_from(arr.data_type());
452                        }
453                        Some(valid_indices)
454                    }
455                    None => None,
456                }
457            } else {
458                None
459            };
460            match self.state.kind {
461                NthValueKind::First => {
462                    if let Some(valid_indices) = &valid_indices {
463                        ScalarValue::try_from_array(arr, valid_indices[0])
464                    } else {
465                        ScalarValue::try_from_array(arr, range.start)
466                    }
467                }
468                NthValueKind::Last => {
469                    if let Some(valid_indices) = &valid_indices {
470                        ScalarValue::try_from_array(
471                            arr,
472                            valid_indices[valid_indices.len() - 1],
473                        )
474                    } else {
475                        ScalarValue::try_from_array(arr, range.end - 1)
476                    }
477                }
478                NthValueKind::Nth => {
479                    match self.n.cmp(&0) {
480                        Ordering::Greater => {
481                            // SQL indices are not 0-based.
482                            let index = (self.n as usize) - 1;
483                            if index >= n_range {
484                                // Outside the range, return NULL:
485                                ScalarValue::try_from(arr.data_type())
486                            } else if let Some(valid_indices) = valid_indices {
487                                if index >= valid_indices.len() {
488                                    return ScalarValue::try_from(arr.data_type());
489                                }
490                                ScalarValue::try_from_array(&arr, valid_indices[index])
491                            } else {
492                                ScalarValue::try_from_array(arr, range.start + index)
493                            }
494                        }
495                        Ordering::Less => {
496                            let reverse_index = (-self.n) as usize;
497                            if n_range < reverse_index {
498                                // Outside the range, return NULL:
499                                ScalarValue::try_from(arr.data_type())
500                            } else if let Some(valid_indices) = valid_indices {
501                                if reverse_index > valid_indices.len() {
502                                    return ScalarValue::try_from(arr.data_type());
503                                }
504                                let new_index =
505                                    valid_indices[valid_indices.len() - reverse_index];
506                                ScalarValue::try_from_array(&arr, new_index)
507                            } else {
508                                ScalarValue::try_from_array(
509                                    arr,
510                                    range.start + n_range - reverse_index,
511                                )
512                            }
513                        }
514                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
515                    }
516                }
517            }
518        }
519    }
520
521    fn supports_bounded_execution(&self) -> bool {
522        true
523    }
524
525    fn uses_window_frame(&self) -> bool {
526        true
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8` of a
    /// fixed Int32 column and checks the per-frame results against `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let column: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = [column];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let scalars = (1..=8)
            .map(|end| evaluator.evaluate(&values, &(0..end)))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let actual = ScalarValue::iter_to_array(scalars.into_iter())?;
        assert_eq!(expected, *as_int32_array(&actual)?);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let column = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[column],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        let column = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[column],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let column = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let position =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column, position],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let column = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let position =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;

        // The first frame holds a single row, so there is no 2nd value yet.
        let expected: Int32Array = [None].into_iter().chain([Some(-2); 7]).collect();

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column, position],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }
}