// datafusion_functions_aggregate/nth_value.rs
use std::any::Any;
22use std::collections::VecDeque;
23use std::mem::{size_of, size_of_val};
24use std::sync::Arc;
25
26use arrow::array::{ArrayRef, AsArray, StructArray, new_empty_array};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields};
28
29use datafusion_common::utils::{SingleRowListArrayBuilder, get_row_at_idx};
30use datafusion_common::{
31 Result, ScalarValue, assert_or_internal_err, exec_err, not_impl_err,
32};
33use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
34use datafusion_expr::utils::format_state_name;
35use datafusion_expr::{
36 Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
37 Signature, SortExpr, Volatility, lit,
38};
39use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
40use datafusion_functions_aggregate_common::utils::ordering_fields;
41use datafusion_macros::user_doc;
42use datafusion_physical_expr::expressions::Literal;
43use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
44
// Generates the `nth_value_udaf()` singleton accessor for `NthValueAgg`.
create_func!(NthValueAgg, nth_value_udaf);
46
47pub fn nth_value(
49 expr: datafusion_expr::Expr,
50 n: i64,
51 order_by: Vec<SortExpr>,
52) -> datafusion_expr::Expr {
53 let args = vec![expr, lit(n)];
54 if !order_by.is_empty() {
55 nth_value_udaf()
56 .call(args)
57 .order_by(order_by)
58 .build()
59 .unwrap()
60 } else {
61 nth_value_udaf().call(args)
62 }
63}
64
/// The `nth_value` aggregate UDF: returns the nth value (1-based; negative
/// counts from the end) in a group, typically combined with an `ORDER BY`
/// requirement on the aggregate.
#[user_doc(
    doc_section(label = "Statistical Functions"),
    description = "Returns the nth value in a group of values.",
    syntax_example = "nth_value(expression, n ORDER BY expression)",
    sql_example = r#"```sql
> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
  FROM employee;
