1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use std::any::Any;
23use std::cmp::Ordering;
24use std::fmt::Debug;
25use std::ops::Range;
26use std::sync::LazyLock;
27
28use datafusion_common::arrow::array::ArrayRef;
29use datafusion_common::arrow::datatypes::{DataType, Field};
30use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
31use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
32use datafusion_expr::window_state::WindowAggState;
33use datafusion_expr::{
34 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
35 Volatility, WindowUDFImpl,
36};
37use datafusion_functions_window_common::field;
38use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
39use field::WindowUDFFieldArgs;
40
// The three invocations below tie each function kind to its `NthValue`
// constructor, a public name and a short description.
// NOTE(review): `get_or_init_udwf!` is defined outside this file; presumably it
// is what generates the `first_value_udwf()`, `last_value_udwf()` and
// `nth_value_udwf()` accessors called later in this module — confirm against
// the macro definition.

// `first_value`, constructed through `NthValue::first`.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
// `last_value`, constructed through `NthValue::last`.
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
// `nth_value`, constructed through `NthValue::nth`.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
59
60pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
63 first_value_udwf().call(vec![arg])
64}
65
66pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
69 last_value_udwf().call(vec![arg])
70}
71
72pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
75 nth_value_udwf().call(vec![arg, n.lit()])
76}
77
/// Selects which of the three window functions an [`NthValue`] implements.
#[derive(Debug, Copy, Clone)]
pub enum NthValueKind {
    /// Behaves as `first_value`.
    First,
    /// Behaves as `last_value`.
    Last,
    /// Behaves as `nth_value`.
    Nth,
}

