1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::buffer::NullBuffer;
23use arrow::datatypes::FieldRef;
24use datafusion_common::arrow::array::ArrayRef;
25use datafusion_common::arrow::datatypes::{DataType, Field};
26use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
27use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
28use datafusion_expr::window_state::WindowAggState;
29use datafusion_expr::{
30 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
31 TypeSignature, Volatility, WindowUDFImpl,
32};
33use datafusion_functions_window_common::field;
34use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
35use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
36use field::WindowUDFFieldArgs;
37use std::any::Any;
38use std::cmp::Ordering;
39use std::fmt::Debug;
40use std::hash::Hash;
41use std::ops::Range;
42use std::sync::{Arc, LazyLock};
43
// Registers the `first_value` window function, backed by `NthValue::first`.
// NOTE(review): presumably this macro generates both the `first_value_udwf()`
// accessor used in `reverse_expr` and a `first_value(arg)` expression builder;
// confirm against the macro definition.
define_udwf_and_expr!(
    First,
    first_value,
    [arg],
    "Returns the first value in the window frame",
    NthValue::first
);
// Registers the `last_value` window function, backed by `NthValue::last`.
define_udwf_and_expr!(
    Last,
    last_value,
    [arg],
    "Returns the last value in the window frame",
    NthValue::last
);
// Registers the `nth_value` window function, backed by `NthValue::nth`.
// Unlike the two invocations above this one uses `get_or_init_udwf!`; the
// `nth_value` expression builder is written by hand further down in this file
// because it takes the extra `n` argument.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "Returns the nth value in the window frame",
    NthValue::nth
);
64
65pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
67 nth_value_udwf().call(vec![arg, n.lit()])
68}
69
/// Selects which of the three window functions an [`NthValue`] instance
/// implements.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`
    First,
    /// `last_value`
    Last,
    /// `nth_value`
    Nth,
}

