datafusion_functions_window/
lead_lag.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `lead` and `lag` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use datafusion_common::arrow::array::ArrayRef;
22use datafusion_common::arrow::datatypes::DataType;
23use datafusion_common::arrow::datatypes::Field;
24use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
25use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
26use datafusion_expr::{
27    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
28    Volatility, WindowUDFImpl,
29};
30use datafusion_functions_window_common::expr::ExpressionArgs;
31use datafusion_functions_window_common::field::WindowUDFFieldArgs;
32use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
33use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
34use std::any::Any;
35use std::cmp::min;
36use std::collections::VecDeque;
37use std::ops::{Neg, Range};
38use std::sync::{Arc, LazyLock};
39
// Declares the `Lag` window UDF via the crate's `get_or_init_udwf!` macro,
// supplying its name, its user-facing description and `WindowShift::lag` as
// the constructor.
// NOTE(review): the macro is defined outside this file; it presumably
// generates the `lag_udwf()` accessor used further down — confirm.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Declares the `Lead` window UDF via the crate's `get_or_init_udwf!` macro,
// supplying its name, its user-facing description and `WindowShift::lead` as
// the constructor.
// NOTE(review): the macro is defined outside this file; it presumably
// generates the `lead_udwf()` accessor used further down — confirm.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
56
57/// Create an expression to represent the `lag` window function
58///
59/// returns value evaluated at the row that is offset rows before the current row within the partition;
60/// if there is no such row, instead return default (which must be of the same type as value).
61/// Both offset and default are evaluated with respect to the current row.
62/// If omitted, offset defaults to 1 and default to null
63pub fn lag(
64    arg: datafusion_expr::Expr,
65    shift_offset: Option<i64>,
66    default_value: Option<ScalarValue>,
67) -> datafusion_expr::Expr {
68    let shift_offset_lit = shift_offset
69        .map(|v| v.lit())
70        .unwrap_or(ScalarValue::Null.lit());
71    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
72
73    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
74}
75
76/// Create an expression to represent the `lead` window function
77///
78/// returns value evaluated at the row that is offset rows after the current row within the partition;
79/// if there is no such row, instead return default (which must be of the same type as value).
80/// Both offset and default are evaluated with respect to the current row.
81/// If omitted, offset defaults to 1 and default to null
82pub fn lead(
83    arg: datafusion_expr::Expr,
84    shift_offset: Option<i64>,
85    default_value: Option<ScalarValue>,
86) -> datafusion_expr::Expr {
87    let shift_offset_lit = shift_offset
88        .map(|v| v.lit())
89        .unwrap_or(ScalarValue::Null.lit());
90    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
91
92    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
93}
94
/// Selects which of the two shift-based window functions is being evaluated.
#[derive(Debug)]
enum WindowShiftKind {
    /// `lag`: look at rows before the current row.
    Lag,
    /// `lead`: look at rows after the current row.
    Lead,
}

