1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use datafusion_common::arrow::array::ArrayRef;
22use datafusion_common::arrow::datatypes::DataType;
23use datafusion_common::arrow::datatypes::Field;
24use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
25use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
26use datafusion_expr::{
27 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
28 Volatility, WindowUDFImpl,
29};
30use datafusion_functions_window_common::expr::ExpressionArgs;
31use datafusion_functions_window_common::field::WindowUDFFieldArgs;
32use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
33use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
34use std::any::Any;
35use std::cmp::min;
36use std::collections::VecDeque;
37use std::ops::{Neg, Range};
38use std::sync::{Arc, LazyLock};
39
// Declares the `lag` window function through the `get_or_init_udwf!` macro
// (defined outside this file), handing it the `Lag`/`lag` identifiers, the
// description text below and `WindowShift::lag` as the constructor.
// NOTE(review): the `lag_udwf()` accessor called elsewhere in this file is
// presumably generated by this invocation — confirm against the macro.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Declares the `lead` window function through the `get_or_init_udwf!` macro
// (defined outside this file), handing it the `Lead`/`lead` identifiers, the
// description text below and `WindowShift::lead` as the constructor.
// NOTE(review): the `lead_udwf()` accessor called elsewhere in this file is
// presumably generated by this invocation — confirm against the macro.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
56
57pub fn lag(
64 arg: datafusion_expr::Expr,
65 shift_offset: Option<i64>,
66 default_value: Option<ScalarValue>,
67) -> datafusion_expr::Expr {
68 let shift_offset_lit = shift_offset
69 .map(|v| v.lit())
70 .unwrap_or(ScalarValue::Null.lit());
71 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
72
73 lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
74}
75
76pub fn lead(
83 arg: datafusion_expr::Expr,
84 shift_offset: Option<i64>,
85 default_value: Option<ScalarValue>,
86) -> datafusion_expr::Expr {
87 let shift_offset_lit = shift_offset
88 .map(|v| v.lit())
89 .unwrap_or(ScalarValue::Null.lit());
90 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
91
92 lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
93}
94
/// Which of the two shift functions a [`WindowShift`] implements.
#[derive(Debug)]
enum WindowShiftKind {
    /// `lag`: read a row before the current one.
    Lag,
    /// `lead`: read a row after the current one.
    Lead,
}