+---------+--------+-------------------------+
| dept_id | salary | second_salary_by_dept |
+---------+--------+-------------------------+
| 1 | 30000 | NULL |
| 1 | 40000 | 40000 |
| 1 | 50000 | 40000 |
| 2 | 35000 | NULL |
| 2 | 45000 | 45000 |
+---------+--------+-------------------------+
```"#,
    argument(
        name = "expression",
        description = "The column or expression to retrieve the nth value from."
    ),
    argument(
        name = "n",
        description = "The position (nth) of the value to retrieve, based on the ordering."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NthValueAgg {
    // Accepts exactly two arguments of any type (value expression and `n`).
    signature: Signature,
}
98
impl NthValueAgg {
    /// Creates the UDAF with a signature accepting exactly two arguments of
    /// any type: the value expression and the target index `n`.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}
107
impl Default for NthValueAgg {
    /// Delegates to [`NthValueAgg::new`].
    fn default() -> Self {
        Self::new()
    }
}
113
impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `nth_value` returns whatever type its first argument has.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // The second argument must be a non-null `Int64` literal `n`;
        // anything else (a column, another literal type) is unsupported.
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                // A reversed aggregate scans from the other end, so flip the
                // sign of `n` (negative `n` counts from the end).
                if acc_args.is_reversed {
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                );
            }
        };

        // Without an ORDER BY requirement a cheap buffer-only accumulator
        // suffices; otherwise track ordering keys alongside the values.
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialNthValueAccumulator::try_new(
                n,
                acc_args.return_field.data_type(),
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        let data_type = acc_args.expr_fields[0].data_type();
        NthValueAccumulator::try_new(n, data_type, &ordering_dtypes, ordering)
            .map(|acc| Box::new(acc) as _)
    }

    /// State layout: a list of buffered values, plus — only when an ordering
    /// requirement exists — a list of ordering-key structs.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            Field::new_list_field(args.input_fields[0].data_type().clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields.into_iter().map(Arc::new).collect())
    }

    /// Reversing `nth_value(x, n)` is `nth_value(x, -n)`; the sign flip is
    /// applied in `accumulator` via `acc_args.is_reversed`.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
196
/// Accumulator for `nth_value` when no `ORDER BY` requirement exists: it
/// simply buffers values in arrival order, keeping at most `|n|` of them.
#[derive(Debug)]
pub struct TrivialNthValueAccumulator {
    /// 1-based target index; negative values count from the end.
    n: i64,
    /// Buffered input values in arrival order.
    values: VecDeque<ScalarValue>,
    /// Data type of the values (used for typed NULLs and state layout).
    datatype: DataType,
}
206
207impl TrivialNthValueAccumulator {
208 pub fn try_new(n: i64, datatype: &DataType) -> Result<Self> {
211 assert_or_internal_err!(
213 n != 0,
214 "Nth value indices are 1 based. 0 is invalid index"
215 );
216 Ok(Self {
217 n,
218 values: VecDeque::new(),
219 datatype: datatype.clone(),
220 })
221 }
222
223 fn append_new_data(
226 &mut self,
227 values: &[ArrayRef],
228 fetch: Option<usize>,
229 ) -> Result<()> {
230 let n_row = values[0].len();
231 let n_to_add = if let Some(fetch) = fetch {
232 std::cmp::min(fetch, n_row)
233 } else {
234 n_row
235 };
236 for index in 0..n_to_add {
237 let mut row = get_row_at_idx(values, index)?;
238 self.values.push_back(row.swap_remove(0));
239 }
241 Ok(())
242 }
243}
244
impl Accumulator for TrivialNthValueAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if !values.is_empty() {
            let n_required = self.n.unsigned_abs() as usize;
            let from_start = self.n > 0;
            if from_start {
                // Counting from the front: only the first `n_required` values
                // can ever be the answer, so cap how many more we buffer.
                let n_remaining = n_required.saturating_sub(self.values.len());
                self.append_new_data(values, Some(n_remaining))?;
            } else {
                // Counting from the back: buffer everything from this batch,
                // then keep only the trailing `n_required` values.
                self.append_new_data(values, None)?;
                let start_offset = self.values.len().saturating_sub(n_required);
                if start_offset > 0 {
                    self.values.drain(0..start_offset);
                }
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if !states.is_empty() {
            let n_required = self.n.unsigned_abs() as usize;
            // `states[0]` is a list array: one list of buffered values per
            // partial accumulator.
            let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
            for v in array_agg_res.into_iter().flatten() {
                self.values.extend(v);
                if self.values.len() > n_required {
                    // NOTE(review): this early exit matches forward (n > 0)
                    // semantics; for n < 0 values from later partitions would
                    // still matter — confirm negative `n` cannot reach here or
                    // is tolerated (no ORDER BY => result is unspecified).
                    break;
                }
            }
        }
        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // Serialize the buffer as a single list scalar; `make_contiguous`
        // exposes the deque contents as one slice.
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatype,
        ))])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // 1-based from the front; `n != 0` is enforced at construction,
            // so this subtraction cannot underflow.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // 1-based from the back; `None` when too few values were seen.
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough rows: return a typed NULL.
            ScalarValue::try_from(self.datatype.clone())
        }
    }

    fn size(&self) -> usize {
        // Deep size: shallow struct size plus the deque's heap contents,
        // minus the deque header already counted by `size_of_val(self)`.
        size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values)
            + size_of::<DataType>()
    }
}
317
/// Accumulator for `nth_value` under an `ORDER BY` requirement: buffers each
/// incoming value together with its ordering key so partial states can be
/// merged in sorted order.
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// 1-based target index; negative values count from the end.
    n: i64,
    /// Buffered input values, maintained in `ordering_req` order.
    values: VecDeque<ScalarValue>,
    /// Ordering key for each entry of `values` (one `ScalarValue` per
    /// ordering expression).
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// `datatypes[0]` is the value type; `datatypes[1..]` are the ordering
    /// expression types.
    datatypes: Vec<DataType>,
    /// The lexicographic ordering the buffered data is maintained under.
    ordering_req: LexOrdering,
}
335
336impl NthValueAccumulator {
337 pub fn try_new(
340 n: i64,
341 datatype: &DataType,
342 ordering_dtypes: &[DataType],
343 ordering_req: LexOrdering,
344 ) -> Result<Self> {
345 assert_or_internal_err!(
347 n != 0,
348 "Nth value indices are 1 based. 0 is invalid index"
349 );
350 let mut datatypes = vec![datatype.clone()];
351 datatypes.extend(ordering_dtypes.iter().cloned());
352 Ok(Self {
353 n,
354 values: VecDeque::new(),
355 ordering_values: VecDeque::new(),
356 datatypes,
357 ordering_req,
358 })
359 }
360
361 fn evaluate_orderings(&self) -> Result<ScalarValue> {
362 let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
363
364 let mut column_wise_ordering_values = vec![];
365 let num_columns = fields.len();
366 for i in 0..num_columns {
367 let column_values = self
368 .ordering_values
369 .iter()
370 .map(|x| x[i].clone())
371 .collect::<Vec<_>>();
372 let array = if column_values.is_empty() {
373 new_empty_array(fields[i].data_type())
374 } else {
375 ScalarValue::iter_to_array(column_values.into_iter())?
376 };
377 column_wise_ordering_values.push(array);
378 }
379
380 let struct_field = Fields::from(fields);
381 let ordering_array =
382 StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
383
384 Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
385 }
386
387 fn evaluate_values(&self) -> ScalarValue {
388 let mut values_cloned = self.values.clone();
389 let values_slice = values_cloned.make_contiguous();
390 ScalarValue::List(ScalarValue::new_list_nullable(
391 values_slice,
392 &self.datatypes[0],
393 ))
394 }
395
396 fn append_new_data(
399 &mut self,
400 values: &[ArrayRef],
401 fetch: Option<usize>,
402 ) -> Result<()> {
403 let n_row = values[0].len();
404 let n_to_add = if let Some(fetch) = fetch {
405 std::cmp::min(fetch, n_row)
406 } else {
407 n_row
408 };
409 for index in 0..n_to_add {
410 let row = get_row_at_idx(values, index)?;
411 self.values.push_back(row[0].clone());
412 self.ordering_values.push_back(row[2..].to_vec());
415 }
416 Ok(())
417 }
418}
419
impl Accumulator for NthValueAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // Counting from the front: only the first `n_required` entries
            // can be the answer, so cap how many more rows we buffer.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // Counting from the back: buffer everything, then keep only the
            // trailing `n_required` entries (values and ordering keys are
            // trimmed in lockstep to stay aligned).
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // State layout (see `state_fields`): `states[0]` is the list of
        // buffered values, `states[1]` the list of ordering-key structs.
        let Some(agg_orderings) = states[1].as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Seed the partition lists with our own buffered data, then append
        // one entry per incoming partial state.
        let mut partition_values = vec![self.values.clone()];
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
        for v in array_agg_res.into_iter().flatten() {
            partition_values.push(v.into());
        }
        let mut partition_ordering_values = vec![self.ordering_values.clone()];
        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
        for partition_ordering_rows in orderings.into_iter().flatten() {
            // Each ordering row is a single-row struct scalar; unpack its
            // columns back into one `ScalarValue` per ordering expression.
            let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
                let ScalarValue::Struct(s_array) = ordering_row else {
                    return exec_err!(
                        "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                        ordering_row.data_type()
                    );
                };
                s_array
                    .columns()
                    .iter()
                    .map(|column| ScalarValue::try_from_array(column, 0))
                    .collect()
            }).collect::<Result<VecDeque<_>>>()?;
            partition_ordering_values.push(ordering_values);
        }

        // Merge the per-partition runs (each already sorted under
        // `ordering_req`) into a single sorted run.
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();
        let (new_values, new_orderings) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;
        self.values = new_values.into();
        self.ordering_values = new_orderings.into();
        Ok(())
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate_values(), self.evaluate_orderings()?])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // 1-based from the front; `n != 0` is enforced at construction,
            // so this subtraction cannot underflow.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // 1-based from the back; `None` when too few values buffered.
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough rows: return a typed NULL.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    fn size(&self) -> usize {
        // Deep size accounting: start from the shallow struct size, then for
        // each heap-backed field add its deep size minus the shallow portion
        // already counted by `size_of_val(self)`.
        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values);

        // Ordering-key rows: outer deque capacity plus each row's deep size.
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }

        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }

        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        total
    }
}