// datafusion_functions_aggregate/nth_value.rs

use std::any::Any;
22use std::collections::VecDeque;
23use std::mem::{size_of, size_of_val};
24use std::sync::Arc;
25
26use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields};
28
29use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
30use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34 lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
35 Signature, SortExpr, Volatility,
36};
37use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
38use datafusion_functions_aggregate_common::utils::ordering_fields;
39use datafusion_macros::user_doc;
40use datafusion_physical_expr::expressions::Literal;
41use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
42
// Generates `nth_value_udaf()`, the accessor used by `nth_value()` and
// `reverse_expr` below. `create_func!` is defined elsewhere in this crate
// (macro body not visible here) — presumably it builds a shared singleton
// `AggregateUDF` from `NthValueAgg`; confirm against the macro definition.
create_func!(NthValueAgg, nth_value_udaf);
44
45pub fn nth_value(
47 expr: datafusion_expr::Expr,
48 n: i64,
49 order_by: Vec<SortExpr>,
50) -> datafusion_expr::Expr {
51 let args = vec![expr, lit(n)];
52 if !order_by.is_empty() {
53 nth_value_udaf()
54 .call(args)
55 .order_by(order_by)
56 .build()
57 .unwrap()
58 } else {
59 nth_value_udaf().call(args)
60 }
61}
62
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the nth value in a group of values.",
    syntax_example = "nth_value(expression, n ORDER BY expression)",
    sql_example = r#"```sql
> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
  FROM employee;
