1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
29 TypeSignature, Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr::expressions;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use std::any::Any;
37use std::cmp::min;
38use std::collections::VecDeque;
39use std::hash::Hash;
40use std::ops::{Neg, Range};
41use std::sync::{Arc, LazyLock};
42
// Invokes `get_or_init_udwf!` for the `Lag` / `lag` identifier pair, using
// `WindowShift::lag` as the constructor. The string literal is the description
// handed to the macro.
// NOTE(review): the macro is defined outside this file; presumably it generates
// the `lag_udwf()` accessor used further down — confirm against the macro.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Invokes `get_or_init_udwf!` for the `Lead` / `lead` identifier pair, using
// `WindowShift::lead` as the constructor. The string literal is the description
// handed to the macro.
// NOTE(review): the macro is defined outside this file; presumably it generates
// the `lead_udwf()` accessor used further down — confirm against the macro.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
59
60pub fn lag(
67 arg: datafusion_expr::Expr,
68 shift_offset: Option<i64>,
69 default_value: Option<ScalarValue>,
70) -> datafusion_expr::Expr {
71 let shift_offset_lit = shift_offset
72 .map(|v| v.lit())
73 .unwrap_or(ScalarValue::Null.lit());
74 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
75
76 lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
77}
78
79pub fn lead(
86 arg: datafusion_expr::Expr,
87 shift_offset: Option<i64>,
88 default_value: Option<ScalarValue>,
89) -> datafusion_expr::Expr {
90 let shift_offset_lit = shift_offset
91 .map(|v| v.lit())
92 .unwrap_or(ScalarValue::Null.lit());
93 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
94
95 lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
96}
97
/// Selects which of the two shift window functions is being implemented.
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum WindowShiftKind {
    /// The `lag` function.
    Lag,
    /// The `lead` function.
    Lead,
}

impl WindowShiftKind {
    /// Name of the window function for this kind (`"lag"` or `"lead"`).
    fn name(&self) -> &'static str {
        match self {
            WindowShiftKind::Lag => "lag",
            WindowShiftKind::Lead => "lead",
        }
    }

    /// Converts the user-supplied offset into the signed offset stored by the
    /// evaluator, where a positive value means "look backwards" (lag) and a
    /// non-positive value means "look forwards" (lead).
    ///
    /// * `Lag` keeps the value as is and defaults to `1`.
    /// * `Lead` negates the value and defaults to `-1`.
    ///
    /// `i64::MIN` has no positive counterpart, so for `Lead` the input is
    /// clamped to `i64::MIN + 1` before negation. This yields `i64::MAX`
    /// instead of overflowing (a panic in debug builds, a silent wrap back to
    /// `i64::MIN` in release builds).
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        match self {
            WindowShiftKind::Lag => value.unwrap_or(1),
            WindowShiftKind::Lead => {
                value.map_or(-1, |v| v.max(i64::MIN + 1).neg())
            }
        }
    }
}
122
123#[derive(Debug, PartialEq, Eq, Hash)]
125pub struct WindowShift {
126 signature: Signature,
127 kind: WindowShiftKind,
128}
129
130impl WindowShift {
131 fn new(kind: WindowShiftKind) -> Self {
132 Self {
133 signature: Signature::one_of(
134 vec![
135 TypeSignature::Any(1),
136 TypeSignature::Any(2),
137 TypeSignature::Any(3),
138 ],
139 Volatility::Immutable,
140 )
141 .with_parameter_names(vec![
142 "expr".to_string(),
143 "offset".to_string(),
144 "default".to_string(),
145 ])
146 .expect("valid parameter names for lead/lag"),
147 kind,
148 }
149 }
150
151 pub fn lag() -> Self {
152 Self::new(WindowShiftKind::Lag)
153 }
154
155 pub fn lead() -> Self {
156 Self::new(WindowShiftKind::Lead)
157 }
158
159 pub fn kind(&self) -> &WindowShiftKind {
160 &self.kind
161 }
162}
163
// User-facing documentation for `lag`, built on first access and returned by
// `get_lag_doc`. It registers a description, the syntax example
// `lag(expression, offset, default)`, the three arguments and a SQL example.
static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).", "lag(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows back \
        the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"
```sql
-- Example usage of the lag window function:
SELECT employee_id,
 salary,
 lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
FROM employees;

+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1 | 30000 | 0 |
| 2 | 50000 | 30000 |
| 3 | 70000 | 50000 |
| 4 | 60000 | 70000 |
+-------------+--------+-------------+
```
"#)
        .build()
});
193
194fn get_lag_doc() -> &'static Documentation {
195 &LAG_DOCUMENTATION
196}
197
// User-facing documentation for `lead`, built on first access and returned by
// `get_lead_doc`. It registers a description, the syntax example
// `lead(expression, offset, default)`, the three arguments and a SQL example.
static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL,
            "Returns value evaluated at the row that is offset rows after the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).",
        "lead(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows \
        forward the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"
