1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
29 Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use std::any::Any;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::hash::{DefaultHasher, Hash, Hasher};
39use std::ops::{Neg, Range};
40use std::sync::{Arc, LazyLock};
41
// Registers the `lag` window function, built by `WindowShift::lag`.
// NOTE(review): `get_or_init_udwf!` is defined outside this file; this file calls
// `lag_udwf()`, so the macro presumably generates that accessor — confirm against
// the macro definition.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Registers the `lead` window function, built by `WindowShift::lead`.
// NOTE(review): presumably generates the `lead_udwf()` accessor used in this
// file — confirm against the macro definition.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
58
59pub fn lag(
66 arg: datafusion_expr::Expr,
67 shift_offset: Option<i64>,
68 default_value: Option<ScalarValue>,
69) -> datafusion_expr::Expr {
70 let shift_offset_lit = shift_offset
71 .map(|v| v.lit())
72 .unwrap_or(ScalarValue::Null.lit());
73 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
74
75 lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
76}
77
78pub fn lead(
85 arg: datafusion_expr::Expr,
86 shift_offset: Option<i64>,
87 default_value: Option<ScalarValue>,
88) -> datafusion_expr::Expr {
89 let shift_offset_lit = shift_offset
90 .map(|v| v.lit())
91 .unwrap_or(ScalarValue::Null.lit());
92 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
93
94 lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
95}
96
/// Selects which of the two shift window functions is being implemented.
#[derive(Debug, PartialEq, Eq, Hash)]
enum WindowShiftKind {
    /// Look at a row before the current one.
    Lag,
    /// Look at a row after the current one.
    Lead,
}

