1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
26use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
30 TypeSignature, Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use field::WindowUDFFieldArgs;
36use std::any::Any;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
// `first_value`: UDWF backed by `NthValue::first`. Its accessor
// `first_value_udwf()` is used by `NthValue::reverse_expr`.
// NOTE(review): the macro is defined elsewhere in this crate; judging by its
// name it presumably also emits a `first_value(arg)` expression builder —
// confirm against the macro definition.
define_udwf_and_expr!(
    First,
    first_value,
    [arg],
    "Returns the first value in the window frame",
    NthValue::first
);
// `last_value`: UDWF backed by `NthValue::last`; `last_value_udwf()` is the
// reverse of `first_value` (see `NthValue::reverse_expr`).
define_udwf_and_expr!(
    Last,
    last_value,
    [arg],
    "Returns the last value in the window frame",
    NthValue::last
);
// `nth_value`: only the UDWF accessor (`nth_value_udwf()`) is generated here;
// the `nth_value(arg, n)` expression builder is written by hand below because
// it takes a plain `i64` for `n`.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "Returns the nth value in the window frame",
    NthValue::nth
);
63
64pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
66 nth_value_udwf().call(vec![arg, n.lit()])
67}
68
/// Selects which row of the window frame an [`NthValue`] evaluator returns.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// The `first_value` window function.
    First,
    /// The `last_value` window function.
    Last,
    /// The `nth_value` window function.
    Nth,
}