impl NthValueKind {
    /// Returns the SQL function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
87
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; the `kind` field decides which of the three it behaves as.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValue {
    // Accepted argument signature, built once in `NthValue::new`.
    signature: Signature,
    // Which of the three functions this instance represents.
    kind: NthValueKind,
}
93
94impl NthValue {
95 pub fn new(kind: NthValueKind) -> Self {
97 Self {
98 signature: Signature::one_of(
99 vec![
100 TypeSignature::Nullary,
101 TypeSignature::Any(1),
102 TypeSignature::Any(2),
103 ],
104 Volatility::Immutable,
105 ),
106 kind,
107 }
108 }
109
110 pub fn first() -> Self {
111 Self::new(NthValueKind::First)
112 }
113
114 pub fn last() -> Self {
115 Self::new(NthValueKind::Last)
116 }
117 pub fn nth() -> Self {
118 Self::new(NthValueKind::Nth)
119 }
120
121 pub fn kind(&self) -> &NthValueKind {
122 &self.kind
123 }
124}
125
/// User-facing documentation for `first_value`, built on first access.
///
/// Carries the description, the syntax string, the single `expression`
/// argument and a SQL usage example with its expected output.
static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the first row of the window \
            frame.",
        "first_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_sql_example(
        r#"
```sql
-- Example usage of the first_value window function:
SELECT department,
 employee_id,
 salary,
 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
FROM employees;

+-------------+-------------+--------+------------+
| department | employee_id | salary | top_salary |
+-------------+-------------+--------+------------+
| Sales | 1 | 70000 | 70000 |
| Sales | 2 | 50000 | 70000 |
| Sales | 3 | 30000 | 70000 |
| Engineering | 4 | 90000 | 90000 |
| Engineering | 5 | 80000 | 90000 |
+-------------+-------------+--------+------------+
```
"#,
    )
    .build()
});
158
159fn get_first_value_doc() -> &'static Documentation {
160 &FIRST_VALUE_DOCUMENTATION
161}
162
/// User-facing documentation for `last_value`, built on first access.
///
/// Carries the description, the syntax string, the single `expression`
/// argument and a SQL usage example with its expected output.
static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the last row of the window \
            frame.",
        "last_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_sql_example(r#"```sql
-- SQL example of last_value:
SELECT department,
 employee_id,
 salary,
 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
FROM employees;

+-------------+-------------+--------+---------------------+
| department | employee_id | salary | running_last_salary |
+-------------+-------------+--------+---------------------+
| Sales | 1 | 30000 | 30000 |
| Sales | 2 | 50000 | 50000 |
| Sales | 3 | 70000 | 70000 |
| Engineering | 4 | 40000 | 40000 |
| Engineering | 5 | 60000 | 60000 |
+-------------+-------------+--------+---------------------+
```
"#)
    .build()
});
192
193fn get_last_value_doc() -> &'static Documentation {
194 &LAST_VALUE_DOCUMENTATION
195}
196
/// User-facing documentation for `nth_value`, built on first access.
///
/// Carries the description, the syntax string, the `expression` and `n`
/// arguments and a SQL usage example with its expected output.
static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns the value evaluated at the nth row of the window frame \
         (counting from 1). Returns NULL if no such row exists.",
        "nth_value(expression, n)",
    )
    .with_argument(
        "expression",
        "The column from which to retrieve the nth value.",
    )
    .with_argument(
        "n",
        "Integer. Specifies the row number (starting from 1) in the window frame.",
    )
    .with_sql_example(
        r#"
```sql
-- Sample employees table:
CREATE TABLE employees (id INT, salary INT);
INSERT INTO employees (id, salary) VALUES
(1, 30000),
(2, 40000),
(3, 50000),
(4, 60000),
(5, 70000);

-- Example usage of nth_value:
SELECT nth_value(salary, 2) OVER (
 ORDER BY salary
 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS nth_value
FROM employees;

+-----------+
| nth_value |
+-----------+
| 40000 |
| 40000 |
| 40000 |
| 40000 |
| 40000 |
+-----------+
```
"#,
    )
    .build()
});
245
246fn get_nth_value_doc() -> &'static Documentation {
247 &NTH_VALUE_DOCUMENTATION
248}
249
250impl WindowUDFImpl for NthValue {
251 fn as_any(&self) -> &dyn Any {
252 self
253 }
254
255 fn name(&self) -> &str {
256 self.kind.name()
257 }
258
259 fn signature(&self) -> &Signature {
260 &self.signature
261 }
262
263 fn partition_evaluator(
264 &self,
265 partition_evaluator_args: PartitionEvaluatorArgs,
266 ) -> Result<Box<dyn PartitionEvaluator>> {
267 let state = NthValueState {
268 finalized_result: None,
269 kind: self.kind,
270 };
271
272 if self.kind != NthValueKind::Nth {
273 return Ok(Box::new(NthValueEvaluator {
274 state,
275 ignore_nulls: partition_evaluator_args.ignore_nulls(),
276 n: 0,
277 }));
278 }
279
280 let n = match get_scalar_value_from_args(
281 partition_evaluator_args.input_exprs(),
282 1,
283 )
284 .map_err(|_e| {
285 exec_datafusion_err!(
286 "Expected a signed integer literal for the second argument of nth_value"
287 )
288 })?
289 .map(|v| get_signed_integer(&v))
290 {
291 Some(Ok(n)) => {
292 if partition_evaluator_args.is_reversed() {
293 -n
294 } else {
295 n
296 }
297 }
298 _ => {
299 return exec_err!(
300 "Expected a signed integer literal for the second argument of nth_value"
301 );
302 }
303 };
304
305 Ok(Box::new(NthValueEvaluator {
306 state,
307 ignore_nulls: partition_evaluator_args.ignore_nulls(),
308 n,
309 }))
310 }
311
312 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
313 let return_type = field_args
314 .input_fields()
315 .first()
316 .map(|f| f.data_type())
317 .cloned()
318 .unwrap_or(DataType::Null);
319
320 Ok(Field::new(field_args.name(), return_type, true).into())
321 }
322
323 fn reverse_expr(&self) -> ReversedUDWF {
324 match self.kind {
325 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
326 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
327 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
328 }
329 }
330
331 fn documentation(&self) -> Option<&Documentation> {
332 match self.kind {
333 NthValueKind::First => Some(get_first_value_doc()),
334 NthValueKind::Last => Some(get_last_value_doc()),
335 NthValueKind::Nth => Some(get_nth_value_doc()),
336 }
337 }
338
339 fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
340 LimitEffect::None }
342}
343
/// Per-partition state kept by the evaluator between calls.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Result cached by `memoize`. Once this is `Some`, `evaluate` returns a
    /// clone of it instead of inspecting the input again.
    pub finalized_result: Option<ScalarValue>,
    /// Which function (`first_value`, `last_value` or `nth_value`) is being
    /// evaluated.
    pub kind: NthValueKind,
}
357
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    // Function kind plus the memoized result, if any.
    state: NthValueState,
    // When true, null rows are skipped while locating the requested row.
    ignore_nulls: bool,
    // Requested position for `nth_value`: positive counts from the start of
    // the frame, negative from its end, and 0 never matches a row. Set to 0
    // for `first_value` / `last_value`, which do not read it.
    n: i64,
}
364
365impl PartitionEvaluator for NthValueEvaluator {
366 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
372 let out = &state.out_col;
373 let size = out.len();
374 if self.ignore_nulls {
375 match self.state.kind {
376 NthValueKind::First => {
378 if let Some(nulls) = out.nulls() {
379 if self.state.finalized_result.is_none() {
380 if let Some(valid_index) = nulls.valid_indices().next() {
381 let result =
382 ScalarValue::try_from_array(out, valid_index)?;
383 self.state.finalized_result = Some(result);
384 } else {
385 }
387 }
388 if state.window_frame_range.start < state.window_frame_range.end {
389 state.window_frame_range.start =
390 state.window_frame_range.end - 1;
391 }
392 return Ok(());
393 } else {
394 }
396 }
397 NthValueKind::Last | NthValueKind::Nth => return Ok(()),
399 }
400 }
401 let mut buffer_size = 1;
402 let (is_prunable, is_reverse_direction) = match self.state.kind {
404 NthValueKind::First => {
405 let n_range =
406 state.window_frame_range.end - state.window_frame_range.start;
407 (n_range > 0 && size > 0, false)
408 }
409 NthValueKind::Last => (true, true),
410 NthValueKind::Nth => {
411 let n_range =
412 state.window_frame_range.end - state.window_frame_range.start;
413 match self.n.cmp(&0) {
414 Ordering::Greater => (
415 n_range >= (self.n as usize) && size > (self.n as usize),
416 false,
417 ),
418 Ordering::Less => {
419 let reverse_index = (-self.n) as usize;
420 buffer_size = reverse_index;
421 (n_range >= reverse_index, true)
423 }
424 Ordering::Equal => (false, false),
425 }
426 }
427 };
428 if is_prunable {
429 if self.state.finalized_result.is_none() && !is_reverse_direction {
430 let result = ScalarValue::try_from_array(out, size - 1)?;
431 self.state.finalized_result = Some(result);
432 }
433 state.window_frame_range.start =
434 state.window_frame_range.end.saturating_sub(buffer_size);
435 }
436 Ok(())
437 }
438
439 fn evaluate(
440 &mut self,
441 values: &[ArrayRef],
442 range: &Range<usize>,
443 ) -> Result<ScalarValue> {
444 if let Some(ref result) = self.state.finalized_result {
445 Ok(result.clone())
446 } else {
447 let arr = &values[0];
449 let n_range = range.end - range.start;
450 if n_range == 0 {
451 return ScalarValue::try_from(arr.data_type());
453 }
454 match self.valid_index(arr, range) {
455 Some(index) => ScalarValue::try_from_array(arr, index),
456 None => ScalarValue::try_from(arr.data_type()),
457 }
458 }
459 }
460
461 fn supports_bounded_execution(&self) -> bool {
462 true
463 }
464
465 fn uses_window_frame(&self) -> bool {
466 true
467 }
468}
469
470impl NthValueEvaluator {
471 fn valid_index(&self, array: &ArrayRef, range: &Range<usize>) -> Option<usize> {
472 let n_range = range.end - range.start;
473 if self.ignore_nulls {
474 let slice = array.slice(range.start, n_range);
476 if let Some(nulls) = slice.nulls()
477 && nulls.null_count() > 0
478 {
479 return self.valid_index_with_nulls(nulls, range.start);
480 }
481 }
482 match self.state.kind {
484 NthValueKind::First => Some(range.start),
485 NthValueKind::Last => Some(range.end - 1),
486 NthValueKind::Nth => match self.n.cmp(&0) {
487 Ordering::Greater => {
488 let index = (self.n as usize) - 1;
490 if index >= n_range {
491 None
493 } else {
494 Some(range.start + index)
495 }
496 }
497 Ordering::Less => {
498 let reverse_index = (-self.n) as usize;
499 if n_range < reverse_index {
500 None
502 } else {
503 Some(range.end - reverse_index)
504 }
505 }
506 Ordering::Equal => None,
507 },
508 }
509 }
510
511 fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option<usize> {
512 match self.state.kind {
513 NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset),
514 NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset),
515 NthValueKind::Nth => {
516 match self.n.cmp(&0) {
517 Ordering::Greater => {
518 let index = (self.n as usize) - 1;
520 nulls.valid_indices().nth(index).map(|idx| idx + offset)
521 }
522 Ordering::Less => {
523 let reverse_index = (-self.n) as usize;
524 let valid_indices_len = nulls.len() - nulls.null_count();
525 if reverse_index > valid_indices_len {
526 return None;
527 }
528 nulls
529 .valid_indices()
530 .nth(valid_indices_len - reverse_index)
531 .map(|idx| idx + offset)
532 }
533 Ordering::Equal => None,
534 }
535 }
536 }
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression for the single input column used by every test.
    fn input_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Nullable `Int32` field describing the input column.
    fn input_field() -> FieldRef {
        Field::new("f", DataType::Int32, true).into()
    }

    /// Evaluates `expr` over the frames `0..1`, `0..2`, ..., `0..8` of a fixed
    /// eight-row `Int32` column and compares the collected results with
    /// `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![input];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;

        let mut scalars: Vec<ScalarValue> = Vec::with_capacity(8);
        for end in 1..=8 {
            let frame = Range { start: 0, end };
            scalars.push(evaluator.evaluate(&values, &frame)?);
        }

        let actual = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>();
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(&[input_column()], &[input_field()], false, false),
            expected,
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(&[input_column()], &[input_field()], false, false),
            expected,
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let n_value =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), n_value],
                &[input_field()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )?;
        Ok(())
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let n_value =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;

        // The first frame holds a single row, so there is no second value yet.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), n_value],
                &[input_field()],
                false,
                false,
            ),
            expected,
        )?;
        Ok(())
    }
}