1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
30 TypeSignature, Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use field::WindowUDFFieldArgs;
36use std::any::Any;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
// Registers the `first_value` window function, backed by `NthValue::first`.
// NOTE(review): the macro is defined outside this file; it presumably also
// generates the `first_value_udwf()` accessor used by `reverse_expr` below —
// confirm against the macro definition.
define_udwf_and_expr!(
    First,
    first_value,
    [arg],
    "Returns the first value in the window frame",
    NthValue::first
);
// Registers the `last_value` window function, backed by `NthValue::last`.
// NOTE(review): presumably also generates `last_value_udwf()` — confirm.
define_udwf_and_expr!(
    Last,
    last_value,
    [arg],
    "Returns the last value in the window frame",
    NthValue::last
);
// Registers the `nth_value` window function, backed by `NthValue::nth`.
// Only the UDWF accessor is requested here; the expression builder
// `nth_value(arg, n)` is written by hand further down in this file.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "Returns the nth value in the window frame",
    NthValue::nth
);
63
64pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
66 nth_value_udwf().call(vec![arg, n.lit()])
67}
68
/// Tag to differentiate the special use cases of the `nth_value` built-in
/// window function.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// The function behaves as `first_value`.
    First,
    /// The function behaves as `last_value`.
    Last,
    /// The function behaves as the general `nth_value`.
    Nth,
}

impl NthValueKind {
    /// Returns the SQL-level function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
86
/// Window UDF shared by `first_value`, `last_value` and `nth_value`; the
/// `kind` field selects which of the three a given instance implements.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValue {
    /// Argument signature returned from `WindowUDFImpl::signature`.
    signature: Signature,
    /// Which of the three functions this instance represents.
    kind: NthValueKind,
}
92
93impl NthValue {
94 pub fn new(kind: NthValueKind) -> Self {
96 Self {
97 signature: Signature::one_of(
98 vec![
99 TypeSignature::Any(0),
100 TypeSignature::Any(1),
101 TypeSignature::Any(2),
102 ],
103 Volatility::Immutable,
104 ),
105 kind,
106 }
107 }
108
109 pub fn first() -> Self {
110 Self::new(NthValueKind::First)
111 }
112
113 pub fn last() -> Self {
114 Self::new(NthValueKind::Last)
115 }
116 pub fn nth() -> Self {
117 Self::new(NthValueKind::Nth)
118 }
119
120 pub fn kind(&self) -> &NthValueKind {
121 &self.kind
122 }
123}
124
/// User-facing documentation for `first_value`, built on first access.
///
/// Filed under the analytical window-function section; contains the
/// description, the syntax example, the single `expression` argument and a
/// SQL usage example with its expected output.
static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the first row of the window \
            frame.",
        "first_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_sql_example(
        r#"