impl NthValueKind {
    /// SQL-visible function name for this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::Nth => "nth_value",
            Self::Last => "last_value",
            Self::First => "first_value",
        }
    }
}
86
87#[derive(Debug, PartialEq, Eq, Hash)]
88pub struct NthValue {
89 signature: Signature,
90 kind: NthValueKind,
91}
92
93impl NthValue {
94 pub fn new(kind: NthValueKind) -> Self {
96 Self {
97 signature: Signature::one_of(
98 vec![
99 TypeSignature::Any(0),
100 TypeSignature::Any(1),
101 TypeSignature::Any(2),
102 ],
103 Volatility::Immutable,
104 ),
105 kind,
106 }
107 }
108
109 pub fn first() -> Self {
110 Self::new(NthValueKind::First)
111 }
112
113 pub fn last() -> Self {
114 Self::new(NthValueKind::Last)
115 }
116 pub fn nth() -> Self {
117 Self::new(NthValueKind::Nth)
118 }
119
120 pub fn kind(&self) -> &NthValueKind {
121 &self.kind
122 }
123}
124
125static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
126 Documentation::builder(
127 DOC_SECTION_ANALYTICAL,
128 "Returns value evaluated at the row that is the first row of the window \
129 frame.",
130 "first_value(expression)",
131 )
132 .with_argument("expression", "Expression to operate on")
133 .with_sql_example(
134 r#"
135```sql
136-- Example usage of the first_value window function:
137SELECT department,
138 employee_id,
139 salary,
140 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
141FROM employees;
142
143+-------------+-------------+--------+------------+
144| department | employee_id | salary | top_salary |
145+-------------+-------------+--------+------------+
146| Sales | 1 | 70000 | 70000 |
147| Sales | 2 | 50000 | 70000 |
148| Sales | 3 | 30000 | 70000 |
149| Engineering | 4 | 90000 | 90000 |
150| Engineering | 5 | 80000 | 90000 |
151+-------------+-------------+--------+------------+
152```
153"#,
154 )
155 .build()
156});
157
158fn get_first_value_doc() -> &'static Documentation {
159 &FIRST_VALUE_DOCUMENTATION
160}
161
162static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
163 Documentation::builder(
164 DOC_SECTION_ANALYTICAL,
165 "Returns value evaluated at the row that is the last row of the window \
166 frame.",
167 "last_value(expression)",
168 )
169 .with_argument("expression", "Expression to operate on")
170 .with_sql_example(r#"```sql
171-- SQL example of last_value:
172SELECT department,
173 employee_id,
174 salary,
175 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
176FROM employees;
177
178+-------------+-------------+--------+---------------------+
179| department | employee_id | salary | running_last_salary |
180+-------------+-------------+--------+---------------------+
181| Sales | 1 | 30000 | 30000 |
182| Sales | 2 | 50000 | 50000 |
183| Sales | 3 | 70000 | 70000 |
184| Engineering | 4 | 40000 | 40000 |
185| Engineering | 5 | 60000 | 60000 |
186+-------------+-------------+--------+---------------------+
187```
188"#)
189 .build()
190});
191
192fn get_last_value_doc() -> &'static Documentation {
193 &LAST_VALUE_DOCUMENTATION
194}
195
196static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
197 Documentation::builder(
198 DOC_SECTION_ANALYTICAL,
199 "Returns the value evaluated at the nth row of the window frame \
200 (counting from 1). Returns NULL if no such row exists.",
201 "nth_value(expression, n)",
202 )
203 .with_argument(
204 "expression",
205 "The column from which to retrieve the nth value.",
206 )
207 .with_argument(
208 "n",
209 "Integer. Specifies the row number (starting from 1) in the window frame.",
210 )
211 .with_sql_example(
212 r#"
213```sql
214-- Sample employees table:
215CREATE TABLE employees (id INT, salary INT);
216INSERT INTO employees (id, salary) VALUES
217(1, 30000),
218(2, 40000),
219(3, 50000),
220(4, 60000),
221(5, 70000);
222
223-- Example usage of nth_value:
224SELECT nth_value(salary, 2) OVER (
225 ORDER BY salary
226 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
227) AS nth_value
228FROM employees;
229
230+-----------+
231| nth_value |
232+-----------+
233| 40000 |
234| 40000 |
235| 40000 |
236| 40000 |
237| 40000 |
238+-----------+
239```
240"#,
241 )
242 .build()
243});
244
245fn get_nth_value_doc() -> &'static Documentation {
246 &NTH_VALUE_DOCUMENTATION
247}
248
249impl WindowUDFImpl for NthValue {
250 fn as_any(&self) -> &dyn Any {
251 self
252 }
253
254 fn name(&self) -> &str {
255 self.kind.name()
256 }
257
258 fn signature(&self) -> &Signature {
259 &self.signature
260 }
261
262 fn partition_evaluator(
263 &self,
264 partition_evaluator_args: PartitionEvaluatorArgs,
265 ) -> Result<Box<dyn PartitionEvaluator>> {
266 let state = NthValueState {
267 finalized_result: None,
268 kind: self.kind,
269 };
270
271 if !matches!(self.kind, NthValueKind::Nth) {
272 return Ok(Box::new(NthValueEvaluator {
273 state,
274 ignore_nulls: partition_evaluator_args.ignore_nulls(),
275 n: 0,
276 }));
277 }
278
279 let n = match get_scalar_value_from_args(
280 partition_evaluator_args.input_exprs(),
281 1,
282 )
283 .map_err(|_e| {
284 exec_datafusion_err!(
285 "Expected a signed integer literal for the second argument of nth_value"
286 )
287 })?
288 .map(|v| get_signed_integer(&v))
289 {
290 Some(Ok(n)) => {
291 if partition_evaluator_args.is_reversed() {
292 -n
293 } else {
294 n
295 }
296 }
297 _ => {
298 return exec_err!(
299 "Expected a signed integer literal for the second argument of nth_value"
300 );
301 }
302 };
303
304 Ok(Box::new(NthValueEvaluator {
305 state,
306 ignore_nulls: partition_evaluator_args.ignore_nulls(),
307 n,
308 }))
309 }
310
311 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
312 let return_type = field_args
313 .input_fields()
314 .first()
315 .map(|f| f.data_type())
316 .cloned()
317 .unwrap_or(DataType::Null);
318
319 Ok(Field::new(field_args.name(), return_type, true).into())
320 }
321
322 fn reverse_expr(&self) -> ReversedUDWF {
323 match self.kind {
324 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
325 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
326 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
327 }
328 }
329
330 fn documentation(&self) -> Option<&Documentation> {
331 match self.kind {
332 NthValueKind::First => Some(get_first_value_doc()),
333 NthValueKind::Last => Some(get_last_value_doc()),
334 NthValueKind::Nth => Some(get_nth_value_doc()),
335 }
336 }
337
338 fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
339 LimitEffect::None }
341}
342
/// Mutable per-partition state of an `NthValueEvaluator`.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Cached output. `memoize` stores it once the frame becomes prunable
    /// (forward-direction kinds only, and only when nulls are not ignored);
    /// after that, `evaluate` returns it without looking at the input.
    pub finalized_result: Option<ScalarValue>,
    /// Which function (`first_value`, `last_value`, `nth_value`) is evaluated.
    pub kind: NthValueKind,
}
356
357#[derive(Debug)]
358pub(crate) struct NthValueEvaluator {
359 state: NthValueState,
360 ignore_nulls: bool,
361 n: i64,
362}
363
364impl PartitionEvaluator for NthValueEvaluator {
365 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
371 let out = &state.out_col;
372 let size = out.len();
373 let mut buffer_size = 1;
374 let (is_prunable, is_reverse_direction) = match self.state.kind {
376 NthValueKind::First => {
377 let n_range =
378 state.window_frame_range.end - state.window_frame_range.start;
379 (n_range > 0 && size > 0, false)
380 }
381 NthValueKind::Last => (true, true),
382 NthValueKind::Nth => {
383 let n_range =
384 state.window_frame_range.end - state.window_frame_range.start;
385 match self.n.cmp(&0) {
386 Ordering::Greater => (
387 n_range >= (self.n as usize) && size > (self.n as usize),
388 false,
389 ),
390 Ordering::Less => {
391 let reverse_index = (-self.n) as usize;
392 buffer_size = reverse_index;
393 (n_range >= reverse_index, true)
395 }
396 Ordering::Equal => (false, false),
397 }
398 }
399 };
400 if is_prunable && !self.ignore_nulls {
402 if self.state.finalized_result.is_none() && !is_reverse_direction {
403 let result = ScalarValue::try_from_array(out, size - 1)?;
404 self.state.finalized_result = Some(result);
405 }
406 state.window_frame_range.start =
407 state.window_frame_range.end.saturating_sub(buffer_size);
408 }
409 Ok(())
410 }
411
412 fn evaluate(
413 &mut self,
414 values: &[ArrayRef],
415 range: &Range<usize>,
416 ) -> Result<ScalarValue> {
417 if let Some(ref result) = self.state.finalized_result {
418 Ok(result.clone())
419 } else {
420 let arr = &values[0];
422 let n_range = range.end - range.start;
423 if n_range == 0 {
424 return ScalarValue::try_from(arr.data_type());
426 }
427
428 let valid_indices = if self.ignore_nulls {
430 let slice = arr.slice(range.start, n_range);
432 match slice.nulls() {
433 Some(nulls) => {
434 let valid_indices = nulls
435 .valid_indices()
436 .map(|idx| {
437 idx + range.start
439 })
440 .collect::<Vec<_>>();
441 if valid_indices.is_empty() {
442 return ScalarValue::try_from(arr.data_type());
444 }
445 Some(valid_indices)
446 }
447 None => None,
448 }
449 } else {
450 None
451 };
452 match self.state.kind {
453 NthValueKind::First => {
454 if let Some(valid_indices) = &valid_indices {
455 ScalarValue::try_from_array(arr, valid_indices[0])
456 } else {
457 ScalarValue::try_from_array(arr, range.start)
458 }
459 }
460 NthValueKind::Last => {
461 if let Some(valid_indices) = &valid_indices {
462 ScalarValue::try_from_array(
463 arr,
464 valid_indices[valid_indices.len() - 1],
465 )
466 } else {
467 ScalarValue::try_from_array(arr, range.end - 1)
468 }
469 }
470 NthValueKind::Nth => {
471 match self.n.cmp(&0) {
472 Ordering::Greater => {
473 let index = (self.n as usize) - 1;
475 if index >= n_range {
476 ScalarValue::try_from(arr.data_type())
478 } else if let Some(valid_indices) = valid_indices {
479 if index >= valid_indices.len() {
480 return ScalarValue::try_from(arr.data_type());
481 }
482 ScalarValue::try_from_array(&arr, valid_indices[index])
483 } else {
484 ScalarValue::try_from_array(arr, range.start + index)
485 }
486 }
487 Ordering::Less => {
488 let reverse_index = (-self.n) as usize;
489 if n_range < reverse_index {
490 ScalarValue::try_from(arr.data_type())
492 } else if let Some(valid_indices) = valid_indices {
493 if reverse_index > valid_indices.len() {
494 return ScalarValue::try_from(arr.data_type());
495 }
496 let new_index =
497 valid_indices[valid_indices.len() - reverse_index];
498 ScalarValue::try_from_array(&arr, new_index)
499 } else {
500 ScalarValue::try_from_array(
501 arr,
502 range.start + n_range - reverse_index,
503 )
504 }
505 }
506 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
507 }
508 }
509 }
510 }
511 }
512
513 fn supports_bounded_execution(&self) -> bool {
514 true
515 }
516
517 fn uses_window_frame(&self) -> bool {
518 true
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Input shared by most tests; contains no NULLs.
    fn default_input() -> Int32Array {
        Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])
    }

    /// Evaluates `expr` over `input` for the growing frames `0..1`, `0..2`,
    /// ..., `0..input.len()` and asserts the collected results equal `expected`.
    fn test_i32_result_with_input(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        input: Int32Array,
        expected: Int32Array,
    ) -> Result<()> {
        let num_rows = input.len();
        let arr: ArrayRef = Arc::new(input);
        let values = vec![arr];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let result = (0..num_rows)
            .map(|i| evaluator.evaluate(&values, &(0..i + 1)))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let result = ScalarValue::iter_to_array(result.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Same as `test_i32_result_with_input`, using `default_input()`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        test_i32_result_with_input(
            expr,
            partition_evaluator_args,
            default_input(),
            expected,
        )
    }

    /// Runs a single-argument function (`first_value` / `last_value`).
    fn test_single_arg(
        udwf: NthValue,
        ignore_nulls: bool,
        input: Int32Array,
        expected: Int32Array,
    ) -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result_with_input(
            udwf,
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                ignore_nulls,
            ),
            input,
            expected,
        )
    }

    /// Runs `nth_value(c3, n)` with the given evaluation flags.
    fn test_nth_value(
        n: i32,
        is_reversed: bool,
        ignore_nulls: bool,
        input: Int32Array,
        expected: Int32Array,
    ) -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let n_value =
            Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) as Arc<dyn PhysicalExpr>;
        test_i32_result_with_input(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[expr, n_value],
                &[Field::new("f", DataType::Int32, true).into()],
                is_reversed,
                ignore_nulls,
            ),
            input,
            expected,
        )
    }

    #[test]
    fn first_value() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
            ]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        test_nth_value(1, false, false, default_input(), Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        test_nth_value(
            2,
            false,
            false,
            default_input(),
            Int32Array::from(vec![
                None,
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
            ]),
        )
    }

    #[test]
    fn nth_value_negative_counts_from_end() -> Result<()> {
        // n = -1 selects the last row of each growing frame.
        test_nth_value(
            -1,
            false,
            false,
            default_input(),
            Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]),
        )
    }

    #[test]
    fn nth_value_reversed() -> Result<()> {
        // Reversed evaluation negates n, so n = 2 selects the second-to-last row.
        test_nth_value(
            2,
            true,
            false,
            default_input(),
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    #[test]
    fn nth_value_zero_is_null() -> Result<()> {
        test_nth_value(
            0,
            false,
            false,
            default_input(),
            Int32Array::from(vec![None::<i32>; 8]),
        )
    }

    #[test]
    fn nth_value_doc_example() -> Result<()> {
        // Mirrors the SQL example in the nth_value documentation: the first
        // frame holds a single row, so the second value does not exist yet.
        test_nth_value(
            2,
            false,
            false,
            Int32Array::from(vec![30000, 40000, 50000, 60000, 70000]),
            Int32Array::from(vec![
                None,
                Some(40000),
                Some(40000),
                Some(40000),
                Some(40000),
            ]),
        )
    }

    #[test]
    fn first_value_ignore_nulls() -> Result<()> {
        test_single_arg(
            NthValue::first(),
            true,
            Int32Array::from(vec![None, Some(1), None, Some(3)]),
            Int32Array::from(vec![None, Some(1), Some(1), Some(1)]),
        )
    }

    #[test]
    fn last_value_ignore_nulls() -> Result<()> {
        test_single_arg(
            NthValue::last(),
            true,
            Int32Array::from(vec![Some(1), None, Some(3), None]),
            Int32Array::from(vec![Some(1), Some(1), Some(3), Some(3)]),
        )
    }

    #[test]
    fn nth_value_ignore_nulls() -> Result<()> {
        test_nth_value(
            2,
            false,
            true,
            Int32Array::from(vec![None, Some(1), None, Some(3), Some(4)]),
            Int32Array::from(vec![None, None, None, Some(3), Some(3)]),
        )
    }
}