1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{DataFusionError, Result, ScalarValue, arrow_datafusion_err};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
29 TypeSignature, Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr::expressions;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::hash::Hash;
39use std::ops::Range;
40use std::sync::{Arc, LazyLock};
41
// Registers the `lag` window function through the `get_or_init_udwf!` macro (defined
// outside this file). The arguments are the `Lag`/`lag` identifiers, the accessor name
// `lag_udwf` (called by `lag()` and `WindowShift::reverse_expr` below), a description
// string, and the constructor `WindowShift::lag`.
get_or_init_udwf!(
    Lag,
    lag,
    lag_udwf,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Registers the `lead` window function through the `get_or_init_udwf!` macro (defined
// outside this file). The arguments are the `Lead`/`lead` identifiers, the accessor name
// `lead_udwf` (called by `lead()` and `WindowShift::reverse_expr` below), a description
// string, and the constructor `WindowShift::lead`.
get_or_init_udwf!(
    Lead,
    lead,
    lead_udwf,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
60
61pub fn lag(
68 arg: datafusion_expr::Expr,
69 shift_offset: Option<i64>,
70 default_value: Option<ScalarValue>,
71) -> datafusion_expr::Expr {
72 let shift_offset_lit = shift_offset
73 .map(|v| v.lit())
74 .unwrap_or(ScalarValue::Null.lit());
75 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
76
77 lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
78}
79
80pub fn lead(
87 arg: datafusion_expr::Expr,
88 shift_offset: Option<i64>,
89 default_value: Option<ScalarValue>,
90) -> datafusion_expr::Expr {
91 let shift_offset_lit = shift_offset
92 .map(|v| v.lit())
93 .unwrap_or(ScalarValue::Null.lit());
94 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
95
96 lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
97}
98
/// Selects which of the two shifting window functions is being evaluated.
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum WindowShiftKind {
    /// Look at a row before the current one.
    Lag,
    /// Look at a row after the current one.
    Lead,
}

impl WindowShiftKind {
    /// Returns the SQL function name for this kind.
    fn name(&self) -> &'static str {
        if matches!(self, Self::Lead) {
            "lead"
        } else {
            "lag"
        }
    }

    /// Converts the user-supplied offset into the signed internal offset.
    ///
    /// A missing offset means one row. `Lag` keeps the sign as given, `Lead` negates
    /// it, so a positive internal offset always means "look backwards". Negation wraps,
    /// which leaves `i64::MIN` unchanged instead of overflowing.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let requested = value.unwrap_or(1);
        match self {
            Self::Lag => requested,
            Self::Lead => requested.wrapping_neg(),
        }
    }
}
123
124#[derive(Debug, PartialEq, Eq, Hash)]
126pub struct WindowShift {
127 signature: Signature,
128 kind: WindowShiftKind,
129}
130
131impl WindowShift {
132 fn new(kind: WindowShiftKind) -> Self {
133 Self {
134 signature: Signature::one_of(
135 vec![
136 TypeSignature::Any(1),
137 TypeSignature::Any(2),
138 TypeSignature::Any(3),
139 ],
140 Volatility::Immutable,
141 )
142 .with_parameter_names(vec![
143 "expr".to_string(),
144 "offset".to_string(),
145 "default".to_string(),
146 ])
147 .expect("valid parameter names for lead/lag"),
148 kind,
149 }
150 }
151
152 pub fn lag() -> Self {
153 Self::new(WindowShiftKind::Lag)
154 }
155
156 pub fn lead() -> Self {
157 Self::new(WindowShiftKind::Lead)
158 }
159
160 pub fn kind(&self) -> &WindowShiftKind {
161 &self.kind
162 }
163}
164
165static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
166 Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
167 current row within the partition; if there is no such row, instead return default \
168 (which must be of the same type as value).", "lag(expression, offset, default)")
169 .with_argument("expression", "Expression to operate on")
170 .with_argument("offset", "Integer. Specifies how many rows back \
171 the value of expression should be retrieved. Defaults to 1.")
172 .with_argument("default", "The default value if the offset is \
173 not within the partition. Must be of the same type as expression.")
174 .with_sql_example(r#"
175```sql
176-- Example usage of the lag window function:
177SELECT employee_id,
178 salary,
179 lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
180FROM employees;
181
182+-------------+--------+-------------+
183| employee_id | salary | prev_salary |
184+-------------+--------+-------------+
185| 1 | 30000 | 0 |
186| 2 | 50000 | 30000 |
187| 3 | 70000 | 50000 |
188| 4 | 60000 | 70000 |
189+-------------+--------+-------------+
190```
191"#)
192 .build()
193});
194
195fn get_lag_doc() -> &'static Documentation {
196 &LAG_DOCUMENTATION
197}
198
199static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
200 Documentation::builder(DOC_SECTION_ANALYTICAL,
201 "Returns value evaluated at the row that is offset rows after the \
202 current row within the partition; if there is no such row, instead return default \
203 (which must be of the same type as value).",
204 "lead(expression, offset, default)")
205 .with_argument("expression", "Expression to operate on")
206 .with_argument("offset", "Integer. Specifies how many rows \
207 forward the value of expression should be retrieved. Defaults to 1.")
208 .with_argument("default", "The default value if the offset is \
209 not within the partition. Must be of the same type as expression.")
210 .with_sql_example(r#"
211```sql
212-- Example usage of lead window function:
213SELECT
214 employee_id,
215 department,
216 salary,
217 lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
218FROM employees;
219
220+-------------+-------------+--------+--------------+
221| employee_id | department | salary | next_salary |
222+-------------+-------------+--------+--------------+
223| 1 | Sales | 30000 | 50000 |
224| 2 | Sales | 50000 | 70000 |
225| 3 | Sales | 70000 | 0 |
226| 4 | Engineering | 40000 | 60000 |
227| 5 | Engineering | 60000 | 0 |
228+-------------+-------------+--------+--------------+
229```
230"#)
231 .build()
232});
233
234fn get_lead_doc() -> &'static Documentation {
235 &LEAD_DOCUMENTATION
236}
237
238impl WindowUDFImpl for WindowShift {
239 fn name(&self) -> &str {
240 self.kind.name()
241 }
242
243 fn signature(&self) -> &Signature {
244 &self.signature
245 }
246
247 fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
253 parse_expr(expr_args.input_exprs(), expr_args.input_fields())
254 .into_iter()
255 .collect::<Vec<_>>()
256 }
257
258 fn partition_evaluator(
259 &self,
260 partition_evaluator_args: PartitionEvaluatorArgs,
261 ) -> Result<Box<dyn PartitionEvaluator>> {
262 let shift_offset =
263 get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
264 .map(|v| get_signed_integer(&v))
265 .map_or(Ok(None), |v| v.map(Some))
266 .map(|n| self.kind.shift_offset(n))
267 .map(|offset| {
268 if partition_evaluator_args.is_reversed() {
269 offset.wrapping_neg()
270 } else {
271 offset
272 }
273 })?;
274 let default_value = parse_default_value(
275 partition_evaluator_args.input_exprs(),
276 partition_evaluator_args.input_fields(),
277 )?;
278
279 Ok(Box::new(WindowShiftEvaluator {
280 shift_offset,
281 default_value,
282 ignore_nulls: partition_evaluator_args.ignore_nulls(),
283 non_null_offsets: VecDeque::new(),
284 }))
285 }
286
287 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
288 let return_field = parse_expr_field(field_args.input_fields())?;
289
290 Ok(return_field
291 .as_ref()
292 .clone()
293 .with_name(field_args.name())
294 .into())
295 }
296
297 fn reverse_expr(&self) -> ReversedUDWF {
298 match self.kind {
299 WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
300 WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
301 }
302 }
303
304 fn documentation(&self) -> Option<&Documentation> {
305 match self.kind {
306 WindowShiftKind::Lag => Some(get_lag_doc()),
307 WindowShiftKind::Lead => Some(get_lead_doc()),
308 }
309 }
310
311 fn limit_effect(&self, args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
312 if self.kind == WindowShiftKind::Lag {
313 return LimitEffect::None;
314 }
315 match args {
316 [_, expr, ..] => {
317 let Some(lit) = expr.downcast_ref::<expressions::Literal>() else {
318 return LimitEffect::Unknown;
319 };
320 let ScalarValue::Int64(Some(amount)) = lit.value() else {
321 return LimitEffect::Unknown; };
323 LimitEffect::Relative((*amount).max(0) as usize)
324 }
325 [_] => LimitEffect::Relative(1), _ => LimitEffect::Unknown, }
328 }
329}
330
331fn parse_expr(
344 input_exprs: &[Arc<dyn PhysicalExpr>],
345 input_fields: &[FieldRef],
346) -> Result<Arc<dyn PhysicalExpr>> {
347 assert!(!input_exprs.is_empty());
348 assert!(!input_fields.is_empty());
349
350 let expr = Arc::clone(input_exprs.first().unwrap());
351 let expr_field = input_fields.first().unwrap();
352
353 if !expr_field.data_type().is_null() {
355 return Ok(expr);
356 }
357
358 let default_value = get_scalar_value_from_args(input_exprs, 2)?;
359 default_value.map_or(Ok(expr), |value| {
360 ScalarValue::try_from(&value.data_type())
361 .map(|v| Arc::new(expressions::Literal::new(v)) as Arc<dyn PhysicalExpr>)
362 })
363}
364
365static NULL_FIELD: LazyLock<FieldRef> =
366 LazyLock::new(|| Field::new("value", DataType::Null, true).into());
367
368fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
373 assert!(!input_fields.is_empty());
374 let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
375
376 if !expr_field.data_type().is_null() {
378 return Ok(expr_field.as_ref().clone().with_nullable(true).into());
379 }
380
381 let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
382 Ok(default_value_field
383 .as_ref()
384 .clone()
385 .with_nullable(true)
386 .into())
387}
388
389fn parse_default_value(
392 input_exprs: &[Arc<dyn PhysicalExpr>],
393 input_types: &[FieldRef],
394) -> Result<ScalarValue> {
395 let expr_field = parse_expr_field(input_types)?;
396 let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
397
398 unparsed
399 .filter(|v| !v.data_type().is_null())
400 .map(|v| v.cast_to(expr_field.data_type()))
401 .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
402}
403
/// Per-partition evaluator shared by `lag` and `lead`.
#[derive(Debug)]
struct WindowShiftEvaluator {
    /// Signed shift: a positive value means lag mode (see `is_lag`); any other value
    /// is handled as lead.
    shift_offset: i64,
    /// Value returned for rows that have no source row.
    default_value: ScalarValue,
    /// Whether null rows are skipped when counting the offset.
    ignore_nulls: bool,
    /// Distances between tracked non-null rows; maintained by `evaluate` and read by
    /// `get_range` when `ignore_nulls` is set.
    non_null_offsets: VecDeque<usize>,
}
412
/// Returns the absolute value of `offset` as a `usize`, saturating at `usize::MAX`
/// when the magnitude does not fit (possible on targets with a narrow `usize`).
fn offset_magnitude(offset: i64) -> usize {
    usize::try_from(offset.unsigned_abs()).unwrap_or(usize::MAX)
}
421
422impl WindowShiftEvaluator {
423 fn is_lag(&self) -> bool {
424 self.shift_offset > 0
426 }
427}
428
429fn evaluate_all_with_ignore_null(
431 array: &ArrayRef,
432 offset: i64,
433 default_value: &ScalarValue,
434 is_lag: bool,
435) -> Result<ArrayRef, DataFusionError> {
436 let valid_indices: Vec<usize> =
437 array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
438 let direction = !is_lag;
439 let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
440 .map(|id| {
441 let result_index = match valid_indices.binary_search(&id) {
442 Ok(pos) => if direction {
443 pos.checked_add(offset as usize)
444 } else {
445 pos.checked_sub(offset.unsigned_abs() as usize)
446 }
447 .and_then(|new_pos| {
448 if new_pos < valid_indices.len() {
449 Some(valid_indices[new_pos])
450 } else {
451 None
452 }
453 }),
454 Err(pos) => if direction {
455 pos.checked_add(offset as usize)
456 } else if pos > 0 {
457 pos.checked_sub(offset.unsigned_abs() as usize)
458 } else {
459 None
460 }
461 .and_then(|new_pos| {
462 if new_pos < valid_indices.len() {
463 Some(valid_indices[new_pos])
464 } else {
465 None
466 }
467 }),
468 };
469
470 match result_index {
471 Some(index) => ScalarValue::try_from_array(array, index),
472 None => Ok(default_value.clone()),
473 }
474 })
475 .collect();
476
477 let new_array = new_array_results?;
478 ScalarValue::iter_to_array(new_array)
479}
480fn shift_with_default_value(
482 array: &ArrayRef,
483 offset: i64,
484 default_value: &ScalarValue,
485) -> Result<ArrayRef> {
486 use datafusion_common::arrow::compute::concat;
487
488 let value_len = array.len() as i64;
489 if offset == 0 {
490 Ok(Arc::clone(array))
491 } else if offset == i64::MIN || offset.abs() >= value_len {
492 default_value.to_array_of_size(value_len as usize)
493 } else {
494 let slice_offset = (-offset).clamp(0, value_len) as usize;
495 let length = array.len() - offset.unsigned_abs() as usize;
496 let slice = array.slice(slice_offset, length);
497
498 let nulls = offset.unsigned_abs() as usize;
500 let default_values = default_value.to_array_of_size(nulls)?;
501
502 if offset > 0 {
504 concat(&[default_values.as_ref(), slice.as_ref()])
505 .map_err(|e| arrow_datafusion_err!(e))
506 } else {
507 concat(&[slice.as_ref(), default_values.as_ref()])
508 .map_err(|e| arrow_datafusion_err!(e))
509 }
510 }
511}
512
impl PartitionEvaluator for WindowShiftEvaluator {
    /// Returns the row range associated with row `idx` in a partition of `n_rows` rows.
    ///
    /// In lag mode the range ends just after `idx` and starts `offset` rows earlier; in
    /// lead mode it starts at `idx` and extends forward, capped at `n_rows`. When nulls
    /// are ignored and exactly `offset` gaps are tracked, their sum replaces the plain
    /// offset; when nulls are ignored and they are not, the range extends to the
    /// partition boundary on that side.
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        let offset = offset_magnitude(self.shift_offset);

