1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
30 Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::hash::{DefaultHasher, Hash, Hasher};
39use std::ops::Range;
40use std::sync::LazyLock;
41
// Registers the three window functions implemented in this file. Each
// invocation passes, in order: an identifier, the function name, a short
// description, and the `NthValue` constructor to use.
//
// NOTE(review): the macro is defined outside this file. The
// `first_value_udwf()`, `last_value_udwf()` and `nth_value_udwf()` accessors
// called further down are presumably generated by these invocations —
// confirm against the macro definition.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
60
61pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
64 first_value_udwf().call(vec![arg])
65}
66
67pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
70 last_value_udwf().call(vec![arg])
71}
72
73pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
76 nth_value_udwf().call(vec![arg, n.lit()])
77}
78
/// Selects which of the three window functions an [`NthValue`] implements.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`
    First,
    /// `last_value`
    Last,
    /// `nth_value`
    Nth,
}

impl NthValueKind {
    /// Returns the SQL function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
96
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; `kind` selects which of the three an instance represents.
#[derive(Debug)]
pub struct NthValue {
    // Accepted argument signature, returned from `WindowUDFImpl::signature`.
    signature: Signature,
    // Which function this instance implements; drives the reported name,
    // documentation, reversal and evaluation.
    kind: NthValueKind,
}
102
103impl NthValue {
104 pub fn new(kind: NthValueKind) -> Self {
106 Self {
107 signature: Signature::one_of(
108 vec![
109 TypeSignature::Any(0),
110 TypeSignature::Any(1),
111 TypeSignature::Any(2),
112 ],
113 Volatility::Immutable,
114 ),
115 kind,
116 }
117 }
118
119 pub fn first() -> Self {
120 Self::new(NthValueKind::First)
121 }
122
123 pub fn last() -> Self {
124 Self::new(NthValueKind::Last)
125 }
126 pub fn nth() -> Self {
127 Self::new(NthValueKind::Nth)
128 }
129}
130
/// Documentation for `first_value`, built on first access: a description in
/// the analytical section, the syntax `first_value(expression)`, one
/// argument entry, and a SQL example followed by its result table.
static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the first row of the window \
            frame.",
        "first_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    // Raw string: every line below is part of the example text, so it must
    // not be re-indented to match the surrounding code.
    .with_sql_example(r#"```sql
 --Example usage of the first_value window function:
 SELECT department,
 employee_id,
 salary,
 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
 FROM employees;
```

```sql
+-------------+-------------+--------+------------+
| department | employee_id | salary | top_salary |
+-------------+-------------+--------+------------+
| Sales | 1 | 70000 | 70000 |
| Sales | 2 | 50000 | 70000 |
| Sales | 3 | 30000 | 70000 |
| Engineering | 4 | 90000 | 90000 |
| Engineering | 5 | 80000 | 90000 |
+-------------+-------------+--------+------------+
```"#)
    .build()
});
161
162fn get_first_value_doc() -> &'static Documentation {
163 &FIRST_VALUE_DOCUMENTATION
164}
165
/// Documentation for `last_value`, built on first access: a description in
/// the analytical section, the syntax `last_value(expression)`, one argument
/// entry, and a SQL example followed by its result table.
static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(
        DOC_SECTION_ANALYTICAL,
        "Returns value evaluated at the row that is the last row of the window \
            frame.",
        "last_value(expression)",
    )
    .with_argument("expression", "Expression to operate on")
    // Raw string: every line below is part of the example text, so it must
    // not be re-indented to match the surrounding code.
    .with_sql_example(r#"```sql
-- SQL example of last_value:
SELECT department,
 employee_id,
 salary,
 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
FROM employees;
```

```sql
+-------------+-------------+--------+---------------------+
| department | employee_id | salary | running_last_salary |
+-------------+-------------+--------+---------------------+
| Sales | 1 | 30000 | 30000 |
| Sales | 2 | 50000 | 50000 |
| Sales | 3 | 70000 | 70000 |
| Engineering | 4 | 40000 | 40000 |
| Engineering | 5 | 60000 | 60000 |
+-------------+-------------+--------+---------------------+
```"#)
    .build()
});
196
197fn get_last_value_doc() -> &'static Documentation {
198 &LAST_VALUE_DOCUMENTATION
199}
200
201static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
202 Documentation::builder(
203 DOC_SECTION_ANALYTICAL,
204 "Returns the value evaluated at the nth row of the window frame \
205 (counting from 1). Returns NULL if no such row exists.",
206 "nth_value(expression, n)",
207 )
208 .with_argument(
209 "expression",
210 "The column from which to retrieve the nth value.",
211 )
212 .with_argument(
213 "n",
214 "Integer. Specifies the row number (starting from 1) in the window frame.",
215 )
216 .with_sql_example(
217 r#"```sql
218-- Sample employees table:
219CREATE TABLE employees (id INT, salary INT);
220INSERT INTO employees (id, salary) VALUES
221(1, 30000),
222(2, 40000),
223(3, 50000),
224(4, 60000),
225(5, 70000);
226
227-- Example usage of nth_value:
228SELECT nth_value(salary, 2) OVER (
229 ORDER BY salary
230 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
231) AS nth_value
232FROM employees;
233```
234
235```text
236+-----------+
237| nth_value |
238+-----------+
239| 40000 |
240| 40000 |
241| 40000 |
242| 40000 |
243| 40000 |
244+-----------+
245```"#,
246 )
247 .build()
248});
249
250fn get_nth_value_doc() -> &'static Documentation {
251 &NTH_VALUE_DOCUMENTATION
252}
253
254impl WindowUDFImpl for NthValue {
255 fn as_any(&self) -> &dyn Any {
256 self
257 }
258
259 fn name(&self) -> &str {
260 self.kind.name()
261 }
262
263 fn signature(&self) -> &Signature {
264 &self.signature
265 }
266
267 fn partition_evaluator(
268 &self,
269 partition_evaluator_args: PartitionEvaluatorArgs,
270 ) -> Result<Box<dyn PartitionEvaluator>> {
271 let state = NthValueState {
272 finalized_result: None,
273 kind: self.kind,
274 };
275
276 if !matches!(self.kind, NthValueKind::Nth) {
277 return Ok(Box::new(NthValueEvaluator {
278 state,
279 ignore_nulls: partition_evaluator_args.ignore_nulls(),
280 n: 0,
281 }));
282 }
283
284 let n =
285 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
286 .map_err(|_e| {
287 exec_datafusion_err!(
288 "Expected a signed integer literal for the second argument of nth_value")
289 })?
290 .map(get_signed_integer)
291 {
292 Some(Ok(n)) => {
293 if partition_evaluator_args.is_reversed() {
294 -n
295 } else {
296 n
297 }
298 }
299 _ => {
300 return exec_err!(
301 "Expected a signed integer literal for the second argument of nth_value"
302 )
303 }
304 };
305
306 Ok(Box::new(NthValueEvaluator {
307 state,
308 ignore_nulls: partition_evaluator_args.ignore_nulls(),
309 n,
310 }))
311 }
312
313 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
314 let return_type = field_args
315 .input_fields()
316 .first()
317 .map(|f| f.data_type())
318 .cloned()
319 .unwrap_or(DataType::Null);
320
321 Ok(Field::new(field_args.name(), return_type, true).into())
322 }
323
324 fn reverse_expr(&self) -> ReversedUDWF {
325 match self.kind {
326 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
327 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
328 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
329 }
330 }
331
332 fn documentation(&self) -> Option<&Documentation> {
333 match self.kind {
334 NthValueKind::First => Some(get_first_value_doc()),
335 NthValueKind::Last => Some(get_last_value_doc()),
336 NthValueKind::Nth => Some(get_nth_value_doc()),
337 }
338 }
339
340 fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
341 let Some(other) = other.as_any().downcast_ref::<Self>() else {
342 return false;
343 };
344 let Self { signature, kind } = self;
345 signature == &other.signature && kind == &other.kind
346 }
347
348 fn hash_value(&self) -> u64 {
349 let Self { signature, kind } = self;
350 let mut hasher = DefaultHasher::new();
351 std::any::type_name::<Self>().hash(&mut hasher);
352 signature.hash(&mut hasher);
353 kind.hash(&mut hasher);
354 hasher.finish()
355 }
356}
357
/// State carried by an `NthValueEvaluator` between calls.
#[derive(Debug, Clone)]
pub struct NthValueState {
    // Result fixed by `memoize`. Once set, `evaluate` returns a clone of it
    // without looking at its inputs.
    pub finalized_result: Option<ScalarValue>,
    // Which of `first_value` / `last_value` / `nth_value` is being evaluated.
    pub kind: NthValueKind,
}
371
/// Partition evaluator for `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    // Function kind plus any result already finalized by `memoize`.
    state: NthValueState,
    // When true, `evaluate` only considers non-null rows of the frame and
    // `memoize` does nothing.
    ignore_nulls: bool,
    // Position used by `nth_value`: positive values count from the start of
    // the frame (1-based), negative values count from its end, and 0 yields
    // NULL. Set to 0 for `first_value` and `last_value`, which do not read it.
    n: i64,
}
378
379impl PartitionEvaluator for NthValueEvaluator {
380 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
386 let out = &state.out_col;
387 let size = out.len();
388 let mut buffer_size = 1;
389 let (is_prunable, is_reverse_direction) = match self.state.kind {
391 NthValueKind::First => {
392 let n_range =
393 state.window_frame_range.end - state.window_frame_range.start;
394 (n_range > 0 && size > 0, false)
395 }
396 NthValueKind::Last => (true, true),
397 NthValueKind::Nth => {
398 let n_range =
399 state.window_frame_range.end - state.window_frame_range.start;
400 match self.n.cmp(&0) {
401 Ordering::Greater => (
402 n_range >= (self.n as usize) && size > (self.n as usize),
403 false,
404 ),
405 Ordering::Less => {
406 let reverse_index = (-self.n) as usize;
407 buffer_size = reverse_index;
408 (n_range >= reverse_index, true)
410 }
411 Ordering::Equal => (false, false),
412 }
413 }
414 };
415 if is_prunable && !self.ignore_nulls {
417 if self.state.finalized_result.is_none() && !is_reverse_direction {
418 let result = ScalarValue::try_from_array(out, size - 1)?;
419 self.state.finalized_result = Some(result);
420 }
421 state.window_frame_range.start =
422 state.window_frame_range.end.saturating_sub(buffer_size);
423 }
424 Ok(())
425 }
426
427 fn evaluate(
428 &mut self,
429 values: &[ArrayRef],
430 range: &Range<usize>,
431 ) -> Result<ScalarValue> {
432 if let Some(ref result) = self.state.finalized_result {
433 Ok(result.clone())
434 } else {
435 let arr = &values[0];
437 let n_range = range.end - range.start;
438 if n_range == 0 {
439 return ScalarValue::try_from(arr.data_type());
441 }
442
443 let valid_indices = if self.ignore_nulls {
445 let slice = arr.slice(range.start, n_range);
447 match slice.nulls() {
448 Some(nulls) => {
449 let valid_indices = nulls
450 .valid_indices()
451 .map(|idx| {
452 idx + range.start
454 })
455 .collect::<Vec<_>>();
456 if valid_indices.is_empty() {
457 return ScalarValue::try_from(arr.data_type());
459 }
460 Some(valid_indices)
461 }
462 None => None,
463 }
464 } else {
465 None
466 };
467 match self.state.kind {
468 NthValueKind::First => {
469 if let Some(valid_indices) = &valid_indices {
470 ScalarValue::try_from_array(arr, valid_indices[0])
471 } else {
472 ScalarValue::try_from_array(arr, range.start)
473 }
474 }
475 NthValueKind::Last => {
476 if let Some(valid_indices) = &valid_indices {
477 ScalarValue::try_from_array(
478 arr,
479 valid_indices[valid_indices.len() - 1],
480 )
481 } else {
482 ScalarValue::try_from_array(arr, range.end - 1)
483 }
484 }
485 NthValueKind::Nth => {
486 match self.n.cmp(&0) {
487 Ordering::Greater => {
488 let index = (self.n as usize) - 1;
490 if index >= n_range {
491 ScalarValue::try_from(arr.data_type())
493 } else if let Some(valid_indices) = valid_indices {
494 if index >= valid_indices.len() {
495 return ScalarValue::try_from(arr.data_type());
496 }
497 ScalarValue::try_from_array(&arr, valid_indices[index])
498 } else {
499 ScalarValue::try_from_array(arr, range.start + index)
500 }
501 }
502 Ordering::Less => {
503 let reverse_index = (-self.n) as usize;
504 if n_range < reverse_index {
505 ScalarValue::try_from(arr.data_type())
507 } else if let Some(valid_indices) = valid_indices {
508 if reverse_index > valid_indices.len() {
509 return ScalarValue::try_from(arr.data_type());
510 }
511 let new_index =
512 valid_indices[valid_indices.len() - reverse_index];
513 ScalarValue::try_from_array(&arr, new_index)
514 } else {
515 ScalarValue::try_from_array(
516 arr,
517 range.start + n_range - reverse_index,
518 )
519 }
520 }
521 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
522 }
523 }
524 }
525 }
526 }
527
528 fn supports_bounded_execution(&self) -> bool {
529 true
530 }
531
532 fn uses_window_frame(&self) -> bool {
533 true
534 }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Evaluates `expr` over the fixed input `[1, -2, 3, -4, 5, -6, 7, 8]`
    /// using the growing frames `0..1`, `0..2`, …, `0..8`, and checks the
    /// collected outputs against `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let columns = vec![input];
        let frames: Vec<Range<usize>> = (1..=8).map(|end| 0..end).collect();

        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let mut outputs = Vec::with_capacity(frames.len());
        for frame in &frames {
            outputs.push(evaluator.evaluate(&columns, frame)?);
        }

