datafusion_functions_window/
lead_lag.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `lead` and `lag` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
29    Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use std::any::Any;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::hash::{DefaultHasher, Hash, Hasher};
39use std::ops::{Neg, Range};
40use std::sync::{Arc, LazyLock};
41
// Registers the `lag` window UDF. NOTE(review): the macro is defined outside
// this file; presumably it generates the `lag_udwf()` accessor used below
// (in `lag()` and `reverse_expr`), built from `WindowShift::lag` with the
// given description.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Registers the `lead` window UDF. NOTE(review): presumably generates the
// `lead_udwf()` accessor used below, built from `WindowShift::lead`.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
58
59/// Create an expression to represent the `lag` window function
60///
61/// returns value evaluated at the row that is offset rows before the current row within the partition;
62/// if there is no such row, instead return default (which must be of the same type as value).
63/// Both offset and default are evaluated with respect to the current row.
64/// If omitted, offset defaults to 1 and default to null
65pub fn lag(
66    arg: datafusion_expr::Expr,
67    shift_offset: Option<i64>,
68    default_value: Option<ScalarValue>,
69) -> datafusion_expr::Expr {
70    let shift_offset_lit = shift_offset
71        .map(|v| v.lit())
72        .unwrap_or(ScalarValue::Null.lit());
73    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
74
75    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
76}
77
78/// Create an expression to represent the `lead` window function
79///
80/// returns value evaluated at the row that is offset rows after the current row within the partition;
81/// if there is no such row, instead return default (which must be of the same type as value).
82/// Both offset and default are evaluated with respect to the current row.
83/// If omitted, offset defaults to 1 and default to null
84pub fn lead(
85    arg: datafusion_expr::Expr,
86    shift_offset: Option<i64>,
87    default_value: Option<ScalarValue>,
88) -> datafusion_expr::Expr {
89    let shift_offset_lit = shift_offset
90        .map(|v| v.lit())
91        .unwrap_or(ScalarValue::Null.lit());
92    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
93
94    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
95}
96
/// Direction of a window shift: `lag` looks backwards, `lead` forwards.
#[derive(Debug, PartialEq, Eq, Hash)]
enum WindowShiftKind {
    Lag,
    Lead,
}

impl WindowShiftKind {
    /// SQL-visible function name for this kind.
    fn name(&self) -> &'static str {
        match self {
            Self::Lead => "lead",
            Self::Lag => "lag",
        }
    }

    /// Converts a user-supplied offset (default 1) into the signed offset
    /// used by [`WindowShiftEvaluator`], where positive means `lag()`.
    /// For `lead()` the requested offset is therefore negated.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let requested = value.unwrap_or(1);
        match self {
            Self::Lag => requested,
            Self::Lead => requested.neg(),
        }
    }
}
121
122/// window shift expression
123#[derive(Debug)]
124pub struct WindowShift {
125    signature: Signature,
126    kind: WindowShiftKind,
127}
128
129impl WindowShift {
130    fn new(kind: WindowShiftKind) -> Self {
131        Self {
132            signature: Signature::one_of(
133                vec![
134                    TypeSignature::Any(1),
135                    TypeSignature::Any(2),
136                    TypeSignature::Any(3),
137                ],
138                Volatility::Immutable,
139            ),
140            kind,
141        }
142    }
143
144    pub fn lag() -> Self {
145        Self::new(WindowShiftKind::Lag)
146    }
147
148    pub fn lead() -> Self {
149        Self::new(WindowShiftKind::Lead)
150    }
151}
152
/// User-facing documentation for `lag`, built once on first access.
/// Exposed through `WindowUDFImpl::documentation` via [`get_lag_doc`].
static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).", "lag(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows back \
        the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
    --Example usage of the lag window function:
    SELECT employee_id,
           salary,
           lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
    FROM employees;
```

```sql
+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1           | 30000  | 0           |
| 2           | 50000  | 30000       |
| 3           | 70000  | 50000       |
| 4           | 60000  | 70000       |
+-------------+--------+-------------+
```"#)
        .build()
});

/// Returns the lazily-built `lag` documentation.
fn get_lag_doc() -> &'static Documentation {
    &LAG_DOCUMENTATION
}
186
/// User-facing documentation for `lead`, built once on first access.
/// Exposed through `WindowUDFImpl::documentation` via [`get_lead_doc`].
static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL,
            "Returns value evaluated at the row that is offset rows after the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).",
        "lead(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows \
        forward the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
-- Example usage of lead() :
SELECT
    employee_id,
    department,
    salary,
    lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;
```

```sql
+-------------+-------------+--------+--------------+
| employee_id | department  | salary | next_salary  |
+-------------+-------------+--------+--------------+
| 1           | Sales       | 30000  | 50000        |
| 2           | Sales       | 50000  | 70000        |
| 3           | Sales       | 70000  | 0            |
| 4           | Engineering | 40000  | 60000        |
| 5           | Engineering | 60000  | 0            |
+-------------+-------------+--------+--------------+
```"#)
        .build()
});

