1use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{ArrayRef, AsArray, BooleanArray};
26use arrow::compute::{self, LexicographicalComparator, SortColumn};
27use arrow::datatypes::{DataType, Field};
28use datafusion_common::utils::{compare_rows, get_row_at_idx};
29use datafusion_common::{
30 arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
31};
32use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
33use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
34use datafusion_expr::{
35 Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature,
36 SortExpr, Volatility,
37};
38use datafusion_functions_aggregate_common::utils::get_sort_options;
39use datafusion_macros::user_doc;
40use datafusion_physical_expr_common::sort_expr::LexOrdering;
41
// Invokes the `create_func!` macro for `FirstValue`. The `first_value_udaf()`
// accessor called throughout this file is presumably generated by this
// invocation — confirm against the macro definition, which is not in this file.
create_func!(FirstValue, first_value_udaf);
43
44pub fn first_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
46 if let Some(order_by) = order_by {
47 first_value_udaf()
48 .call(vec![expression])
49 .order_by(order_by)
50 .build()
51 .unwrap()
53 } else {
54 first_value_udaf().call(vec![expression])
55 }
56}
57
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// Aggregate UDF implementation registered under the name `first_value`.
pub struct FirstValue {
    /// Signature handed out by `signature()`; built in `new()`.
    signature: Signature,
    /// Set through `with_requirement_satisfied` (which `with_beneficial_ordering`
    /// calls). `accumulator()` ORs it with `ordering_req.is_empty()` and forwards
    /// the result to the accumulator it creates.
    requirement_satisfied: bool,
}
76
77impl Debug for FirstValue {
78 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
79 f.debug_struct("FirstValue")
80 .field("name", &self.name())
81 .field("signature", &self.signature)
82 .field("accumulator", &"<FUNC>")
83 .finish()
84 }
85}
86
87impl Default for FirstValue {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl FirstValue {
94 pub fn new() -> Self {
95 Self {
96 signature: Signature::any(1, Volatility::Immutable),
97 requirement_satisfied: false,
98 }
99 }
100
101 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
102 self.requirement_satisfied = requirement_satisfied;
103 self
104 }
105}
106
107impl AggregateUDFImpl for FirstValue {
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111
112 fn name(&self) -> &str {
113 "first_value"
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
121 Ok(arg_types[0].clone())
122 }
123
124 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
125 let ordering_dtypes = acc_args
126 .ordering_req
127 .iter()
128 .map(|e| e.expr.data_type(acc_args.schema))
129 .collect::<Result<Vec<_>>>()?;
130
131 let requirement_satisfied =
134 acc_args.ordering_req.is_empty() || self.requirement_satisfied;
135
136 FirstValueAccumulator::try_new(
137 acc_args.return_type,
138 &ordering_dtypes,
139 acc_args.ordering_req.clone(),
140 acc_args.ignore_nulls,
141 )
142 .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
143 }
144
145 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
146 let mut fields = vec![Field::new(
147 format_state_name(args.name, "first_value"),
148 args.return_type.clone(),
149 true,
150 )];
151 fields.extend(args.ordering_fields.to_vec());
152 fields.push(Field::new("is_set", DataType::Boolean, true));
153 Ok(fields)
154 }
155
156 fn aliases(&self) -> &[String] {
157 &[]
158 }
159
160 fn with_beneficial_ordering(
161 self: Arc<Self>,
162 beneficial_ordering: bool,
163 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
164 Ok(Some(Arc::new(
165 FirstValue::new().with_requirement_satisfied(beneficial_ordering),
166 )))
167 }
168
169 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
170 AggregateOrderSensitivity::Beneficial
171 }
172
173 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
174 datafusion_expr::ReversedUDAF::Reversed(last_value_udaf())
175 }
176
177 fn documentation(&self) -> Option<&Documentation> {
178 self.doc()
179 }
180}
181
/// Accumulator state for the `first_value` aggregate.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    /// Value currently held as the first one; returned by `evaluate`.
    first: ScalarValue,
    /// Whether `first` has been assigned from an input row. It is emitted as
    /// the last state column and used in `merge_batch` to filter out partial
    /// states that never received a value.
    is_set: bool,
    /// Ordering values of the row that produced `first`.
    orderings: Vec<ScalarValue>,
    /// Ordering requirement; supplies the sort options used for comparisons.
    ordering_req: LexOrdering,
    /// When `true`, `get_first_idx` picks rows by position instead of
    /// comparing ordering columns.
    requirement_satisfied: bool,
    /// When `true`, rows whose value is null are skipped when picking a row.
    ignore_nulls: bool,
}
198
199impl FirstValueAccumulator {
200 pub fn try_new(
202 data_type: &DataType,
203 ordering_dtypes: &[DataType],
204 ordering_req: LexOrdering,
205 ignore_nulls: bool,
206 ) -> Result<Self> {
207 let orderings = ordering_dtypes
208 .iter()
209 .map(ScalarValue::try_from)
210 .collect::<Result<Vec<_>>>()?;
211 let requirement_satisfied = ordering_req.is_empty();
212 ScalarValue::try_from(data_type).map(|first| Self {
213 first,
214 is_set: false,
215 orderings,
216 ordering_req,
217 requirement_satisfied,
218 ignore_nulls,
219 })
220 }
221
222 pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
223 self.requirement_satisfied = requirement_satisfied;
224 self
225 }
226
227 fn update_with_new_row(&mut self, row: &[ScalarValue]) {
229 self.first = row[0].clone();
230 self.orderings = row[1..].to_vec();
231 self.is_set = true;
232 }
233
234 fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
235 let [value, ordering_values @ ..] = values else {
236 return internal_err!("Empty row in FIRST_VALUE");
237 };
238 if self.requirement_satisfied {
239 if self.ignore_nulls {
241 for i in 0..value.len() {
243 if !value.is_null(i) {
244 return Ok(Some(i));
245 }
246 }
247 return Ok(None);
248 } else {
249 return Ok((!value.is_empty()).then_some(0));
251 }
252 }
253
254 let sort_columns = ordering_values
255 .iter()
256 .zip(self.ordering_req.iter())
257 .map(|(values, req)| SortColumn {
258 values: Arc::clone(values),
259 options: Some(req.options),
260 })
261 .collect::<Vec<_>>();
262
263 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
264
265 let min_index = if self.ignore_nulls {
266 (0..value.len())
267 .filter(|&index| !value.is_null(index))
268 .min_by(|&a, &b| comparator.compare(a, b))
269 } else {
270 (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
271 };
272
273 Ok(min_index)
274 }
275}
276
277impl Accumulator for FirstValueAccumulator {
278 fn state(&mut self) -> Result<Vec<ScalarValue>> {
279 let mut result = vec![self.first.clone()];
280 result.extend(self.orderings.iter().cloned());
281 result.push(ScalarValue::Boolean(Some(self.is_set)));
282 Ok(result)
283 }
284
285 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
286 if !self.is_set {
287 if let Some(first_idx) = self.get_first_idx(values)? {
288 let row = get_row_at_idx(values, first_idx)?;
289 self.update_with_new_row(&row);
290 }
291 } else if !self.requirement_satisfied {
292 if let Some(first_idx) = self.get_first_idx(values)? {
293 let row = get_row_at_idx(values, first_idx)?;
294 let orderings = &row[1..];
295 if compare_rows(
296 &self.orderings,
297 orderings,
298 &get_sort_options(self.ordering_req.as_ref()),
299 )?
300 .is_gt()
301 {
302 self.update_with_new_row(&row);
303 }
304 }
305 }
306 Ok(())
307 }
308
309 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
310 let is_set_idx = states.len() - 1;
313 let flags = states[is_set_idx].as_boolean();
314 let filtered_states =
315 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
316 let sort_columns = convert_to_sort_cols(
318 &filtered_states[1..is_set_idx],
319 self.ordering_req.as_ref(),
320 );
321
322 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
323 let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
324
325 if let Some(first_idx) = min {
326 let first_row = get_row_at_idx(&filtered_states, first_idx)?;
327 let first_ordering = &first_row[1..is_set_idx];
329 let sort_options = get_sort_options(self.ordering_req.as_ref());
330 if !self.is_set
332 || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
333 {
334 self.update_with_new_row(&first_row[0..is_set_idx]);
338 }
339 }
340 Ok(())
341 }
342
343 fn evaluate(&mut self) -> Result<ScalarValue> {
344 Ok(self.first.clone())
345 }
346
347 fn size(&self) -> usize {
348 size_of_val(self) - size_of_val(&self.first)
349 + self.first.size()
350 + ScalarValue::size_of_vec(&self.orderings)
351 - size_of_val(&self.orderings)
352 }
353}
354
// Invokes `make_udaf_expr_and_func!` for `LastValue`. `last_value_udaf()` is
// called elsewhere in this file (`FirstValue::reverse_expr`), so it is
// presumably generated here together with a `last_value` expression helper
// carrying the description below — confirm against the macro definition,
// which is not in this file.
make_udaf_expr_and_func!(
    LastValue,
    last_value,
    "Returns the last value in a group of values.",
    last_value_udaf
);
361
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// Aggregate UDF implementation registered under the name `last_value`.
pub struct LastValue {
    /// Signature handed out by `signature()`; built in `new()`.
    signature: Signature,
    /// Set through `with_requirement_satisfied` (which `with_beneficial_ordering`
    /// calls). `accumulator()` ORs it with `ordering_req.is_empty()` and forwards
    /// the result to the accumulator it creates.
    requirement_satisfied: bool,
}
380
381impl Debug for LastValue {
382 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
383 f.debug_struct("LastValue")
384 .field("name", &self.name())
385 .field("signature", &self.signature)
386 .field("accumulator", &"<FUNC>")
387 .finish()
388 }
389}
390
391impl Default for LastValue {
392 fn default() -> Self {
393 Self::new()
394 }
395}
396
397impl LastValue {
398 pub fn new() -> Self {
399 Self {
400 signature: Signature::any(1, Volatility::Immutable),
401 requirement_satisfied: false,
402 }
403 }
404
405 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
406 self.requirement_satisfied = requirement_satisfied;
407 self
408 }
409}
410
411impl AggregateUDFImpl for LastValue {
412 fn as_any(&self) -> &dyn Any {
413 self
414 }
415
416 fn name(&self) -> &str {
417 "last_value"
418 }
419
420 fn signature(&self) -> &Signature {
421 &self.signature
422 }
423
424 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
425 Ok(arg_types[0].clone())
426 }
427
428 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
429 let ordering_dtypes = acc_args
430 .ordering_req
431 .iter()
432 .map(|e| e.expr.data_type(acc_args.schema))
433 .collect::<Result<Vec<_>>>()?;
434
435 let requirement_satisfied =
436 acc_args.ordering_req.is_empty() || self.requirement_satisfied;
437
438 LastValueAccumulator::try_new(
439 acc_args.return_type,
440 &ordering_dtypes,
441 acc_args.ordering_req.clone(),
442 acc_args.ignore_nulls,
443 )
444 .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
445 }
446
447 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
448 let StateFieldsArgs {
449 name,
450 input_types,
451 return_type: _,
452 ordering_fields,
453 is_distinct: _,
454 } = args;
455 let mut fields = vec![Field::new(
456 format_state_name(name, "last_value"),
457 input_types[0].clone(),
458 true,
459 )];
460 fields.extend(ordering_fields.to_vec());
461 fields.push(Field::new("is_set", DataType::Boolean, true));
462 Ok(fields)
463 }
464
465 fn aliases(&self) -> &[String] {
466 &[]
467 }
468
469 fn with_beneficial_ordering(
470 self: Arc<Self>,
471 beneficial_ordering: bool,
472 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
473 Ok(Some(Arc::new(
474 LastValue::new().with_requirement_satisfied(beneficial_ordering),
475 )))
476 }
477
478 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
479 AggregateOrderSensitivity::Beneficial
480 }
481
482 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
483 datafusion_expr::ReversedUDAF::Reversed(first_value_udaf())
484 }
485
486 fn documentation(&self) -> Option<&Documentation> {
487 self.doc()
488 }
489}
490
/// Accumulator state for the `last_value` aggregate.
#[derive(Debug)]
struct LastValueAccumulator {
    /// Value currently held as the last one; returned by `evaluate`.
    last: ScalarValue,
    /// Whether `last` has been assigned from an input row. It is emitted as
    /// the last state column and used in `merge_batch` to filter out partial
    /// states that never received a value.
    is_set: bool,
    /// Ordering values of the row that produced `last`.
    orderings: Vec<ScalarValue>,
    /// Ordering requirement; supplies the sort options used for comparisons.
    ordering_req: LexOrdering,
    /// When `true`, `get_last_idx` picks rows by position instead of
    /// comparing ordering columns.
    requirement_satisfied: bool,
    /// When `true`, rows whose value is null are skipped when picking a row.
    ignore_nulls: bool,
}
506
507impl LastValueAccumulator {
508 pub fn try_new(
510 data_type: &DataType,
511 ordering_dtypes: &[DataType],
512 ordering_req: LexOrdering,
513 ignore_nulls: bool,
514 ) -> Result<Self> {
515 let orderings = ordering_dtypes
516 .iter()
517 .map(ScalarValue::try_from)
518 .collect::<Result<Vec<_>>>()?;
519 let requirement_satisfied = ordering_req.is_empty();
520 ScalarValue::try_from(data_type).map(|last| Self {
521 last,
522 is_set: false,
523 orderings,
524 ordering_req,
525 requirement_satisfied,
526 ignore_nulls,
527 })
528 }
529
530 fn update_with_new_row(&mut self, row: &[ScalarValue]) {
532 self.last = row[0].clone();
533 self.orderings = row[1..].to_vec();
534 self.is_set = true;
535 }
536
537 fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
538 let [value, ordering_values @ ..] = values else {
539 return internal_err!("Empty row in LAST_VALUE");
540 };
541 if self.requirement_satisfied {
542 if self.ignore_nulls {
544 for i in (0..value.len()).rev() {
546 if !value.is_null(i) {
547 return Ok(Some(i));
548 }
549 }
550 return Ok(None);
551 } else {
552 return Ok((!value.is_empty()).then_some(value.len() - 1));
553 }
554 }
555 let sort_columns = ordering_values
556 .iter()
557 .zip(self.ordering_req.iter())
558 .map(|(values, req)| SortColumn {
559 values: Arc::clone(values),
560 options: Some(req.options),
561 })
562 .collect::<Vec<_>>();
563
564 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
565 let max_ind = if self.ignore_nulls {
566 (0..value.len())
567 .filter(|&index| !(value.is_null(index)))
568 .max_by(|&a, &b| comparator.compare(a, b))
569 } else {
570 (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
571 };
572
573 Ok(max_ind)
574 }
575
576 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
577 self.requirement_satisfied = requirement_satisfied;
578 self
579 }
580}
581
582impl Accumulator for LastValueAccumulator {
583 fn state(&mut self) -> Result<Vec<ScalarValue>> {
584 let mut result = vec![self.last.clone()];
585 result.extend(self.orderings.clone());
586 result.push(ScalarValue::Boolean(Some(self.is_set)));
587 Ok(result)
588 }
589
590 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
591 if !self.is_set || self.requirement_satisfied {
592 if let Some(last_idx) = self.get_last_idx(values)? {
593 let row = get_row_at_idx(values, last_idx)?;
594 self.update_with_new_row(&row);
595 }
596 } else if let Some(last_idx) = self.get_last_idx(values)? {
597 let row = get_row_at_idx(values, last_idx)?;
598 let orderings = &row[1..];
599 if compare_rows(
601 &self.orderings,
602 orderings,
603 &get_sort_options(self.ordering_req.as_ref()),
604 )?
605 .is_lt()
606 {
607 self.update_with_new_row(&row);
608 }
609 }
610
611 Ok(())
612 }
613
614 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
615 let is_set_idx = states.len() - 1;
618 let flags = states[is_set_idx].as_boolean();
619 let filtered_states =
620 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
621 let sort_columns = convert_to_sort_cols(
623 &filtered_states[1..is_set_idx],
624 self.ordering_req.as_ref(),
625 );
626
627 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
628 let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
629
630 if let Some(last_idx) = max {
631 let last_row = get_row_at_idx(&filtered_states, last_idx)?;
632 let last_ordering = &last_row[1..is_set_idx];
634 let sort_options = get_sort_options(self.ordering_req.as_ref());
635 if !self.is_set
638 || self.requirement_satisfied
639 || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
640 {
641 self.update_with_new_row(&last_row[0..is_set_idx]);
645 }
646 }
647 Ok(())
648 }
649
650 fn evaluate(&mut self) -> Result<ScalarValue> {
651 Ok(self.last.clone())
652 }
653
654 fn size(&self) -> usize {
655 size_of_val(self) - size_of_val(&self.last)
656 + self.last.size()
657 + ScalarValue::size_of_vec(&self.orderings)
658 - size_of_val(&self.orderings)
659 }
660}
661
662fn filter_states_according_to_is_set(
665 states: &[ArrayRef],
666 flags: &BooleanArray,
667) -> Result<Vec<ArrayRef>> {
668 states
669 .iter()
670 .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
671 .collect::<Result<Vec<_>>>()
672}
673
674fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
676 arrs.iter()
677 .zip(sort_exprs.iter())
678 .map(|(item, sort_expr)| SortColumn {
679 values: Arc::clone(item),
680 options: Some(sort_expr.options),
681 })
682 .collect::<Vec<_>>()
683}
684
#[cfg(test)]
mod tests {
    use arrow::array::Int64Array;