impl WindowShiftKind {
    /// Name of the window function for this kind (`"lag"` or `"lead"`).
    fn name(&self) -> &'static str {
        match self {
            WindowShiftKind::Lag => "lag",
            WindowShiftKind::Lead => "lead",
        }
    }

    /// Converts the user supplied offset into the signed offset stored in the
    /// evaluator, where a positive value means `lag` and a non-positive value
    /// means `lead`.
    ///
    /// * `Lag` keeps `value` as is and defaults to `1`.
    /// * `Lead` negates `value` and defaults to `-1`.
    ///
    /// `i64::MIN` has no positive counterpart, so plain negation would
    /// overflow (a panic when overflow checks are enabled). A `lead` by
    /// `i64::MIN` rows is a `lag` by more rows than any partition can hold,
    /// hence it is mapped to `i64::MAX`.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        match self {
            WindowShiftKind::Lag => value.unwrap_or(1),
            WindowShiftKind::Lead => match value {
                None => -1,
                Some(i64::MIN) => i64::MAX,
                Some(v) => v.neg(),
            },
        }
    }
}
119
120#[derive(Debug)]
122pub struct WindowShift {
123 signature: Signature,
124 kind: WindowShiftKind,
125}
126
127impl WindowShift {
128 fn new(kind: WindowShiftKind) -> Self {
129 Self {
130 signature: Signature::one_of(
131 vec![
132 TypeSignature::Any(1),
133 TypeSignature::Any(2),
134 TypeSignature::Any(3),
135 ],
136 Volatility::Immutable,
137 ),
138 kind,
139 }
140 }
141
142 pub fn lag() -> Self {
143 Self::new(WindowShiftKind::Lag)
144 }
145
146 pub fn lead() -> Self {
147 Self::new(WindowShiftKind::Lead)
148 }
149}
150
151static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
152 Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
153 current row within the partition; if there is no such row, instead return default \
154 (which must be of the same type as value).", "lag(expression, offset, default)")
155 .with_argument("expression", "Expression to operate on")
156 .with_argument("offset", "Integer. Specifies how many rows back \
157 the value of expression should be retrieved. Defaults to 1.")
158 .with_argument("default", "The default value if the offset is \
159 not within the partition. Must be of the same type as expression.")
160 .build()
161});
162
163fn get_lag_doc() -> &'static Documentation {
164 &LAG_DOCUMENTATION
165}
166
167static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
168 Documentation::builder(DOC_SECTION_ANALYTICAL,
169 "Returns value evaluated at the row that is offset rows after the \
170 current row within the partition; if there is no such row, instead return default \
171 (which must be of the same type as value).",
172 "lead(expression, offset, default)")
173 .with_argument("expression", "Expression to operate on")
174 .with_argument("offset", "Integer. Specifies how many rows \
175 forward the value of expression should be retrieved. Defaults to 1.")
176 .with_argument("default", "The default value if the offset is \
177 not within the partition. Must be of the same type as expression.")
178 .build()
179});
180
181fn get_lead_doc() -> &'static Documentation {
182 &LEAD_DOCUMENTATION
183}
184
185impl WindowUDFImpl for WindowShift {
186 fn as_any(&self) -> &dyn Any {
187 self
188 }
189
190 fn name(&self) -> &str {
191 self.kind.name()
192 }
193
194 fn signature(&self) -> &Signature {
195 &self.signature
196 }
197
198 fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
204 parse_expr(expr_args.input_exprs(), expr_args.input_types())
205 .into_iter()
206 .collect::<Vec<_>>()
207 }
208
209 fn partition_evaluator(
210 &self,
211 partition_evaluator_args: PartitionEvaluatorArgs,
212 ) -> Result<Box<dyn PartitionEvaluator>> {
213 let shift_offset =
214 get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
215 .map(get_signed_integer)
216 .map_or(Ok(None), |v| v.map(Some))
217 .map(|n| self.kind.shift_offset(n))
218 .map(|offset| {
219 if partition_evaluator_args.is_reversed() {
220 -offset
221 } else {
222 offset
223 }
224 })?;
225 let default_value = parse_default_value(
226 partition_evaluator_args.input_exprs(),
227 partition_evaluator_args.input_types(),
228 )?;
229
230 Ok(Box::new(WindowShiftEvaluator {
231 shift_offset,
232 default_value,
233 ignore_nulls: partition_evaluator_args.ignore_nulls(),
234 non_null_offsets: VecDeque::new(),
235 }))
236 }
237
238 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
239 let return_type = parse_expr_type(field_args.input_types())?;
240
241 Ok(Field::new(field_args.name(), return_type, true))
242 }
243
244 fn reverse_expr(&self) -> ReversedUDWF {
245 match self.kind {
246 WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
247 WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
248 }
249 }
250
251 fn documentation(&self) -> Option<&Documentation> {
252 match self.kind {
253 WindowShiftKind::Lag => Some(get_lag_doc()),
254 WindowShiftKind::Lead => Some(get_lead_doc()),
255 }
256 }
257}
258
259fn parse_expr(
272 input_exprs: &[Arc<dyn PhysicalExpr>],
273 input_types: &[DataType],
274) -> Result<Arc<dyn PhysicalExpr>> {
275 assert!(!input_exprs.is_empty());
276 assert!(!input_types.is_empty());
277
278 let expr = Arc::clone(input_exprs.first().unwrap());
279 let expr_type = input_types.first().unwrap();
280
281 if !expr_type.is_null() {
283 return Ok(expr);
284 }
285
286 let default_value = get_scalar_value_from_args(input_exprs, 2)?;
287 default_value.map_or(Ok(expr), |value| {
288 ScalarValue::try_from(&value.data_type()).map(|v| {
289 Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
290 as Arc<dyn PhysicalExpr>
291 })
292 })
293}
294
295fn parse_expr_type(input_types: &[DataType]) -> Result<DataType> {
300 assert!(!input_types.is_empty());
301 let expr_type = input_types.first().unwrap_or(&DataType::Null);
302
303 if !expr_type.is_null() {
305 return Ok(expr_type.clone());
306 }
307
308 let default_value_type = input_types.get(2).unwrap_or(&DataType::Null);
309 Ok(default_value_type.clone())
310}
311
312fn parse_default_value(
315 input_exprs: &[Arc<dyn PhysicalExpr>],
316 input_types: &[DataType],
317) -> Result<ScalarValue> {
318 let expr_type = parse_expr_type(input_types)?;
319 let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
320
321 unparsed
322 .filter(|v| !v.data_type().is_null())
323 .map(|v| v.cast_to(&expr_type))
324 .unwrap_or(ScalarValue::try_from(expr_type))
325}
326
327#[derive(Debug)]
328struct WindowShiftEvaluator {
329 shift_offset: i64,
330 default_value: ScalarValue,
331 ignore_nulls: bool,
332 non_null_offsets: VecDeque<usize>,
334}
335
336impl WindowShiftEvaluator {
337 fn is_lag(&self) -> bool {
338 self.shift_offset > 0
340 }
341}
342
343fn evaluate_all_with_ignore_null(
345 array: &ArrayRef,
346 offset: i64,
347 default_value: &ScalarValue,
348 is_lag: bool,
349) -> Result<ArrayRef, DataFusionError> {
350 let valid_indices: Vec<usize> =
351 array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
352 let direction = !is_lag;
353 let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
354 .map(|id| {
355 let result_index = match valid_indices.binary_search(&id) {
356 Ok(pos) => if direction {
357 pos.checked_add(offset as usize)
358 } else {
359 pos.checked_sub(offset.unsigned_abs() as usize)
360 }
361 .and_then(|new_pos| {
362 if new_pos < valid_indices.len() {
363 Some(valid_indices[new_pos])
364 } else {
365 None
366 }
367 }),
368 Err(pos) => if direction {
369 pos.checked_add(offset as usize)
370 } else if pos > 0 {
371 pos.checked_sub(offset.unsigned_abs() as usize)
372 } else {
373 None
374 }
375 .and_then(|new_pos| {
376 if new_pos < valid_indices.len() {
377 Some(valid_indices[new_pos])
378 } else {
379 None
380 }
381 }),
382 };
383
384 match result_index {
385 Some(index) => ScalarValue::try_from_array(array, index),
386 None => Ok(default_value.clone()),
387 }
388 })
389 .collect();
390
391 let new_array = new_array_results?;
392 ScalarValue::iter_to_array(new_array)
393}
394fn shift_with_default_value(
396 array: &ArrayRef,
397 offset: i64,
398 default_value: &ScalarValue,
399) -> Result<ArrayRef> {
400 use datafusion_common::arrow::compute::concat;
401
402 let value_len = array.len() as i64;
403 if offset == 0 {
404 Ok(Arc::clone(array))
405 } else if offset == i64::MIN || offset.abs() >= value_len {
406 default_value.to_array_of_size(value_len as usize)
407 } else {
408 let slice_offset = (-offset).clamp(0, value_len) as usize;
409 let length = array.len() - offset.unsigned_abs() as usize;
410 let slice = array.slice(slice_offset, length);
411
412 let nulls = offset.unsigned_abs() as usize;
414 let default_values = default_value.to_array_of_size(nulls)?;
415
416 if offset > 0 {
418 concat(&[default_values.as_ref(), slice.as_ref()])
419 .map_err(|e| arrow_datafusion_err!(e))
420 } else {
421 concat(&[slice.as_ref(), default_values.as_ref()])
422 .map_err(|e| arrow_datafusion_err!(e))
423 }
424 }
425}
426
427impl PartitionEvaluator for WindowShiftEvaluator {
428 fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
429 if self.is_lag() {
430 let start = if self.non_null_offsets.len() == self.shift_offset as usize {
431 let offset: usize = self.non_null_offsets.iter().sum();
433 idx.saturating_sub(offset)
434 } else if !self.ignore_nulls {
435 let offset = self.shift_offset as usize;
436 idx.saturating_sub(offset)
437 } else {
438 0
439 };
440 let end = idx + 1;
441 Ok(Range { start, end })
442 } else {
443 let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
444 let offset: usize = self.non_null_offsets.iter().sum();
446 min(idx + offset + 1, n_rows)
447 } else if !self.ignore_nulls {
448 let offset = (-self.shift_offset) as usize;
449 min(idx + offset, n_rows)
450 } else {
451 n_rows
452 };
453 Ok(Range { start: idx, end })
454 }
455 }
456
457 fn is_causal(&self) -> bool {
458 self.is_lag()
460 }
461
462 fn evaluate(
463 &mut self,
464 values: &[ArrayRef],
465 range: &Range<usize>,
466 ) -> Result<ScalarValue> {
467 let array = &values[0];
468 let len = array.len();
469
470 let i = if self.is_lag() {
472 (range.end as i64 - self.shift_offset - 1) as usize
473 } else {
474 (range.start as i64 - self.shift_offset) as usize
476 };
477
478 let mut idx: Option<usize> = if i < len { Some(i) } else { None };
479
480 if self.ignore_nulls && self.is_lag() {
483 idx = if self.non_null_offsets.len() == self.shift_offset as usize {
486 let total_offset: usize = self.non_null_offsets.iter().sum();
487 Some(range.end - 1 - total_offset)
488 } else {
489 None
490 };
491
492 if array.is_valid(range.end - 1) {
494 self.non_null_offsets.push_back(1);
496 if self.non_null_offsets.len() > self.shift_offset as usize {
497 self.non_null_offsets.pop_front();
499 }
500 } else if !self.non_null_offsets.is_empty() {
501 let end_idx = self.non_null_offsets.len() - 1;
503 self.non_null_offsets[end_idx] += 1;
504 }
505 } else if self.ignore_nulls && !self.is_lag() {
506 let non_null_row_count = (-self.shift_offset) as usize;
509
510 if self.non_null_offsets.is_empty() {
511 let mut offset_val = 1;
513 for idx in range.start + 1..range.end {
514 if array.is_valid(idx) {
515 self.non_null_offsets.push_back(offset_val);
516 offset_val = 1;
517 } else {
518 offset_val += 1;
519 }
520 if self.non_null_offsets.len() == non_null_row_count + 1 {
523 break;
524 }
525 }
526 } else if range.end < len && array.is_valid(range.end) {
527 if array.is_valid(range.end) {
529 self.non_null_offsets.push_back(1);
531 } else {
532 let last_idx = self.non_null_offsets.len() - 1;
534 self.non_null_offsets[last_idx] += 1;
535 }
536 }
537
538 idx = if self.non_null_offsets.len() >= non_null_row_count {
540 let total_offset: usize =
541 self.non_null_offsets.iter().take(non_null_row_count).sum();
542 Some(range.start + total_offset)
543 } else {
544 None
545 };
546 if !self.non_null_offsets.is_empty() {
549 self.non_null_offsets[0] -= 1;
550 if self.non_null_offsets[0] == 0 {
551 self.non_null_offsets.pop_front();
553 }
554 }
555 }
556
557 #[allow(clippy::unnecessary_unwrap)]
563 if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
564 ScalarValue::try_from_array(array, idx.unwrap())
565 } else {
566 Ok(self.default_value.clone())
567 }
568 }
569
570 fn evaluate_all(
571 &mut self,
572 values: &[ArrayRef],
573 _num_rows: usize,
574 ) -> Result<ArrayRef> {
575 let value = &values[0];
577 if !self.ignore_nulls {
578 shift_with_default_value(value, self.shift_offset, &self.default_value)
579 } else {
580 evaluate_all_with_ignore_null(
581 value,
582 self.shift_offset,
583 &self.default_value,
584 self.is_lag(),
585 )
586 }
587 }
588
589 fn supports_bounded_execution(&self) -> bool {
590 true
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Evaluates `expr` over the fixed column `[1, -2, 3, -4, 5, -6, 7, 8]`
    /// with `evaluate_all` and compares the outcome with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // The row count is the length of the column, not the number of
        // columns passed to the evaluator.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2) ignoring nulls, with two tracked non-null rows whose gaps
        // add up to 4.
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2) ignoring nulls, with two tracked non-null rows whose gaps
        // add up to 4.
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_types: &[DataType] =
            &[DataType::Int32, DataType::Int32, DataType::Int32];

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, input_types, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
}