1use std::any::Any;
21use std::fmt::Debug;
22use std::hash::Hash;
23use std::mem::size_of_val;
24use std::sync::Arc;
25
26use arrow::array::{
27 Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
28 PrimitiveArray,
29};
30use arrow::buffer::{BooleanBuffer, NullBuffer};
31use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
32use arrow::datatypes::{
33 DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type,
34 Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int8Type,
35 Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
36 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
38 UInt16Type, UInt32Type, UInt64Type,
39};
40use datafusion_common::cast::as_boolean_array;
41use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
42use datafusion_common::{
43 DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err,
44 not_impl_err,
45};
46use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
47use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
48use datafusion_expr::{
49 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
50 GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
51};
52use datafusion_functions_aggregate_common::utils::get_sort_options;
53use datafusion_macros::user_doc;
54use datafusion_physical_expr_common::sort_expr::LexOrdering;
55
// Generates the `FirstValue`/`LastValue` UDAF singletons plus the
// `first_value_udaf()` / `last_value_udaf()` accessor functions used below.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
58
59pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
61 first_value_udaf()
62 .call(vec![expression])
63 .order_by(order_by)
64 .build()
65 .unwrap()
67}
68
69pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
71 last_value_udaf()
72 .call(vec![expression])
73 .order_by(order_by)
74 .build()
75 .unwrap()
77}
78
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct FirstValue {
    /// Accepts exactly one argument of any type (see [`FirstValue::new`]).
    signature: Signature,
    /// Set by `with_beneficial_ordering` when the planner reports the input
    /// already satisfies the requested ordering.
    is_input_pre_ordered: bool,
}
98
99impl Default for FirstValue {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl FirstValue {
106 pub fn new() -> Self {
107 Self {
108 signature: Signature::any(1, Volatility::Immutable),
109 is_input_pre_ordered: false,
110 }
111 }
112}
113
impl AggregateUDFImpl for FirstValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "first_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Never called: the return type is derived in [`Self::return_field`]
    /// instead, so the input field's metadata can be preserved.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        not_impl_err!("Not called because the return_field_from_args is implemented")
    }

    /// The output field mirrors the input's data type and metadata and is
    /// always nullable (a group may yield no value, e.g. with IGNORE NULLS).
    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
        Ok(Arc::new(
            Field::new(
                self.name(),
                arg_fields[0].data_type().clone(),
                true,
            )
            .with_metadata(arg_fields[0].metadata().clone()),
        ))
    }

    /// Picks the cheap [`TrivialFirstValueAccumulator`] when there is no
    /// ORDER BY, otherwise the ordering-aware [`FirstValueAccumulator`].
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialFirstValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(FirstValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    /// State layout: `[first_value, ordering fields.., first_value_is_set]`.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![
            Field::new(
                format_state_name(args.name, "first_value"),
                args.return_type().clone(),
                true,
            )
            .into(),
        ];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(
            Field::new(
                format_state_name(args.name, "first_value_is_set"),
                DataType::Boolean,
                true,
            )
            .into(),
        );
        Ok(fields)
    }

    /// The groups accumulator requires an ORDER BY and a primitive result
    /// type (values are stored in a `Vec<T::Native>`).
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal32(_, _)
                    | Decimal64(_, _)
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    /// Dispatches on the concrete primitive type to a
    /// [`FirstPrimitiveGroupsAccumulator`] with `pick_first_in_group == true`.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        // Monomorphizing helper; one instantiation per supported arrow type.
        fn create_accumulator<T: ArrowPrimitiveType + Send>(
            args: &AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>> {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                true,
            )
            .map(|acc| Box::new(acc) as _)
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(&args),
            DataType::Int16 => create_accumulator::<Int16Type>(&args),
            DataType::Int32 => create_accumulator::<Int32Type>(&args),
            DataType::Int64 => create_accumulator::<Int64Type>(&args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(&args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(&args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(&args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(&args),
            DataType::Float16 => create_accumulator::<Float16Type>(&args),
            DataType::Float32 => create_accumulator::<Float32Type>(&args),
            DataType::Float64 => create_accumulator::<Float64Type>(&args),

            DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(&args),
            DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(&args),
            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(&args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(&args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(&args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(&args),
            DataType::Date64 => create_accumulator::<Date64Type>(&args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(&args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(&args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(&args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(&args)
            }

            // `groups_accumulator_supported` should have filtered these out.
            _ => internal_err!(
                "GroupsAccumulator not supported for first_value({})",
                args.return_field.data_type()
            ),
        }
    }

    /// Records whether the planner could provide the beneficial ordering;
    /// the flag feeds `is_input_pre_ordered` on the per-group accumulators.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    /// Ordering helps (avoids comparisons) but is not required for
    /// correctness.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    /// FIRST_VALUE over reversed input is LAST_VALUE.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(last_value_udaf())
    }

    fn supports_null_handling_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
318
/// Group accumulator shared by `first_value` and `last_value` for primitive
/// result types; `pick_first_in_group` selects which extreme to keep.
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    /// Current best value per group (meaningful only where `is_sets` is set).
    vals: Vec<T::Native>,
    /// Ordering values associated with each group's stored value.
    orderings: Vec<Vec<ScalarValue>>,
    /// One bit per group: has any value been stored yet?
    is_sets: BooleanBufferBuilder,
    /// One bit per group: validity (non-null) of the stored value.
    null_builder: BooleanBufferBuilder,
    /// Running heap size of `orderings`, maintained incrementally so `size()`
    /// is O(1).
    size_of_orderings: usize,

    /// Scratch reused per batch: best row index per group, plus a bitmap
    /// marking which groups have a candidate in the current batch.
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    /// The requested ORDER BY expressions.
    ordering_req: LexOrdering,
    /// True → keep the ordering minimum (first); false → maximum (last).
    pick_first_in_group: bool,
    /// Sort options derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    /// When true, null input values are skipped.
    ignore_nulls: bool,
    /// Result data type (carries parameters such as timestamp timezone).
    data_type: DataType,
    /// Null scalars of the ordering types, used to pre-fill new groups.
    default_orderings: Vec<ScalarValue>,
}
363
impl<T> FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    /// Builds an accumulator for `first_value` (`pick_first_in_group == true`)
    /// or `last_value` (`pick_first_in_group == false`).
    fn try_new(
        ordering_req: LexOrdering,
        ignore_nulls: bool,
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        pick_first_in_group: bool,
    ) -> Result<Self> {
        // Null scalars of each ordering type; clones of these initialize
        // every newly-seen group.
        let default_orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;

        let sort_options = get_sort_options(&ordering_req);

        Ok(Self {
            null_builder: BooleanBufferBuilder::new(0),
            ordering_req,
            sort_options,
            ignore_nulls,
            default_orderings,
            data_type: data_type.clone(),
            vals: Vec::new(),
            orderings: Vec::new(),
            is_sets: BooleanBufferBuilder::new(0),
            size_of_orderings: 0,
            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
            pick_first_in_group,
        })
    }

    /// Returns true when `new_ordering_values` should replace the stored row
    /// for `group_idx`: the group is still unset, or the candidate compares
    /// strictly better under the requested sort direction.
    fn should_update_state(
        &self,
        group_idx: usize,
        new_ordering_values: &[ScalarValue],
    ) -> Result<bool> {
        if !self.is_sets.get_bit(group_idx) {
            return Ok(true);
        }

        assert!(new_ordering_values.len() == self.ordering_req.len());
        let current_ordering = &self.orderings[group_idx];
        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
            // first_value keeps the minimum, last_value the maximum.
            if self.pick_first_in_group {
                x.is_gt()
            } else {
                x.is_lt()
            }
        })
    }

    /// Removes and returns the ordering rows requested by `emit_to`, keeping
    /// the incremental `size_of_orderings` total in sync.
    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
        let result = emit_to.take_needed(&mut self.orderings);

        match emit_to {
            EmitTo::All => self.size_of_orderings = 0,
            EmitTo::First(_) => {
                self.size_of_orderings -=
                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
            }
        }

        result
    }

    /// Splits off the first `n` (or all) bits from a boolean builder,
    /// re-appending the untaken tail so the builder keeps the remainder.
    fn take_need(
        bool_buf_builder: &mut BooleanBufferBuilder,
        emit_to: EmitTo,
    ) -> BooleanBuffer {
        let bool_buf = bool_buf_builder.finish();
        match emit_to {
            EmitTo::All => bool_buf,
            EmitTo::First(n) => {
                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
                // `finish` emptied the builder; push back the bits past `n`.
                for b in bool_buf.iter().skip(n) {
                    bool_buf_builder.append(b);
                }
                first_n
            }
        }
    }

    /// Grows every per-group buffer to `new_size`, filling new slots with
    /// defaults. Never shrinks.
    fn resize_states(&mut self, new_size: usize) {
        self.vals.resize(new_size, T::default_value());

        self.null_builder.resize(new_size);

        if self.orderings.len() < new_size {
            let current_len = self.orderings.len();

            self.orderings
                .resize(new_size, self.default_orderings.clone());

            // All appended rows are clones of `default_orderings`, so a
            // single measurement covers every new slot.
            self.size_of_orderings += (new_size - current_len)
                * ScalarValue::size_of_vec(
                    self.orderings.last().unwrap(),
                );
        }

        self.is_sets.resize(new_size);

        self.min_of_each_group_buf.0.resize(new_size, 0);
        self.min_of_each_group_buf.1.resize(new_size);
    }

    /// Overwrites the stored value and ordering for `group_idx` and marks the
    /// group as set.
    fn update_state(
        &mut self,
        group_idx: usize,
        orderings: &[ScalarValue],
        new_val: T::Native,
        is_null: bool,
    ) {
        self.vals[group_idx] = new_val;
        self.is_sets.set_bit(group_idx, true);

        // The validity bit is the inverse of `is_null`.
        self.null_builder.set_bit(group_idx, !is_null);

        assert!(orderings.len() == self.ordering_req.len());
        // Keep `size_of_orderings` exact: subtract the replaced row's size,
        // add the new row's size.
        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.orderings[group_idx].clear();
        self.orderings[group_idx].extend_from_slice(orderings);
        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
    }

    /// Emits `(values, ordering rows, is_set bits)` for the requested groups
    /// and truncates the scratch buffers to the remaining group count.
    fn take_state(
        &mut self,
        emit_to: EmitTo,
    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
        self.min_of_each_group_buf
            .1
            .truncate(self.min_of_each_group_buf.0.len());

        (
            self.take_vals_and_null_buf(emit_to),
            self.take_orderings(emit_to),
            Self::take_need(&mut self.is_sets, emit_to),
        )
    }

    /// Test-only check: recomputes `size_of_orderings` from scratch so tests
    /// can validate the incremental bookkeeping.
    #[cfg(test)]
    fn compute_size_of_orderings(&self) -> usize {
        self.orderings
            .iter()
            .map(ScalarValue::size_of_vec)
            .sum::<usize>()
    }

    /// Scans one batch and returns, per group, the index of the best row
    /// (ordering minimum for `first_value`, maximum for `last_value`) among
    /// rows that pass the filter, the `is_set` flags, and null handling.
    ///
    /// Returns `(group_idx, idx_in_batch)` pairs for groups with at least one
    /// candidate row.
    fn get_filtered_min_of_each_group(
        &mut self,
        orderings: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        vals: &PrimitiveArray<T>,
        is_set_arr: Option<&BooleanArray>,
    ) -> Result<Vec<(usize, usize)>> {
        // Reset the per-group "has candidate" bitmap.
        self.min_of_each_group_buf.1.truncate(0);
        self.min_of_each_group_buf
            .1
            .append_n(self.vals.len(), false);

        let comparator = {
            assert_eq!(orderings.len(), self.ordering_req.len());
            let sort_columns = orderings
                .iter()
                .zip(self.ordering_req.iter())
                .map(|(array, req)| SortColumn {
                    values: Arc::clone(array),
                    options: Some(req.options),
                })
                .collect::<Vec<_>>();

            LexicographicalComparator::try_new(&sort_columns)?
        };

        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
            let group_idx = *group_idx;

            // Missing filter / is_set column means "all rows pass".
            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

            if !passed_filter || !is_set {
                continue;
            }

            if self.ignore_nulls && vals.is_null(idx_in_val) {
                continue;
            }

            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);

            if !is_valid {
                // First candidate for this group in the batch.
                self.min_of_each_group_buf.1.set_bit(group_idx, true);
                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
            } else {
                let ordering = comparator
                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);

                if (ordering.is_gt() && self.pick_first_in_group)
                    || (ordering.is_lt() && !self.pick_first_in_group)
                {
                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
                }
            }
        }

        Ok(self
            .min_of_each_group_buf
            .0
            .iter()
            .enumerate()
            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
            .collect::<Vec<_>>())
    }

    /// Takes the emitted values plus their validity bits as a primitive
    /// array, restoring the logical data type (e.g. decimal precision).
    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
        let r = emit_to.take_needed(&mut self.vals);

        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));

        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf))
            .with_data_type(self.data_type.clone());
        Arc::new(values)
    }
}
610
611impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
612where
613 T: ArrowPrimitiveType + Send,
614{
615 fn update_batch(
616 &mut self,
617 values_and_order_cols: &[ArrayRef],
619 group_indices: &[usize],
620 opt_filter: Option<&BooleanArray>,
621 total_num_groups: usize,
622 ) -> Result<()> {
623 self.resize_states(total_num_groups);
624
625 let vals = values_and_order_cols[0].as_primitive::<T>();
626
627 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
628
629 for (group_idx, idx) in self
631 .get_filtered_min_of_each_group(
632 &values_and_order_cols[1..],
633 group_indices,
634 opt_filter,
635 vals,
636 None,
637 )?
638 .into_iter()
639 {
640 extract_row_at_idx_to_buf(
641 &values_and_order_cols[1..],
642 idx,
643 &mut ordering_buf,
644 )?;
645
646 if self.should_update_state(group_idx, &ordering_buf)? {
647 self.update_state(
648 group_idx,
649 &ordering_buf,
650 vals.value(idx),
651 vals.is_null(idx),
652 );
653 }
654 }
655
656 Ok(())
657 }
658
659 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
660 Ok(self.take_state(emit_to).0)
661 }
662
663 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
664 let (val_arr, orderings, is_sets) = self.take_state(emit_to);
665 let mut result = Vec::with_capacity(self.orderings.len() + 2);
666
667 result.push(val_arr);
668
669 let ordering_cols = {
670 let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
671 for _ in 0..self.ordering_req.len() {
672 ordering_cols.push(Vec::with_capacity(self.orderings.len()));
673 }
674 for row in orderings.into_iter() {
675 assert_eq!(row.len(), self.ordering_req.len());
676 for (col_idx, ordering) in row.into_iter().enumerate() {
677 ordering_cols[col_idx].push(ordering);
678 }
679 }
680
681 ordering_cols
682 };
683 for ordering_col in ordering_cols {
684 result.push(ScalarValue::iter_to_array(ordering_col)?);
685 }
686
687 result.push(Arc::new(BooleanArray::new(is_sets, None)));
688
689 Ok(result)
690 }
691
692 fn merge_batch(
693 &mut self,
694 values: &[ArrayRef],
695 group_indices: &[usize],
696 opt_filter: Option<&BooleanArray>,
697 total_num_groups: usize,
698 ) -> Result<()> {
699 self.resize_states(total_num_groups);
700
701 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
702
703 let (is_set_arr, val_and_order_cols) = match values.split_last() {
704 Some(result) => result,
705 None => return internal_err!("Empty row in FIRST_VALUE"),
706 };
707
708 let is_set_arr = as_boolean_array(is_set_arr)?;
709
710 let vals = values[0].as_primitive::<T>();
711 let groups = self.get_filtered_min_of_each_group(
713 &val_and_order_cols[1..],
714 group_indices,
715 opt_filter,
716 vals,
717 Some(is_set_arr),
718 )?;
719
720 for (group_idx, idx) in groups.into_iter() {
721 extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
722
723 if self.should_update_state(group_idx, &ordering_buf)? {
724 self.update_state(
725 group_idx,
726 &ordering_buf,
727 vals.value(idx),
728 vals.is_null(idx),
729 );
730 }
731 }
732
733 Ok(())
734 }
735
736 fn size(&self) -> usize {
737 self.vals.capacity() * size_of::<T::Native>()
738 + self.null_builder.capacity() / 8 + self.is_sets.capacity() / 8
740 + self.size_of_orderings
741 + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
742 + self.min_of_each_group_buf.1.capacity() / 8
743 }
744
745 fn supports_convert_to_state(&self) -> bool {
746 true
747 }
748
749 fn convert_to_state(
750 &self,
751 values: &[ArrayRef],
752 opt_filter: Option<&BooleanArray>,
753 ) -> Result<Vec<ArrayRef>> {
754 let mut result = values.to_vec();
755 match opt_filter {
756 Some(f) => {
757 result.push(Arc::new(f.clone()));
758 Ok(result)
759 }
760 None => {
761 result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
762 Ok(result)
763 }
764 }
765 }
766}
767
/// Accumulator for `first_value` when no ordering is requested: keeps the
/// first (optionally non-null) value seen and ignores all later input.
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    // The value to return; starts as a null scalar of the target type.
    first: ScalarValue,
    // Whether a value has been stored; once true, updates are skipped.
    is_set: bool,
    // When true, null inputs are skipped.
    ignore_nulls: bool,
}
780
781impl TrivialFirstValueAccumulator {
782 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
784 ScalarValue::try_from(data_type).map(|first| Self {
785 first,
786 is_set: false,
787 ignore_nulls,
788 })
789 }
790}
791
792impl Accumulator for TrivialFirstValueAccumulator {
793 fn state(&mut self) -> Result<Vec<ScalarValue>> {
794 Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
795 }
796
797 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
798 if !self.is_set {
799 let value = &values[0];
801 let mut first_idx = None;
802 if self.ignore_nulls {
803 for i in 0..value.len() {
805 if !value.is_null(i) {
806 first_idx = Some(i);
807 break;
808 }
809 }
810 } else if !value.is_empty() {
811 first_idx = Some(0);
813 }
814 if let Some(first_idx) = first_idx {
815 let mut row = get_row_at_idx(values, first_idx)?;
816 self.first = row.swap_remove(0);
817 self.first.compact();
818 self.is_set = true;
819 }
820 }
821 Ok(())
822 }
823
824 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
825 if !self.is_set {
828 let flags = states[1].as_boolean();
829 validate_is_set_flags(flags, "first_value")?;
830
831 let filtered_states =
832 filter_states_according_to_is_set(&states[0..1], flags)?;
833 if let Some(first) = filtered_states.first()
834 && !first.is_empty()
835 {
836 self.first = ScalarValue::try_from_array(first, 0)?;
837 self.is_set = true;
838 }
839 }
840 Ok(())
841 }
842
843 fn evaluate(&mut self) -> Result<ScalarValue> {
844 Ok(self.first.clone())
845 }
846
847 fn size(&self) -> usize {
848 size_of_val(self) - size_of_val(&self.first) + self.first.size()
849 }
850}
851
/// Ordering-aware accumulator for `first_value`: tracks the value whose
/// ordering key sorts first under `ordering_req`.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    // Current best value; starts as a null scalar of the target type.
    first: ScalarValue,
    // Whether `first`/`orderings` hold a real observed row yet.
    is_set: bool,
    // Ordering key associated with `first`.
    orderings: Vec<ScalarValue>,
    // The requested ORDER BY expressions.
    ordering_req: LexOrdering,
    // True when the planner guarantees the input already satisfies
    // `ordering_req`, enabling a take-first shortcut.
    is_input_pre_ordered: bool,
    // When true, null input values are skipped.
    ignore_nulls: bool,
}
867
impl FirstValueAccumulator {
    /// Creates a new accumulator; `first` starts as a null of `data_type`
    /// and `orderings` as nulls of the ordering types.
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;
        ScalarValue::try_from(data_type).map(|first| Self {
            first,
            is_set: false,
            orderings,
            ordering_req,
            is_input_pre_ordered,
            ignore_nulls,
        })
    }

    /// Replaces the stored row with `row` = `[value, ordering values..]`,
    /// compacting the scalars to drop references to the source arrays.
    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
        for s in row.iter_mut() {
            s.compact();
        }
        self.first = row.remove(0);
        self.orderings = row;
        self.is_set = true;
    }

    /// Finds the batch row this accumulator should consider: the leading
    /// (non-null) row when input is pre-ordered, otherwise the lexicographic
    /// minimum under the requested ordering. `values` is
    /// `[value, ordering columns..]`.
    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
        let [value, ordering_values @ ..] = values else {
            return internal_err!("Empty row in FIRST_VALUE");
        };
        if self.is_input_pre_ordered {
            if self.ignore_nulls {
                // First non-null entry, if any.
                for i in 0..value.len() {
                    if !value.is_null(i) {
                        return Ok(Some(i));
                    }
                }
                return Ok(None);
            } else {
                return Ok((!value.is_empty()).then_some(0));
            }
        }

        let sort_columns = ordering_values
            .iter()
            .zip(self.ordering_req.iter())
            .map(|(values, req)| SortColumn {
                values: Arc::clone(values),
                options: Some(req.options),
            })
            .collect::<Vec<_>>();

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;

        let min_index = if self.ignore_nulls {
            (0..value.len())
                .filter(|&index| !value.is_null(index))
                .min_by(|&a, &b| comparator.compare(a, b))
        } else {
            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
        };

        Ok(min_index)
    }
}
944
impl Accumulator for FirstValueAccumulator {
    /// State layout: `[first, ordering values.., is_set]`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.first.clone()];
        result.extend(self.orderings.iter().cloned());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(first_idx) = self.get_first_idx(values)? {
            let row = get_row_at_idx(values, first_idx)?;
            // Replace when unset, or (for unordered input) when the candidate
            // ordering sorts before the stored one.
            if !self.is_set
                || (!self.is_input_pre_ordered
                    && compare_rows(
                        &self.orderings,
                        &row[1..],
                        &get_sort_options(&self.ordering_req),
                    )?
                    .is_gt())
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // The last column is the `is_set` flag; the rest mirror `state()`.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        validate_is_set_flags(flags, "first_value")?;

        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        // Columns 1..is_set_idx are the ordering columns.
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));

        if let Some(first_idx) = min {
            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
            let first_ordering = &first_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            if !self.is_set
                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
            {
                // Truncate the trailing flag entry before storing the row.
                assert!(is_set_idx <= first_row.len());
                first_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(first_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.first)
            + self.first.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}
1018
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct LastValue {
    /// Accepts exactly one argument of any type (see [`LastValue::new`]).
    signature: Signature,
    /// Set by `with_beneficial_ordering` when the planner reports the input
    /// already satisfies the requested ordering.
    is_input_pre_ordered: bool,
}
1038
1039impl Default for LastValue {
1040 fn default() -> Self {
1041 Self::new()
1042 }
1043}
1044
1045impl LastValue {
1046 pub fn new() -> Self {
1047 Self {
1048 signature: Signature::any(1, Volatility::Immutable),
1049 is_input_pre_ordered: false,
1050 }
1051 }
1052}
1053
impl AggregateUDFImpl for LastValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "last_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Never called: the return type is derived in [`Self::return_field`]
    /// instead, so the input field's metadata can be preserved.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        not_impl_err!("Not called because the return_field_from_args is implemented")
    }

    /// The output field mirrors the input's data type and metadata and is
    /// always nullable.
    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
        Ok(Arc::new(
            Field::new(
                self.name(),
                arg_fields[0].data_type().clone(),
                true,
            )
            .with_metadata(arg_fields[0].metadata().clone()),
        ))
    }

    /// Picks the cheap [`TrivialLastValueAccumulator`] when there is no
    /// ORDER BY, otherwise the ordering-aware `LastValueAccumulator`.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialLastValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(LastValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    /// State layout: `[last_value, ordering fields.., last_value_is_set]`.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![
            Field::new(
                format_state_name(args.name, "last_value"),
                args.return_field.data_type().clone(),
                true,
            )
            .into(),
        ];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(
            Field::new(
                format_state_name(args.name, "last_value_is_set"),
                DataType::Boolean,
                true,
            )
            .into(),
        );
        Ok(fields)
    }

    /// Records whether the planner could provide the beneficial ordering;
    /// the flag feeds `is_input_pre_ordered` on the per-group accumulators.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    /// Ordering helps but is not required for correctness.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    /// LAST_VALUE over reversed input is FIRST_VALUE.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(first_value_udaf())
    }

    fn supports_null_handling_clause(&self) -> bool {
        true
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    /// The groups accumulator requires an ORDER BY and a primitive result
    /// type (values are stored in a `Vec<T::Native>`).
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal32(_, _)
                    | Decimal64(_, _)
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    /// Dispatches on the concrete primitive type to a
    /// [`FirstPrimitiveGroupsAccumulator`] with `pick_first_in_group == false`
    /// so it keeps the ordering maximum.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        // Monomorphizing helper; one instantiation per supported arrow type.
        fn create_accumulator<T>(
            args: &AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>>
        where
            T: ArrowPrimitiveType + Send,
        {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                false,
            )?))
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(&args),
            DataType::Int16 => create_accumulator::<Int16Type>(&args),
            DataType::Int32 => create_accumulator::<Int32Type>(&args),
            DataType::Int64 => create_accumulator::<Int64Type>(&args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(&args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(&args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(&args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(&args),
            DataType::Float16 => create_accumulator::<Float16Type>(&args),
            DataType::Float32 => create_accumulator::<Float32Type>(&args),
            DataType::Float64 => create_accumulator::<Float64Type>(&args),

            DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(&args),
            DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(&args),
            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(&args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(&args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(&args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(&args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(&args),
            DataType::Date64 => create_accumulator::<Date64Type>(&args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(&args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(&args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(&args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(&args)
            }

            // `groups_accumulator_supported` should have filtered these out.
            _ => {
                internal_err!(
                    "GroupsAccumulator not supported for last_value({})",
                    args.return_field.data_type()
                )
            }
        }
    }
}
1262
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    /// The most recently observed value; the type's null scalar until `is_set`.
    last: ScalarValue,
    /// Whether `last` has been populated from at least one input row.
    is_set: bool,
    /// If true, null input values are skipped when selecting the last value.
    ignore_nulls: bool,
}
1277
1278impl TrivialLastValueAccumulator {
1279 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
1281 ScalarValue::try_from(data_type).map(|last| Self {
1282 last,
1283 is_set: false,
1284 ignore_nulls,
1285 })
1286 }
1287}
1288
1289impl Accumulator for TrivialLastValueAccumulator {
1290 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1291 Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
1292 }
1293
1294 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1295 let value = &values[0];
1297 let mut last_idx = None;
1298 if self.ignore_nulls {
1299 for i in (0..value.len()).rev() {
1301 if !value.is_null(i) {
1302 last_idx = Some(i);
1303 break;
1304 }
1305 }
1306 } else if !value.is_empty() {
1307 last_idx = Some(value.len() - 1);
1309 }
1310 if let Some(last_idx) = last_idx {
1311 let mut row = get_row_at_idx(values, last_idx)?;
1312 self.last = row.swap_remove(0);
1313 self.last.compact();
1314 self.is_set = true;
1315 }
1316 Ok(())
1317 }
1318
1319 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1320 let flags = states[1].as_boolean();
1323 validate_is_set_flags(flags, "last_value")?;
1324
1325 let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
1326 if let Some(last) = filtered_states.last()
1327 && !last.is_empty()
1328 {
1329 self.last = ScalarValue::try_from_array(last, 0)?;
1330 self.is_set = true;
1331 }
1332 Ok(())
1333 }
1334
1335 fn evaluate(&mut self) -> Result<ScalarValue> {
1336 Ok(self.last.clone())
1337 }
1338
1339 fn size(&self) -> usize {
1340 size_of_val(self) - size_of_val(&self.last) + self.last.size()
1341 }
1342}
1343
#[derive(Debug)]
struct LastValueAccumulator {
    /// The currently selected "last" value; the type's null scalar until set.
    last: ScalarValue,
    /// Whether `last`/`orderings` hold data from at least one row.
    is_set: bool,
    /// Ordering-column values of the row stored in `last`; compared against
    /// incoming rows to decide whether they are "later".
    orderings: Vec<ScalarValue>,
    /// The lexicographic ordering requirement rows are compared under.
    ordering_req: LexOrdering,
    /// If true, input batches already satisfy `ordering_req`, so the last row
    /// of each batch can be taken directly without comparisons.
    is_input_pre_ordered: bool,
    /// If true, null values are skipped when selecting the last value.
    ignore_nulls: bool,
}
1361
1362impl LastValueAccumulator {
1363 pub fn try_new(
1365 data_type: &DataType,
1366 ordering_dtypes: &[DataType],
1367 ordering_req: LexOrdering,
1368 is_input_pre_ordered: bool,
1369 ignore_nulls: bool,
1370 ) -> Result<Self> {
1371 let orderings = ordering_dtypes
1372 .iter()
1373 .map(ScalarValue::try_from)
1374 .collect::<Result<_>>()?;
1375 ScalarValue::try_from(data_type).map(|last| Self {
1376 last,
1377 is_set: false,
1378 orderings,
1379 ordering_req,
1380 is_input_pre_ordered,
1381 ignore_nulls,
1382 })
1383 }
1384
1385 fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1387 for s in row.iter_mut() {
1389 s.compact();
1390 }
1391 self.last = row.remove(0);
1392 self.orderings = row;
1393 self.is_set = true;
1394 }
1395
1396 fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1397 let [value, ordering_values @ ..] = values else {
1398 return internal_err!("Empty row in LAST_VALUE");
1399 };
1400 if self.is_input_pre_ordered {
1401 if self.ignore_nulls {
1403 for i in (0..value.len()).rev() {
1405 if !value.is_null(i) {
1406 return Ok(Some(i));
1407 }
1408 }
1409 return Ok(None);
1410 } else {
1411 return Ok((!value.is_empty()).then_some(value.len() - 1));
1412 }
1413 }
1414
1415 let sort_columns = ordering_values
1416 .iter()
1417 .zip(self.ordering_req.iter())
1418 .map(|(values, req)| SortColumn {
1419 values: Arc::clone(values),
1420 options: Some(req.options),
1421 })
1422 .collect::<Vec<_>>();
1423
1424 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1425 let max_ind = if self.ignore_nulls {
1426 (0..value.len())
1427 .filter(|&index| !(value.is_null(index)))
1428 .max_by(|&a, &b| comparator.compare(a, b))
1429 } else {
1430 (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1431 };
1432
1433 Ok(max_ind)
1434 }
1435}
1436
impl Accumulator for LastValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // State layout: [last value, ordering values ..., is_set flag].
        let mut result = vec![self.last.clone()];
        result.extend(self.orderings.clone());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(last_idx) = self.get_last_idx(values)? {
            let row = get_row_at_idx(values, last_idx)?;
            // row[0] is the value, the rest are this row's ordering values.
            let orderings = &row[1..];
            // Replace the stored value when nothing is stored yet, when the
            // input is pre-ordered (later batches always win), or when this
            // row sorts after the stored one.
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(
                    &self.orderings,
                    orderings,
                    &get_sort_options(&self.ordering_req),
                )?
                .is_lt()
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // State layout: [value, ordering values ..., is_set flag]; the flag
        // is always the final column.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        validate_is_set_flags(flags, "last_value")?;

        // Drop rows from partial states that never saw any input.
        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        // Columns 1..is_set_idx of the (filtered) state are ordering columns.
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        // Index of the row that sorts last among surviving partial states.
        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));

        if let Some(last_idx) = max {
            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
            let last_ordering = &last_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            // Same replacement rule as update_batch, applied to the winning
            // merged row.
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
            {
                // Trim to [value, orderings...] before storing (drops any
                // trailing columns; resize never grows here).
                assert!(is_set_idx <= last_row.len());
                last_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(last_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        // Returns the type's null scalar when no row has ever been seen.
        Ok(self.last.clone())
    }

    fn size(&self) -> usize {
        // Swap the shallow sizes of `last` and `orderings` for their deep sizes.
        size_of_val(self) - size_of_val(&self.last)
            + self.last.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}
1514
1515fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
1517 if flags.null_count() > 0 {
1518 return Err(DataFusionError::Internal(format!(
1519 "{function_name}: is_set flags contain nulls"
1520 )));
1521 }
1522 Ok(())
1523}
1524
1525fn filter_states_according_to_is_set(
1528 states: &[ArrayRef],
1529 flags: &BooleanArray,
1530) -> Result<Vec<ArrayRef>> {
1531 states
1532 .iter()
1533 .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1534 .collect()
1535}
1536
1537fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1539 arrs.iter()
1540 .zip(sort_exprs.iter())
1541 .map(|(item, sort_expr)| SortColumn {
1542 values: Arc::clone(item),
1543 options: Some(sort_expr.options),
1544 })
1545 .collect()
1546}
1547
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use arrow::{
        array::{BooleanArray, Int64Array, ListArray, StringArray},
        compute::SortOptions,
        datatypes::Schema,
    };
    use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};

    use super::*;

    // FIRST_VALUE retains the very first row across batches; LAST_VALUE the
    // final row of the final batch.
    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect::<Vec<_>>();
        for arr in arrs {
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }
        // First batch starts at 0; the final batch (2..13) ends with 12.
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    // Concatenating two accumulators' state columns and merging them must
    // round-trip to a state vector of the same arity.
    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect::<Vec<_>>();

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;

        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        // Simulate the planner's state exchange: concatenate per-column.
        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        // Same round-trip for LAST_VALUE.
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;

        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    // Exercises the grouped FIRST_VALUE accumulator through update, state
    // emission, merge and evaluate, checking size bookkeeping along the way.
    #[test]
    fn test_first_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        // Cached ordering size must match a from-scratch recomputation.
        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);

        assert_eq!(eval_result, &expect);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        Ok(())
    }

    // Stress-checks that the incrementally-maintained `size_of_orderings`
    // stays consistent with full recomputation across every emit path.
    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&BooleanArray::from(vec![true, true, false, true])),
                100,
            )?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.state(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            let s = group_acc.state(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );
        }

        Ok(())
    }

    // Same shape as test_first_group_acc but with `pick_first_in_group =
    // false`, so the accumulator keeps the last value per group.
    #[test]
    fn test_last_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            false,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

        assert_eq!(eval_result, &expect);

        Ok(())
    }

    // The accumulator must compact its stored scalar: its reported size
    // should not depend on the size of the batch the value came from.
    #[test]
    fn test_first_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            first_accumulator.update_batch(values)?;

            Ok(first_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    // LAST_VALUE counterpart of test_first_list_acc_size.
    #[test]
    fn test_last_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            last_accumulator.update_batch(values)?;

            Ok(last_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    // Merging state whose is_set column contains nulls must fail with a
    // descriptive internal error (corrupted intermediate state).
    #[test]
    fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        let mut trivial_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = FirstValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }

    // LAST_VALUE counterpart of test_first_value_merge_with_is_set_nulls.
    #[test]
    fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        let mut trivial_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = LastValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }
}