    use super::*;

    /// Creates a `FirstValueAccumulator` over `Int64` with no ordering
    /// requirement and `ignore_nulls` disabled.
    fn new_first_accumulator() -> Result<FirstValueAccumulator> {
        FirstValueAccumulator::try_new(
            &DataType::Int64,
            &[],
            LexOrdering::default(),
            false,
        )
    }

    /// Creates a `LastValueAccumulator` over `Int64` with no ordering
    /// requirement and `ignore_nulls` disabled.
    fn new_last_accumulator() -> Result<LastValueAccumulator> {
        LastValueAccumulator::try_new(
            &DataType::Int64,
            &[],
            LexOrdering::default(),
            false,
        )
    }

    /// Concatenates two accumulator states column by column, producing the
    /// arrays expected by `merge_batch`.
    fn concat_states(
        left: &[ScalarValue],
        right: &[ScalarValue],
    ) -> Result<Vec<ArrayRef>> {
        left.iter()
            .zip(right.iter())
            .map(|(lhs, rhs)| -> Result<ArrayRef> {
                let lhs = lhs.to_array()?;
                let rhs = rhs.to_array()?;
                Ok(compute::concat(&[&lhs, &rhs])?)
            })
            .collect()
    }

    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator = new_first_accumulator()?;
        let mut last_accumulator = new_last_accumulator()?;

        // Three batches: 0..10, 1..11 and 2..13.
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        for (start, end) in ranges {
            let batch: ArrayRef =
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>()));
            first_accumulator.update_batch(&[Arc::clone(&batch)])?;
            last_accumulator.update_batch(&[batch])?;
        }

        // First value of the first batch, last value of the last batch.
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs: Vec<ArrayRef> = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect();

        // FIRST_VALUE: build two partial states, merge them, and check that
        // the merged state has the same number of entries.
        let mut first_accumulator = new_first_accumulator()?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator = new_first_accumulator()?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());
        let states = concat_states(&state1, &state2)?;

        let mut first_accumulator = new_first_accumulator()?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        // LAST_VALUE: same procedure.
        let mut last_accumulator = new_last_accumulator()?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator = new_last_accumulator()?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());
        let states = concat_states(&state1, &state2)?;

        let mut last_accumulator = new_last_accumulator()?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }
}