datafusion_functions_window/
nth_value.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `nth_value` window function implementation
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use std::any::Any;
23use std::cmp::Ordering;
24use std::fmt::Debug;
25use std::ops::Range;
26use std::sync::LazyLock;
27
28use datafusion_common::arrow::array::ArrayRef;
29use datafusion_common::arrow::datatypes::{DataType, Field};
30use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
31use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
32use datafusion_expr::window_state::WindowAggState;
33use datafusion_expr::{
34    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
35    Volatility, WindowUDFImpl,
36};
37use datafusion_functions_window_common::field;
38use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
39use field::WindowUDFFieldArgs;
40
// Shared window-UDF accessors for the three functions implemented in this file.
// Judging by the call sites below, each invocation makes a `<name>_udwf()`
// accessor available (`first_value_udwf()`, `last_value_udwf()`,
// `nth_value_udwf()`), backed by the given constructor (`NthValue::first`,
// `NthValue::last`, `NthValue::nth`). The string argument is a short description.
// NOTE(review): the exact expansion (presumably a lazily-initialized static,
// per the macro name, and the use of the first identifier) is defined by the
// crate's `get_or_init_udwf!` macro — confirm there.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
59
60/// Create an expression to represent the `first_value` window function
61///
62pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
63    first_value_udwf().call(vec![arg])
64}
65
66/// Create an expression to represent the `last_value` window function
67///
68pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
69    last_value_udwf().call(vec![arg])
70}
71
72/// Create an expression to represent the `nth_value` window function
73///
74pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
75    nth_value_udwf().call(vec![arg, n.lit()])
76}
77
/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value(expr)`: the value at the first row of the window frame.
    First,
    /// `last_value(expr)`: the value at the last row of the window frame.
    Last,
    /// `nth_value(expr, n)`: the value at the `n`-th row of the window frame,
    /// counting from 1.
    Nth,
}

impl NthValueKind {
    /// Returns the SQL function name for this kind.
    fn name(&self) -> &'static str {
        match self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
95
96#[derive(Debug)]
97pub struct NthValue {
98    signature: Signature,
99    kind: NthValueKind,
100}
101
102impl NthValue {
103    /// Create a new `nth_value` function
104    pub fn new(kind: NthValueKind) -> Self {
105        Self {
106            signature: Signature::one_of(
107                vec![
108                    TypeSignature::Any(0),
109                    TypeSignature::Any(1),
110                    TypeSignature::Any(2),
111                ],
112                Volatility::Immutable,
113            ),
114            kind,
115        }
116    }
117
118    pub fn first() -> Self {
119        Self::new(NthValueKind::First)
120    }
121
122    pub fn last() -> Self {
123        Self::new(NthValueKind::Last)
124    }
125    pub fn nth() -> Self {
126        Self::new(NthValueKind::Nth)
127    }
128}
129
130static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
131    Documentation::builder(
132        DOC_SECTION_ANALYTICAL,
133        "Returns value evaluated at the row that is the first row of the window \
134            frame.",
135        "first_value(expression)",
136    )
137    .with_argument("expression", "Expression to operate on")
138    .build()
139});
140
141fn get_first_value_doc() -> &'static Documentation {
142    &FIRST_VALUE_DOCUMENTATION
143}
144
145static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
146    Documentation::builder(
147        DOC_SECTION_ANALYTICAL,
148        "Returns value evaluated at the row that is the last row of the window \
149            frame.",
150        "last_value(expression)",
151    )
152    .with_argument("expression", "Expression to operate on")
153    .build()
154});
155
156fn get_last_value_doc() -> &'static Documentation {
157    &LAST_VALUE_DOCUMENTATION
158}
159
160static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
161    Documentation::builder(
162        DOC_SECTION_ANALYTICAL,
163        "Returns value evaluated at the row that is the nth row of the window \
164            frame (counting from 1); null if no such row.",
165        "nth_value(expression, n)",
166    )
167    .with_argument(
168        "expression",
169        "The name the column of which nth \
170        value to retrieve",
171    )
172    .with_argument("n", "Integer. Specifies the n in nth")
173    .build()
174});
175
176fn get_nth_value_doc() -> &'static Documentation {
177    &NTH_VALUE_DOCUMENTATION
178}
179
180impl WindowUDFImpl for NthValue {
181    fn as_any(&self) -> &dyn Any {
182        self
183    }
184
185    fn name(&self) -> &str {
186        self.kind.name()
187    }
188
189    fn signature(&self) -> &Signature {
190        &self.signature
191    }
192
193    fn partition_evaluator(
194        &self,
195        partition_evaluator_args: PartitionEvaluatorArgs,
196    ) -> Result<Box<dyn PartitionEvaluator>> {
197        let state = NthValueState {
198            finalized_result: None,
199            kind: self.kind,
200        };
201
202        if !matches!(self.kind, NthValueKind::Nth) {
203            return Ok(Box::new(NthValueEvaluator {
204                state,
205                ignore_nulls: partition_evaluator_args.ignore_nulls(),
206                n: 0,
207            }));
208        }
209
210        let n =
211            match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
212                .map_err(|_e| {
213                    exec_datafusion_err!(
214                "Expected a signed integer literal for the second argument of nth_value")
215                })?
216                .map(get_signed_integer)
217            {
218                Some(Ok(n)) => {
219                    if partition_evaluator_args.is_reversed() {
220                        -n
221                    } else {
222                        n
223                    }
224                }
225                _ => {
226                    return exec_err!(
227                "Expected a signed integer literal for the second argument of nth_value"
228            )
229                }
230            };
231
232        Ok(Box::new(NthValueEvaluator {
233            state,
234            ignore_nulls: partition_evaluator_args.ignore_nulls(),
235            n,
236        }))
237    }
238
239    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
240        let nullable = true;
241        let return_type = field_args.input_types().first().unwrap_or(&DataType::Null);
242
243        Ok(Field::new(field_args.name(), return_type.clone(), nullable))
244    }
245
246    fn reverse_expr(&self) -> ReversedUDWF {
247        match self.kind {
248            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
249            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
250            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
251        }
252    }
253
254    fn documentation(&self) -> Option<&Documentation> {
255        match self.kind {
256            NthValueKind::First => Some(get_first_value_doc()),
257            NthValueKind::Last => Some(get_last_value_doc()),
258            NthValueKind::Nth => Some(get_nth_value_doc()),
259        }
260    }
261}
262
263#[derive(Debug, Clone)]
264pub struct NthValueState {
265    // In certain cases, we can finalize the result early. Consider this usage:
266    // ```
267    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
268    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
269    // ```
270    // The result will always be the first entry in the table. We can store such
271    // early-finalizing results and then just reuse them as necessary. This opens
272    // opportunities to prune our datasets.
273    pub finalized_result: Option<ScalarValue>,
274    pub kind: NthValueKind,
275}
276
277#[derive(Debug)]
278pub(crate) struct NthValueEvaluator {
279    state: NthValueState,
280    ignore_nulls: bool,
281    n: i64,
282}
283
284impl PartitionEvaluator for NthValueEvaluator {
285    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
286    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
287    /// can memoize the result.  Once result is calculated, it will always stay
288    /// same. Hence, we do not need to keep past data as we process the entire
289    /// dataset.
290    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
291        let out = &state.out_col;
292        let size = out.len();
293        let mut buffer_size = 1;
294        // Decide if we arrived at a final result yet:
295        let (is_prunable, is_reverse_direction) = match self.state.kind {
296            NthValueKind::First => {
297                let n_range =
298                    state.window_frame_range.end - state.window_frame_range.start;
299                (n_range > 0 && size > 0, false)
300            }
301            NthValueKind::Last => (true, true),
302            NthValueKind::Nth => {
303                let n_range =
304                    state.window_frame_range.end - state.window_frame_range.start;
305                match self.n.cmp(&0) {
306                    Ordering::Greater => (
307                        n_range >= (self.n as usize) && size > (self.n as usize),
308                        false,
309                    ),
310                    Ordering::Less => {
311                        let reverse_index = (-self.n) as usize;
312                        buffer_size = reverse_index;
313                        // Negative index represents reverse direction.
314                        (n_range >= reverse_index, true)
315                    }
316                    Ordering::Equal => (false, false),
317                }
318            }
319        };
320        // Do not memoize results when nulls are ignored.
321        if is_prunable && !self.ignore_nulls {
322            if self.state.finalized_result.is_none() && !is_reverse_direction {
323                let result = ScalarValue::try_from_array(out, size - 1)?;
324                self.state.finalized_result = Some(result);
325            }
326            state.window_frame_range.start =
327                state.window_frame_range.end.saturating_sub(buffer_size);
328        }
329        Ok(())
330    }
331
332    fn evaluate(
333        &mut self,
334        values: &[ArrayRef],
335        range: &Range<usize>,
336    ) -> Result<ScalarValue> {
337        if let Some(ref result) = self.state.finalized_result {
338            Ok(result.clone())
339        } else {
340            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
341            let arr = &values[0];
342            let n_range = range.end - range.start;
343            if n_range == 0 {
344                // We produce None if the window is empty.
345                return ScalarValue::try_from(arr.data_type());
346            }
347
348            // If null values exist and need to be ignored, extract the valid indices.
349            let valid_indices = if self.ignore_nulls {
350                // Calculate valid indices, inside the window frame boundaries.
351                let slice = arr.slice(range.start, n_range);
352                match slice.nulls() {
353                    Some(nulls) => {
354                        let valid_indices = nulls
355                            .valid_indices()
356                            .map(|idx| {
357                                // Add offset `range.start` to valid indices, to point correct index in the original arr.
358                                idx + range.start
359                            })
360                            .collect::<Vec<_>>();
361                        if valid_indices.is_empty() {
362                            // If all values are null, return directly.
363                            return ScalarValue::try_from(arr.data_type());
364                        }
365                        Some(valid_indices)
366                    }
367                    None => None,
368                }
369            } else {
370                None
371            };
372            match self.state.kind {
373                NthValueKind::First => {
374                    if let Some(valid_indices) = &valid_indices {
375                        ScalarValue::try_from_array(arr, valid_indices[0])
376                    } else {
377                        ScalarValue::try_from_array(arr, range.start)
378                    }
379                }
380                NthValueKind::Last => {
381                    if let Some(valid_indices) = &valid_indices {
382                        ScalarValue::try_from_array(
383                            arr,
384                            valid_indices[valid_indices.len() - 1],
385                        )
386                    } else {
387                        ScalarValue::try_from_array(arr, range.end - 1)
388                    }
389                }
390                NthValueKind::Nth => {
391                    match self.n.cmp(&0) {
392                        Ordering::Greater => {
393                            // SQL indices are not 0-based.
394                            let index = (self.n as usize) - 1;
395                            if index >= n_range {
396                                // Outside the range, return NULL:
397                                ScalarValue::try_from(arr.data_type())
398                            } else if let Some(valid_indices) = valid_indices {
399                                if index >= valid_indices.len() {
400                                    return ScalarValue::try_from(arr.data_type());
401                                }
402                                ScalarValue::try_from_array(&arr, valid_indices[index])
403                            } else {
404                                ScalarValue::try_from_array(arr, range.start + index)
405                            }
406                        }
407                        Ordering::Less => {
408                            let reverse_index = (-self.n) as usize;
409                            if n_range < reverse_index {
410                                // Outside the range, return NULL:
411                                ScalarValue::try_from(arr.data_type())
412                            } else if let Some(valid_indices) = valid_indices {
413                                if reverse_index > valid_indices.len() {
414                                    return ScalarValue::try_from(arr.data_type());
415                                }
416                                let new_index =
417                                    valid_indices[valid_indices.len() - reverse_index];
418                                ScalarValue::try_from_array(&arr, new_index)
419                            } else {
420                                ScalarValue::try_from_array(
421                                    arr,
422                                    range.start + n_range - reverse_index,
423                                )
424                            }
425                        }
426                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
427                    }
428                }
429            }
430        }
431    }
432
433    fn supports_bounded_execution(&self) -> bool {
434        true
435    }
436
437    fn uses_window_frame(&self) -> bool {
438        true
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Evaluates `expr` over the growing frames `[0, 1)`, `[0, 2)`, ..., `[0, 8)`
    /// of the column `[1, -2, 3, -4, 5, -6, 7, 8]` and compares the results
    /// with `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![arr];
        let ranges: Vec<Range<usize>> = (1..=8).map(|end| 0..end).collect();
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let result = ranges
            .iter()
            .map(|range| evaluator.evaluate(&values, range))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let result = ScalarValue::iter_to_array(result.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// The value column read by every test.
    fn value_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// An `Int32` literal used as the `n` argument of `nth_value`.
    fn n_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Runs `nth_value(c3, n)` (optionally reversed) through `test_i32_result`.
    fn test_nth_value(n: i32, is_reversed: bool, expected: Int32Array) -> Result<()> {
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[value_column(), n_literal(n)],
                &[DataType::Int32],
                is_reversed,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn first_value() -> Result<()> {
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[DataType::Int32],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[DataType::Int32],
                false,
                false,
            ),
            Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        test_nth_value(1, false, Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        test_nth_value(
            2,
            false,
            Int32Array::from(vec![
                None,
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
            ]),
        )
    }

    /// A negative `n` counts from the end of the frame.
    #[test]
    fn nth_value_negative_counts_from_end() -> Result<()> {
        test_nth_value(
            -2,
            false,
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    /// A reversed `nth_value(c3, 2)` behaves like `nth_value(c3, -2)`.
    #[test]
    fn nth_value_reversed_negates_n() -> Result<()> {
        test_nth_value(
            2,
            true,
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    /// `n == 0` never selects a row.
    #[test]
    fn nth_value_zero_is_null() -> Result<()> {
        test_nth_value(0, false, Int32Array::from(vec![None::<i32>; 8]))
    }
}