datafusion_functions_window/
lead_lag.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `lead` and `lag` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
29    TypeSignature, Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr::expressions;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use std::any::Any;
37use std::cmp::min;
38use std::collections::VecDeque;
39use std::hash::Hash;
40use std::ops::{Neg, Range};
41use std::sync::{Arc, LazyLock};
42
// Registers the `lag` window UDF backed by `WindowShift::lag`.
// NOTE(review): `get_or_init_udwf!` is defined elsewhere in the crate; judging
// by its name and by the `lag_udwf()` calls in this file, it presumably
// generates a lazily initialised `lag_udwf()` accessor — confirm against the
// macro definition.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Registers the `lead` window UDF backed by `WindowShift::lead`
// (presumably exposed as `lead_udwf()`, which is called in this file).
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
59
60/// Create an expression to represent the `lag` window function
61///
62/// returns value evaluated at the row that is offset rows before the current row within the partition;
63/// if there is no such row, instead return default (which must be of the same type as value).
64/// Both offset and default are evaluated with respect to the current row.
65/// If omitted, offset defaults to 1 and default to null
66pub fn lag(
67    arg: datafusion_expr::Expr,
68    shift_offset: Option<i64>,
69    default_value: Option<ScalarValue>,
70) -> datafusion_expr::Expr {
71    let shift_offset_lit = shift_offset
72        .map(|v| v.lit())
73        .unwrap_or(ScalarValue::Null.lit());
74    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
75
76    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
77}
78
79/// Create an expression to represent the `lead` window function
80///
81/// returns value evaluated at the row that is offset rows after the current row within the partition;
82/// if there is no such row, instead return default (which must be of the same type as value).
83/// Both offset and default are evaluated with respect to the current row.
84/// If omitted, offset defaults to 1 and default to null
85pub fn lead(
86    arg: datafusion_expr::Expr,
87    shift_offset: Option<i64>,
88    default_value: Option<ScalarValue>,
89) -> datafusion_expr::Expr {
90    let shift_offset_lit = shift_offset
91        .map(|v| v.lit())
92        .unwrap_or(ScalarValue::Null.lit());
93    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
94
95    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
96}
97
/// Selects which of the two shift functions a `WindowShift` implements.
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum WindowShiftKind {
    /// `lag`: reads a row before the current one.
    Lag,
    /// `lead`: reads a row after the current one.
    Lead,
}

impl WindowShiftKind {
    /// SQL name of the function: `"lag"` or `"lead"`.
    fn name(&self) -> &'static str {
        match self {
            Self::Lag => "lag",
            Self::Lead => "lead",
        }
    }

    /// Converts the user-supplied offset (default 1) into the signed shift
    /// used by `WindowShiftEvaluator`.
    ///
    /// The evaluator treats a positive shift as `lag`, so the offset is kept
    /// as-is for `lag` and negated for `lead`.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let rows = value.unwrap_or(1);
        match self {
            Self::Lag => rows,
            Self::Lead => rows.neg(),
        }
    }
}
122
123/// window shift expression
124#[derive(Debug, PartialEq, Eq, Hash)]
125pub struct WindowShift {
126    signature: Signature,
127    kind: WindowShiftKind,
128}
129
130impl WindowShift {
131    fn new(kind: WindowShiftKind) -> Self {
132        Self {
133            signature: Signature::one_of(
134                vec![
135                    TypeSignature::Any(1),
136                    TypeSignature::Any(2),
137                    TypeSignature::Any(3),
138                ],
139                Volatility::Immutable,
140            )
141            .with_parameter_names(vec![
142                "expr".to_string(),
143                "offset".to_string(),
144                "default".to_string(),
145            ])
146            .expect("valid parameter names for lead/lag"),
147            kind,
148        }
149    }
150
151    pub fn lag() -> Self {
152        Self::new(WindowShiftKind::Lag)
153    }
154
155    pub fn lead() -> Self {
156        Self::new(WindowShiftKind::Lead)
157    }
158
159    pub fn kind(&self) -> &WindowShiftKind {
160        &self.kind
161    }
162}
163
// User-facing documentation for `lag`, built on first access and returned by
// `get_lag_doc` (which `WindowShift::documentation` calls for the `Lag` kind).
// The strings below are rendered documentation; edit them with care.
static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).", "lag(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows back \
        the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"
```sql
-- Example usage of the lag window function:
SELECT employee_id,
    salary,
    lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
FROM employees;

+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1           | 30000  | 0           |
| 2           | 50000  | 30000       |
| 3           | 70000  | 50000       |
| 4           | 60000  | 70000       |
+-------------+--------+-------------+
```
"#)
        .build()
});