        if self.is_lag() {
            let start = if self.non_null_offsets.len() == offset {
                // Total distance back to the oldest tracked non-null row.
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                idx.saturating_sub(offset)
            } else {
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == offset {
                // Total distance forward to the last tracked non-null row; the extra
                // `+ 1` makes the range end one past that row.
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx.saturating_add(offset).saturating_add(1), n_rows)
            } else if !self.ignore_nulls {
                // NOTE(review): this end bound excludes row `idx + offset` itself;
                // `evaluate` bounds its index by the array length, not by this range.
                min(idx.saturating_add(offset), n_rows)
            } else {
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    /// Reports the evaluator as causal only in lag mode.
    fn is_causal(&self) -> bool {
        self.is_lag()
    }

    /// Computes the shifted value for the row described by `range`.
    ///
    /// `values[0]` is the column being shifted. Returns `default_value` when the source
    /// index falls outside the array, or when nulls are ignored and the selected row is
    /// null. With `ignore_nulls` set this method updates `non_null_offsets`, so the
    /// result depends on the calls made before it.
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        // Source index when nulls are not skipped. The checked arithmetic turns
        // underflow/overflow into `None` instead of wrapping.
        let i = if self.is_lag() {
            range
                .end
                .checked_sub(1)
                .and_then(|end| (end as i64).checked_sub(self.shift_offset))
                .and_then(|value| usize::try_from(value).ok())
        } else {
            (range.start as i64)
                .checked_sub(self.shift_offset)
                .and_then(|value| usize::try_from(value).ok())
        };
        let mut idx: Option<usize> = i.filter(|i| *i < len);

        if self.ignore_nulls && self.is_lag() {
            let shift_offset = offset_magnitude(self.shift_offset);
            // A source row exists only once `shift_offset` gaps have been tracked; it
            // lies their total distance before the current row (`range.end - 1`).
            idx = if self.non_null_offsets.len() == shift_offset {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            // Record the current row for later calls.
            if array.is_valid(range.end - 1) {
                // Non-null row: start a new gap of length one.
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > shift_offset {
                    // Keep at most `shift_offset` gaps.
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                // Null row: widen the most recent gap.
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            let non_null_row_count = offset_magnitude(self.shift_offset);

            if self.non_null_offsets.is_empty() {
                // Nothing tracked yet: scan the rows after the current one inside
                // `range`, recording the distance between consecutive non-null rows.
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    // Stop once `non_null_row_count + 1` gaps have been collected.
                    if self.non_null_offsets.len() == non_null_row_count.saturating_add(1) {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                // NOTE(review): the enclosing condition already requires
                // `array.is_valid(range.end)`, so the inner `else` below cannot run.
                // Confirm whether a null at `range.end` was meant to widen the last gap.
                if array.is_valid(range.end) {
                    self.non_null_offsets.push_back(1);
                } else {
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            // The source row lies the sum of the first `non_null_row_count` gaps after
            // the current row, provided that many gaps are tracked.
            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            // Re-base the gaps onto the next row: shorten the first gap by one and drop
            // it once it reaches zero.
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    self.non_null_offsets.pop_front();
                }
            }
        }

        // Fall back to the default when there is no source row, or when nulls are
        // ignored and the selected row is null.
        #[expect(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

    /// Computes the shifted column for a whole partition at once.
    ///
    /// `values[0]` is the column being shifted; `_num_rows` is not used. Delegates to
    /// `shift_with_default_value`, or to `evaluate_all_with_ignore_null` when nulls
    /// are ignored.
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    /// Always returns `true`.
    fn supports_bounded_execution(&self) -> bool {
        true
    }
}
686
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};

    /// Evaluates `expr` over a fixed eight-row `Int32` column with `evaluate_all` and
    /// compares the output with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // The row count is the length of the column, not the number of columns.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Checks `get_range` for lag and lead, with and without tracked non-null gaps.
    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2) IGNORE NULLS with two tracked gaps of two rows each
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2) IGNORE NULLS with two tracked gaps of two rows each
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    /// `lead` with the default offset shifts values up by one row; the last row is null.
    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]),
        )
    }

    /// `lag` with the default offset shifts values down by one row; the first row is
    /// null.
    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    /// `lag` with an explicit offset of one and a default of 100 fills the first row
    /// with the default.
    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            Int32Array::from(vec![
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }
}