1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
30 Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::hash::Hash;
39use std::ops::Range;
40use std::sync::LazyLock;
41
// Registers the `first_value` window function backed by `NthValue::first`.
// NOTE(review): the macro definition is not visible in this file; the
// `first_value_udwf()` accessor used below is presumably generated here — confirm.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
// Registers the `last_value` window function backed by `NthValue::last`
// (presumably generating `last_value_udwf()` — confirm against the macro).
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
// Registers the `nth_value` window function backed by `NthValue::nth`
// (presumably generating `nth_value_udwf()` — confirm against the macro).
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
60
61pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
64 first_value_udwf().call(vec![arg])
65}
66
67pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
70 last_value_udwf().call(vec![arg])
71}
72
73pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
76 nth_value_udwf().call(vec![arg, n.lit()])
77}
78
/// Tag distinguishing the three window functions that share the `NthValue`
/// implementation.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// The `first_value` function.
    First,
    /// The `last_value` function.
    Last,
    /// The `nth_value` function.
    Nth,
}

impl NthValueKind {
    /// Returns the function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
96
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; `kind` selects which of the three an instance represents.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValue {
    // Accepted argument signature, returned by `WindowUDFImpl::signature`.
    signature: Signature,
    // Which of the three functions this instance implements.
    kind: NthValueKind,
}
102
103impl NthValue {
104 pub fn new(kind: NthValueKind) -> Self {
106 Self {
107 signature: Signature::one_of(
108 vec![
109 TypeSignature::Any(0),
110 TypeSignature::Any(1),
111 TypeSignature::Any(2),
112 ],
113 Volatility::Immutable,
114 ),
115 kind,
116 }
117 }
118
119 pub fn first() -> Self {
120 Self::new(NthValueKind::First)
121 }
122
123 pub fn last() -> Self {
124 Self::new(NthValueKind::Last)
125 }
126 pub fn nth() -> Self {
127 Self::new(NthValueKind::Nth)
128 }
129}
130
131static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
132 Documentation::builder(
133 DOC_SECTION_ANALYTICAL,
134 "Returns value evaluated at the row that is the first row of the window \
135 frame.",
136 "first_value(expression)",
137 )
138 .with_argument("expression", "Expression to operate on")
139 .with_sql_example(
140 r#"
141```sql
142-- Example usage of the first_value window function:
143SELECT department,
144 employee_id,
145 salary,
146 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
147FROM employees;
148
149+-------------+-------------+--------+------------+
150| department | employee_id | salary | top_salary |
151+-------------+-------------+--------+------------+
152| Sales | 1 | 70000 | 70000 |
153| Sales | 2 | 50000 | 70000 |
154| Sales | 3 | 30000 | 70000 |
155| Engineering | 4 | 90000 | 90000 |
156| Engineering | 5 | 80000 | 90000 |
157+-------------+-------------+--------+------------+
158```
159"#,
160 )
161 .build()
162});
163
164fn get_first_value_doc() -> &'static Documentation {
165 &FIRST_VALUE_DOCUMENTATION
166}
167
168static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
169 Documentation::builder(
170 DOC_SECTION_ANALYTICAL,
171 "Returns value evaluated at the row that is the last row of the window \
172 frame.",
173 "last_value(expression)",
174 )
175 .with_argument("expression", "Expression to operate on")
176 .with_sql_example(r#"```sql
177-- SQL example of last_value:
178SELECT department,
179 employee_id,
180 salary,
181 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
182FROM employees;
183
184+-------------+-------------+--------+---------------------+
185| department | employee_id | salary | running_last_salary |
186+-------------+-------------+--------+---------------------+
187| Sales | 1 | 30000 | 30000 |
188| Sales | 2 | 50000 | 50000 |
189| Sales | 3 | 70000 | 70000 |
190| Engineering | 4 | 40000 | 40000 |
191| Engineering | 5 | 60000 | 60000 |
192+-------------+-------------+--------+---------------------+
193```
194"#)
195 .build()
196});
197
198fn get_last_value_doc() -> &'static Documentation {
199 &LAST_VALUE_DOCUMENTATION
200}
201
202static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
203 Documentation::builder(
204 DOC_SECTION_ANALYTICAL,
205 "Returns the value evaluated at the nth row of the window frame \
206 (counting from 1). Returns NULL if no such row exists.",
207 "nth_value(expression, n)",
208 )
209 .with_argument(
210 "expression",
211 "The column from which to retrieve the nth value.",
212 )
213 .with_argument(
214 "n",
215 "Integer. Specifies the row number (starting from 1) in the window frame.",
216 )
217 .with_sql_example(
218 r#"
219```sql
220-- Sample employees table:
221CREATE TABLE employees (id INT, salary INT);
222INSERT INTO employees (id, salary) VALUES
223(1, 30000),
224(2, 40000),
225(3, 50000),
226(4, 60000),
227(5, 70000);
228
229-- Example usage of nth_value:
230SELECT nth_value(salary, 2) OVER (
231 ORDER BY salary
232 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
233) AS nth_value
234FROM employees;
235
236+-----------+
237| nth_value |
238+-----------+
239| 40000 |
240| 40000 |
241| 40000 |
242| 40000 |
243| 40000 |
244+-----------+
245```
246"#,
247 )
248 .build()
249});
250
251fn get_nth_value_doc() -> &'static Documentation {
252 &NTH_VALUE_DOCUMENTATION
253}
254
255impl WindowUDFImpl for NthValue {
256 fn as_any(&self) -> &dyn Any {
257 self
258 }
259
260 fn name(&self) -> &str {
261 self.kind.name()
262 }
263
264 fn signature(&self) -> &Signature {
265 &self.signature
266 }
267
268 fn partition_evaluator(
269 &self,
270 partition_evaluator_args: PartitionEvaluatorArgs,
271 ) -> Result<Box<dyn PartitionEvaluator>> {
272 let state = NthValueState {
273 finalized_result: None,
274 kind: self.kind,
275 };
276
277 if !matches!(self.kind, NthValueKind::Nth) {
278 return Ok(Box::new(NthValueEvaluator {
279 state,
280 ignore_nulls: partition_evaluator_args.ignore_nulls(),
281 n: 0,
282 }));
283 }
284
285 let n =
286 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
287 .map_err(|_e| {
288 exec_datafusion_err!(
289 "Expected a signed integer literal for the second argument of nth_value")
290 })?
291 .map(get_signed_integer)
292 {
293 Some(Ok(n)) => {
294 if partition_evaluator_args.is_reversed() {
295 -n
296 } else {
297 n
298 }
299 }
300 _ => {
301 return exec_err!(
302 "Expected a signed integer literal for the second argument of nth_value"
303 )
304 }
305 };
306
307 Ok(Box::new(NthValueEvaluator {
308 state,
309 ignore_nulls: partition_evaluator_args.ignore_nulls(),
310 n,
311 }))
312 }
313
314 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
315 let return_type = field_args
316 .input_fields()
317 .first()
318 .map(|f| f.data_type())
319 .cloned()
320 .unwrap_or(DataType::Null);
321
322 Ok(Field::new(field_args.name(), return_type, true).into())
323 }
324
325 fn reverse_expr(&self) -> ReversedUDWF {
326 match self.kind {
327 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
328 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
329 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
330 }
331 }
332
333 fn documentation(&self) -> Option<&Documentation> {
334 match self.kind {
335 NthValueKind::First => Some(get_first_value_doc()),
336 NthValueKind::Last => Some(get_last_value_doc()),
337 NthValueKind::Nth => Some(get_nth_value_doc()),
338 }
339 }
340}
341
/// State carried by an `NthValueEvaluator` across calls.
#[derive(Debug, Clone)]
pub struct NthValueState {
    // Result cached by `memoize` once it has decided the output can be
    // finalized; when set, `evaluate` returns a clone of it without looking
    // at its inputs.
    pub finalized_result: Option<ScalarValue>,
    // Which of `first_value` / `last_value` / `nth_value` is being evaluated.
    pub kind: NthValueKind,
}
355
/// Partition evaluator shared by `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    // Function kind plus the optionally memoized result.
    state: NthValueState,
    // When true, `evaluate` only considers non-null rows and `memoize` does
    // not prune the frame.
    ignore_nulls: bool,
    // 1-based row index inside the frame; negative values count from the end
    // of the frame. It is 0 for `first_value` / `last_value`.
    n: i64,
}
362
363impl PartitionEvaluator for NthValueEvaluator {
364 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
370 let out = &state.out_col;
371 let size = out.len();
372 let mut buffer_size = 1;
373 let (is_prunable, is_reverse_direction) = match self.state.kind {
375 NthValueKind::First => {
376 let n_range =
377 state.window_frame_range.end - state.window_frame_range.start;
378 (n_range > 0 && size > 0, false)
379 }
380 NthValueKind::Last => (true, true),
381 NthValueKind::Nth => {
382 let n_range =
383 state.window_frame_range.end - state.window_frame_range.start;
384 match self.n.cmp(&0) {
385 Ordering::Greater => (
386 n_range >= (self.n as usize) && size > (self.n as usize),
387 false,
388 ),
389 Ordering::Less => {
390 let reverse_index = (-self.n) as usize;
391 buffer_size = reverse_index;
392 (n_range >= reverse_index, true)
394 }
395 Ordering::Equal => (false, false),
396 }
397 }
398 };
399 if is_prunable && !self.ignore_nulls {
401 if self.state.finalized_result.is_none() && !is_reverse_direction {
402 let result = ScalarValue::try_from_array(out, size - 1)?;
403 self.state.finalized_result = Some(result);
404 }
405 state.window_frame_range.start =
406 state.window_frame_range.end.saturating_sub(buffer_size);
407 }
408 Ok(())
409 }
410
411 fn evaluate(
412 &mut self,
413 values: &[ArrayRef],
414 range: &Range<usize>,
415 ) -> Result<ScalarValue> {
416 if let Some(ref result) = self.state.finalized_result {
417 Ok(result.clone())
418 } else {
419 let arr = &values[0];
421 let n_range = range.end - range.start;
422 if n_range == 0 {
423 return ScalarValue::try_from(arr.data_type());
425 }
426
427 let valid_indices = if self.ignore_nulls {
429 let slice = arr.slice(range.start, n_range);
431 match slice.nulls() {
432 Some(nulls) => {
433 let valid_indices = nulls
434 .valid_indices()
435 .map(|idx| {
436 idx + range.start
438 })
439 .collect::<Vec<_>>();
440 if valid_indices.is_empty() {
441 return ScalarValue::try_from(arr.data_type());
443 }
444 Some(valid_indices)
445 }
446 None => None,
447 }
448 } else {
449 None
450 };
451 match self.state.kind {
452 NthValueKind::First => {
453 if let Some(valid_indices) = &valid_indices {
454 ScalarValue::try_from_array(arr, valid_indices[0])
455 } else {
456 ScalarValue::try_from_array(arr, range.start)
457 }
458 }
459 NthValueKind::Last => {
460 if let Some(valid_indices) = &valid_indices {
461 ScalarValue::try_from_array(
462 arr,
463 valid_indices[valid_indices.len() - 1],
464 )
465 } else {
466 ScalarValue::try_from_array(arr, range.end - 1)
467 }
468 }
469 NthValueKind::Nth => {
470 match self.n.cmp(&0) {
471 Ordering::Greater => {
472 let index = (self.n as usize) - 1;
474 if index >= n_range {
475 ScalarValue::try_from(arr.data_type())
477 } else if let Some(valid_indices) = valid_indices {
478 if index >= valid_indices.len() {
479 return ScalarValue::try_from(arr.data_type());
480 }
481 ScalarValue::try_from_array(&arr, valid_indices[index])
482 } else {
483 ScalarValue::try_from_array(arr, range.start + index)
484 }
485 }
486 Ordering::Less => {
487 let reverse_index = (-self.n) as usize;
488 if n_range < reverse_index {
489 ScalarValue::try_from(arr.data_type())
491 } else if let Some(valid_indices) = valid_indices {
492 if reverse_index > valid_indices.len() {
493 return ScalarValue::try_from(arr.data_type());
494 }
495 let new_index =
496 valid_indices[valid_indices.len() - reverse_index];
497 ScalarValue::try_from_array(&arr, new_index)
498 } else {
499 ScalarValue::try_from_array(
500 arr,
501 range.start + n_range - reverse_index,
502 )
503 }
504 }
505 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
506 }
507 }
508 }
509 }
510 }
511
512 fn supports_bounded_execution(&self) -> bool {
513 true
514 }
515
516 fn uses_window_frame(&self) -> bool {
517 true
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression for the value column used by every test.
    fn column_expr() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>
    }

    /// Physical `Int32` literal used as the `n` argument of `nth_value`.
    fn literal_expr(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) as Arc<dyn PhysicalExpr>
    }

    /// Nullable `Int32` field describing the value column.
    fn int32_field() -> FieldRef {
        Field::new("f", DataType::Int32, true).into()
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8` of
    /// a fixed eight-element `Int32` column and compares against `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let values = vec![input];
        let ranges: Vec<Range<usize>> = (0..8).map(|last| 0..last + 1).collect();

        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let mut scalars: Vec<ScalarValue> = Vec::with_capacity(ranges.len());
        for range in &ranges {
            scalars.push(evaluator.evaluate(&values, range)?);
        }

        let result = ScalarValue::iter_to_array(scalars)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let exprs = [column_expr()];
        let fields = [int32_field()];
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>(),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        let exprs = [column_expr()];
        let fields = [int32_field()];
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            expected,
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let exprs = [column_expr(), literal_expr(1)];
        let fields = [int32_field()];
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let exprs = [column_expr(), literal_expr(2)];
        let fields = [int32_field()];
        // The first frame holds a single row, so there is no second value yet.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            expected,
        )
    }
}