+---------+--------+-------------------------+
| dept_id | salary | second_salary_by_dept   |
+---------+--------+-------------------------+
| 1       | 30000  | NULL                    |
| 1       | 40000  | 40000                   |
| 1       | 50000  | 40000                   |
| 2       | 35000  | NULL                    |
| 2       | 45000  | 45000                   |
+---------+--------+-------------------------+
```"#,
    argument(
        name = "expression",
        description = "The column or expression to retrieve the nth value from."
    ),
    argument(
        name = "n",
        description = "The position (nth) of the value to retrieve, based on the ordering."
    )
)]
#[derive(Debug)]
/// `NTH_VALUE` aggregate UDF: returns the value of the first argument at the
/// 1-based position given by the second (literal) argument, counting from the
/// end when that position is negative.
pub struct NthValueAgg {
    /// Accepts exactly two arguments of any type (`Signature::any(2, ..)`);
    /// the literal-ness of `n` is validated later, in `accumulator`.
    signature: Signature,
}
96
97impl NthValueAgg {
98 pub fn new() -> Self {
100 Self {
101 signature: Signature::any(2, Volatility::Immutable),
102 }
103 }
104}
105
106impl Default for NthValueAgg {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result has the same type as the first argument (the expression
    /// whose nth value is returned).
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `n` must be an `Int64` literal; any other second argument (e.g. a
        // column reference) is rejected as unsupported below.
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                // A reversed aggregate counts from the opposite end; that is
                // encoded here by negating `n` (see `reverse_expr`).
                if acc_args.is_reversed {
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                )
            }
        };

        // With no ORDER BY requirement the values can be taken in arrival
        // order, so the cheaper accumulator (bounded buffering, no ordering
        // state) suffices.
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialNthValueAccumulator::try_new(
                n,
                acc_args.return_field.data_type(),
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        NthValueAccumulator::try_new(n, &data_type, &ordering_dtypes, ordering)
            .map(|acc| Box::new(acc) as _)
    }

    /// Intermediate state layout: a list of the buffered values, plus — only
    /// when the aggregate has an ordering requirement — a list of structs
    /// holding the corresponding ordering-expression values.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            // List items are nullable because buffered values may be NULL.
            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields.into_iter().map(Arc::new).collect())
    }

    /// `NTH_VALUE` is its own reverse: the reversed direction is expressed by
    /// negating `n` inside `accumulator`, so the reversed UDAF is the same
    /// function.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        // `doc()` is generated by the `#[user_doc]` attribute on the struct.
        self.doc()
    }
}
194
/// Accumulator for `NTH_VALUE` without an ORDER BY requirement: values are
/// taken in arrival order, so only a bounded number of them need buffering.
#[derive(Debug)]
pub struct TrivialNthValueAccumulator {
    /// 1-based target position; positive counts from the start of the group,
    /// negative from the end. Never zero (rejected in `try_new`).
    n: i64,
    /// Buffered input values, in arrival order (at most |n| are retained).
    values: VecDeque<ScalarValue>,
    /// Data type of the aggregated expression, used for typed NULL results
    /// and for building the list-typed state.
    datatype: DataType,
}
204
205impl TrivialNthValueAccumulator {
206 pub fn try_new(n: i64, datatype: &DataType) -> Result<Self> {
209 if n == 0 {
210 return internal_err!("Nth value indices are 1 based. 0 is invalid index");
212 }
213 Ok(Self {
214 n,
215 values: VecDeque::new(),
216 datatype: datatype.clone(),
217 })
218 }
219
220 fn append_new_data(
223 &mut self,
224 values: &[ArrayRef],
225 fetch: Option<usize>,
226 ) -> Result<()> {
227 let n_row = values[0].len();
228 let n_to_add = if let Some(fetch) = fetch {
229 std::cmp::min(fetch, n_row)
230 } else {
231 n_row
232 };
233 for index in 0..n_to_add {
234 let mut row = get_row_at_idx(values, index)?;
235 self.values.push_back(row.swap_remove(0));
236 }
238 Ok(())
239 }
240}
241
242impl Accumulator for TrivialNthValueAccumulator {
243 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
246 if !values.is_empty() {
247 let n_required = self.n.unsigned_abs() as usize;
248 let from_start = self.n > 0;
249 if from_start {
250 let n_remaining = n_required.saturating_sub(self.values.len());
252 self.append_new_data(values, Some(n_remaining))?;
253 } else {
254 self.append_new_data(values, None)?;
256 let start_offset = self.values.len().saturating_sub(n_required);
257 if start_offset > 0 {
258 self.values.drain(0..start_offset);
259 }
260 }
261 }
262 Ok(())
263 }
264
265 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
266 if !states.is_empty() {
267 let n_required = self.n.unsigned_abs() as usize;
269 let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
270 for v in array_agg_res.into_iter() {
271 self.values.extend(v);
272 if self.values.len() > n_required {
273 break;
275 }
276 }
277 }
278 Ok(())
279 }
280
281 fn state(&mut self) -> Result<Vec<ScalarValue>> {
282 let mut values_cloned = self.values.clone();
283 let values_slice = values_cloned.make_contiguous();
284 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
285 values_slice,
286 &self.datatype,
287 ))])
288 }
289
290 fn evaluate(&mut self) -> Result<ScalarValue> {
291 let n_required = self.n.unsigned_abs() as usize;
292 let from_start = self.n > 0;
293 let nth_value_idx = if from_start {
294 let forward_idx = n_required - 1;
296 (forward_idx < self.values.len()).then_some(forward_idx)
297 } else {
298 self.values.len().checked_sub(n_required)
300 };
301 if let Some(idx) = nth_value_idx {
302 Ok(self.values[idx].clone())
303 } else {
304 ScalarValue::try_from(self.datatype.clone())
305 }
306 }
307
308 fn size(&self) -> usize {
309 size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
310 - size_of_val(&self.values)
311 + size_of::<DataType>()
312 }
313}
314
/// Accumulator for `NTH_VALUE` with an ORDER BY requirement: buffers values
/// together with their ordering keys so partial states can be merged while
/// preserving the required sort order.
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// 1-based target position; positive counts from the start of the
    /// ordered group, negative from the end. Never zero (rejected in
    /// `try_new`).
    n: i64,
    /// Buffered input values, kept index-aligned with `ordering_values`.
    values: VecDeque<ScalarValue>,
    /// For each buffered value, the values of the ordering expressions at
    /// that row.
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// `datatypes[0]` is the value type; `datatypes[1..]` are the ordering
    /// expression types.
    datatypes: Vec<DataType>,
    /// The lexicographic ordering the input is expected to satisfy.
    ordering_req: LexOrdering,
}
332
impl NthValueAccumulator {
    /// Creates an accumulator targeting 1-based position `n` (negative `n`
    /// counts from the end of the ordered group). `datatype` is the value
    /// type, `ordering_dtypes` the ORDER BY expression types, and
    /// `ordering_req` the ordering requirement itself.
    ///
    /// # Errors
    /// Returns an internal error for `n == 0`, which is not a valid 1-based
    /// index.
    pub fn try_new(
        n: i64,
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        if n == 0 {
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
        }
        // Slot 0 holds the value type; ordering types follow.
        let mut datatypes = vec![datatype.clone()];
        datatypes.extend(ordering_dtypes.iter().cloned());
        Ok(Self {
            n,
            values: VecDeque::new(),
            ordering_values: VecDeque::new(),
            datatypes,
            ordering_req,
        })
    }

