use crate::utils::{get_scalar_value_from_args, get_signed_integer};
use arrow::datatypes::FieldRef;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::arrow::datatypes::Field;
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
use datafusion_expr::{
    Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
    TypeSignature, Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr::expressions;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::cmp::min;
use std::collections::VecDeque;
use std::hash::Hash;
use std::ops::{Neg, Range};
use std::sync::{Arc, LazyLock};

get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within the partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);

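/// Create an expression to invoke the `lag` window function on `arg`.
///
/// `shift_offset` and `default_value` are forwarded as literal arguments;
/// passing `None` forwards a `NULL` literal for that argument. Illustrative
/// call (assumes `col` from `datafusion_expr`):
/// `lag(col("salary"), Some(1), None)` references the previous row's `salary`
/// within the partition.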
pub fn lag(
    arg: datafusion_expr::Expr,
    shift_offset: Option<i64>,
    default_value: Option<ScalarValue>,
) -> datafusion_expr::Expr {
    let shift_offset_lit = shift_offset
        .map(|v| v.lit())
        .unwrap_or(ScalarValue::Null.lit());
    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();

    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
}

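/// Create an expression to invoke the `lead` window function on `arg`.
///
/// `shift_offset` and `default_value` are forwarded as literal arguments;
/// passing `None` forwards a `NULL` literal for that argument, mirroring
/// [`lag`].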
pub fn lead(
    arg: datafusion_expr::Expr,
    shift_offset: Option<i64>,
    default_value: Option<ScalarValue>,
) -> datafusion_expr::Expr {
    let shift_offset_lit = shift_offset
        .map(|v| v.lit())
        .unwrap_or(ScalarValue::Null.lit());
    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();

    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
}

#[derive(Debug, PartialEq, Eq, Hash)]
pub enum WindowShiftKind {
    Lag,
    Lead,
}

impl WindowShiftKind {
    fn name(&self) -> &'static str {
        match self {
            WindowShiftKind::Lag => "lag",
            WindowShiftKind::Lead => "lead",
        }
    }

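    /// Resolves the shift offset used by the evaluator: positive for `lag`
    /// (look backwards) and negative for `lead` (look forwards). A missing
    /// value defaults to a shift of one row.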
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        match self {
            WindowShiftKind::Lag => value.unwrap_or(1),
            WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1),
        }
    }
}

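/// Shared implementation of the `lag` and `lead` window functions.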
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct WindowShift {
    signature: Signature,
    kind: WindowShiftKind,
}

impl WindowShift {
    fn new(kind: WindowShiftKind) -> Self {
        Self {
            signature: Signature::one_of(
                vec![
                    TypeSignature::Any(1),
                    TypeSignature::Any(2),
                    TypeSignature::Any(3),
                ],
                Volatility::Immutable,
            ),
            kind,
        }
    }

    pub fn lag() -> Self {
        Self::new(WindowShiftKind::Lag)
    }

    pub fn lead() -> Self {
        Self::new(WindowShiftKind::Lead)
    }

    pub fn kind(&self) -> &WindowShiftKind {
        &self.kind
    }
}

static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns the value evaluated at the row that is offset rows before the \
        current row within the partition; if there is no such row, returns \
        default instead (which must be of the same type as value).",
        "lag(expression, offset, default)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_argument(
        "offset",
        "Integer. Specifies how many rows back the value of expression \
        should be retrieved. Defaults to 1.",
    )
    .with_argument(
        "default",
        "The default value if the offset is not within the partition. \
        Must be of the same type as expression.",
    )
    .with_sql_example(
        r#"
```sql
-- Example usage of the lag window function:
SELECT employee_id,
       salary,
       lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
FROM employees;

+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1           | 30000  | 0           |
| 2           | 50000  | 30000       |
| 3           | 70000  | 50000       |
| 4           | 60000  | 70000       |
+-------------+--------+-------------+
```
"#,
    )
    .build()
});

fn get_lag_doc() -> &'static Documentation {
    &LAG_DOCUMENTATION
}

static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns the value evaluated at the row that is offset rows after the \
        current row within the partition; if there is no such row, returns \
        default instead (which must be of the same type as value).",
        "lead(expression, offset, default)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_argument(
        "offset",
        "Integer. Specifies how many rows forward the value of expression \
        should be retrieved. Defaults to 1.",
    )
    .with_argument(
        "default",
        "The default value if the offset is not within the partition. \
        Must be of the same type as expression.",
    )
    .with_sql_example(
        r#"
```sql
-- Example usage of the lead window function:
SELECT
    employee_id,
    department,
    salary,
    lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;

+-------------+-------------+--------+-------------+
| employee_id | department  | salary | next_salary |
+-------------+-------------+--------+-------------+
| 1           | Sales       | 30000  | 50000       |
| 2           | Sales       | 50000  | 70000       |
| 3           | Sales       | 70000  | 0           |
| 4           | Engineering | 40000  | 60000       |
| 5           | Engineering | 60000  | 0           |
+-------------+-------------+--------+-------------+
```
"#,
    )
    .build()
});

fn get_lead_doc() -> &'static Documentation {
    &LEAD_DOCUMENTATION
}

impl WindowUDFImpl for WindowShift {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.kind.name()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

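    /// When the input expression is typed `NULL`, `parse_expr` substitutes a
    /// typed null literal derived from the default value so that downstream
    /// code sees a concrete data type.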
    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
        parse_expr(expr_args.input_exprs(), expr_args.input_fields())
            .into_iter()
            .collect::<Vec<_>>()
    }

    fn partition_evaluator(
        &self,
        partition_evaluator_args: PartitionEvaluatorArgs,
    ) -> Result<Box<dyn PartitionEvaluator>> {
        let shift_offset =
            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
                .map(get_signed_integer)
                .map_or(Ok(None), |v| v.map(Some))
                .map(|n| self.kind.shift_offset(n))
                .map(|offset| {
                    if partition_evaluator_args.is_reversed() {
                        -offset
                    } else {
                        offset
                    }
                })?;
        let default_value = parse_default_value(
            partition_evaluator_args.input_exprs(),
            partition_evaluator_args.input_fields(),
        )?;

        Ok(Box::new(WindowShiftEvaluator {
            shift_offset,
            default_value,
            ignore_nulls: partition_evaluator_args.ignore_nulls(),
            non_null_offsets: VecDeque::new(),
        }))
    }

    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
        let return_field = parse_expr_field(field_args.input_fields())?;

        Ok(return_field
            .as_ref()
            .clone()
            .with_name(field_args.name())
            .into())
    }

    fn reverse_expr(&self) -> ReversedUDWF {
        match self.kind {
            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
        }
    }

    fn documentation(&self) -> Option<&Documentation> {
        match self.kind {
            WindowShiftKind::Lag => Some(get_lag_doc()),
            WindowShiftKind::Lead => Some(get_lead_doc()),
        }
    }

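    /// `lag` only looks at preceding rows, so a limit needs no extra input;
    /// `lead` needs `offset` additional rows past the limit, which is only
    /// known when the offset argument is an integer literal.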
    fn limit_effect(&self, args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
        if self.kind == WindowShiftKind::Lag {
            return LimitEffect::None;
        }
        match args {
            [_, expr, ..] => {
                let Some(lit) = expr.as_any().downcast_ref::<expressions::Literal>()
                else {
                    return LimitEffect::Unknown;
                };
                let ScalarValue::Int64(Some(amount)) = lit.value() else {
                    return LimitEffect::Unknown;
                };
                LimitEffect::Relative((*amount).max(0) as usize)
            }
            [_] => LimitEffect::Relative(1),
            _ => LimitEffect::Unknown,
        }
    }
}

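/// Resolves the expression to evaluate: normally the first argument, but when
/// that argument is typed `NULL` a typed null literal is built from the data
/// type of the default value (argument index 2).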
fn parse_expr(
    input_exprs: &[Arc<dyn PhysicalExpr>],
    input_fields: &[FieldRef],
) -> Result<Arc<dyn PhysicalExpr>> {
    assert!(!input_exprs.is_empty());
    assert!(!input_fields.is_empty());

    let expr = Arc::clone(input_exprs.first().unwrap());
    let expr_field = input_fields.first().unwrap();

    if !expr_field.data_type().is_null() {
        return Ok(expr);
    }

    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
    default_value.map_or(Ok(expr), |value| {
        ScalarValue::try_from(&value.data_type())
            .map(|v| Arc::new(expressions::Literal::new(v)) as Arc<dyn PhysicalExpr>)
    })
}

static NULL_FIELD: LazyLock<FieldRef> =
    LazyLock::new(|| Field::new("value", DataType::Null, true).into());

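/// Determines the output field: the first input field made nullable, falling
/// back to the default value's field (index 2) when the expression field is
/// typed `NULL`.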
fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
    assert!(!input_fields.is_empty());
    let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);

    if !expr_field.data_type().is_null() {
        return Ok(expr_field.as_ref().clone().with_nullable(true).into());
    }

    let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
    Ok(default_value_field
        .as_ref()
        .clone()
        .with_nullable(true)
        .into())
}

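/// Parses the default value (argument index 2) and casts it to the output
/// type; when the argument is absent or `NULL`, returns a typed `NULL` of
/// that type instead.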
fn parse_default_value(
    input_exprs: &[Arc<dyn PhysicalExpr>],
    input_types: &[FieldRef],
) -> Result<ScalarValue> {
    let expr_field = parse_expr_field(input_types)?;
    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;

    unparsed
        .filter(|v| !v.data_type().is_null())
        .map(|v| v.cast_to(expr_field.data_type()))
        .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
}

#[derive(Debug)]
struct WindowShiftEvaluator {
    shift_offset: i64,
    default_value: ScalarValue,
    ignore_nulls: bool,
    // Tracks the distances between consecutive non-null rows; only used when
    // IGNORE NULLS is set.
    non_null_offsets: VecDeque<usize>,
}

impl WindowShiftEvaluator {
    fn is_lag(&self) -> bool {
        // A positive shift offset means LAG, a negative one means LEAD
        self.shift_offset > 0
    }
}

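/// IGNORE NULLS evaluation over a whole partition: collect the indices of the
/// non-null rows, then for each output row step `|offset|` positions through
/// that list, backwards for `lag` and forwards for `lead`. Rows whose target
/// falls off either end receive `default_value`.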
fn evaluate_all_with_ignore_null(
    array: &ArrayRef,
    offset: i64,
    default_value: &ScalarValue,
    is_lag: bool,
) -> Result<ArrayRef, DataFusionError> {
    // An array without a null buffer has no nulls, so every index is valid.
    let valid_indices: Vec<usize> = match array.nulls() {
        Some(nulls) => nulls.valid_indices().collect(),
        None => (0..array.len()).collect(),
    };
    // `offset` is positive for lag and negative for lead; in both cases we
    // step by its magnitude, only the direction differs.
    let direction = !is_lag;
    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
        .map(|id| {
            let result_index = match valid_indices.binary_search(&id) {
                Ok(pos) => if direction {
                    pos.checked_add(offset.unsigned_abs() as usize)
                } else {
                    pos.checked_sub(offset.unsigned_abs() as usize)
                }
                .and_then(|new_pos| {
                    if new_pos < valid_indices.len() {
                        Some(valid_indices[new_pos])
                    } else {
                        None
                    }
                }),
                Err(pos) => if direction {
                    pos.checked_add(offset.unsigned_abs() as usize)
                } else if pos > 0 {
                    pos.checked_sub(offset.unsigned_abs() as usize)
                } else {
                    None
                }
                .and_then(|new_pos| {
                    if new_pos < valid_indices.len() {
                        Some(valid_indices[new_pos])
                    } else {
                        None
                    }
                }),
            };

            match result_index {
                Some(index) => ScalarValue::try_from_array(array, index),
                None => Ok(default_value.clone()),
            }
        })
        .collect();

    let new_array = new_array_results?;
    ScalarValue::iter_to_array(new_array)
}
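
/// Shifts `array` by `offset` rows, filling the vacated positions with
/// `default_value`: a positive offset (lag) prepends defaults, a negative
/// offset (lead) appends them. For example, shifting `[1, 2, 3, 4]` by 1 with
/// default 0 yields `[0, 1, 2, 3]`.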
fn shift_with_default_value(
    array: &ArrayRef,
    offset: i64,
    default_value: &ScalarValue,
) -> Result<ArrayRef> {
    use datafusion_common::arrow::compute::concat;

    let value_len = array.len() as i64;
    if offset == 0 {
        Ok(Arc::clone(array))
    } else if offset == i64::MIN || offset.abs() >= value_len {
        default_value.to_array_of_size(value_len as usize)
    } else {
        let slice_offset = (-offset).clamp(0, value_len) as usize;
        let length = array.len() - offset.unsigned_abs() as usize;
        let slice = array.slice(slice_offset, length);

        // Generate an array of default values for the vacated slots
        let nulls = offset.unsigned_abs() as usize;
        let default_values = default_value.to_array_of_size(nulls)?;

        // Concatenate both arrays: defaults go in front for lag (offset > 0)
        // and behind for lead (offset < 0)
        if offset > 0 {
            concat(&[default_values.as_ref(), slice.as_ref()])
                .map_err(|e| arrow_datafusion_err!(e))
        } else {
            concat(&[slice.as_ref(), default_values.as_ref()])
                .map_err(|e| arrow_datafusion_err!(e))
        }
    }
}

impl PartitionEvaluator for WindowShiftEvaluator {
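    /// Returns the slice of input rows needed to produce the value at `idx`:
    /// for `lag` the rows up to `shift_offset` positions (or the tracked
    /// non-null offsets under IGNORE NULLS) before `idx`, for `lead` the rows
    /// after it.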
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        if self.is_lag() {
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                let offset = self.shift_offset as usize;
                idx.saturating_sub(offset)
            } else {
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx + offset + 1, n_rows)
            } else if !self.ignore_nulls {
                let offset = (-self.shift_offset) as usize;
                min(idx + offset, n_rows)
            } else {
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    fn is_causal(&self) -> bool {
        // Lag only looks at rows at or before the current row, so it is causal
        self.is_lag()
    }

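    /// Bounded (streaming) evaluation of a single output row. Without IGNORE
    /// NULLS the target index follows directly from `shift_offset`; with
    /// IGNORE NULLS, `non_null_offsets` tracks the gaps between non-null rows
    /// so the target row can be located without rescanning the partition.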
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        let i = if self.is_lag() {
            (range.end as i64 - self.shift_offset - 1) as usize
        } else {
            (range.start as i64 - self.shift_offset) as usize
        };

        let mut idx: Option<usize> = if i < len { Some(i) } else { None };

        if self.ignore_nulls && self.is_lag() {
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            if array.is_valid(range.end - 1) {
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > self.shift_offset as usize {
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            let non_null_row_count = (-self.shift_offset) as usize;

            if self.non_null_offsets.is_empty() {
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                if array.is_valid(range.end) {
                    self.non_null_offsets.push_back(1);
                } else {
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    self.non_null_offsets.pop_front();
                }
            }
        }

        #[allow(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

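    /// Unbounded evaluation over the whole partition: a plain array shift when
    /// nulls are respected, otherwise the IGNORE NULLS path.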
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    fn supports_bounded_execution(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![arr];
        // Number of rows in the (single-column) partition
        let num_rows = values[0].len();
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2 IGNORE NULLS) with two tracked non-null offsets
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2 IGNORE NULLS) with two tracked non-null offsets
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
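
    // Companion to `test_lag_with_default`, added for illustration: `lead`
    // with an explicit offset literal and a default value.
    #[test]
    fn test_lead_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        // Shifting [1, -2, 3, -4, 5, -6, 7, 8] forward by one row appends the
        // default (100) at the end of the partition.
        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                Some(100),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }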
}