datafusion_functions_aggregate/
nth_value.rs1use std::any::Any;
22use std::collections::VecDeque;
23use std::mem::{size_of, size_of_val};
24use std::sync::Arc;
25
26use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields};
28
29use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
30use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
31use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
32use datafusion_expr::utils::format_state_name;
33use datafusion_expr::{
34 lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
35 Signature, SortExpr, Volatility,
36};
37use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
38use datafusion_functions_aggregate_common::utils::ordering_fields;
39use datafusion_macros::user_doc;
40use datafusion_physical_expr::expressions::Literal;
41use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
42
43create_func!(NthValueAgg, nth_value_udaf);
44
45pub fn nth_value(
47 expr: datafusion_expr::Expr,
48 n: i64,
49 order_by: Vec<SortExpr>,
50) -> datafusion_expr::Expr {
51 let args = vec![expr, lit(n)];
52 if !order_by.is_empty() {
53 nth_value_udaf()
54 .call(args)
55 .order_by(order_by)
56 .build()
57 .unwrap()
58 } else {
59 nth_value_udaf().call(args)
60 }
61}
62
63#[user_doc(
64 doc_section(label = "Statistical Functions"),
65 description = "Returns the nth value in a group of values.",
66 syntax_example = "nth_value(expression, n ORDER BY expression)",
67 sql_example = r#"```sql
68> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
69 FROM employee;
70+---------+--------+-------------------------+
71| dept_id | salary | second_salary_by_dept |
72+---------+--------+-------------------------+
73| 1 | 30000 | NULL |
74| 1 | 40000 | 40000 |
75| 1 | 50000 | 40000 |
76| 2 | 35000 | NULL |
77| 2 | 45000 | 45000 |
78+---------+--------+-------------------------+
79```"#,
80 argument(
81 name = "expression",
82 description = "The column or expression to retrieve the nth value from."
83 ),
84 argument(
85 name = "n",
86 description = "The position (nth) of the value to retrieve, based on the ordering."
87 )
88)]
89#[derive(Debug)]
93pub struct NthValueAgg {
94 signature: Signature,
95}
96
97impl NthValueAgg {
98 pub fn new() -> Self {
100 Self {
101 signature: Signature::any(2, Volatility::Immutable),
102 }
103 }
104}
105
106impl Default for NthValueAgg {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl AggregateUDFImpl for NthValueAgg {
113 fn as_any(&self) -> &dyn Any {
114 self
115 }
116
117 fn name(&self) -> &str {
118 "nth_value"
119 }
120
121 fn signature(&self) -> &Signature {
122 &self.signature
123 }
124
125 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
126 Ok(arg_types[0].clone())
127 }
128
129 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
130 let n = match acc_args.exprs[1]
131 .as_any()
132 .downcast_ref::<Literal>()
133 .map(|lit| lit.value())
134 {
135 Some(ScalarValue::Int64(Some(value))) => {
136 if acc_args.is_reversed {
137 -*value
138 } else {
139 *value
140 }
141 }
142 _ => {
143 return not_impl_err!(
144 "{} not supported for n: {}",
145 self.name(),
146 &acc_args.exprs[1]
147 )
148 }
149 };
150
151 let ordering_dtypes = acc_args
152 .ordering_req
153 .iter()
154 .map(|e| e.expr.data_type(acc_args.schema))
155 .collect::<Result<Vec<_>>>()?;
156
157 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
158 NthValueAccumulator::try_new(
159 n,
160 &data_type,
161 &ordering_dtypes,
162 acc_args.ordering_req.clone(),
163 )
164 .map(|acc| Box::new(acc) as _)
165 }
166
167 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
168 let mut fields = vec![Field::new_list(
169 format_state_name(self.name(), "nth_value"),
170 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
172 false,
173 )];
174 let orderings = args.ordering_fields.to_vec();
175 if !orderings.is_empty() {
176 fields.push(Field::new_list(
177 format_state_name(self.name(), "nth_value_orderings"),
178 Field::new_list_field(DataType::Struct(Fields::from(orderings)), true),
179 false,
180 ));
181 }
182 Ok(fields.into_iter().map(Arc::new).collect())
183 }
184
185 fn aliases(&self) -> &[String] {
186 &[]
187 }
188
189 fn reverse_expr(&self) -> ReversedUDAF {
190 ReversedUDAF::Reversed(nth_value_udaf())
191 }
192
193 fn documentation(&self) -> Option<&Documentation> {
194 self.doc()
195 }
196}
197
198#[derive(Debug)]
199pub struct NthValueAccumulator {
200 n: i64,
202 values: VecDeque<ScalarValue>,
204 ordering_values: VecDeque<Vec<ScalarValue>>,
209 datatypes: Vec<DataType>,
212 ordering_req: LexOrdering,
214}
215
216impl NthValueAccumulator {
217 pub fn try_new(
220 n: i64,
221 datatype: &DataType,
222 ordering_dtypes: &[DataType],
223 ordering_req: LexOrdering,
224 ) -> Result<Self> {
225 if n == 0 {
226 return internal_err!("Nth value indices are 1 based. 0 is invalid index");
228 }
229 let mut datatypes = vec![datatype.clone()];
230 datatypes.extend(ordering_dtypes.iter().cloned());
231 Ok(Self {
232 n,
233 values: VecDeque::new(),
234 ordering_values: VecDeque::new(),
235 datatypes,
236 ordering_req,
237 })
238 }
239}
240
241impl Accumulator for NthValueAccumulator {
242 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
245 if values.is_empty() {
246 return Ok(());
247 }
248
249 let n_required = self.n.unsigned_abs() as usize;
250 let from_start = self.n > 0;
251 if from_start {
252 let n_remaining = n_required.saturating_sub(self.values.len());
254 self.append_new_data(values, Some(n_remaining))?;
255 } else {
256 self.append_new_data(values, None)?;
258 let start_offset = self.values.len().saturating_sub(n_required);
259 if start_offset > 0 {
260 self.values.drain(0..start_offset);
261 self.ordering_values.drain(0..start_offset);
262 }
263 }
264
265 Ok(())
266 }
267
268 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
269 if states.is_empty() {
270 return Ok(());
271 }
272 let array_agg_values = &states[0];
274 let n_required = self.n.unsigned_abs() as usize;
275 if self.ordering_req.is_empty() {
276 let array_agg_res =
277 ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
278 for v in array_agg_res.into_iter() {
279 self.values.extend(v);
280 if self.values.len() > n_required {
281 break;
283 }
284 }
285 } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
286 let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
292 let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
294
295 partition_values.push(self.values.clone());
297
298 partition_ordering_values.push(self.ordering_values.clone());
299
300 let array_agg_res =
301 ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
302
303 for v in array_agg_res.into_iter() {
304 partition_values.push(v.into());
305 }
306
307 let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
308
309 let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
310 partition_ordering_rows.into_iter().map(|ordering_row| {
312 if let ScalarValue::Struct(s) = ordering_row {
313 let mut ordering_columns_per_row = vec![];
314
315 for column in s.columns() {
316 let sv = ScalarValue::try_from_array(column, 0)?;
317 ordering_columns_per_row.push(sv);
318 }
319
320 Ok(ordering_columns_per_row)
321 } else {
322 exec_err!(
323 "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
324 ordering_row.data_type()
325 )
326 }
327 }).collect::<Result<Vec<_>>>()
328 }).collect::<Result<Vec<_>>>()?;
329 for ordering_values in ordering_values.into_iter() {
330 partition_ordering_values.push(ordering_values.into());
331 }
332
333 let sort_options = self
334 .ordering_req
335 .iter()
336 .map(|sort_expr| sort_expr.options)
337 .collect::<Vec<_>>();
338 let (new_values, new_orderings) = merge_ordered_arrays(
339 &mut partition_values,
340 &mut partition_ordering_values,
341 &sort_options,
342 )?;
343 self.values = new_values.into();
344 self.ordering_values = new_orderings.into();
345 } else {
346 return exec_err!("Expects to receive a list array");
347 }
348 Ok(())
349 }
350
351 fn state(&mut self) -> Result<Vec<ScalarValue>> {
352 let mut result = vec![self.evaluate_values()];
353 if !self.ordering_req.is_empty() {
354 result.push(self.evaluate_orderings()?);
355 }
356 Ok(result)
357 }
358
359 fn evaluate(&mut self) -> Result<ScalarValue> {
360 let n_required = self.n.unsigned_abs() as usize;
361 let from_start = self.n > 0;
362 let nth_value_idx = if from_start {
363 let forward_idx = n_required - 1;
365 (forward_idx < self.values.len()).then_some(forward_idx)
366 } else {
367 self.values.len().checked_sub(n_required)
369 };
370 if let Some(idx) = nth_value_idx {
371 Ok(self.values[idx].clone())
372 } else {
373 ScalarValue::try_from(self.datatypes[0].clone())
374 }
375 }
376
377 fn size(&self) -> usize {
378 let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
379 - size_of_val(&self.values);
380
381 total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
383 for row in &self.ordering_values {
384 total += ScalarValue::size_of_vec(row) - size_of_val(row);
385 }
386
387 total += size_of::<DataType>() * self.datatypes.capacity();
389 for dtype in &self.datatypes {
390 total += dtype.size() - size_of_val(dtype);
391 }
392
393 total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
395 total
397 }
398}
399
400impl NthValueAccumulator {
401 fn evaluate_orderings(&self) -> Result<ScalarValue> {
402 let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
403
404 let mut column_wise_ordering_values = vec![];
405 let num_columns = fields.len();
406 for i in 0..num_columns {
407 let column_values = self
408 .ordering_values
409 .iter()
410 .map(|x| x[i].clone())
411 .collect::<Vec<_>>();
412 let array = if column_values.is_empty() {
413 new_empty_array(fields[i].data_type())
414 } else {
415 ScalarValue::iter_to_array(column_values.into_iter())?
416 };
417 column_wise_ordering_values.push(array);
418 }
419
420 let struct_field = Fields::from(fields);
421 let ordering_array =
422 StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
423
424 Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar())
425 }
426
427 fn evaluate_values(&self) -> ScalarValue {
428 let mut values_cloned = self.values.clone();
429 let values_slice = values_cloned.make_contiguous();
430 ScalarValue::List(ScalarValue::new_list_nullable(
431 values_slice,
432 &self.datatypes[0],
433 ))
434 }
435
436 fn append_new_data(
439 &mut self,
440 values: &[ArrayRef],
441 fetch: Option<usize>,
442 ) -> Result<()> {
443 let n_row = values[0].len();
444 let n_to_add = if let Some(fetch) = fetch {
445 std::cmp::min(fetch, n_row)
446 } else {
447 n_row
448 };
449 for index in 0..n_to_add {
450 let row = get_row_at_idx(values, index)?;
451 self.values.push_back(row[0].clone());
452 self.ordering_values.push_back(row[2..].to_vec());
455 }
456 Ok(())
457 }
458}