    /// Packs the buffered ordering rows into a single-row list-of-structs
    /// scalar, mirroring the layout declared in `state_fields`.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);

        // Transpose the row-oriented buffer into one array per ordering
        // column.
        let mut column_wise_ordering_values = vec![];
        let num_columns = fields.len();
        for i in 0..num_columns {
            let column_values = self
                .ordering_values
                .iter()
                .map(|x| x[i].clone())
                .collect::<Vec<_>>();
            let array = if column_values.is_empty() {
                // `iter_to_array` cannot infer a type from zero scalars, so
                // build an explicitly typed empty array instead.
                new_empty_array(fields[i].data_type())
            } else {
                ScalarValue::iter_to_array(column_values.into_iter())?
            };
            column_wise_ordering_values.push(array);
        }

        let struct_field = Fields::from(fields);
        let ordering_array =
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;

        Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
    }

    /// Packs the buffered values into a single nullable list scalar (the
    /// first state column).
    fn evaluate_values(&self) -> ScalarValue {
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatypes[0],
        ))
    }

    /// Buffers up to `fetch` rows (all rows when `fetch` is `None`),
    /// recording each value together with its ordering key.
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        let n_to_add = if let Some(fetch) = fetch {
            std::cmp::min(fetch, n_row)
        } else {
            n_row
        };
        for index in 0..n_to_add {
            let row = get_row_at_idx(values, index)?;
            self.values.push_back(row[0].clone());
            // `row[1]` is skipped: it appears to hold the constant `n`
            // argument (exprs[1] in `accumulator`), while `row[2..]` are the
            // ordering expression values — TODO(review) confirm the input
            // column layout against the caller.
            self.ordering_values.push_back(row[2..].to_vec());
        }
        Ok(())
    }
}
415
impl Accumulator for NthValueAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // Counting from the front: only the first |n| values can ever be
            // the answer, so stop buffering once that many are held.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // Counting from the back: buffer everything, then discard all
            // but the trailing |n| values (and their ordering keys, kept in
            // lockstep).
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // `states[1]` holds each partial accumulator's ordering rows as a
        // list of structs (see `state_fields`).
        let Some(agg_orderings) = states[1].as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Collect this accumulator's buffer plus every incoming partition's
        // buffered values...
        let mut partition_values = vec![self.values.clone()];
        // `states[0]` is each partition's value list.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
        for v in array_agg_res.into_iter() {
            partition_values.push(v.into());
        }
        // ...and the matching ordering rows, decoded from the struct
        // scalars.
        let mut partition_ordering_values = vec![self.ordering_values.clone()];
        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        for partition_ordering_rows in orderings.into_iter() {
            let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
                let ScalarValue::Struct(s_array) = ordering_row else {
                    return exec_err!(
                        "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                        ordering_row.data_type()
                    );
                };
                // Each column of the single-row struct holds one ordering
                // expression's value.
                s_array
                    .columns()
                    .iter()
                    .map(|column| ScalarValue::try_from_array(column, 0))
                    .collect()
            }).collect::<Result<VecDeque<_>>>()?;
            partition_ordering_values.push(ordering_values);
        }

        // Merge the per-partition runs so the combined buffer satisfies the
        // ordering requirement; each run is assumed already sorted by it.
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        let (new_values, new_orderings) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;
        self.values = new_values.into();
        self.ordering_values = new_orderings.into();
        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate_values(), self.evaluate_orderings()?])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // 1-based position from the front maps to index `n - 1`.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // From the back: subtraction underflows (-> None) when fewer
            // than |n| values were seen.
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough input rows: the result is a typed NULL.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    /// Estimates the in-memory footprint: the struct itself plus the heap
    /// behind each buffer, subtracting inline sizes to avoid double
    /// counting.
    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values);

        // Heap usage of `ordering_values`: the deque's spine plus each row's
        // own heap allocation.
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        // Heap usage of `datatypes`.
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        // Heap usage of the ordering requirement's expression storage.
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        total
    }
}