/// Returns the lazily built documentation for `lag`.
fn get_lag_doc() -> &'static Documentation {
    &LAG_DOCUMENTATION
}
197
// User-facing documentation for `lead`, built on first access and returned by
// `get_lead_doc` (which `WindowShift::documentation` calls for the `Lead` kind).
// The strings below are rendered documentation; edit them with care.
static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL,
            "Returns value evaluated at the row that is offset rows after the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).",
        "lead(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows \
        forward the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"
```sql
-- Example usage of lead window function:
SELECT
    employee_id,
    department,
    salary,
    lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;

+-------------+-------------+--------+--------------+
| employee_id | department  | salary | next_salary  |
+-------------+-------------+--------+--------------+
| 1           | Sales       | 30000  | 50000        |
| 2           | Sales       | 50000  | 70000        |
| 3           | Sales       | 70000  | 0            |
| 4           | Engineering | 40000  | 60000        |
| 5           | Engineering | 60000  | 0            |
+-------------+-------------+--------+--------------+
```
"#)
        .build()
});

/// Returns the lazily built documentation for `lead`.
fn get_lead_doc() -> &'static Documentation {
    &LEAD_DOCUMENTATION
}
236
237impl WindowUDFImpl for WindowShift {
238    fn as_any(&self) -> &dyn Any {
239        self
240    }
241
242    fn name(&self) -> &str {
243        self.kind.name()
244    }
245
246    fn signature(&self) -> &Signature {
247        &self.signature
248    }
249
250    /// Handles the case where `NULL` expression is passed as an
251    /// argument to `lead`/`lag`. The type is refined depending
252    /// on the default value argument.
253    ///
254    /// For more details see: <https://github.com/apache/datafusion/issues/12717>
255    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
256        parse_expr(expr_args.input_exprs(), expr_args.input_fields())
257            .into_iter()
258            .collect::<Vec<_>>()
259    }
260
261    fn partition_evaluator(
262        &self,
263        partition_evaluator_args: PartitionEvaluatorArgs,
264    ) -> Result<Box<dyn PartitionEvaluator>> {
265        let shift_offset =
266            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
267                .map(get_signed_integer)
268                .map_or(Ok(None), |v| v.map(Some))
269                .map(|n| self.kind.shift_offset(n))
270                .map(|offset| {
271                    if partition_evaluator_args.is_reversed() {
272                        -offset
273                    } else {
274                        offset
275                    }
276                })?;
277        let default_value = parse_default_value(
278            partition_evaluator_args.input_exprs(),
279            partition_evaluator_args.input_fields(),
280        )?;
281
282        Ok(Box::new(WindowShiftEvaluator {
283            shift_offset,
284            default_value,
285            ignore_nulls: partition_evaluator_args.ignore_nulls(),
286            non_null_offsets: VecDeque::new(),
287        }))
288    }
289
290    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
291        let return_field = parse_expr_field(field_args.input_fields())?;
292
293        Ok(return_field
294            .as_ref()
295            .clone()
296            .with_name(field_args.name())
297            .into())
298    }
299
300    fn reverse_expr(&self) -> ReversedUDWF {
301        match self.kind {
302            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
303            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
304        }
305    }
306
307    fn documentation(&self) -> Option<&Documentation> {
308        match self.kind {
309            WindowShiftKind::Lag => Some(get_lag_doc()),
310            WindowShiftKind::Lead => Some(get_lead_doc()),
311        }
312    }
313
314    fn limit_effect(&self, args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
315        if self.kind == WindowShiftKind::Lag {
316            return LimitEffect::None;
317        }
318        match args {
319            [_, expr, ..] => {
320                let Some(lit) = expr.as_any().downcast_ref::<expressions::Literal>()
321                else {
322                    return LimitEffect::Unknown;
323                };
324                let ScalarValue::Int64(Some(amount)) = lit.value() else {
325                    return LimitEffect::Unknown; // we should only get int64 from the parser
326                };
327                LimitEffect::Relative((*amount).max(0) as usize)
328            }
329            [_] => LimitEffect::Relative(1), // default value
330            _ => LimitEffect::Unknown,       // invalid arguments
331        }
332    }
333}
334
335/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
336/// refine it by matching it with the type of the default value.
337///
338/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
339/// is refined into `ScalarValue::Boolean(None)`. Only the type is
340/// refined, the expression value remains `NULL`.
341///
342/// When the window function is evaluated with `NULL` expression
343/// this guarantees that the type matches with that of the default
344/// value.
345///
346/// For more details see: <https://github.com/apache/datafusion/issues/12717>
347fn parse_expr(
348    input_exprs: &[Arc<dyn PhysicalExpr>],
349    input_fields: &[FieldRef],
350) -> Result<Arc<dyn PhysicalExpr>> {
351    assert!(!input_exprs.is_empty());
352    assert!(!input_fields.is_empty());
353
354    let expr = Arc::clone(input_exprs.first().unwrap());
355    let expr_field = input_fields.first().unwrap();
356
357    // Handles the most common case where NULL is unexpected
358    if !expr_field.data_type().is_null() {
359        return Ok(expr);
360    }
361
362    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
363    default_value.map_or(Ok(expr), |value| {
364        ScalarValue::try_from(&value.data_type())
365            .map(|v| Arc::new(expressions::Literal::new(v)) as Arc<dyn PhysicalExpr>)
366    })
367}
368
369static NULL_FIELD: LazyLock<FieldRef> =
370    LazyLock::new(|| Field::new("value", DataType::Null, true).into());
371
372/// Returns the field of the default value(if provided) when the
373/// expression is `NULL`.
374///
375/// Otherwise, returns the expression field unchanged.
376fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
377    assert!(!input_fields.is_empty());
378    let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
379
380    // Handles the most common case where NULL is unexpected
381    if !expr_field.data_type().is_null() {
382        return Ok(expr_field.as_ref().clone().with_nullable(true).into());
383    }
384
385    let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
386    Ok(default_value_field
387        .as_ref()
388        .clone()
389        .with_nullable(true)
390        .into())
391}
392
393/// Handles type coercion and null value refinement for default value
394/// argument depending on the data type of the input expression.
395fn parse_default_value(
396    input_exprs: &[Arc<dyn PhysicalExpr>],
397    input_types: &[FieldRef],
398) -> Result<ScalarValue> {
399    let expr_field = parse_expr_field(input_types)?;
400    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
401
402    unparsed
403        .filter(|v| !v.data_type().is_null())
404        .map(|v| v.cast_to(expr_field.data_type()))
405        .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
406}
407
408#[derive(Debug)]
409struct WindowShiftEvaluator {
410    shift_offset: i64,
411    default_value: ScalarValue,
412    ignore_nulls: bool,
413    // VecDeque contains offset values that between non-null entries
414    non_null_offsets: VecDeque<usize>,
415}
416
417impl WindowShiftEvaluator {
418    fn is_lag(&self) -> bool {
419        // Mode is LAG, when shift_offset is positive
420        self.shift_offset > 0
421    }
422}
423
424// implement ignore null for evaluate_all
425fn evaluate_all_with_ignore_null(
426    array: &ArrayRef,
427    offset: i64,
428    default_value: &ScalarValue,
429    is_lag: bool,
430) -> Result<ArrayRef, DataFusionError> {
431    let valid_indices: Vec<usize> =
432        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
433    let direction = !is_lag;
434    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
435        .map(|id| {
436            let result_index = match valid_indices.binary_search(&id) {
437                Ok(pos) => if direction {
438                    pos.checked_add(offset as usize)
439                } else {
440                    pos.checked_sub(offset.unsigned_abs() as usize)
441                }
442                .and_then(|new_pos| {
443                    if new_pos < valid_indices.len() {
444                        Some(valid_indices[new_pos])
445                    } else {
446                        None
447                    }
448                }),
449                Err(pos) => if direction {
450                    pos.checked_add(offset as usize)
451                } else if pos > 0 {
452                    pos.checked_sub(offset.unsigned_abs() as usize)
453                } else {
454                    None
455                }
456                .and_then(|new_pos| {
457                    if new_pos < valid_indices.len() {
458                        Some(valid_indices[new_pos])
459                    } else {
460                        None
461                    }
462                }),
463            };
464
465            match result_index {
466                Some(index) => ScalarValue::try_from_array(array, index),
467                None => Ok(default_value.clone()),
468            }
469        })
470        .collect();
471
472    let new_array = new_array_results?;
473    ScalarValue::iter_to_array(new_array)
474}
475// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
476fn shift_with_default_value(
477    array: &ArrayRef,
478    offset: i64,
479    default_value: &ScalarValue,
480) -> Result<ArrayRef> {
481    use datafusion_common::arrow::compute::concat;
482
483    let value_len = array.len() as i64;
484    if offset == 0 {
485        Ok(Arc::clone(array))
486    } else if offset == i64::MIN || offset.abs() >= value_len {
487        default_value.to_array_of_size(value_len as usize)
488    } else {
489        let slice_offset = (-offset).clamp(0, value_len) as usize;
490        let length = array.len() - offset.unsigned_abs() as usize;
491        let slice = array.slice(slice_offset, length);
492
493        // Generate array with remaining `null` items
494        let nulls = offset.unsigned_abs() as usize;
495        let default_values = default_value.to_array_of_size(nulls)?;
496
497        // Concatenate both arrays, add nulls after if shift > 0 else before
498        if offset > 0 {
499            concat(&[default_values.as_ref(), slice.as_ref()])
500                .map_err(|e| arrow_datafusion_err!(e))
501        } else {
502            concat(&[slice.as_ref(), default_values.as_ref()])
503                .map_err(|e| arrow_datafusion_err!(e))
504        }
505    }
506}
507
// Row-by-row (`get_range` + `evaluate`) and whole-partition (`evaluate_all`)
// evaluation of `lag`/`lead`. The sign of `shift_offset` selects the mode
// (see `is_lag`). `non_null_offsets` carries state between `evaluate` calls
// in `IGNORE NULLS` mode.
impl PartitionEvaluator for WindowShiftEvaluator {
    /// Returns the rows that must be available to compute the result for row
    /// `idx`, given `n_rows` rows currently in the partition.
    ///
    /// * `lag`: the range ends just after `idx`. It starts `shift_offset` rows
    ///   back, or, with `IGNORE NULLS`, back by the sum of the tracked gaps
    ///   once `shift_offset` of them are known (otherwise from row 0).
    /// * `lead`: the range starts at `idx`. Once enough gaps are tracked it
    ///   ends one row past them. Without `IGNORE NULLS` it ends at
    ///   `idx + offset`. With `IGNORE NULLS` and too few gaps it ends at
    ///   `n_rows`. The end is always capped at `n_rows`.
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        if self.is_lag() {
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
                // Number of rows before the current row needed to reach the
                // `shift_offset`-th previous non-null row.
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                let offset = self.shift_offset as usize;
                idx.saturating_sub(offset)
            } else {
                // Not enough non-null rows seen yet: keep everything.
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
                // Number of rows after the current row needed to reach the
                // required next non-null row. The `+ 1` presumably lets
                // `evaluate` inspect `range.end` when extending
                // `non_null_offsets`.
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx + offset + 1, n_rows)
            } else if !self.ignore_nulls {
                let offset = (-self.shift_offset) as usize;
                // NOTE(review): this exclusive bound stops just before the
                // target row `idx + offset`. `evaluate` reads the target from
                // the full `values` array, so this presumably relies on the
                // caller not evaluating an unfinished partition while
                // `end == n_rows` — confirm.
                min(idx + offset, n_rows)
            } else {
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    fn is_causal(&self) -> bool {
        // Lagging windows are causal by definition:
        self.is_lag()
    }

    /// Computes the result for one row. The current row is `range.end - 1`
    /// in `lag` mode and `range.start` in `lead` mode.
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        // Target index when NULLs are respected.
        // LAG mode. A negative intermediate value wraps to a huge `usize`,
        // which the `i < len` check below maps to `None`.
        let i = if self.is_lag() {
            (range.end as i64 - self.shift_offset - 1) as usize
        } else {
            // LEAD mode
            (range.start as i64 - self.shift_offset) as usize
        };

        let mut idx: Option<usize> = if i < len { Some(i) } else { None };

        // LAG with IGNORE NULLS: the target is the current row index minus the offset, counting only non-NULL rows.
        // A current row that holds NULL is NOT counted.
        if self.ignore_nulls && self.is_lag() {
            // LAG when NULLS are ignored.
            // Locate the non-NULL row `shift_offset` non-NULL rows before the
            // current row, using the gaps recorded by earlier calls (the
            // current row is not yet included).
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            // Record the current row in the gaps between non-null entries.
            if array.is_valid(range.end - 1) {
                // Non-null: start a new gap of length 1.
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > self.shift_offset as usize {
                    // Only the last `shift_offset` gaps are needed.
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                // Entry is null, increment offset value of the last entry.
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            // LEAD when NULLS are ignored.
            // Number of non-null rows after the current row to skip over.
            let non_null_row_count = (-self.shift_offset) as usize;

            if self.non_null_offsets.is_empty() {
                // When empty, rebuild the gaps from the rows strictly after the current row.
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    // Tracking `non_null_row_count + 1` non-null offsets is enough;
                    // further data is unnecessary for the result.
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                // Update `non_null_offsets` with the new end data.
                // NOTE(review): the outer condition already requires
                // `array.is_valid(range.end)`, so the `else` branch below is
                // unreachable. A NULL at `range.end` leaves `non_null_offsets`
                // unchanged. Verify that later rows still see enough non-null
                // entries (e.g. `lead(x, 2) IGNORE NULLS` when a NULL follows
                // the tracked rows).
                if array.is_valid(range.end) {
                    // When non-null, append a new offset.
                    self.non_null_offsets.push_back(1);
                } else {
                    // When null, increment offset count of the last entry
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            // The target is the `non_null_row_count`-th non-NULL row after the current row.
            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            // Shorten the first gap by one (dropping it when it reaches 0) so
            // that on the next call the gaps are measured from the next row.
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    // When offset is 0. Remove it.
                    self.non_null_offsets.pop_front();
                }
            }
        }

        // Set the default value if
        // - index is out of window bounds
        // OR
        // - ignore nulls mode and current value is null and is within window bounds
        // .unwrap() is safe here as there is a none check in front
        #[allow(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

    /// Computes the whole partition at once. It dispatches to a plain shift
    /// or to the `IGNORE NULLS` variant; the row count argument is unused.
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        // LEAD, LAG window functions take single column, values will have size 1
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    // Row-by-row evaluation through `get_range`/`evaluate` is supported.
    fn supports_bounded_execution(&self) -> bool {
        true
    }
}
674
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Runs `evaluate_all` for `expr` over `input` and compares the `Int32`
    /// result with `expected`.
    fn assert_i32_result(
        input: ArrayRef,
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        // `evaluate_all` takes the number of rows in the partition, not the
        // number of input columns.
        let num_rows = input.len();
        let values = vec![input];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Same as [`assert_i32_result`] over the fixed input
    /// `[1, -2, 3, -4, 5, -6, 7, 8]`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        assert_i32_result(arr, expr, partition_evaluator_args, expected)
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2 ignore nulls)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
            non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2 ignore nulls)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    /// `lag(x) IGNORE NULLS` over a whole partition containing NULLs: each
    /// row receives the closest previous non-null value.
    #[test]
    fn test_lag_ignore_nulls_evaluate_all() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let input: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            None,
            Some(5),
        ]));

        assert_i32_result(
            input,
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                true,
            ),
            [None, Some(1), Some(1), Some(3), Some(3)]
                .iter()
                .collect::<Int32Array>(),
        )
    }
}