1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature,
30 TypeSignature, Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use field::WindowUDFFieldArgs;
36use std::any::Any;
37use std::cmp::Ordering;
38use std::fmt::Debug;
39use std::hash::Hash;
40use std::ops::Range;
41use std::sync::{Arc, LazyLock};
42
// Each `get_or_init_udwf!` invocation below passes, in order: an identifier,
// a function name, a one-line description and a constructor for the window
// function (`NthValue::first`, `NthValue::last`, `NthValue::nth`).
//
// NOTE(review): the macro is defined outside this file. It presumably
// generates the `<name>_udwf()` accessors (`first_value_udwf`,
// `last_value_udwf`, `nth_value_udwf`) that the rest of this file calls;
// confirm against the macro definition.

// `first_value`: constructed through `NthValue::first`.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
// `last_value`: constructed through `NthValue::last`.
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
// `nth_value`: constructed through `NthValue::nth`.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
61
62pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
65 first_value_udwf().call(vec![arg])
66}
67
68pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
71 last_value_udwf().call(vec![arg])
72}
73
74pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
77 nth_value_udwf().call(vec![arg, n.lit()])
78}
79
/// Selects which of the three window functions is being implemented.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NthValueKind {
    /// `first_value`
    First,
    /// `last_value`
    Last,
    /// `nth_value`
    Nth,
}