```sql
-- Example usage of lead window function:
SELECT
 employee_id,
 department,
 salary,
 lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;

+-------------+-------------+--------+--------------+
| employee_id | department | salary | next_salary |
+-------------+-------------+--------+--------------+
| 1 | Sales | 30000 | 50000 |
| 2 | Sales | 50000 | 70000 |
| 3 | Sales | 70000 | 0 |
| 4 | Engineering | 40000 | 60000 |
| 5 | Engineering | 60000 | 0 |
+-------------+-------------+--------+--------------+
```
"#)
        .build()
});
232
233fn get_lead_doc() -> &'static Documentation {
234 &LEAD_DOCUMENTATION
235}
236
237impl WindowUDFImpl for WindowShift {
238 fn as_any(&self) -> &dyn Any {
239 self
240 }
241
242 fn name(&self) -> &str {
243 self.kind.name()
244 }
245
246 fn signature(&self) -> &Signature {
247 &self.signature
248 }
249
250 fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
256 parse_expr(expr_args.input_exprs(), expr_args.input_fields())
257 .into_iter()
258 .collect::<Vec<_>>()
259 }
260
261 fn partition_evaluator(
262 &self,
263 partition_evaluator_args: PartitionEvaluatorArgs,
264 ) -> Result<Box<dyn PartitionEvaluator>> {
265 let shift_offset =
266 get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
267 .map(get_signed_integer)
268 .map_or(Ok(None), |v| v.map(Some))
269 .map(|n| self.kind.shift_offset(n))
270 .map(|offset| {
271 if partition_evaluator_args.is_reversed() {
272 -offset
273 } else {
274 offset
275 }
276 })?;
277 let default_value = parse_default_value(
278 partition_evaluator_args.input_exprs(),
279 partition_evaluator_args.input_fields(),
280 )?;
281
282 Ok(Box::new(WindowShiftEvaluator {
283 shift_offset,
284 default_value,
285 ignore_nulls: partition_evaluator_args.ignore_nulls(),
286 non_null_offsets: VecDeque::new(),
287 }))
288 }
289
290 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
291 let return_field = parse_expr_field(field_args.input_fields())?;
292
293 Ok(return_field
294 .as_ref()
295 .clone()
296 .with_name(field_args.name())
297 .into())
298 }
299
300 fn reverse_expr(&self) -> ReversedUDWF {
301 match self.kind {
302 WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
303 WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
304 }
305 }
306
307 fn documentation(&self) -> Option<&Documentation> {
308 match self.kind {
309 WindowShiftKind::Lag => Some(get_lag_doc()),
310 WindowShiftKind::Lead => Some(get_lead_doc()),
311 }
312 }
313
314 fn limit_effect(&self, args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
315 if self.kind == WindowShiftKind::Lag {
316 return LimitEffect::None;
317 }
318 match args {
319 [_, expr, ..] => {
320 let Some(lit) = expr.as_any().downcast_ref::<expressions::Literal>()
321 else {
322 return LimitEffect::Unknown;
323 };
324 let ScalarValue::Int64(Some(amount)) = lit.value() else {
325 return LimitEffect::Unknown; };
327 LimitEffect::Relative((*amount).max(0) as usize)
328 }
329 [_] => LimitEffect::Relative(1), _ => LimitEffect::Unknown, }
332 }
333}
334
335fn parse_expr(
348 input_exprs: &[Arc<dyn PhysicalExpr>],
349 input_fields: &[FieldRef],
350) -> Result<Arc<dyn PhysicalExpr>> {
351 assert!(!input_exprs.is_empty());
352 assert!(!input_fields.is_empty());
353
354 let expr = Arc::clone(input_exprs.first().unwrap());
355 let expr_field = input_fields.first().unwrap();
356
357 if !expr_field.data_type().is_null() {
359 return Ok(expr);
360 }
361
362 let default_value = get_scalar_value_from_args(input_exprs, 2)?;
363 default_value.map_or(Ok(expr), |value| {
364 ScalarValue::try_from(&value.data_type())
365 .map(|v| Arc::new(expressions::Literal::new(v)) as Arc<dyn PhysicalExpr>)
366 })
367}
368
// Shared nullable field named "value" of data type `Null`, used as the
// fallback when an expected input field is missing.
static NULL_FIELD: LazyLock<FieldRef> =
    LazyLock::new(|| Field::new("value", DataType::Null, true).into());