impl WindowShiftKind {
    /// Returns the SQL name of the function (`"lag"` or `"lead"`).
    fn name(&self) -> &'static str {
        match self {
            WindowShiftKind::Lag => "lag",
            WindowShiftKind::Lead => "lead",
        }
    }

    /// Converts the user supplied offset into the signed offset used by
    /// [`WindowShiftEvaluator`].
    ///
    /// In [`WindowShiftEvaluator`] a positive offset is used to signal
    /// computation of `lag()`, so the input offset is negated when
    /// computing `lead()`. A missing offset means "one row" in the
    /// direction of the function.
    ///
    /// The result is never `i64::MIN`: that value has no positive
    /// counterpart, so negating it (here or later in the evaluator) would
    /// overflow. It is clamped to the nearest representable magnitude, which
    /// still lies far outside any partition and therefore yields the same
    /// rows.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        match (self, value) {
            (WindowShiftKind::Lag, None) => 1,
            (WindowShiftKind::Lead, None) => -1,
            (WindowShiftKind::Lag, Some(i64::MIN)) => -i64::MAX,
            (WindowShiftKind::Lag, Some(v)) => v,
            (WindowShiftKind::Lead, Some(i64::MIN)) => i64::MAX,
            (WindowShiftKind::Lead, Some(v)) => v.neg(),
        }
    }
}
119
120/// window shift expression
121#[derive(Debug)]
122pub struct WindowShift {
123    signature: Signature,
124    kind: WindowShiftKind,
125}
126
127impl WindowShift {
128    fn new(kind: WindowShiftKind) -> Self {
129        Self {
130            signature: Signature::one_of(
131                vec![
132                    TypeSignature::Any(1),
133                    TypeSignature::Any(2),
134                    TypeSignature::Any(3),
135                ],
136                Volatility::Immutable,
137            ),
138            kind,
139        }
140    }
141
142    pub fn lag() -> Self {
143        Self::new(WindowShiftKind::Lag)
144    }
145
146    pub fn lead() -> Self {
147        Self::new(WindowShiftKind::Lead)
148    }
149}
150
151static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
152    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
153            current row within the partition; if there is no such row, instead return default \
154            (which must be of the same type as value).", "lag(expression, offset, default)")
155        .with_argument("expression", "Expression to operate on")
156        .with_argument("offset", "Integer. Specifies how many rows back \
157        the value of expression should be retrieved. Defaults to 1.")
158        .with_argument("default", "The default value if the offset is \
159        not within the partition. Must be of the same type as expression.")
160        .build()
161});
162
163fn get_lag_doc() -> &'static Documentation {
164    &LAG_DOCUMENTATION
165}
166
167static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
168    Documentation::builder(DOC_SECTION_ANALYTICAL,
169            "Returns value evaluated at the row that is offset rows after the \
170            current row within the partition; if there is no such row, instead return default \
171            (which must be of the same type as value).",
172        "lead(expression, offset, default)")
173        .with_argument("expression", "Expression to operate on")
174        .with_argument("offset", "Integer. Specifies how many rows \
175        forward the value of expression should be retrieved. Defaults to 1.")
176        .with_argument("default", "The default value if the offset is \
177        not within the partition. Must be of the same type as expression.")
178        .build()
179});
180
181fn get_lead_doc() -> &'static Documentation {
182    &LEAD_DOCUMENTATION
183}
184
185impl WindowUDFImpl for WindowShift {
186    fn as_any(&self) -> &dyn Any {
187        self
188    }
189
190    fn name(&self) -> &str {
191        self.kind.name()
192    }
193
194    fn signature(&self) -> &Signature {
195        &self.signature
196    }
197
198    /// Handles the case where `NULL` expression is passed as an
199    /// argument to `lead`/`lag`. The type is refined depending
200    /// on the default value argument.
201    ///
202    /// For more details see: <https://github.com/apache/datafusion/issues/12717>
203    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
204        parse_expr(expr_args.input_exprs(), expr_args.input_types())
205            .into_iter()
206            .collect::<Vec<_>>()
207    }
208
209    fn partition_evaluator(
210        &self,
211        partition_evaluator_args: PartitionEvaluatorArgs,
212    ) -> Result<Box<dyn PartitionEvaluator>> {
213        let shift_offset =
214            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
215                .map(get_signed_integer)
216                .map_or(Ok(None), |v| v.map(Some))
217                .map(|n| self.kind.shift_offset(n))
218                .map(|offset| {
219                    if partition_evaluator_args.is_reversed() {
220                        -offset
221                    } else {
222                        offset
223                    }
224                })?;
225        let default_value = parse_default_value(
226            partition_evaluator_args.input_exprs(),
227            partition_evaluator_args.input_types(),
228        )?;
229
230        Ok(Box::new(WindowShiftEvaluator {
231            shift_offset,
232            default_value,
233            ignore_nulls: partition_evaluator_args.ignore_nulls(),
234            non_null_offsets: VecDeque::new(),
235        }))
236    }
237
238    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
239        let return_type = parse_expr_type(field_args.input_types())?;
240
241        Ok(Field::new(field_args.name(), return_type, true))
242    }
243
244    fn reverse_expr(&self) -> ReversedUDWF {
245        match self.kind {
246            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
247            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
248        }
249    }
250
251    fn documentation(&self) -> Option<&Documentation> {
252        match self.kind {
253            WindowShiftKind::Lag => Some(get_lag_doc()),
254            WindowShiftKind::Lead => Some(get_lead_doc()),
255        }
256    }
257}
258
259/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
260/// refine it by matching it with the type of the default value.
261///
262/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
263/// is refined into `ScalarValue::Boolean(None)`. Only the type is
264/// refined, the expression value remains `NULL`.
265///
266/// When the window function is evaluated with `NULL` expression
267/// this guarantees that the type matches with that of the default
268/// value.
269///
270/// For more details see: <https://github.com/apache/datafusion/issues/12717>
271fn parse_expr(
272    input_exprs: &[Arc<dyn PhysicalExpr>],
273    input_types: &[DataType],
274) -> Result<Arc<dyn PhysicalExpr>> {
275    assert!(!input_exprs.is_empty());
276    assert!(!input_types.is_empty());
277
278    let expr = Arc::clone(input_exprs.first().unwrap());
279    let expr_type = input_types.first().unwrap();
280
281    // Handles the most common case where NULL is unexpected
282    if !expr_type.is_null() {
283        return Ok(expr);
284    }
285
286    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
287    default_value.map_or(Ok(expr), |value| {
288        ScalarValue::try_from(&value.data_type()).map(|v| {
289            Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
290                as Arc<dyn PhysicalExpr>
291        })
292    })
293}
294
295/// Returns the data type of the default value(if provided) when the
296/// expression is `NULL`.
297///
298/// Otherwise, returns the expression type unchanged.
299fn parse_expr_type(input_types: &[DataType]) -> Result<DataType> {
300    assert!(!input_types.is_empty());
301    let expr_type = input_types.first().unwrap_or(&DataType::Null);
302
303    // Handles the most common case where NULL is unexpected
304    if !expr_type.is_null() {
305        return Ok(expr_type.clone());
306    }
307
308    let default_value_type = input_types.get(2).unwrap_or(&DataType::Null);
309    Ok(default_value_type.clone())
310}
311
312/// Handles type coercion and null value refinement for default value
313/// argument depending on the data type of the input expression.
314fn parse_default_value(
315    input_exprs: &[Arc<dyn PhysicalExpr>],
316    input_types: &[DataType],
317) -> Result<ScalarValue> {
318    let expr_type = parse_expr_type(input_types)?;
319    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
320
321    unparsed
322        .filter(|v| !v.data_type().is_null())
323        .map(|v| v.cast_to(&expr_type))
324        .unwrap_or(ScalarValue::try_from(expr_type))
325}
326
327#[derive(Debug)]
328struct WindowShiftEvaluator {
329    shift_offset: i64,
330    default_value: ScalarValue,
331    ignore_nulls: bool,
332    // VecDeque contains offset values that between non-null entries
333    non_null_offsets: VecDeque<usize>,
334}
335
336impl WindowShiftEvaluator {
337    fn is_lag(&self) -> bool {
338        // Mode is LAG, when shift_offset is positive
339        self.shift_offset > 0
340    }
341}
342
343// implement ignore null for evaluate_all
344fn evaluate_all_with_ignore_null(
345    array: &ArrayRef,
346    offset: i64,
347    default_value: &ScalarValue,
348    is_lag: bool,
349) -> Result<ArrayRef, DataFusionError> {
350    let valid_indices: Vec<usize> =
351        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
352    let direction = !is_lag;
353    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
354        .map(|id| {
355            let result_index = match valid_indices.binary_search(&id) {
356                Ok(pos) => if direction {
357                    pos.checked_add(offset as usize)
358                } else {
359                    pos.checked_sub(offset.unsigned_abs() as usize)
360                }
361                .and_then(|new_pos| {
362                    if new_pos < valid_indices.len() {
363                        Some(valid_indices[new_pos])
364                    } else {
365                        None
366                    }
367                }),
368                Err(pos) => if direction {
369                    pos.checked_add(offset as usize)
370                } else if pos > 0 {
371                    pos.checked_sub(offset.unsigned_abs() as usize)
372                } else {
373                    None
374                }
375                .and_then(|new_pos| {
376                    if new_pos < valid_indices.len() {
377                        Some(valid_indices[new_pos])
378                    } else {
379                        None
380                    }
381                }),
382            };
383
384            match result_index {
385                Some(index) => ScalarValue::try_from_array(array, index),
386                None => Ok(default_value.clone()),
387            }
388        })
389        .collect();
390
391    let new_array = new_array_results?;
392    ScalarValue::iter_to_array(new_array)
393}
394// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
395fn shift_with_default_value(
396    array: &ArrayRef,
397    offset: i64,
398    default_value: &ScalarValue,
399) -> Result<ArrayRef> {
400    use datafusion_common::arrow::compute::concat;
401
402    let value_len = array.len() as i64;
403    if offset == 0 {
404        Ok(Arc::clone(array))
405    } else if offset == i64::MIN || offset.abs() >= value_len {
406        default_value.to_array_of_size(value_len as usize)
407    } else {
408        let slice_offset = (-offset).clamp(0, value_len) as usize;
409        let length = array.len() - offset.unsigned_abs() as usize;
410        let slice = array.slice(slice_offset, length);
411
412        // Generate array with remaining `null` items
413        let nulls = offset.unsigned_abs() as usize;
414        let default_values = default_value.to_array_of_size(nulls)?;
415
416        // Concatenate both arrays, add nulls after if shift > 0 else before
417        if offset > 0 {
418            concat(&[default_values.as_ref(), slice.as_ref()])
419                .map_err(|e| arrow_datafusion_err!(e))
420        } else {
421            concat(&[slice.as_ref(), default_values.as_ref()])
422                .map_err(|e| arrow_datafusion_err!(e))
423        }
424    }
425}
426
impl PartitionEvaluator for WindowShiftEvaluator {
    /// Returns the range of rows used to compute the result for row `idx` in
    /// a partition of `n_rows` rows.
    ///
    /// In LAG mode the range ends right after the current row and starts as
    /// far back as the shift requires. In LEAD mode it starts at the current
    /// row and extends forward. With `ignore_nulls` the extent is derived
    /// from `non_null_offsets` once enough offsets are tracked; until then
    /// the range reaches to the partition boundary.
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        if self.is_lag() {
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
                // Number of rows before the current row that are needed to
                // produce the lag result.
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                let offset = self.shift_offset as usize;
                idx.saturating_sub(offset)
            } else {
                // Not enough non-null rows tracked yet: start at the
                // beginning of the partition.
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
                // Number of rows after the current row that are needed to
                // produce the lead result.
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx + offset + 1, n_rows)
            } else if !self.ignore_nulls {
                let offset = (-self.shift_offset) as usize;
                min(idx + offset, n_rows)
            } else {
                // Not enough non-null rows tracked yet: extend to the end of
                // the partition.
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    fn is_causal(&self) -> bool {
        // Lagging windows are causal by definition:
        self.is_lag()
    }

    /// Computes the result for a single row.
    ///
    /// The current row is `range.end - 1` in LAG mode and `range.start` in
    /// LEAD mode. With `ignore_nulls` this method also updates
    /// `non_null_offsets`, so its result depends on the calls made before.
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        // Index of the shifted row when nulls are not ignored. A negative
        // intermediate value wraps to a large `usize` and is rejected by the
        // bounds check below.
        // LAG mode
        let i = if self.is_lag() {
            (range.end as i64 - self.shift_offset - 1) as usize
        } else {
            // LEAD mode
            (range.start as i64 - self.shift_offset) as usize
        };

        let mut idx: Option<usize> = if i < len { Some(i) } else { None };

        // LAG with IGNORE NULLS is calculated as the current row index - offset, but only for non-NULL rows.
        // If the current row index points to a NULL value the row is NOT counted.
        if self.ignore_nulls && self.is_lag() {
            // LAG when NULLS are ignored.
            // Find the non-NULL row index that is shifted by offset relative to the current row index.
            // This is computed before the current row is recorded below.
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            // Keep track of offset values between non-null entries
            if array.is_valid(range.end - 1) {
                // Non-null: add a new offset.
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > self.shift_offset as usize {
                    // We do not need to keep track of more than `lag number of offset` values.
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                // Entry is null, increment offset value of the last entry.
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            // LEAD when NULLS are ignored.
            // Stores the necessary non-null entry number further than the current row.
            let non_null_row_count = (-self.shift_offset) as usize;

            if self.non_null_offsets.is_empty() {
                // When empty, fill non_null offsets with the data further than the current row.
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    // It is enough to keep track of `non_null_row_count + 1` non-null offset.
                    // further data is unnecessary for the result.
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                // Update `non_null_offsets` with the new end data.
                // NOTE(review): the enclosing condition already requires
                // `array.is_valid(range.end)`, so the `else` branch below
                // cannot run as written — confirm the intended behavior for
                // a null row at `range.end`.
                if array.is_valid(range.end) {
                    // When non-null, append a new offset.
                    self.non_null_offsets.push_back(1);
                } else {
                    // When null, increment offset count of the last entry
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            // Find the non-NULL row index that is shifted by offset relative to the current row index.
            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            // Prune `self.non_null_offsets` from the start, so that at the next iteration
            // the start of `self.non_null_offsets` matches the current row.
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    // When offset is 0. Remove it.
                    self.non_null_offsets.pop_front();
                }
            }
        }

        // Set the default value if
        // - index is out of window bounds
        // OR
        // - ignore nulls mode and current value is null and is within window bounds
        // .unwrap() is safe here as there is a none check in front
        #[allow(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

    /// Computes the results for the whole partition in one call.
    ///
    /// Without `ignore_nulls` this is a plain shift filled with the default
    /// value; otherwise the null-skipping implementation is used.
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        // LEAD, LAG window functions take single column, values will have size 1
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    fn supports_bounded_execution(&self) -> bool {
        true
    }
}
593
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Runs `expr` over a fixed eight-row `Int32` column through
    /// `evaluate_all` and compares the outcome with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // The row count is the length of the column, not the number of
        // columns handed to the evaluator.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2 ignore nulls)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
            non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2 ignore nulls)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_types: &[DataType] =
            &[DataType::Int32, DataType::Int32, DataType::Int32];

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, input_types, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
}
734        )
735    }
736}