impl NthValueKind {
    /// Function name associated with this variant.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
97
/// Window UDF implementation shared by `first_value`, `last_value` and
/// `nth_value`; `kind` selects which of the three an instance represents.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValue {
    /// Accepted argument signature, handed out by `WindowUDFImpl::signature`.
    signature: Signature,
    /// Which function this instance implements.
    kind: NthValueKind,
}
103
104impl NthValue {
105 pub fn new(kind: NthValueKind) -> Self {
107 Self {
108 signature: Signature::one_of(
109 vec![
110 TypeSignature::Any(0),
111 TypeSignature::Any(1),
112 TypeSignature::Any(2),
113 ],
114 Volatility::Immutable,
115 ),
116 kind,
117 }
118 }
119
120 pub fn first() -> Self {
121 Self::new(NthValueKind::First)
122 }
123
124 pub fn last() -> Self {
125 Self::new(NthValueKind::Last)
126 }
127 pub fn nth() -> Self {
128 Self::new(NthValueKind::Nth)
129 }
130
131 pub fn kind(&self) -> &NthValueKind {
132 &self.kind
133 }
134}
135
136static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
137 Documentation::builder(
138 DOC_SECTION_ANALYTICAL,
139 "Returns value evaluated at the row that is the first row of the window \
140 frame.",
141 "first_value(expression)",
142 )
143 .with_argument("expression", "Expression to operate on")
144 .with_sql_example(
145 r#"
146```sql
147-- Example usage of the first_value window function:
148SELECT department,
149 employee_id,
150 salary,
151 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
152FROM employees;
153
154+-------------+-------------+--------+------------+
155| department | employee_id | salary | top_salary |
156+-------------+-------------+--------+------------+
157| Sales | 1 | 70000 | 70000 |
158| Sales | 2 | 50000 | 70000 |
159| Sales | 3 | 30000 | 70000 |
160| Engineering | 4 | 90000 | 90000 |
161| Engineering | 5 | 80000 | 90000 |
162+-------------+-------------+--------+------------+
163```
164"#,
165 )
166 .build()
167});
168
169fn get_first_value_doc() -> &'static Documentation {
170 &FIRST_VALUE_DOCUMENTATION
171}
172
173static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
174 Documentation::builder(
175 DOC_SECTION_ANALYTICAL,
176 "Returns value evaluated at the row that is the last row of the window \
177 frame.",
178 "last_value(expression)",
179 )
180 .with_argument("expression", "Expression to operate on")
181 .with_sql_example(r#"```sql
182-- SQL example of last_value:
183SELECT department,
184 employee_id,
185 salary,
186 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
187FROM employees;
188
189+-------------+-------------+--------+---------------------+
190| department | employee_id | salary | running_last_salary |
191+-------------+-------------+--------+---------------------+
192| Sales | 1 | 30000 | 30000 |
193| Sales | 2 | 50000 | 50000 |
194| Sales | 3 | 70000 | 70000 |
195| Engineering | 4 | 40000 | 40000 |
196| Engineering | 5 | 60000 | 60000 |
197+-------------+-------------+--------+---------------------+
198```
199"#)
200 .build()
201});
202
203fn get_last_value_doc() -> &'static Documentation {
204 &LAST_VALUE_DOCUMENTATION
205}
206
207static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
208 Documentation::builder(
209 DOC_SECTION_ANALYTICAL,
210 "Returns the value evaluated at the nth row of the window frame \
211 (counting from 1). Returns NULL if no such row exists.",
212 "nth_value(expression, n)",
213 )
214 .with_argument(
215 "expression",
216 "The column from which to retrieve the nth value.",
217 )
218 .with_argument(
219 "n",
220 "Integer. Specifies the row number (starting from 1) in the window frame.",
221 )
222 .with_sql_example(
223 r#"
224```sql
225-- Sample employees table:
226CREATE TABLE employees (id INT, salary INT);
227INSERT INTO employees (id, salary) VALUES
228(1, 30000),
229(2, 40000),
230(3, 50000),
231(4, 60000),
232(5, 70000);
233
234-- Example usage of nth_value:
235SELECT nth_value(salary, 2) OVER (
236 ORDER BY salary
237 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
238) AS nth_value
239FROM employees;
240
241+-----------+
242| nth_value |
243+-----------+
244| 40000 |
245| 40000 |
246| 40000 |
247| 40000 |
248| 40000 |
249+-----------+
250```
251"#,
252 )
253 .build()
254});
255
256fn get_nth_value_doc() -> &'static Documentation {
257 &NTH_VALUE_DOCUMENTATION
258}
259
260impl WindowUDFImpl for NthValue {
261 fn as_any(&self) -> &dyn Any {
262 self
263 }
264
265 fn name(&self) -> &str {
266 self.kind.name()
267 }
268
269 fn signature(&self) -> &Signature {
270 &self.signature
271 }
272
273 fn partition_evaluator(
274 &self,
275 partition_evaluator_args: PartitionEvaluatorArgs,
276 ) -> Result<Box<dyn PartitionEvaluator>> {
277 let state = NthValueState {
278 finalized_result: None,
279 kind: self.kind,
280 };
281
282 if !matches!(self.kind, NthValueKind::Nth) {
283 return Ok(Box::new(NthValueEvaluator {
284 state,
285 ignore_nulls: partition_evaluator_args.ignore_nulls(),
286 n: 0,
287 }));
288 }
289
290 let n =
291 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
292 .map_err(|_e| {
293 exec_datafusion_err!(
294 "Expected a signed integer literal for the second argument of nth_value")
295 })?
296 .map(get_signed_integer)
297 {
298 Some(Ok(n)) => {
299 if partition_evaluator_args.is_reversed() {
300 -n
301 } else {
302 n
303 }
304 }
305 _ => {
306 return exec_err!(
307 "Expected a signed integer literal for the second argument of nth_value"
308 )
309 }
310 };
311
312 Ok(Box::new(NthValueEvaluator {
313 state,
314 ignore_nulls: partition_evaluator_args.ignore_nulls(),
315 n,
316 }))
317 }
318
319 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
320 let return_type = field_args
321 .input_fields()
322 .first()
323 .map(|f| f.data_type())
324 .cloned()
325 .unwrap_or(DataType::Null);
326
327 Ok(Field::new(field_args.name(), return_type, true).into())
328 }
329
330 fn reverse_expr(&self) -> ReversedUDWF {
331 match self.kind {
332 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
333 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
334 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
335 }
336 }
337
338 fn documentation(&self) -> Option<&Documentation> {
339 match self.kind {
340 NthValueKind::First => Some(get_first_value_doc()),
341 NthValueKind::Last => Some(get_last_value_doc()),
342 NthValueKind::Nth => Some(get_nth_value_doc()),
343 }
344 }
345
346 fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
347 LimitEffect::None }
349}
350
/// State carried by an `NthValueEvaluator`.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Cached result. Starts as `None`; once `memoize` stores a value here,
    /// `evaluate` returns a clone of it instead of reading its inputs.
    pub finalized_result: Option<ScalarValue>,
    /// Which of first / last / nth is being evaluated.
    pub kind: NthValueKind,
}
364
/// Partition evaluator behind `first_value`, `last_value` and `nth_value`.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
    /// Function kind plus the memoized result, if any.
    state: NthValueState,
    /// When true, `evaluate` only considers non-null rows of the frame and
    /// `memoize` never caches or prunes.
    ignore_nulls: bool,
    /// Row position for `nth_value`: positive values count from the start of
    /// the frame (1-based), negative values count from its end, and 0 always
    /// evaluates to a null value. Set to 0 for `first_value` / `last_value`.
    n: i64,
}
371
372impl PartitionEvaluator for NthValueEvaluator {
373 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
379 let out = &state.out_col;
380 let size = out.len();
381 let mut buffer_size = 1;
382 let (is_prunable, is_reverse_direction) = match self.state.kind {
384 NthValueKind::First => {
385 let n_range =
386 state.window_frame_range.end - state.window_frame_range.start;
387 (n_range > 0 && size > 0, false)
388 }
389 NthValueKind::Last => (true, true),
390 NthValueKind::Nth => {
391 let n_range =
392 state.window_frame_range.end - state.window_frame_range.start;
393 match self.n.cmp(&0) {
394 Ordering::Greater => (
395 n_range >= (self.n as usize) && size > (self.n as usize),
396 false,
397 ),
398 Ordering::Less => {
399 let reverse_index = (-self.n) as usize;
400 buffer_size = reverse_index;
401 (n_range >= reverse_index, true)
403 }
404 Ordering::Equal => (false, false),
405 }
406 }
407 };
408 if is_prunable && !self.ignore_nulls {
410 if self.state.finalized_result.is_none() && !is_reverse_direction {
411 let result = ScalarValue::try_from_array(out, size - 1)?;
412 self.state.finalized_result = Some(result);
413 }
414 state.window_frame_range.start =
415 state.window_frame_range.end.saturating_sub(buffer_size);
416 }
417 Ok(())
418 }
419
420 fn evaluate(
421 &mut self,
422 values: &[ArrayRef],
423 range: &Range<usize>,
424 ) -> Result<ScalarValue> {
425 if let Some(ref result) = self.state.finalized_result {
426 Ok(result.clone())
427 } else {
428 let arr = &values[0];
430 let n_range = range.end - range.start;
431 if n_range == 0 {
432 return ScalarValue::try_from(arr.data_type());
434 }
435
436 let valid_indices = if self.ignore_nulls {
438 let slice = arr.slice(range.start, n_range);
440 match slice.nulls() {
441 Some(nulls) => {
442 let valid_indices = nulls
443 .valid_indices()
444 .map(|idx| {
445 idx + range.start
447 })
448 .collect::<Vec<_>>();
449 if valid_indices.is_empty() {
450 return ScalarValue::try_from(arr.data_type());
452 }
453 Some(valid_indices)
454 }
455 None => None,
456 }
457 } else {
458 None
459 };
460 match self.state.kind {
461 NthValueKind::First => {
462 if let Some(valid_indices) = &valid_indices {
463 ScalarValue::try_from_array(arr, valid_indices[0])
464 } else {
465 ScalarValue::try_from_array(arr, range.start)
466 }
467 }
468 NthValueKind::Last => {
469 if let Some(valid_indices) = &valid_indices {
470 ScalarValue::try_from_array(
471 arr,
472 valid_indices[valid_indices.len() - 1],
473 )
474 } else {
475 ScalarValue::try_from_array(arr, range.end - 1)
476 }
477 }
478 NthValueKind::Nth => {
479 match self.n.cmp(&0) {
480 Ordering::Greater => {
481 let index = (self.n as usize) - 1;
483 if index >= n_range {
484 ScalarValue::try_from(arr.data_type())
486 } else if let Some(valid_indices) = valid_indices {
487 if index >= valid_indices.len() {
488 return ScalarValue::try_from(arr.data_type());
489 }
490 ScalarValue::try_from_array(&arr, valid_indices[index])
491 } else {
492 ScalarValue::try_from_array(arr, range.start + index)
493 }
494 }
495 Ordering::Less => {
496 let reverse_index = (-self.n) as usize;
497 if n_range < reverse_index {
498 ScalarValue::try_from(arr.data_type())
500 } else if let Some(valid_indices) = valid_indices {
501 if reverse_index > valid_indices.len() {
502 return ScalarValue::try_from(arr.data_type());
503 }
504 let new_index =
505 valid_indices[valid_indices.len() - reverse_index];
506 ScalarValue::try_from_array(&arr, new_index)
507 } else {
508 ScalarValue::try_from_array(
509 arr,
510 range.start + n_range - reverse_index,
511 )
512 }
513 }
514 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
515 }
516 }
517 }
518 }
519 }
520
521 fn supports_bounded_execution(&self) -> bool {
522 true
523 }
524
525 fn uses_window_frame(&self) -> bool {
526 true
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Physical expression referring to column `c3` at index 0.
    fn column_c3() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Physical `Int32` literal used as the `n` argument of `nth_value`.
    fn int32_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, ..., `0..8`
    /// of a fixed eight-element `Int32` column and compares the collected
    /// outputs with `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let columns = vec![input];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;

        let mut outputs: Vec<ScalarValue> = Vec::with_capacity(8);
        for end in 1..=8 {
            let frame = Range { start: 0, end };
            outputs.push(evaluator.evaluate(&columns, &frame)?);
        }

        let actual = ScalarValue::iter_to_array(outputs)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>();
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[column_c3()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        // With frames that always end at the current row, the last value is
        // the current row itself.
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[column_c3()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column_c3(), int32_literal(1)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        // The first frame holds a single row, so it has no second value.
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column_c3(), int32_literal(2)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            expected,
        )
    }
}