```sql
-- Example usage of the first_value window function:
SELECT department,
  employee_id,
  salary,
  first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
FROM employees;

+-------------+-------------+--------+------------+
| department  | employee_id | salary | top_salary |
+-------------+-------------+--------+------------+
| Sales       | 1           | 70000  | 70000      |
| Sales       | 2           | 50000  | 70000      |
| Sales       | 3           | 30000  | 70000      |
| Engineering | 4           | 90000  | 90000      |
| Engineering | 5           | 80000  | 90000      |
+-------------+-------------+--------+------------+
```
"#,
    )
    .build()
});
157
158fn get_first_value_doc() -> &'static Documentation {
159 &FIRST_VALUE_DOCUMENTATION
160}
161
/// User-facing documentation for `last_value`, built on first access.
///
/// Filed under the analytical window-function section; contains the
/// description, the syntax example, the single `expression` argument and a
/// SQL usage example with its expected output.
static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the last row of the window \
            frame.",
        "last_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    .with_sql_example(r#"```sql
-- SQL example of last_value:
SELECT department,
  employee_id,
  salary,
  last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
FROM employees;

+-------------+-------------+--------+---------------------+
| department  | employee_id | salary | running_last_salary |
+-------------+-------------+--------+---------------------+
| Sales       | 1           | 30000  | 30000               |
| Sales       | 2           | 50000  | 50000               |
| Sales       | 3           | 70000  | 70000               |
| Engineering | 4           | 40000  | 40000               |
| Engineering | 5           | 60000  | 60000               |
+-------------+-------------+--------+---------------------+
```
"#)
    .build()
});
191
192fn get_last_value_doc() -> &'static Documentation {
193 &LAST_VALUE_DOCUMENTATION
194}
195
196static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
197 Documentation::builder(
198 DOC_SECTION_ANALYTICAL,
199 "Returns the value evaluated at the nth row of the window frame \
200 (counting from 1). Returns NULL if no such row exists.",
201 "nth_value(expression, n)",
202 )
203 .with_argument(
204 "expression",
205 "The column from which to retrieve the nth value.",
206 )
207 .with_argument(
208 "n",
209 "Integer. Specifies the row number (starting from 1) in the window frame.",
210 )
211 .with_sql_example(
212 r#"
213```sql
214-- Sample employees table:
215CREATE TABLE employees (id INT, salary INT);
216INSERT INTO employees (id, salary) VALUES
217(1, 30000),
218(2, 40000),
219(3, 50000),
220(4, 60000),
221(5, 70000);
222
223-- Example usage of nth_value:
224SELECT nth_value(salary, 2) OVER (
225 ORDER BY salary
226 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
227) AS nth_value
228FROM employees;
229
230+-----------+
231| nth_value |
232+-----------+
233| 40000 |
234| 40000 |
235| 40000 |
236| 40000 |
237| 40000 |
238+-----------+
239```
240"#,
241 )
242 .build()
243});
244
245fn get_nth_value_doc() -> &'static Documentation {
246 &NTH_VALUE_DOCUMENTATION
247}
248
249impl WindowUDFImpl for NthValue {
250 fn as_any(&self) -> &dyn Any {
251 self
252 }
253
254 fn name(&self) -> &str {
255 self.kind.name()
256 }
257
258 fn signature(&self) -> &Signature {
259 &self.signature
260 }
261
262 fn partition_evaluator(
263 &self,
264 partition_evaluator_args: PartitionEvaluatorArgs,
265 ) -> Result<Box<dyn PartitionEvaluator>> {
266 let state = NthValueState {
267 finalized_result: None,
268 kind: self.kind,
269 };
270
271 if !matches!(self.kind, NthValueKind::Nth) {
272 return Ok(Box::new(NthValueEvaluator {
273 state,
274 ignore_nulls: partition_evaluator_args.ignore_nulls(),
275 n: 0,
276 }));
277 }
278
279 let n =
280 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
281 .map_err(|_e| {
282 exec_datafusion_err!(
283 "Expected a signed integer literal for the second argument of nth_value")
284 })?
285 .map(get_signed_integer)
286 {
287 Some(Ok(n)) => {
288 if partition_evaluator_args.is_reversed() {
289 -n
290 } else {
291 n
292 }
293 }
294 _ => {
295 return exec_err!(
296 "Expected a signed integer literal for the second argument of nth_value"
297 )
298 }
299 };
300
301 Ok(Box::new(NthValueEvaluator {
302 state,
303 ignore_nulls: partition_evaluator_args.ignore_nulls(),
304 n,
305 }))
306 }
307
308 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
309 let return_type = field_args
310 .input_fields()
311 .first()
312 .map(|f| f.data_type())
313 .cloned()
314 .unwrap_or(DataType::Null);
315
316 Ok(Field::new(field_args.name(), return_type, true).into())
317 }
318
319 fn reverse_expr(&self) -> ReversedUDWF {
320 match self.kind {
321 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
322 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
323 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
324 }
325 }
326
327 fn documentation(&self) -> Option<&Documentation> {
328 match self.kind {
329 NthValueKind::First => Some(get_first_value_doc()),
330 NthValueKind::Last => Some(get_last_value_doc()),
331 NthValueKind::Nth => Some(get_nth_value_doc()),
332 }
333 }
334
335 fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
336 LimitEffect::None }
338}
339
/// Mutable state carried by an `NthValueEvaluator` across calls.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Result cached by `memoize`. Once it is `Some`, `evaluate` returns a
    /// clone of it instead of inspecting the input again.
    pub finalized_result: Option<ScalarValue>,
    /// Which function (`first_value`, `last_value`, `nth_value`) is evaluated.
    pub kind: NthValueKind,
}
353
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    /// Function kind plus the memoized result, if any.
    state: NthValueState,
    /// When true, NULL entries inside the frame are skipped during evaluation.
    ignore_nulls: bool,
    /// Requested position for `nth_value`: positive counts from the start of
    /// the frame (1-based), negative counts from the end, and zero always
    /// yields NULL. Set to 0 for `first_value` / `last_value`, which do not
    /// read it.
    n: i64,
}
360
361impl PartitionEvaluator for NthValueEvaluator {
362 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
368 let out = &state.out_col;
369 let size = out.len();
370 let mut buffer_size = 1;
371 let (is_prunable, is_reverse_direction) = match self.state.kind {
373 NthValueKind::First => {
374 let n_range =
375 state.window_frame_range.end - state.window_frame_range.start;
376 (n_range > 0 && size > 0, false)
377 }
378 NthValueKind::Last => (true, true),
379 NthValueKind::Nth => {
380 let n_range =
381 state.window_frame_range.end - state.window_frame_range.start;
382 match self.n.cmp(&0) {
383 Ordering::Greater => (
384 n_range >= (self.n as usize) && size > (self.n as usize),
385 false,
386 ),
387 Ordering::Less => {
388 let reverse_index = (-self.n) as usize;
389 buffer_size = reverse_index;
390 (n_range >= reverse_index, true)
392 }
393 Ordering::Equal => (false, false),
394 }
395 }
396 };
397 if is_prunable && !self.ignore_nulls {
399 if self.state.finalized_result.is_none() && !is_reverse_direction {
400 let result = ScalarValue::try_from_array(out, size - 1)?;
401 self.state.finalized_result = Some(result);
402 }
403 state.window_frame_range.start =
404 state.window_frame_range.end.saturating_sub(buffer_size);
405 }
406 Ok(())
407 }
408
409 fn evaluate(
410 &mut self,
411 values: &[ArrayRef],
412 range: &Range<usize>,
413 ) -> Result<ScalarValue> {
414 if let Some(ref result) = self.state.finalized_result {
415 Ok(result.clone())
416 } else {
417 let arr = &values[0];
419 let n_range = range.end - range.start;
420 if n_range == 0 {
421 return ScalarValue::try_from(arr.data_type());
423 }
424
425 let valid_indices = if self.ignore_nulls {
427 let slice = arr.slice(range.start, n_range);
429 match slice.nulls() {
430 Some(nulls) => {
431 let valid_indices = nulls
432 .valid_indices()
433 .map(|idx| {
434 idx + range.start
436 })
437 .collect::<Vec<_>>();
438 if valid_indices.is_empty() {
439 return ScalarValue::try_from(arr.data_type());
441 }
442 Some(valid_indices)
443 }
444 None => None,
445 }
446 } else {
447 None
448 };
449 match self.state.kind {
450 NthValueKind::First => {
451 if let Some(valid_indices) = &valid_indices {
452 ScalarValue::try_from_array(arr, valid_indices[0])
453 } else {
454 ScalarValue::try_from_array(arr, range.start)
455 }
456 }
457 NthValueKind::Last => {
458 if let Some(valid_indices) = &valid_indices {
459 ScalarValue::try_from_array(
460 arr,
461 valid_indices[valid_indices.len() - 1],
462 )
463 } else {
464 ScalarValue::try_from_array(arr, range.end - 1)
465 }
466 }
467 NthValueKind::Nth => {
468 match self.n.cmp(&0) {
469 Ordering::Greater => {
470 let index = (self.n as usize) - 1;
472 if index >= n_range {
473 ScalarValue::try_from(arr.data_type())
475 } else if let Some(valid_indices) = valid_indices {
476 if index >= valid_indices.len() {
477 return ScalarValue::try_from(arr.data_type());
478 }
479 ScalarValue::try_from_array(&arr, valid_indices[index])
480 } else {
481 ScalarValue::try_from_array(arr, range.start + index)
482 }
483 }
484 Ordering::Less => {
485 let reverse_index = (-self.n) as usize;
486 if n_range < reverse_index {
487 ScalarValue::try_from(arr.data_type())
489 } else if let Some(valid_indices) = valid_indices {
490 if reverse_index > valid_indices.len() {
491 return ScalarValue::try_from(arr.data_type());
492 }
493 let new_index =
494 valid_indices[valid_indices.len() - reverse_index];
495 ScalarValue::try_from_array(&arr, new_index)
496 } else {
497 ScalarValue::try_from_array(
498 arr,
499 range.start + n_range - reverse_index,
500 )
501 }
502 }
503 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
504 }
505 }
506 }
507 }
508 }
509
510 fn supports_bounded_execution(&self) -> bool {
511 true
512 }
513
514 fn uses_window_frame(&self) -> bool {
515 true
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression referring to the single input column.
    fn input_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Physical `Int32` literal used as the `n` argument of `nth_value`.
    fn n_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8` of
    /// a fixed eight-element `Int32` column and asserts that the collected
    /// results equal `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let columns = vec![input];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let scalars = (1..=8)
            .map(|end| evaluator.evaluate(&columns, &(0..end)))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let actual = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>();
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[input_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        // With frames that always end at the current row, the last value is the
        // current row itself.
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[input_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), n_literal(1)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        // The first frame holds a single row, so it has no second value.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[input_column(), n_literal(2)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }
}