Skip to main content

datafusion_functions_window/
lead_lag.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `lead` and `lag` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{DataFusionError, Result, ScalarValue, arrow_datafusion_err};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
29    TypeSignature, Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr::expressions;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::hash::Hash;
39use std::ops::Range;
40use std::sync::{Arc, LazyLock};
41
42get_or_init_udwf!(
43    Lag,
44    lag,
45    lag_udwf,
46    "Returns the row value that precedes the current row by a specified \
47    offset within partition. If no such row exists, then returns the \
48    default value.",
49    WindowShift::lag
50);
51get_or_init_udwf!(
52    Lead,
53    lead,
54    lead_udwf,
55    "Returns the value from a row that follows the current row by a \
56    specified offset within the partition. If no such row exists, then \
57    returns the default value.",
58    WindowShift::lead
59);
60
61/// Create an expression to represent the `lag` window function
62///
63/// returns value evaluated at the row that is offset rows before the current row within the partition;
64/// if there is no such row, instead return default (which must be of the same type as value).
65/// Both offset and default are evaluated with respect to the current row.
66/// If omitted, offset defaults to 1 and default to null
67pub fn lag(
68    arg: datafusion_expr::Expr,
69    shift_offset: Option<i64>,
70    default_value: Option<ScalarValue>,
71) -> datafusion_expr::Expr {
72    let shift_offset_lit = shift_offset
73        .map(|v| v.lit())
74        .unwrap_or(ScalarValue::Null.lit());
75    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
76
77    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
78}
79
80/// Create an expression to represent the `lead` window function
81///
82/// returns value evaluated at the row that is offset rows after the current row within the partition;
83/// if there is no such row, instead return default (which must be of the same type as value).
84/// Both offset and default are evaluated with respect to the current row.
85/// If omitted, offset defaults to 1 and default to null
86pub fn lead(
87    arg: datafusion_expr::Expr,
88    shift_offset: Option<i64>,
89    default_value: Option<ScalarValue>,
90) -> datafusion_expr::Expr {
91    let shift_offset_lit = shift_offset
92        .map(|v| v.lit())
93        .unwrap_or(ScalarValue::Null.lit());
94    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
95
96    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
97}
98
/// Selects which of the two shift window functions a `WindowShift` implements.
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum WindowShiftKind {
    /// `lag`: reads rows before the current row.
    Lag,
    /// `lead`: reads rows after the current row.
    Lead,
}