        let outputs = ScalarValue::iter_to_array(outputs.into_iter())?;
        let actual = as_int32_array(&outputs)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    /// `first_value` stays at the first input row for every frame.
    #[test]
    fn first_value() -> Result<()> {
        let exprs = [Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>];
        let fields: [FieldRef; 1] = [Field::new("f", DataType::Int32, true).into()];
        let expected = std::iter::repeat(Some(1_i32))
            .take(8)
            .collect::<Int32Array>();

        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            expected,
        )
    }

    /// `last_value` follows the last row of each growing frame.
    #[test]
    fn last_value() -> Result<()> {
        let exprs = [Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>];
        let fields: [FieldRef; 1] = [Field::new("f", DataType::Int32, true).into()];
        let expected_values: Vec<Option<i32>> =
            [1, -2, 3, -4, 5, -6, 7, 8].into_iter().map(Some).collect();

        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            Int32Array::from(expected_values),
        )
    }

    /// `nth_value(_, 1)` behaves like `first_value`.
    #[test]
    fn nth_value_1() -> Result<()> {
        let exprs = [
            Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>,
        ];
        let fields: [FieldRef; 1] = [Field::new("f", DataType::Int32, true).into()];

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            Int32Array::from(vec![1; 8]),
        )
    }

    /// `nth_value(_, 2)` is NULL while the frame holds a single row, then
    /// stays at the second input row.
    #[test]
    fn nth_value_2() -> Result<()> {
        let exprs = [
            Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>,
            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>,
        ];
        let fields: [FieldRef; 1] = [Field::new("f", DataType::Int32, true).into()];
        let expected_values: Vec<Option<i32>> = std::iter::once(None)
            .chain(std::iter::repeat(Some(-2)).take(7))
            .collect();

        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(&exprs, &fields, false, false),
            Int32Array::from(expected_values),
        )
    }
}