impl WindowShiftKind {
    /// Returns the SQL-facing function name for this kind.
    fn name(&self) -> &'static str {
        match self {
            WindowShiftKind::Lag => "lag",
            WindowShiftKind::Lead => "lead",
        }
    }

    /// Converts the user-supplied offset into the signed shift used internally.
    ///
    /// The evaluator in this file treats a positive shift as `lag`, so the
    /// offset is kept as-is for `Lag` and negated for `Lead`. A missing offset
    /// defaults to one row (`1` for `Lag`, `-1` for `Lead`).
    ///
    /// `i64::MIN` has no positive counterpart, so negating it would overflow
    /// (a panic in debug builds, a silent wrap back to `i64::MIN` in release
    /// builds). For `Lead` it is therefore saturated to `i64::MAX`.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        match (self, value) {
            (WindowShiftKind::Lag, None) => 1,
            (WindowShiftKind::Lag, Some(v)) => v,
            (WindowShiftKind::Lead, None) => -1,
            (WindowShiftKind::Lead, Some(i64::MIN)) => i64::MAX,
            (WindowShiftKind::Lead, Some(v)) => v.neg(),
        }
    }
}
121
122#[derive(Debug)]
124pub struct WindowShift {
125 signature: Signature,
126 kind: WindowShiftKind,
127}
128
129impl WindowShift {
130 fn new(kind: WindowShiftKind) -> Self {
131 Self {
132 signature: Signature::one_of(
133 vec![
134 TypeSignature::Any(1),
135 TypeSignature::Any(2),
136 TypeSignature::Any(3),
137 ],
138 Volatility::Immutable,
139 ),
140 kind,
141 }
142 }
143
144 pub fn lag() -> Self {
145 Self::new(WindowShiftKind::Lag)
146 }
147
148 pub fn lead() -> Self {
149 Self::new(WindowShiftKind::Lead)
150 }
151}
152
153static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
154 Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
155 current row within the partition; if there is no such row, instead return default \
156 (which must be of the same type as value).", "lag(expression, offset, default)")
157 .with_argument("expression", "Expression to operate on")
158 .with_argument("offset", "Integer. Specifies how many rows back \
159 the value of expression should be retrieved. Defaults to 1.")
160 .with_argument("default", "The default value if the offset is \
161 not within the partition. Must be of the same type as expression.")
162 .with_sql_example(r#"```sql
163 --Example usage of the lag window function:
164 SELECT employee_id,
165 salary,
166 lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
167 FROM employees;
168```
169
170```sql
171+-------------+--------+-------------+
172| employee_id | salary | prev_salary |
173+-------------+--------+-------------+
174| 1 | 30000 | 0 |
175| 2 | 50000 | 30000 |
176| 3 | 70000 | 50000 |
177| 4 | 60000 | 70000 |
178+-------------+--------+-------------+
179```"#)
180 .build()
181});
182
183fn get_lag_doc() -> &'static Documentation {
184 &LAG_DOCUMENTATION
185}
186
187static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
188 Documentation::builder(DOC_SECTION_ANALYTICAL,
189 "Returns value evaluated at the row that is offset rows after the \
190 current row within the partition; if there is no such row, instead return default \
191 (which must be of the same type as value).",
192 "lead(expression, offset, default)")
193 .with_argument("expression", "Expression to operate on")
194 .with_argument("offset", "Integer. Specifies how many rows \
195 forward the value of expression should be retrieved. Defaults to 1.")
196 .with_argument("default", "The default value if the offset is \
197 not within the partition. Must be of the same type as expression.")
198 .with_sql_example(r#"```sql
199-- Example usage of lead() :
200SELECT
201 employee_id,
202 department,
203 salary,
204 lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
205FROM employees;
206```
207
208```sql
209+-------------+-------------+--------+--------------+
210| employee_id | department | salary | next_salary |
211+-------------+-------------+--------+--------------+
212| 1 | Sales | 30000 | 50000 |
213| 2 | Sales | 50000 | 70000 |
214| 3 | Sales | 70000 | 0 |
215| 4 | Engineering | 40000 | 60000 |
216| 5 | Engineering | 60000 | 0 |
217+-------------+-------------+--------+--------------+
218```"#)
219 .build()
220});
221
222fn get_lead_doc() -> &'static Documentation {
223 &LEAD_DOCUMENTATION
224}
225
226impl WindowUDFImpl for WindowShift {
227 fn as_any(&self) -> &dyn Any {
228 self
229 }
230
231 fn name(&self) -> &str {
232 self.kind.name()
233 }
234
235 fn signature(&self) -> &Signature {
236 &self.signature
237 }
238
239 fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
245 parse_expr(expr_args.input_exprs(), expr_args.input_fields())
246 .into_iter()
247 .collect::<Vec<_>>()
248 }
249
250 fn partition_evaluator(
251 &self,
252 partition_evaluator_args: PartitionEvaluatorArgs,
253 ) -> Result<Box<dyn PartitionEvaluator>> {
254 let shift_offset =
255 get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
256 .map(get_signed_integer)
257 .map_or(Ok(None), |v| v.map(Some))
258 .map(|n| self.kind.shift_offset(n))
259 .map(|offset| {
260 if partition_evaluator_args.is_reversed() {
261 -offset
262 } else {
263 offset
264 }
265 })?;
266 let default_value = parse_default_value(
267 partition_evaluator_args.input_exprs(),
268 partition_evaluator_args.input_fields(),
269 )?;
270
271 Ok(Box::new(WindowShiftEvaluator {
272 shift_offset,
273 default_value,
274 ignore_nulls: partition_evaluator_args.ignore_nulls(),
275 non_null_offsets: VecDeque::new(),
276 }))
277 }
278
279 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
280 let return_field = parse_expr_field(field_args.input_fields())?;
281
282 Ok(return_field
283 .as_ref()
284 .clone()
285 .with_name(field_args.name())
286 .into())
287 }
288
289 fn reverse_expr(&self) -> ReversedUDWF {
290 match self.kind {
291 WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
292 WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
293 }
294 }
295
296 fn documentation(&self) -> Option<&Documentation> {
297 match self.kind {
298 WindowShiftKind::Lag => Some(get_lag_doc()),
299 WindowShiftKind::Lead => Some(get_lead_doc()),
300 }
301 }
302
303 fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
304 let Some(other) = other.as_any().downcast_ref::<Self>() else {
305 return false;
306 };
307 let Self { signature, kind } = self;
308 signature == &other.signature && kind == &other.kind
309 }
310
311 fn hash_value(&self) -> u64 {
312 let Self { signature, kind } = self;
313 let mut hasher = DefaultHasher::new();
314 std::any::type_name::<Self>().hash(&mut hasher);
315 signature.hash(&mut hasher);
316 kind.hash(&mut hasher);
317 hasher.finish()
318 }
319}
320
321fn parse_expr(
334 input_exprs: &[Arc<dyn PhysicalExpr>],
335 input_fields: &[FieldRef],
336) -> Result<Arc<dyn PhysicalExpr>> {
337 assert!(!input_exprs.is_empty());
338 assert!(!input_fields.is_empty());
339
340 let expr = Arc::clone(input_exprs.first().unwrap());
341 let expr_field = input_fields.first().unwrap();
342
343 if !expr_field.data_type().is_null() {
345 return Ok(expr);
346 }
347
348 let default_value = get_scalar_value_from_args(input_exprs, 2)?;
349 default_value.map_or(Ok(expr), |value| {
350 ScalarValue::try_from(&value.data_type()).map(|v| {
351 Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
352 as Arc<dyn PhysicalExpr>
353 })
354 })
355}
356
357static NULL_FIELD: LazyLock<FieldRef> =
358 LazyLock::new(|| Field::new("value", DataType::Null, true).into());
359
360fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
365 assert!(!input_fields.is_empty());
366 let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
367
368 if !expr_field.data_type().is_null() {
370 return Ok(expr_field.as_ref().clone().with_nullable(true).into());
371 }
372
373 let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
374 Ok(default_value_field
375 .as_ref()
376 .clone()
377 .with_nullable(true)
378 .into())
379}
380
381fn parse_default_value(
384 input_exprs: &[Arc<dyn PhysicalExpr>],
385 input_types: &[FieldRef],
386) -> Result<ScalarValue> {
387 let expr_field = parse_expr_field(input_types)?;
388 let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
389
390 unparsed
391 .filter(|v| !v.data_type().is_null())
392 .map(|v| v.cast_to(expr_field.data_type()))
393 .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
394}
395
396#[derive(Debug)]
397struct WindowShiftEvaluator {
398 shift_offset: i64,
399 default_value: ScalarValue,
400 ignore_nulls: bool,
401 non_null_offsets: VecDeque<usize>,
403}
404
405impl WindowShiftEvaluator {
406 fn is_lag(&self) -> bool {
407 self.shift_offset > 0
409 }
410}
411
412fn evaluate_all_with_ignore_null(
414 array: &ArrayRef,
415 offset: i64,
416 default_value: &ScalarValue,
417 is_lag: bool,
418) -> Result<ArrayRef, DataFusionError> {
419 let valid_indices: Vec<usize> =
420 array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
421 let direction = !is_lag;
422 let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
423 .map(|id| {
424 let result_index = match valid_indices.binary_search(&id) {
425 Ok(pos) => if direction {
426 pos.checked_add(offset as usize)
427 } else {
428 pos.checked_sub(offset.unsigned_abs() as usize)
429 }
430 .and_then(|new_pos| {
431 if new_pos < valid_indices.len() {
432 Some(valid_indices[new_pos])
433 } else {
434 None
435 }
436 }),
437 Err(pos) => if direction {
438 pos.checked_add(offset as usize)
439 } else if pos > 0 {
440 pos.checked_sub(offset.unsigned_abs() as usize)
441 } else {
442 None
443 }
444 .and_then(|new_pos| {
445 if new_pos < valid_indices.len() {
446 Some(valid_indices[new_pos])
447 } else {
448 None
449 }
450 }),
451 };
452
453 match result_index {
454 Some(index) => ScalarValue::try_from_array(array, index),
455 None => Ok(default_value.clone()),
456 }
457 })
458 .collect();
459
460 let new_array = new_array_results?;
461 ScalarValue::iter_to_array(new_array)
462}
463fn shift_with_default_value(
465 array: &ArrayRef,
466 offset: i64,
467 default_value: &ScalarValue,
468) -> Result<ArrayRef> {
469 use datafusion_common::arrow::compute::concat;
470
471 let value_len = array.len() as i64;
472 if offset == 0 {
473 Ok(Arc::clone(array))
474 } else if offset == i64::MIN || offset.abs() >= value_len {
475 default_value.to_array_of_size(value_len as usize)
476 } else {
477 let slice_offset = (-offset).clamp(0, value_len) as usize;
478 let length = array.len() - offset.unsigned_abs() as usize;
479 let slice = array.slice(slice_offset, length);
480
481 let nulls = offset.unsigned_abs() as usize;
483 let default_values = default_value.to_array_of_size(nulls)?;
484
485 if offset > 0 {
487 concat(&[default_values.as_ref(), slice.as_ref()])
488 .map_err(|e| arrow_datafusion_err!(e))
489 } else {
490 concat(&[slice.as_ref(), default_values.as_ref()])
491 .map_err(|e| arrow_datafusion_err!(e))
492 }
493 }
494}
495
496impl PartitionEvaluator for WindowShiftEvaluator {
497 fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
498 if self.is_lag() {
499 let start = if self.non_null_offsets.len() == self.shift_offset as usize {
500 let offset: usize = self.non_null_offsets.iter().sum();
502 idx.saturating_sub(offset)
503 } else if !self.ignore_nulls {
504 let offset = self.shift_offset as usize;
505 idx.saturating_sub(offset)
506 } else {
507 0
508 };
509 let end = idx + 1;
510 Ok(Range { start, end })
511 } else {
512 let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
513 let offset: usize = self.non_null_offsets.iter().sum();
515 min(idx + offset + 1, n_rows)
516 } else if !self.ignore_nulls {
517 let offset = (-self.shift_offset) as usize;
518 min(idx + offset, n_rows)
519 } else {
520 n_rows
521 };
522 Ok(Range { start: idx, end })
523 }
524 }
525
526 fn is_causal(&self) -> bool {
527 self.is_lag()
529 }
530
531 fn evaluate(
532 &mut self,
533 values: &[ArrayRef],
534 range: &Range<usize>,
535 ) -> Result<ScalarValue> {
536 let array = &values[0];
537 let len = array.len();
538
539 let i = if self.is_lag() {
541 (range.end as i64 - self.shift_offset - 1) as usize
542 } else {
543 (range.start as i64 - self.shift_offset) as usize
545 };
546
547 let mut idx: Option<usize> = if i < len { Some(i) } else { None };
548
549 if self.ignore_nulls && self.is_lag() {
552 idx = if self.non_null_offsets.len() == self.shift_offset as usize {
555 let total_offset: usize = self.non_null_offsets.iter().sum();
556 Some(range.end - 1 - total_offset)
557 } else {
558 None
559 };
560
561 if array.is_valid(range.end - 1) {
563 self.non_null_offsets.push_back(1);
565 if self.non_null_offsets.len() > self.shift_offset as usize {
566 self.non_null_offsets.pop_front();
568 }
569 } else if !self.non_null_offsets.is_empty() {
570 let end_idx = self.non_null_offsets.len() - 1;
572 self.non_null_offsets[end_idx] += 1;
573 }
574 } else if self.ignore_nulls && !self.is_lag() {
575 let non_null_row_count = (-self.shift_offset) as usize;
578
579 if self.non_null_offsets.is_empty() {
580 let mut offset_val = 1;
582 for idx in range.start + 1..range.end {
583 if array.is_valid(idx) {
584 self.non_null_offsets.push_back(offset_val);
585 offset_val = 1;
586 } else {
587 offset_val += 1;
588 }
589 if self.non_null_offsets.len() == non_null_row_count + 1 {
592 break;
593 }
594 }
595 } else if range.end < len && array.is_valid(range.end) {
596 if array.is_valid(range.end) {
598 self.non_null_offsets.push_back(1);
600 } else {
601 let last_idx = self.non_null_offsets.len() - 1;
603 self.non_null_offsets[last_idx] += 1;
604 }
605 }
606
607 idx = if self.non_null_offsets.len() >= non_null_row_count {
609 let total_offset: usize =
610 self.non_null_offsets.iter().take(non_null_row_count).sum();
611 Some(range.start + total_offset)
612 } else {
613 None
614 };
615 if !self.non_null_offsets.is_empty() {
618 self.non_null_offsets[0] -= 1;
619 if self.non_null_offsets[0] == 0 {
620 self.non_null_offsets.pop_front();
622 }
623 }
624 }
625
626 #[allow(clippy::unnecessary_unwrap)]
632 if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
633 ScalarValue::try_from_array(array, idx.unwrap())
634 } else {
635 Ok(self.default_value.clone())
636 }
637 }
638
639 fn evaluate_all(
640 &mut self,
641 values: &[ArrayRef],
642 _num_rows: usize,
643 ) -> Result<ArrayRef> {
644 let value = &values[0];
646 if !self.ignore_nulls {
647 shift_with_default_value(value, self.shift_offset, &self.default_value)
648 } else {
649 evaluate_all_with_ignore_null(
650 value,
651 self.shift_offset,
652 &self.default_value,
653 self.is_lag(),
654 )
655 }
656 }
657
658 fn supports_bounded_execution(&self) -> bool {
659 true
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Runs `expr` over a fixed eight-row `Int32` column via `evaluate_all`
    /// and compares the outcome with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // Row count of the column; `values.len()` would be the number of
        // columns (1), not the number of rows.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2) ignoring nulls, with two tracked distances of 2 rows each
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2) ignoring nulls, with two tracked distances of 2 rows each
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
}