371
372fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
377 assert!(!input_fields.is_empty());
378 let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
379
380 if !expr_field.data_type().is_null() {
382 return Ok(expr_field.as_ref().clone().with_nullable(true).into());
383 }
384
385 let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
386 Ok(default_value_field
387 .as_ref()
388 .clone()
389 .with_nullable(true)
390 .into())
391}
392
393fn parse_default_value(
396 input_exprs: &[Arc<dyn PhysicalExpr>],
397 input_types: &[FieldRef],
398) -> Result<ScalarValue> {
399 let expr_field = parse_expr_field(input_types)?;
400 let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
401
402 unparsed
403 .filter(|v| !v.data_type().is_null())
404 .map(|v| v.cast_to(expr_field.data_type()))
405 .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
406}
407
408#[derive(Debug)]
409struct WindowShiftEvaluator {
410 shift_offset: i64,
411 default_value: ScalarValue,
412 ignore_nulls: bool,
413 non_null_offsets: VecDeque<usize>,
415}
416
417impl WindowShiftEvaluator {
418 fn is_lag(&self) -> bool {
419 self.shift_offset > 0
421 }
422}
423
424fn evaluate_all_with_ignore_null(
426 array: &ArrayRef,
427 offset: i64,
428 default_value: &ScalarValue,
429 is_lag: bool,
430) -> Result<ArrayRef, DataFusionError> {
431 let valid_indices: Vec<usize> =
432 array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
433 let direction = !is_lag;
434 let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
435 .map(|id| {
436 let result_index = match valid_indices.binary_search(&id) {
437 Ok(pos) => if direction {
438 pos.checked_add(offset as usize)
439 } else {
440 pos.checked_sub(offset.unsigned_abs() as usize)
441 }
442 .and_then(|new_pos| {
443 if new_pos < valid_indices.len() {
444 Some(valid_indices[new_pos])
445 } else {
446 None
447 }
448 }),
449 Err(pos) => if direction {
450 pos.checked_add(offset as usize)
451 } else if pos > 0 {
452 pos.checked_sub(offset.unsigned_abs() as usize)
453 } else {
454 None
455 }
456 .and_then(|new_pos| {
457 if new_pos < valid_indices.len() {
458 Some(valid_indices[new_pos])
459 } else {
460 None
461 }
462 }),
463 };
464
465 match result_index {
466 Some(index) => ScalarValue::try_from_array(array, index),
467 None => Ok(default_value.clone()),
468 }
469 })
470 .collect();
471
472 let new_array = new_array_results?;
473 ScalarValue::iter_to_array(new_array)
474}
475fn shift_with_default_value(
477 array: &ArrayRef,
478 offset: i64,
479 default_value: &ScalarValue,
480) -> Result<ArrayRef> {
481 use datafusion_common::arrow::compute::concat;
482
483 let value_len = array.len() as i64;
484 if offset == 0 {
485 Ok(Arc::clone(array))
486 } else if offset == i64::MIN || offset.abs() >= value_len {
487 default_value.to_array_of_size(value_len as usize)
488 } else {
489 let slice_offset = (-offset).clamp(0, value_len) as usize;
490 let length = array.len() - offset.unsigned_abs() as usize;
491 let slice = array.slice(slice_offset, length);
492
493 let nulls = offset.unsigned_abs() as usize;
495 let default_values = default_value.to_array_of_size(nulls)?;
496
497 if offset > 0 {
499 concat(&[default_values.as_ref(), slice.as_ref()])
500 .map_err(|e| arrow_datafusion_err!(e))
501 } else {
502 concat(&[slice.as_ref(), default_values.as_ref()])
503 .map_err(|e| arrow_datafusion_err!(e))
504 }
505 }
506}
507
impl PartitionEvaluator for WindowShiftEvaluator {
    /// Returns the row range handed to `evaluate` for the row at `idx` in a
    /// partition of `n_rows` rows.
    ///
    /// Lag mode ends the range just after `idx` and reaches backwards; lead
    /// mode starts at `idx` and reaches forwards, capped at `n_rows`.
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        if self.is_lag() {
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
                // Enough non-null gaps are tracked: reach back by their total.
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                // Plain lag: reach back exactly `shift_offset` rows.
                let offset = self.shift_offset as usize;
                idx.saturating_sub(offset)
            } else {
                // Ignoring nulls without enough tracked gaps: start of partition.
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
                // Enough non-null gaps are tracked: reach forward by their total.
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx + offset + 1, n_rows)
            } else if !self.ignore_nulls {
                // Plain lead: the range end is `idx + offset`, capped at `n_rows`.
                let offset = (-self.shift_offset) as usize;
                min(idx + offset, n_rows)
            } else {
                // Ignoring nulls without the exact gap count: rest of partition.
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    /// Reports the evaluator as causal exactly when it is in lag mode.
    fn is_causal(&self) -> bool {
        self.is_lag()
    }

    /// Computes the shifted value for the current row described by `range`.
    ///
    /// Only `values[0]` is read. In lag mode the current row is
    /// `range.end - 1`; in lead mode it is `range.start`. When `ignore_nulls`
    /// is set this also updates `non_null_offsets`, so the result depends on
    /// the calls made before it.
    /// NOTE(review): that bookkeeping presumably relies on rows being
    /// evaluated in order — confirm against the caller.
    ///
    /// Returns `default_value` when the target row is out of bounds or, with
    /// `ignore_nulls`, when the target row is null.
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        // Target row when nulls are not skipped. A negative intermediate value
        // becomes a huge `usize` through the cast and fails the `i < len`
        // check below.
        let i = if self.is_lag() {
            (range.end as i64 - self.shift_offset - 1) as usize
        } else {
            (range.start as i64 - self.shift_offset) as usize
        };

        let mut idx: Option<usize> = if i < len { Some(i) } else { None };

        if self.ignore_nulls && self.is_lag() {
            // The target is known only once `shift_offset` non-null gaps are
            // tracked; their sum is the distance back from the current row.
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            // Record the current row for the benefit of later rows.
            if array.is_valid(range.end - 1) {
                // Non-null row: start a new gap and keep at most
                // `shift_offset` gaps.
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > self.shift_offset as usize {
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                // Null row: widen the most recent gap.
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            // Number of non-null rows to step over after the current row.
            let non_null_row_count = (-self.shift_offset) as usize;

            if self.non_null_offsets.is_empty() {
                // Nothing tracked yet: scan the rows after the current one and
                // record the gaps between non-null rows, stopping after
                // `non_null_row_count + 1` gaps.
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                // NOTE(review): the condition above already requires
                // `array.is_valid(range.end)`, so the inner `else` branch below
                // cannot run. Confirm whether a null row at `range.end` was
                // meant to widen the last gap.
                if array.is_valid(range.end) {
                    self.non_null_offsets.push_back(1);
                } else {
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            // The target is the sum of the first `non_null_row_count` gaps
            // ahead of the current row, if that many are tracked.
            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            // Shorten the first gap by one row before the next call and drop
            // it once it reaches zero.
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    self.non_null_offsets.pop_front();
                }
            }
        }

        // The `unwrap` calls are guarded by the `is_none` check that
        // short-circuits before them, hence the lint allowance.
        #[allow(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

    /// Computes the shifted values for the whole partition at once.
    ///
    /// Only `values[0]` is read and `_num_rows` is unused. Without
    /// `ignore_nulls` this delegates to `shift_with_default_value`, otherwise
    /// to `evaluate_all_with_ignore_null`.
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    /// Always reports that bounded execution is supported.
    fn supports_bounded_execution(&self) -> bool {
        true
    }
}
674
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Runs `expr` over a fixed eight-row `Int32` column through
    /// `evaluate_all` and compares the output with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // Row count of the column; `values.len()` would be the number of
        // columns (always 1 here), not the number of rows.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // lag by 2, nulls respected
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // lag by 2, nulls ignored, two tracked gaps of two rows each
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // lead by 2, nulls respected
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // lead by 2, nulls ignored, two tracked gaps of two rows each
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        // Default offset of 1: every row sees its successor, the last row NULL.
        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        // Default offset of 1: every row sees its predecessor, the first row NULL.
        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        // Explicit offset 1 and default 100: the first row gets the default.
        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
}