/// Returns the lazily-built `lead` documentation.
fn get_lead_doc() -> &'static Documentation {
    &LEAD_DOCUMENTATION
}
225
226impl WindowUDFImpl for WindowShift {
227    fn as_any(&self) -> &dyn Any {
228        self
229    }
230
231    fn name(&self) -> &str {
232        self.kind.name()
233    }
234
235    fn signature(&self) -> &Signature {
236        &self.signature
237    }
238
239    /// Handles the case where `NULL` expression is passed as an
240    /// argument to `lead`/`lag`. The type is refined depending
241    /// on the default value argument.
242    ///
243    /// For more details see: <https://github.com/apache/datafusion/issues/12717>
244    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
245        parse_expr(expr_args.input_exprs(), expr_args.input_fields())
246            .into_iter()
247            .collect::<Vec<_>>()
248    }
249
250    fn partition_evaluator(
251        &self,
252        partition_evaluator_args: PartitionEvaluatorArgs,
253    ) -> Result<Box<dyn PartitionEvaluator>> {
254        let shift_offset =
255            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
256                .map(get_signed_integer)
257                .map_or(Ok(None), |v| v.map(Some))
258                .map(|n| self.kind.shift_offset(n))
259                .map(|offset| {
260                    if partition_evaluator_args.is_reversed() {
261                        -offset
262                    } else {
263                        offset
264                    }
265                })?;
266        let default_value = parse_default_value(
267            partition_evaluator_args.input_exprs(),
268            partition_evaluator_args.input_fields(),
269        )?;
270
271        Ok(Box::new(WindowShiftEvaluator {
272            shift_offset,
273            default_value,
274            ignore_nulls: partition_evaluator_args.ignore_nulls(),
275            non_null_offsets: VecDeque::new(),
276        }))
277    }
278
279    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
280        let return_field = parse_expr_field(field_args.input_fields())?;
281
282        Ok(return_field
283            .as_ref()
284            .clone()
285            .with_name(field_args.name())
286            .into())
287    }
288
289    fn reverse_expr(&self) -> ReversedUDWF {
290        match self.kind {
291            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
292            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
293        }
294    }
295
296    fn documentation(&self) -> Option<&Documentation> {
297        match self.kind {
298            WindowShiftKind::Lag => Some(get_lag_doc()),
299            WindowShiftKind::Lead => Some(get_lead_doc()),
300        }
301    }
302
303    fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
304        let Some(other) = other.as_any().downcast_ref::<Self>() else {
305            return false;
306        };
307        let Self { signature, kind } = self;
308        signature == &other.signature && kind == &other.kind
309    }
310
311    fn hash_value(&self) -> u64 {
312        let Self { signature, kind } = self;
313        let mut hasher = DefaultHasher::new();
314        std::any::type_name::<Self>().hash(&mut hasher);
315        signature.hash(&mut hasher);
316        kind.hash(&mut hasher);
317        hasher.finish()
318    }
319}
320
321/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
322/// refine it by matching it with the type of the default value.
323///
324/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
325/// is refined into `ScalarValue::Boolean(None)`. Only the type is
326/// refined, the expression value remains `NULL`.
327///
328/// When the window function is evaluated with `NULL` expression
329/// this guarantees that the type matches with that of the default
330/// value.
331///
332/// For more details see: <https://github.com/apache/datafusion/issues/12717>
333fn parse_expr(
334    input_exprs: &[Arc<dyn PhysicalExpr>],
335    input_fields: &[FieldRef],
336) -> Result<Arc<dyn PhysicalExpr>> {
337    assert!(!input_exprs.is_empty());
338    assert!(!input_fields.is_empty());
339
340    let expr = Arc::clone(input_exprs.first().unwrap());
341    let expr_field = input_fields.first().unwrap();
342
343    // Handles the most common case where NULL is unexpected
344    if !expr_field.data_type().is_null() {
345        return Ok(expr);
346    }
347
348    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
349    default_value.map_or(Ok(expr), |value| {
350        ScalarValue::try_from(&value.data_type()).map(|v| {
351            Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
352                as Arc<dyn PhysicalExpr>
353        })
354    })
355}
356
357static NULL_FIELD: LazyLock<FieldRef> =
358    LazyLock::new(|| Field::new("value", DataType::Null, true).into());
359
360/// Returns the field of the default value(if provided) when the
361/// expression is `NULL`.
362///
363/// Otherwise, returns the expression field unchanged.
364fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
365    assert!(!input_fields.is_empty());
366    let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
367
368    // Handles the most common case where NULL is unexpected
369    if !expr_field.data_type().is_null() {
370        return Ok(expr_field.as_ref().clone().with_nullable(true).into());
371    }
372
373    let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
374    Ok(default_value_field
375        .as_ref()
376        .clone()
377        .with_nullable(true)
378        .into())
379}
380
381/// Handles type coercion and null value refinement for default value
382/// argument depending on the data type of the input expression.
383fn parse_default_value(
384    input_exprs: &[Arc<dyn PhysicalExpr>],
385    input_types: &[FieldRef],
386) -> Result<ScalarValue> {
387    let expr_field = parse_expr_field(input_types)?;
388    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
389
390    unparsed
391        .filter(|v| !v.data_type().is_null())
392        .map(|v| v.cast_to(expr_field.data_type()))
393        .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
394}
395
396#[derive(Debug)]
397struct WindowShiftEvaluator {
398    shift_offset: i64,
399    default_value: ScalarValue,
400    ignore_nulls: bool,
401    // VecDeque contains offset values that between non-null entries
402    non_null_offsets: VecDeque<usize>,
403}
404
405impl WindowShiftEvaluator {
406    fn is_lag(&self) -> bool {
407        // Mode is LAG, when shift_offset is positive
408        self.shift_offset > 0
409    }
410}
411
412// implement ignore null for evaluate_all
413fn evaluate_all_with_ignore_null(
414    array: &ArrayRef,
415    offset: i64,
416    default_value: &ScalarValue,
417    is_lag: bool,
418) -> Result<ArrayRef, DataFusionError> {
419    let valid_indices: Vec<usize> =
420        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
421    let direction = !is_lag;
422    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
423        .map(|id| {
424            let result_index = match valid_indices.binary_search(&id) {
425                Ok(pos) => if direction {
426                    pos.checked_add(offset as usize)
427                } else {
428                    pos.checked_sub(offset.unsigned_abs() as usize)
429                }
430                .and_then(|new_pos| {
431                    if new_pos < valid_indices.len() {
432                        Some(valid_indices[new_pos])
433                    } else {
434                        None
435                    }
436                }),
437                Err(pos) => if direction {
438                    pos.checked_add(offset as usize)
439                } else if pos > 0 {
440                    pos.checked_sub(offset.unsigned_abs() as usize)
441                } else {
442                    None
443                }
444                .and_then(|new_pos| {
445                    if new_pos < valid_indices.len() {
446                        Some(valid_indices[new_pos])
447                    } else {
448                        None
449                    }
450                }),
451            };
452
453            match result_index {
454                Some(index) => ScalarValue::try_from_array(array, index),
455                None => Ok(default_value.clone()),
456            }
457        })
458        .collect();
459
460    let new_array = new_array_results?;
461    ScalarValue::iter_to_array(new_array)
462}
463// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
464fn shift_with_default_value(
465    array: &ArrayRef,
466    offset: i64,
467    default_value: &ScalarValue,
468) -> Result<ArrayRef> {
469    use datafusion_common::arrow::compute::concat;
470
471    let value_len = array.len() as i64;
472    if offset == 0 {
473        Ok(Arc::clone(array))
474    } else if offset == i64::MIN || offset.abs() >= value_len {
475        default_value.to_array_of_size(value_len as usize)
476    } else {
477        let slice_offset = (-offset).clamp(0, value_len) as usize;
478        let length = array.len() - offset.unsigned_abs() as usize;
479        let slice = array.slice(slice_offset, length);
480
481        // Generate array with remaining `null` items
482        let nulls = offset.unsigned_abs() as usize;
483        let default_values = default_value.to_array_of_size(nulls)?;
484
485        // Concatenate both arrays, add nulls after if shift > 0 else before
486        if offset > 0 {
487            concat(&[default_values.as_ref(), slice.as_ref()])
488                .map_err(|e| arrow_datafusion_err!(e))
489        } else {
490            concat(&[slice.as_ref(), default_values.as_ref()])
491                .map_err(|e| arrow_datafusion_err!(e))
492        }
493    }
494}
495
496impl PartitionEvaluator for WindowShiftEvaluator {
497    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
498        if self.is_lag() {
499            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
500                // How many rows needed previous than the current row to get necessary lag result
501                let offset: usize = self.non_null_offsets.iter().sum();
502                idx.saturating_sub(offset)
503            } else if !self.ignore_nulls {
504                let offset = self.shift_offset as usize;
505                idx.saturating_sub(offset)
506            } else {
507                0
508            };
509            let end = idx + 1;
510            Ok(Range { start, end })
511        } else {
512            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
513                // How many rows needed further than the current row to get necessary lead result
514                let offset: usize = self.non_null_offsets.iter().sum();
515                min(idx + offset + 1, n_rows)
516            } else if !self.ignore_nulls {
517                let offset = (-self.shift_offset) as usize;
518                min(idx + offset, n_rows)
519            } else {
520                n_rows
521            };
522            Ok(Range { start: idx, end })
523        }
524    }
525
526    fn is_causal(&self) -> bool {
527        // Lagging windows are causal by definition:
528        self.is_lag()
529    }
530
531    fn evaluate(
532        &mut self,
533        values: &[ArrayRef],
534        range: &Range<usize>,
535    ) -> Result<ScalarValue> {
536        let array = &values[0];
537        let len = array.len();
538
539        // LAG mode
540        let i = if self.is_lag() {
541            (range.end as i64 - self.shift_offset - 1) as usize
542        } else {
543            // LEAD mode
544            (range.start as i64 - self.shift_offset) as usize
545        };
546
547        let mut idx: Option<usize> = if i < len { Some(i) } else { None };
548
549        // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows
550        // If current row index points to NULL value the row is NOT counted
551        if self.ignore_nulls && self.is_lag() {
552            // LAG when NULLS are ignored.
553            // Find the nonNULL row index that shifted by offset comparing to current row index
554            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
555                let total_offset: usize = self.non_null_offsets.iter().sum();
556                Some(range.end - 1 - total_offset)
557            } else {
558                None
559            };
560
561            // Keep track of offset values between non-null entries
562            if array.is_valid(range.end - 1) {
563                // Non-null add new offset
564                self.non_null_offsets.push_back(1);
565                if self.non_null_offsets.len() > self.shift_offset as usize {
566                    // WE do not need to keep track of more than `lag number of offset` values.
567                    self.non_null_offsets.pop_front();
568                }
569            } else if !self.non_null_offsets.is_empty() {
570                // Entry is null, increment offset value of the last entry.
571                let end_idx = self.non_null_offsets.len() - 1;
572                self.non_null_offsets[end_idx] += 1;
573            }
574        } else if self.ignore_nulls && !self.is_lag() {
575            // LEAD when NULLS are ignored.
576            // Stores the necessary non-null entry number further than the current row.
577            let non_null_row_count = (-self.shift_offset) as usize;
578
579            if self.non_null_offsets.is_empty() {
580                // When empty, fill non_null offsets with the data further than the current row.
581                let mut offset_val = 1;
582                for idx in range.start + 1..range.end {
583                    if array.is_valid(idx) {
584                        self.non_null_offsets.push_back(offset_val);
585                        offset_val = 1;
586                    } else {
587                        offset_val += 1;
588                    }
589                    // It is enough to keep track of `non_null_row_count + 1` non-null offset.
590                    // further data is unnecessary for the result.
591                    if self.non_null_offsets.len() == non_null_row_count + 1 {
592                        break;
593                    }
594                }
595            } else if range.end < len && array.is_valid(range.end) {
596                // Update `non_null_offsets` with the new end data.
597                if array.is_valid(range.end) {
598                    // When non-null, append a new offset.
599                    self.non_null_offsets.push_back(1);
600                } else {
601                    // When null, increment offset count of the last entry
602                    let last_idx = self.non_null_offsets.len() - 1;
603                    self.non_null_offsets[last_idx] += 1;
604                }
605            }
606
607            // Find the nonNULL row index that shifted by offset comparing to current row index
608            idx = if self.non_null_offsets.len() >= non_null_row_count {
609                let total_offset: usize =
610                    self.non_null_offsets.iter().take(non_null_row_count).sum();
611                Some(range.start + total_offset)
612            } else {
613                None
614            };
615            // Prune `self.non_null_offsets` from the start. so that at next iteration
616            // start of the `self.non_null_offsets` matches with current row.
617            if !self.non_null_offsets.is_empty() {
618                self.non_null_offsets[0] -= 1;
619                if self.non_null_offsets[0] == 0 {
620                    // When offset is 0. Remove it.
621                    self.non_null_offsets.pop_front();
622                }
623            }
624        }
625
626        // Set the default value if
627        // - index is out of window bounds
628        // OR
629        // - ignore nulls mode and current value is null and is within window bounds
630        // .unwrap() is safe here as there is a none check in front
631        #[allow(clippy::unnecessary_unwrap)]
632        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
633            ScalarValue::try_from_array(array, idx.unwrap())
634        } else {
635            Ok(self.default_value.clone())
636        }
637    }
638
639    fn evaluate_all(
640        &mut self,
641        values: &[ArrayRef],
642        _num_rows: usize,
643    ) -> Result<ArrayRef> {
644        // LEAD, LAG window functions take single column, values will have size 1
645        let value = &values[0];
646        if !self.ignore_nulls {
647            shift_with_default_value(value, self.shift_offset, &self.default_value)
648        } else {
649            evaluate_all_with_ignore_null(
650                value,
651                self.shift_offset,
652                &self.default_value,
653                self.is_lag(),
654            )
655        }
656    }
657
658    fn supports_bounded_execution(&self) -> bool {
659        true
660    }
661}
662
663#[cfg(test)]
664mod tests {
665    use super::*;
666    use arrow::array::*;
667    use datafusion_common::cast::as_int32_array;
668    use datafusion_physical_expr::expressions::{Column, Literal};
669    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
670
671    fn test_i32_result(
672        expr: WindowShift,
673        partition_evaluator_args: PartitionEvaluatorArgs,
674        expected: Int32Array,
675    ) -> Result<()> {
676        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
677        let values = vec![arr];
678        let num_rows = values.len();
679        let result = expr
680            .partition_evaluator(partition_evaluator_args)?
681            .evaluate_all(&values, num_rows)?;
682        let result = as_int32_array(&result)?;
683        assert_eq!(expected, *result);
684        Ok(())
685    }
686
687    #[test]
688    fn lead_lag_get_range() -> Result<()> {
689        // LAG(2)
690        let lag_fn = WindowShiftEvaluator {
691            shift_offset: 2,
692            default_value: ScalarValue::Null,
693            ignore_nulls: false,
694            non_null_offsets: Default::default(),
695        };
696        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
697        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });
698
699        // LAG(2 ignore nulls)
700        let lag_fn = WindowShiftEvaluator {
701            shift_offset: 2,
702            default_value: ScalarValue::Null,
703            ignore_nulls: true,
704            // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
705            non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
706        };
707        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });
708
709        // LEAD(2)
710        let lead_fn = WindowShiftEvaluator {
711            shift_offset: -2,
712            default_value: ScalarValue::Null,
713            ignore_nulls: false,
714            non_null_offsets: Default::default(),
715        };
716        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
717        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });
718
719        // LEAD(2 ignore nulls)
720        let lead_fn = WindowShiftEvaluator {
721            shift_offset: -2,
722            default_value: ScalarValue::Null,
723            ignore_nulls: true,
724            // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
725            non_null_offsets: vec![2, 2].into(),
726        };
727        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });
728
729        Ok(())
730    }
731
732    #[test]
733    fn test_lead_window_shift() -> Result<()> {
734        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
735
736        test_i32_result(
737            WindowShift::lead(),
738            PartitionEvaluatorArgs::new(
739                &[expr],
740                &[Field::new("f", DataType::Int32, true).into()],
741                false,
742                false,
743            ),
744            [
745                Some(-2),
746                Some(3),
747                Some(-4),
748                Some(5),
749                Some(-6),
750                Some(7),
751                Some(8),
752                None,
753            ]
754            .iter()
755            .collect::<Int32Array>(),
756        )
757    }
758
759    #[test]
760    fn test_lag_window_shift() -> Result<()> {
761        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
762
763        test_i32_result(
764            WindowShift::lag(),
765            PartitionEvaluatorArgs::new(
766                &[expr],
767                &[Field::new("f", DataType::Int32, true).into()],
768                false,
769                false,
770            ),
771            [
772                None,
773                Some(1),
774                Some(-2),
775                Some(3),
776                Some(-4),
777                Some(5),
778                Some(-6),
779                Some(7),
780            ]
781            .iter()
782            .collect::<Int32Array>(),
783        )
784    }
785
786    #[test]
787    fn test_lag_with_default() -> Result<()> {
788        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
789        let shift_offset =
790            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
791        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
792            as Arc<dyn PhysicalExpr>;
793
794        let input_exprs = &[expr, shift_offset, default_value];
795        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
796            .into_iter()
797            .map(|d| Field::new("f", d, true))
798            .map(Arc::new)
799            .collect::<Vec<_>>();
800
801        test_i32_result(
802            WindowShift::lag(),
803            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
804            [
805                Some(100),
806                Some(1),
807                Some(-2),
808                Some(3),
809                Some(-4),
810                Some(5),
811                Some(-6),
812                Some(7),
813            ]
814            .iter()
815            .collect::<Int32Array>(),
816        )
817    }
818}