impl WindowShiftKind {
    /// SQL name of the window function.
    fn name(&self) -> &'static str {
        match self {
            Self::Lag => "lag",
            Self::Lead => "lead",
        }
    }

    /// Converts the user-supplied offset (default 1) into the signed shift used by
    /// `WindowShiftEvaluator`, where a positive value means `lag`.
    ///
    /// `lead` therefore negates the offset; `wrapping_neg` keeps `i64::MIN` unchanged
    /// instead of overflowing.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let requested = value.unwrap_or(1);
        match self {
            Self::Lag => requested,
            Self::Lead => requested.wrapping_neg(),
        }
    }
}
123
124/// window shift expression
125#[derive(Debug, PartialEq, Eq, Hash)]
126pub struct WindowShift {
127    signature: Signature,
128    kind: WindowShiftKind,
129}
130
131impl WindowShift {
132    fn new(kind: WindowShiftKind) -> Self {
133        Self {
134            signature: Signature::one_of(
135                vec![
136                    TypeSignature::Any(1),
137                    TypeSignature::Any(2),
138                    TypeSignature::Any(3),
139                ],
140                Volatility::Immutable,
141            )
142            .with_parameter_names(vec![
143                "expr".to_string(),
144                "offset".to_string(),
145                "default".to_string(),
146            ])
147            .expect("valid parameter names for lead/lag"),
148            kind,
149        }
150    }
151
152    pub fn lag() -> Self {
153        Self::new(WindowShiftKind::Lag)
154    }
155
156    pub fn lead() -> Self {
157        Self::new(WindowShiftKind::Lead)
158    }
159
160    pub fn kind(&self) -> &WindowShiftKind {
161        &self.kind
162    }
163}
164
165static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
166    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
167            current row within the partition; if there is no such row, instead return default \
168            (which must be of the same type as value).", "lag(expression, offset, default)")
169        .with_argument("expression", "Expression to operate on")
170        .with_argument("offset", "Integer. Specifies how many rows back \
171        the value of expression should be retrieved. Defaults to 1.")
172        .with_argument("default", "The default value if the offset is \
173        not within the partition. Must be of the same type as expression.")
174        .with_sql_example(r#"
175```sql
176-- Example usage of the lag window function:
177SELECT employee_id,
178    salary,
179    lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
180FROM employees;
181
182+-------------+--------+-------------+
183| employee_id | salary | prev_salary |
184+-------------+--------+-------------+
185| 1           | 30000  | 0           |
186| 2           | 50000  | 30000       |
187| 3           | 70000  | 50000       |
188| 4           | 60000  | 70000       |
189+-------------+--------+-------------+
190```
191"#)
192        .build()
193});
194
195fn get_lag_doc() -> &'static Documentation {
196    &LAG_DOCUMENTATION
197}
198
199static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
200    Documentation::builder(DOC_SECTION_ANALYTICAL,
201            "Returns value evaluated at the row that is offset rows after the \
202            current row within the partition; if there is no such row, instead return default \
203            (which must be of the same type as value).",
204        "lead(expression, offset, default)")
205        .with_argument("expression", "Expression to operate on")
206        .with_argument("offset", "Integer. Specifies how many rows \
207        forward the value of expression should be retrieved. Defaults to 1.")
208        .with_argument("default", "The default value if the offset is \
209        not within the partition. Must be of the same type as expression.")
210        .with_sql_example(r#"
211```sql
212-- Example usage of lead window function:
213SELECT
214    employee_id,
215    department,
216    salary,
217    lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
218FROM employees;
219
220+-------------+-------------+--------+--------------+
221| employee_id | department  | salary | next_salary  |
222+-------------+-------------+--------+--------------+
223| 1           | Sales       | 30000  | 50000        |
224| 2           | Sales       | 50000  | 70000        |
225| 3           | Sales       | 70000  | 0            |
226| 4           | Engineering | 40000  | 60000        |
227| 5           | Engineering | 60000  | 0            |
228+-------------+-------------+--------+--------------+
229```
230"#)
231        .build()
232});
233
234fn get_lead_doc() -> &'static Documentation {
235    &LEAD_DOCUMENTATION
236}
237
238impl WindowUDFImpl for WindowShift {
239    fn name(&self) -> &str {
240        self.kind.name()
241    }
242
243    fn signature(&self) -> &Signature {
244        &self.signature
245    }
246
247    /// Handles the case where `NULL` expression is passed as an
248    /// argument to `lead`/`lag`. The type is refined depending
249    /// on the default value argument.
250    ///
251    /// For more details see: <https://github.com/apache/datafusion/issues/12717>
252    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
253        parse_expr(expr_args.input_exprs(), expr_args.input_fields())
254            .into_iter()
255            .collect::<Vec<_>>()
256    }
257
258    fn partition_evaluator(
259        &self,
260        partition_evaluator_args: PartitionEvaluatorArgs,
261    ) -> Result<Box<dyn PartitionEvaluator>> {
262        let shift_offset =
263            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
264                .map(|v| get_signed_integer(&v))
265                .map_or(Ok(None), |v| v.map(Some))
266                .map(|n| self.kind.shift_offset(n))
267                .map(|offset| {
268                    if partition_evaluator_args.is_reversed() {
269                        offset.wrapping_neg()
270                    } else {
271                        offset
272                    }
273                })?;
274        let default_value = parse_default_value(
275            partition_evaluator_args.input_exprs(),
276            partition_evaluator_args.input_fields(),
277        )?;
278
279        Ok(Box::new(WindowShiftEvaluator {
280            shift_offset,
281            default_value,
282            ignore_nulls: partition_evaluator_args.ignore_nulls(),
283            non_null_offsets: VecDeque::new(),
284        }))
285    }
286
287    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
288        let return_field = parse_expr_field(field_args.input_fields())?;
289
290        Ok(return_field
291            .as_ref()
292            .clone()
293            .with_name(field_args.name())
294            .into())
295    }
296
297    fn reverse_expr(&self) -> ReversedUDWF {
298        match self.kind {
299            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
300            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
301        }
302    }
303
304    fn documentation(&self) -> Option<&Documentation> {
305        match self.kind {
306            WindowShiftKind::Lag => Some(get_lag_doc()),
307            WindowShiftKind::Lead => Some(get_lead_doc()),
308        }
309    }
310
311    fn limit_effect(&self, args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
312        if self.kind == WindowShiftKind::Lag {
313            return LimitEffect::None;
314        }
315        match args {
316            [_, expr, ..] => {
317                let Some(lit) = expr.downcast_ref::<expressions::Literal>() else {
318                    return LimitEffect::Unknown;
319                };
320                let ScalarValue::Int64(Some(amount)) = lit.value() else {
321                    return LimitEffect::Unknown; // we should only get int64 from the parser
322                };
323                LimitEffect::Relative((*amount).max(0) as usize)
324            }
325            [_] => LimitEffect::Relative(1), // default value
326            _ => LimitEffect::Unknown,       // invalid arguments
327        }
328    }
329}
330
331/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
332/// refine it by matching it with the type of the default value.
333///
334/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
335/// is refined into `ScalarValue::Boolean(None)`. Only the type is
336/// refined, the expression value remains `NULL`.
337///
338/// When the window function is evaluated with `NULL` expression
339/// this guarantees that the type matches with that of the default
340/// value.
341///
342/// For more details see: <https://github.com/apache/datafusion/issues/12717>
343fn parse_expr(
344    input_exprs: &[Arc<dyn PhysicalExpr>],
345    input_fields: &[FieldRef],
346) -> Result<Arc<dyn PhysicalExpr>> {
347    assert!(!input_exprs.is_empty());
348    assert!(!input_fields.is_empty());
349
350    let expr = Arc::clone(input_exprs.first().unwrap());
351    let expr_field = input_fields.first().unwrap();
352
353    // Handles the most common case where NULL is unexpected
354    if !expr_field.data_type().is_null() {
355        return Ok(expr);
356    }
357
358    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
359    default_value.map_or(Ok(expr), |value| {
360        ScalarValue::try_from(&value.data_type())
361            .map(|v| Arc::new(expressions::Literal::new(v)) as Arc<dyn PhysicalExpr>)
362    })
363}
364
365static NULL_FIELD: LazyLock<FieldRef> =
366    LazyLock::new(|| Field::new("value", DataType::Null, true).into());
367
368/// Returns the field of the default value(if provided) when the
369/// expression is `NULL`.
370///
371/// Otherwise, returns the expression field unchanged.
372fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
373    assert!(!input_fields.is_empty());
374    let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
375
376    // Handles the most common case where NULL is unexpected
377    if !expr_field.data_type().is_null() {
378        return Ok(expr_field.as_ref().clone().with_nullable(true).into());
379    }
380
381    let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
382    Ok(default_value_field
383        .as_ref()
384        .clone()
385        .with_nullable(true)
386        .into())
387}
388
389/// Handles type coercion and null value refinement for default value
390/// argument depending on the data type of the input expression.
391fn parse_default_value(
392    input_exprs: &[Arc<dyn PhysicalExpr>],
393    input_types: &[FieldRef],
394) -> Result<ScalarValue> {
395    let expr_field = parse_expr_field(input_types)?;
396    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
397
398    unparsed
399        .filter(|v| !v.data_type().is_null())
400        .map(|v| v.cast_to(expr_field.data_type()))
401        .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
402}
403
404#[derive(Debug)]
405struct WindowShiftEvaluator {
406    shift_offset: i64,
407    default_value: ScalarValue,
408    ignore_nulls: bool,
409    // VecDeque contains offset values that between non-null entries
410    non_null_offsets: VecDeque<usize>,
411}
412
/// Absolute value of a signed shift as a row count, saturating at `usize::MAX` on
/// targets where `usize` is narrower than 64 bits.
fn offset_magnitude(offset: i64) -> usize {
    usize::try_from(offset.unsigned_abs()).unwrap_or(usize::MAX)
}
421
422impl WindowShiftEvaluator {
423    fn is_lag(&self) -> bool {
424        // Mode is LAG, when shift_offset is positive
425        self.shift_offset > 0
426    }
427}
428
429// implement ignore null for evaluate_all
430fn evaluate_all_with_ignore_null(
431    array: &ArrayRef,
432    offset: i64,
433    default_value: &ScalarValue,
434    is_lag: bool,
435) -> Result<ArrayRef, DataFusionError> {
436    let valid_indices: Vec<usize> =
437        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
438    let direction = !is_lag;
439    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
440        .map(|id| {
441            let result_index = match valid_indices.binary_search(&id) {
442                Ok(pos) => if direction {
443                    pos.checked_add(offset as usize)
444                } else {
445                    pos.checked_sub(offset.unsigned_abs() as usize)
446                }
447                .and_then(|new_pos| {
448                    if new_pos < valid_indices.len() {
449                        Some(valid_indices[new_pos])
450                    } else {
451                        None
452                    }
453                }),
454                Err(pos) => if direction {
455                    pos.checked_add(offset as usize)
456                } else if pos > 0 {
457                    pos.checked_sub(offset.unsigned_abs() as usize)
458                } else {
459                    None
460                }
461                .and_then(|new_pos| {
462                    if new_pos < valid_indices.len() {
463                        Some(valid_indices[new_pos])
464                    } else {
465                        None
466                    }
467                }),
468            };
469
470            match result_index {
471                Some(index) => ScalarValue::try_from_array(array, index),
472                None => Ok(default_value.clone()),
473            }
474        })
475        .collect();
476
477    let new_array = new_array_results?;
478    ScalarValue::iter_to_array(new_array)
479}
480// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
481fn shift_with_default_value(
482    array: &ArrayRef,
483    offset: i64,
484    default_value: &ScalarValue,
485) -> Result<ArrayRef> {
486    use datafusion_common::arrow::compute::concat;
487
488    let value_len = array.len() as i64;
489    if offset == 0 {
490        Ok(Arc::clone(array))
491    } else if offset == i64::MIN || offset.abs() >= value_len {
492        default_value.to_array_of_size(value_len as usize)
493    } else {
494        let slice_offset = (-offset).clamp(0, value_len) as usize;
495        let length = array.len() - offset.unsigned_abs() as usize;
496        let slice = array.slice(slice_offset, length);
497
498        // Generate array with remaining `null` items
499        let nulls = offset.unsigned_abs() as usize;
500        let default_values = default_value.to_array_of_size(nulls)?;
501
502        // Concatenate both arrays, add nulls after if shift > 0 else before
503        if offset > 0 {
504            concat(&[default_values.as_ref(), slice.as_ref()])
505                .map_err(|e| arrow_datafusion_err!(e))
506        } else {
507            concat(&[slice.as_ref(), default_values.as_ref()])
508                .map_err(|e| arrow_datafusion_err!(e))
509        }
510    }
511}
512
513impl PartitionEvaluator for WindowShiftEvaluator {
514    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
515        let offset = offset_magnitude(self.shift_offset);
516
517        if self.is_lag() {
518            let start = if self.non_null_offsets.len() == offset {
519                // How many rows needed previous than the current row to get necessary lag result
520                let offset: usize = self.non_null_offsets.iter().sum();
521                idx.saturating_sub(offset)
522            } else if !self.ignore_nulls {
523                idx.saturating_sub(offset)
524            } else {
525                0
526            };
527            let end = idx + 1;
528            Ok(Range { start, end })
529        } else {
530            let end = if self.non_null_offsets.len() == offset {
531                // How many rows needed further than the current row to get necessary lead result
532                let offset: usize = self.non_null_offsets.iter().sum();
533                min(idx.saturating_add(offset).saturating_add(1), n_rows)
534            } else if !self.ignore_nulls {
535                min(idx.saturating_add(offset), n_rows)
536            } else {
537                n_rows
538            };
539            Ok(Range { start: idx, end })
540        }
541    }
542
543    fn is_causal(&self) -> bool {
544        // Lagging windows are causal by definition:
545        self.is_lag()
546    }
547
548    fn evaluate(
549        &mut self,
550        values: &[ArrayRef],
551        range: &Range<usize>,
552    ) -> Result<ScalarValue> {
553        let array = &values[0];
554        let len = array.len();
555
556        // LAG mode
557        let i = if self.is_lag() {
558            range
559                .end
560                .checked_sub(1)
561                .and_then(|end| (end as i64).checked_sub(self.shift_offset))
562                .and_then(|value| usize::try_from(value).ok())
563        } else {
564            // LEAD mode
565            (range.start as i64)
566                .checked_sub(self.shift_offset)
567                .and_then(|value| usize::try_from(value).ok())
568        };
569        let mut idx: Option<usize> = i.filter(|i| *i < len);
570
571        // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows
572        // If current row index points to NULL value the row is NOT counted
573        if self.ignore_nulls && self.is_lag() {
574            // LAG when NULLS are ignored.
575            // Find the nonNULL row index that shifted by offset comparing to current row index
576            let shift_offset = offset_magnitude(self.shift_offset);
577            idx = if self.non_null_offsets.len() == shift_offset {
578                let total_offset: usize = self.non_null_offsets.iter().sum();
579                Some(range.end - 1 - total_offset)
580            } else {
581                None
582            };
583
584            // Keep track of offset values between non-null entries
585            if array.is_valid(range.end - 1) {
586                // Non-null add new offset
587                self.non_null_offsets.push_back(1);
588                if self.non_null_offsets.len() > shift_offset {
589                    // WE do not need to keep track of more than `lag number of offset` values.
590                    self.non_null_offsets.pop_front();
591                }
592            } else if !self.non_null_offsets.is_empty() {
593                // Entry is null, increment offset value of the last entry.
594                let end_idx = self.non_null_offsets.len() - 1;
595                self.non_null_offsets[end_idx] += 1;
596            }
597        } else if self.ignore_nulls && !self.is_lag() {
598            // LEAD when NULLS are ignored.
599            // Stores the necessary non-null entry number further than the current row.
600            let non_null_row_count = offset_magnitude(self.shift_offset);
601
602            if self.non_null_offsets.is_empty() {
603                // When empty, fill non_null offsets with the data further than the current row.
604                let mut offset_val = 1;
605                for idx in range.start + 1..range.end {
606                    if array.is_valid(idx) {
607                        self.non_null_offsets.push_back(offset_val);
608                        offset_val = 1;
609                    } else {
610                        offset_val += 1;
611                    }
612                    // It is enough to keep track of `non_null_row_count + 1` non-null offset.
613                    // further data is unnecessary for the result.
614                    if self.non_null_offsets.len() == non_null_row_count.saturating_add(1)
615                    {
616                        break;
617                    }
618                }
619            } else if range.end < len && array.is_valid(range.end) {
620                // Update `non_null_offsets` with the new end data.
621                if array.is_valid(range.end) {
622                    // When non-null, append a new offset.
623                    self.non_null_offsets.push_back(1);
624                } else {
625                    // When null, increment offset count of the last entry
626                    let last_idx = self.non_null_offsets.len() - 1;
627                    self.non_null_offsets[last_idx] += 1;
628                }
629            }
630
631            // Find the nonNULL row index that shifted by offset comparing to current row index
632            idx = if self.non_null_offsets.len() >= non_null_row_count {
633                let total_offset: usize =
634                    self.non_null_offsets.iter().take(non_null_row_count).sum();
635                Some(range.start + total_offset)
636            } else {
637                None
638            };
639            // Prune `self.non_null_offsets` from the start. so that at next iteration
640            // start of the `self.non_null_offsets` matches with current row.
641            if !self.non_null_offsets.is_empty() {
642                self.non_null_offsets[0] -= 1;
643                if self.non_null_offsets[0] == 0 {
644                    // When offset is 0. Remove it.
645                    self.non_null_offsets.pop_front();
646                }
647            }
648        }
649
650        // Set the default value if
651        // - index is out of window bounds
652        // OR
653        // - ignore nulls mode and current value is null and is within window bounds
654        // .unwrap() is safe here as there is a none check in front
655        #[expect(clippy::unnecessary_unwrap)]
656        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
657            ScalarValue::try_from_array(array, idx.unwrap())
658        } else {
659            Ok(self.default_value.clone())
660        }
661    }
662
663    fn evaluate_all(
664        &mut self,
665        values: &[ArrayRef],
666        _num_rows: usize,
667    ) -> Result<ArrayRef> {
668        // LEAD, LAG window functions take single column, values will have size 1
669        let value = &values[0];
670        if !self.ignore_nulls {
671            shift_with_default_value(value, self.shift_offset, &self.default_value)
672        } else {
673            evaluate_all_with_ignore_null(
674                value,
675                self.shift_offset,
676                &self.default_value,
677                self.is_lag(),
678            )
679        }
680    }
681
682    fn supports_bounded_execution(&self) -> bool {
683        true
684    }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};

    /// Runs `evaluate_all` for `expr` over a fixed 8-row `Int32` column and compares
    /// the output with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // `evaluate_all` expects the number of rows, not the number of columns.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Builds `(column c3, Int32 offset literal, Int32 default literal)` arguments
    /// together with their `Int32` fields.
    fn int32_args(
        shift_offset: i32,
        default_value: i32,
    ) -> (Vec<Arc<dyn PhysicalExpr>>, Vec<FieldRef>) {
        let exprs = vec![
            Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>,
            Arc::new(Literal::new(ScalarValue::Int32(Some(shift_offset))))
                as Arc<dyn PhysicalExpr>,
            Arc::new(Literal::new(ScalarValue::Int32(Some(default_value))))
                as Arc<dyn PhysicalExpr>,
        ];
        let fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();
        (exprs, fields)
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2 ignore nulls)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
            non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2 ignore nulls)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let (input_exprs, input_fields) = int32_args(1, 100);

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(&input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lead_with_default() -> Result<()> {
        let (input_exprs, input_fields) = int32_args(1, 100);

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(&input_exprs, &input_fields, false, false),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                Some(100),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_offset_beyond_partition() -> Result<()> {
        // An offset larger than the partition yields the default for every row.
        let (input_exprs, input_fields) = int32_args(10, 100);

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(&input_exprs, &input_fields, false, false),
            [Some(100); 8].iter().collect::<Int32Array>(),
        )
    }
}