impl NthValueKind {
    /// Function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
95
96#[derive(Debug)]
97pub struct NthValue {
98 signature: Signature,
99 kind: NthValueKind,
100}
101
102impl NthValue {
103 pub fn new(kind: NthValueKind) -> Self {
105 Self {
106 signature: Signature::one_of(
107 vec![
108 TypeSignature::Any(0),
109 TypeSignature::Any(1),
110 TypeSignature::Any(2),
111 ],
112 Volatility::Immutable,
113 ),
114 kind,
115 }
116 }
117
118 pub fn first() -> Self {
119 Self::new(NthValueKind::First)
120 }
121
122 pub fn last() -> Self {
123 Self::new(NthValueKind::Last)
124 }
125 pub fn nth() -> Self {
126 Self::new(NthValueKind::Nth)
127 }
128}
129
130static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
131 Documentation::builder(
132 DOC_SECTION_ANALYTICAL,
133 "Returns value evaluated at the row that is the first row of the window \
134 frame.",
135 "first_value(expression)",
136 )
137 .with_argument("expression", "Expression to operate on")
138 .build()
139});
140
141fn get_first_value_doc() -> &'static Documentation {
142 &FIRST_VALUE_DOCUMENTATION
143}
144
145static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
146 Documentation::builder(
147 DOC_SECTION_ANALYTICAL,
148 "Returns value evaluated at the row that is the last row of the window \
149 frame.",
150 "last_value(expression)",
151 )
152 .with_argument("expression", "Expression to operate on")
153 .build()
154});
155
156fn get_last_value_doc() -> &'static Documentation {
157 &LAST_VALUE_DOCUMENTATION
158}
159
160static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
161 Documentation::builder(
162 DOC_SECTION_ANALYTICAL,
163 "Returns value evaluated at the row that is the nth row of the window \
164 frame (counting from 1); null if no such row.",
165 "nth_value(expression, n)",
166 )
167 .with_argument(
168 "expression",
169 "The name the column of which nth \
170 value to retrieve",
171 )
172 .with_argument("n", "Integer. Specifies the n in nth")
173 .build()
174});
175
176fn get_nth_value_doc() -> &'static Documentation {
177 &NTH_VALUE_DOCUMENTATION
178}
179
180impl WindowUDFImpl for NthValue {
181 fn as_any(&self) -> &dyn Any {
182 self
183 }
184
185 fn name(&self) -> &str {
186 self.kind.name()
187 }
188
189 fn signature(&self) -> &Signature {
190 &self.signature
191 }
192
193 fn partition_evaluator(
194 &self,
195 partition_evaluator_args: PartitionEvaluatorArgs,
196 ) -> Result<Box<dyn PartitionEvaluator>> {
197 let state = NthValueState {
198 finalized_result: None,
199 kind: self.kind,
200 };
201
202 if !matches!(self.kind, NthValueKind::Nth) {
203 return Ok(Box::new(NthValueEvaluator {
204 state,
205 ignore_nulls: partition_evaluator_args.ignore_nulls(),
206 n: 0,
207 }));
208 }
209
210 let n =
211 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
212 .map_err(|_e| {
213 exec_datafusion_err!(
214 "Expected a signed integer literal for the second argument of nth_value")
215 })?
216 .map(get_signed_integer)
217 {
218 Some(Ok(n)) => {
219 if partition_evaluator_args.is_reversed() {
220 -n
221 } else {
222 n
223 }
224 }
225 _ => {
226 return exec_err!(
227 "Expected a signed integer literal for the second argument of nth_value"
228 )
229 }
230 };
231
232 Ok(Box::new(NthValueEvaluator {
233 state,
234 ignore_nulls: partition_evaluator_args.ignore_nulls(),
235 n,
236 }))
237 }
238
239 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
240 let nullable = true;
241 let return_type = field_args.input_types().first().unwrap_or(&DataType::Null);
242
243 Ok(Field::new(field_args.name(), return_type.clone(), nullable))
244 }
245
246 fn reverse_expr(&self) -> ReversedUDWF {
247 match self.kind {
248 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
249 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
250 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
251 }
252 }
253
254 fn documentation(&self) -> Option<&Documentation> {
255 match self.kind {
256 NthValueKind::First => Some(get_first_value_doc()),
257 NthValueKind::Last => Some(get_last_value_doc()),
258 NthValueKind::Nth => Some(get_nth_value_doc()),
259 }
260 }
261}
262
/// State kept by the `first_value` / `last_value` / `nth_value` evaluator
/// between calls.
#[derive(Debug, Clone)]
pub struct NthValueState {
    /// Cached result. `NthValueEvaluator::memoize` stores the most recent
    /// output value here when its pruning conditions hold for a forward
    /// (non-reversed) lookup; once set, `evaluate` returns a clone of it
    /// without looking at its inputs.
    pub finalized_result: Option<ScalarValue>,
    /// Which of the three functions is being evaluated.
    pub kind: NthValueKind,
}
276
277#[derive(Debug)]
278pub(crate) struct NthValueEvaluator {
279 state: NthValueState,
280 ignore_nulls: bool,
281 n: i64,
282}
283
284impl PartitionEvaluator for NthValueEvaluator {
285 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
291 let out = &state.out_col;
292 let size = out.len();
293 let mut buffer_size = 1;
294 let (is_prunable, is_reverse_direction) = match self.state.kind {
296 NthValueKind::First => {
297 let n_range =
298 state.window_frame_range.end - state.window_frame_range.start;
299 (n_range > 0 && size > 0, false)
300 }
301 NthValueKind::Last => (true, true),
302 NthValueKind::Nth => {
303 let n_range =
304 state.window_frame_range.end - state.window_frame_range.start;
305 match self.n.cmp(&0) {
306 Ordering::Greater => (
307 n_range >= (self.n as usize) && size > (self.n as usize),
308 false,
309 ),
310 Ordering::Less => {
311 let reverse_index = (-self.n) as usize;
312 buffer_size = reverse_index;
313 (n_range >= reverse_index, true)
315 }
316 Ordering::Equal => (false, false),
317 }
318 }
319 };
320 if is_prunable && !self.ignore_nulls {
322 if self.state.finalized_result.is_none() && !is_reverse_direction {
323 let result = ScalarValue::try_from_array(out, size - 1)?;
324 self.state.finalized_result = Some(result);
325 }
326 state.window_frame_range.start =
327 state.window_frame_range.end.saturating_sub(buffer_size);
328 }
329 Ok(())
330 }
331
332 fn evaluate(
333 &mut self,
334 values: &[ArrayRef],
335 range: &Range<usize>,
336 ) -> Result<ScalarValue> {
337 if let Some(ref result) = self.state.finalized_result {
338 Ok(result.clone())
339 } else {
340 let arr = &values[0];
342 let n_range = range.end - range.start;
343 if n_range == 0 {
344 return ScalarValue::try_from(arr.data_type());
346 }
347
348 let valid_indices = if self.ignore_nulls {
350 let slice = arr.slice(range.start, n_range);
352 match slice.nulls() {
353 Some(nulls) => {
354 let valid_indices = nulls
355 .valid_indices()
356 .map(|idx| {
357 idx + range.start
359 })
360 .collect::<Vec<_>>();
361 if valid_indices.is_empty() {
362 return ScalarValue::try_from(arr.data_type());
364 }
365 Some(valid_indices)
366 }
367 None => None,
368 }
369 } else {
370 None
371 };
372 match self.state.kind {
373 NthValueKind::First => {
374 if let Some(valid_indices) = &valid_indices {
375 ScalarValue::try_from_array(arr, valid_indices[0])
376 } else {
377 ScalarValue::try_from_array(arr, range.start)
378 }
379 }
380 NthValueKind::Last => {
381 if let Some(valid_indices) = &valid_indices {
382 ScalarValue::try_from_array(
383 arr,
384 valid_indices[valid_indices.len() - 1],
385 )
386 } else {
387 ScalarValue::try_from_array(arr, range.end - 1)
388 }
389 }
390 NthValueKind::Nth => {
391 match self.n.cmp(&0) {
392 Ordering::Greater => {
393 let index = (self.n as usize) - 1;
395 if index >= n_range {
396 ScalarValue::try_from(arr.data_type())
398 } else if let Some(valid_indices) = valid_indices {
399 if index >= valid_indices.len() {
400 return ScalarValue::try_from(arr.data_type());
401 }
402 ScalarValue::try_from_array(&arr, valid_indices[index])
403 } else {
404 ScalarValue::try_from_array(arr, range.start + index)
405 }
406 }
407 Ordering::Less => {
408 let reverse_index = (-self.n) as usize;
409 if n_range < reverse_index {
410 ScalarValue::try_from(arr.data_type())
412 } else if let Some(valid_indices) = valid_indices {
413 if reverse_index > valid_indices.len() {
414 return ScalarValue::try_from(arr.data_type());
415 }
416 let new_index =
417 valid_indices[valid_indices.len() - reverse_index];
418 ScalarValue::try_from_array(&arr, new_index)
419 } else {
420 ScalarValue::try_from_array(
421 arr,
422 range.start + n_range - reverse_index,
423 )
424 }
425 }
426 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
427 }
428 }
429 }
430 }
431 }
432
433 fn supports_bounded_execution(&self) -> bool {
434 true
435 }
436
437 fn uses_window_frame(&self) -> bool {
438 true
439 }
440}
441
442#[cfg(test)]
443mod tests {
444 use super::*;
445 use arrow::array::*;
446 use datafusion_common::cast::as_int32_array;
447 use datafusion_physical_expr::expressions::{Column, Literal};
448 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
449 use std::sync::Arc;
450
451 fn test_i32_result(
452 expr: NthValue,
453 partition_evaluator_args: PartitionEvaluatorArgs,
454 expected: Int32Array,
455 ) -> Result<()> {
456 let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
457 let values = vec![arr];
458 let mut ranges: Vec<Range<usize>> = vec![];
459 for i in 0..8 {
460 ranges.push(Range {
461 start: 0,
462 end: i + 1,
463 })
464 }
465 let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
466 let result = ranges
467 .iter()
468 .map(|range| evaluator.evaluate(&values, range))
469 .collect::<Result<Vec<ScalarValue>>>()?;
470 let result = ScalarValue::iter_to_array(result.into_iter())?;
471 let result = as_int32_array(&result)?;
472 assert_eq!(expected, *result);
473 Ok(())
474 }
475
476 #[test]
477 fn first_value() -> Result<()> {
478 let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
479 test_i32_result(
480 NthValue::first(),
481 PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
482 Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>(),
483 )
484 }
485
486 #[test]
487 fn last_value() -> Result<()> {
488 let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
489 test_i32_result(
490 NthValue::last(),
491 PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
492 Int32Array::from(vec![
493 Some(1),
494 Some(-2),
495 Some(3),
496 Some(-4),
497 Some(5),
498 Some(-6),
499 Some(7),
500 Some(8),
501 ]),
502 )
503 }
504
505 #[test]
506 fn nth_value_1() -> Result<()> {
507 let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
508 let n_value =
509 Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
510
511 test_i32_result(
512 NthValue::nth(),
513 PartitionEvaluatorArgs::new(
514 &[expr, n_value],
515 &[DataType::Int32],
516 false,
517 false,
518 ),
519 Int32Array::from(vec![1; 8]),
520 )?;
521 Ok(())
522 }
523
524 #[test]
525 fn nth_value_2() -> Result<()> {
526 let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
527 let n_value =
528 Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
529
530 test_i32_result(
531 NthValue::nth(),
532 PartitionEvaluatorArgs::new(
533 &[expr, n_value],
534 &[DataType::Int32],
535 false,
536 false,
537 ),
538 Int32Array::from(vec![
539 None,
540 Some(-2),
541 Some(-2),
542 Some(-2),
543 Some(-2),
544 Some(-2),
545 Some(-2),
546 Some(-2),
547 ]),
548 )?;
549 Ok(())
550 }
551}