use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
    PrimitiveArray,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
use arrow::datatypes::{
    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef,
    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
    UInt8Type,
};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
use datafusion_common::{
    arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
    GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);

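/// Returns an expression for the `first_value` aggregate function, ordered by
/// the given `order_by` expressions.
///
/// A usage sketch (the column names below are illustrative only):
///
/// ```ignore
/// use datafusion_expr::col;
///
/// // SQL equivalent: first_value(price ORDER BY ts ASC)
/// let agg = first_value(col("price"), vec![col("ts").sort(true, false)]);
/// ```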
pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
    first_value_udaf()
        .call(vec![expression])
        .order_by(order_by)
        .build()
        .unwrap()
}

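/// Returns an expression for the `last_value` aggregate function, ordered by
/// the given `order_by` expressions.
///
/// A usage sketch (the column names below are illustrative only):
///
/// ```ignore
/// use datafusion_expr::col;
///
/// // SQL equivalent: last_value(price ORDER BY ts ASC)
/// let agg = last_value(col("price"), vec![col("ts").sort(true, false)]);
/// ```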
pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
    last_value_udaf()
        .call(vec![expression])
        .order_by(order_by)
        .build()
        .unwrap()
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element                                 |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct FirstValue {
    signature: Signature,
    is_input_pre_ordered: bool,
}

impl Debug for FirstValue {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("FirstValue")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .field("accumulator", &"<FUNC>")
            .finish()
    }
}

impl Default for FirstValue {
    fn default() -> Self {
        Self::new()
    }
}

impl FirstValue {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
            is_input_pre_ordered: false,
        }
    }
}

impl AggregateUDFImpl for FirstValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "first_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialFirstValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(FirstValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new(
            format_state_name(args.name, "first_value"),
            args.return_type().clone(),
            true,
        )
        .into()];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(Field::new("is_set", DataType::Boolean, true).into());
        Ok(fields)
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        fn create_accumulator<T: ArrowPrimitiveType + Send>(
            args: AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>> {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                true,
            )
            .map(|acc| Box::new(acc) as _)
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(args),
            DataType::Int16 => create_accumulator::<Int16Type>(args),
            DataType::Int32 => create_accumulator::<Int32Type>(args),
            DataType::Int64 => create_accumulator::<Int64Type>(args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
            DataType::Float16 => create_accumulator::<Float16Type>(args),
            DataType::Float32 => create_accumulator::<Float32Type>(args),
            DataType::Float64 => create_accumulator::<Float64Type>(args),

            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(args),
            DataType::Date64 => create_accumulator::<Date64Type>(args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(args)
            }

            _ => internal_err!(
                "GroupsAccumulator not supported for first_value({})",
                args.return_field.data_type()
            ),
        }
    }

    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(last_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}

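/// A [`GroupsAccumulator`] for primitive types that keeps, for every group, the
/// value whose ordering key is minimal (FIRST_VALUE) or maximal (LAST_VALUE),
/// depending on `pick_first_in_group`.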
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    /// The currently selected value for each group.
    vals: Vec<T::Native>,
    /// Ordering key associated with the selected value of each group.
    orderings: Vec<Vec<ScalarValue>>,
    /// Whether a value has been seen for each group yet.
    is_sets: BooleanBufferBuilder,
    /// Validity (non-null) flags for `vals`.
    null_builder: BooleanBufferBuilder,
    /// Cached total memory size of `orderings`, maintained incrementally.
    size_of_orderings: usize,

    /// Scratch buffer reused across batches: for each group, the index of the
    /// best row seen in the current batch, plus a validity bit per group.
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    /// The ordering requirement (ORDER BY expressions).
    ordering_req: LexOrdering,
    /// `true` for FIRST_VALUE, `false` for LAST_VALUE.
    pick_first_in_group: bool,
    /// Sort options derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    /// Whether to ignore null values.
    ignore_nulls: bool,
    /// Data type of the aggregated value.
    data_type: DataType,
    /// Default (null) ordering values, used to initialize new groups.
    default_orderings: Vec<ScalarValue>,
}

impl<T> FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    fn try_new(
        ordering_req: LexOrdering,
        ignore_nulls: bool,
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        pick_first_in_group: bool,
    ) -> Result<Self> {
        let default_orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;

        let sort_options = get_sort_options(&ordering_req);

        Ok(Self {
            null_builder: BooleanBufferBuilder::new(0),
            ordering_req,
            sort_options,
            ignore_nulls,
            default_orderings,
            data_type: data_type.clone(),
            vals: Vec::new(),
            orderings: Vec::new(),
            is_sets: BooleanBufferBuilder::new(0),
            size_of_orderings: 0,
            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
            pick_first_in_group,
        })
    }

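    /// Decides whether the stored value of `group_idx` should be replaced by a
    /// row whose ordering key is `new_ordering_values`: always `true` for an
    /// unset group, otherwise `true` when the new key sorts before (FIRST_VALUE)
    /// or after (LAST_VALUE) the current key.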
    fn should_update_state(
        &self,
        group_idx: usize,
        new_ordering_values: &[ScalarValue],
    ) -> Result<bool> {
        if !self.is_sets.get_bit(group_idx) {
            return Ok(true);
        }

        assert!(new_ordering_values.len() == self.ordering_req.len());
        let current_ordering = &self.orderings[group_idx];
        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
            if self.pick_first_in_group {
                x.is_gt()
            } else {
                x.is_lt()
            }
        })
    }

    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
        let result = emit_to.take_needed(&mut self.orderings);

        match emit_to {
            EmitTo::All => self.size_of_orderings = 0,
            EmitTo::First(_) => {
                self.size_of_orderings -=
                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
            }
        }

        result
    }

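    /// Takes the first `n` bits (or all of them, for `EmitTo::All`) out of the
    /// builder, pushing any remaining bits back so later groups keep their state.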
    fn take_need(
        bool_buf_builder: &mut BooleanBufferBuilder,
        emit_to: EmitTo,
    ) -> BooleanBuffer {
        let bool_buf = bool_buf_builder.finish();
        match emit_to {
            EmitTo::All => bool_buf,
            EmitTo::First(n) => {
                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
                for b in bool_buf.iter().skip(n) {
                    bool_buf_builder.append(b);
                }
                first_n
            }
        }
    }

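    /// Grows every per-group buffer to hold `new_size` groups, initializing any
    /// newly added groups with default values and keeping `size_of_orderings`
    /// in sync.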
    fn resize_states(&mut self, new_size: usize) {
        self.vals.resize(new_size, T::default_value());

        self.null_builder.resize(new_size);

        if self.orderings.len() < new_size {
            let current_len = self.orderings.len();

            self.orderings
                .resize(new_size, self.default_orderings.clone());

            self.size_of_orderings += (new_size - current_len)
                * ScalarValue::size_of_vec(
                    // All newly added groups share the same default ordering,
                    // so measuring the last entry is representative.
                    self.orderings.last().unwrap(),
                );
        }

        self.is_sets.resize(new_size);

        self.min_of_each_group_buf.0.resize(new_size, 0);
        self.min_of_each_group_buf.1.resize(new_size);
    }

    fn update_state(
        &mut self,
        group_idx: usize,
        orderings: &[ScalarValue],
        new_val: T::Native,
        is_null: bool,
    ) {
        self.vals[group_idx] = new_val;
        self.is_sets.set_bit(group_idx, true);

        self.null_builder.set_bit(group_idx, !is_null);

        assert!(orderings.len() == self.ordering_req.len());
        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.orderings[group_idx].clear();
        self.orderings[group_idx].extend_from_slice(orderings);
        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
    }

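    /// Takes the per-group state (values with their null buffer, ordering keys,
    /// and `is_set` flags) of the groups selected by `emit_to`, shrinking the
    /// internal buffers accordingly.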
    fn take_state(
        &mut self,
        emit_to: EmitTo,
    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
        self.min_of_each_group_buf
            .1
            .truncate(self.min_of_each_group_buf.0.len());

        (
            self.take_vals_and_null_buf(emit_to),
            self.take_orderings(emit_to),
            Self::take_need(&mut self.is_sets, emit_to),
        )
    }

    #[cfg(test)]
    // Recomputes the size of `orderings` from scratch; used by tests to check
    // that the incrementally maintained `size_of_orderings` stays correct.
    fn compute_size_of_orderings(&self) -> usize {
        self.orderings
            .iter()
            .map(ScalarValue::size_of_vec)
            .sum::<usize>()
    }

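    /// Within the current batch, finds for each group the index of the row with
    /// the minimum (FIRST_VALUE) or maximum (LAST_VALUE) ordering key, after
    /// applying `opt_filter`, the optional `is_set` flags and null handling.
    /// Returns `(group_idx, row_idx)` pairs for groups that had at least one
    /// qualifying row.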
    fn get_filtered_min_of_each_group(
        &mut self,
        orderings: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        vals: &PrimitiveArray<T>,
        is_set_arr: Option<&BooleanArray>,
    ) -> Result<Vec<(usize, usize)>> {
        self.min_of_each_group_buf.1.truncate(0);
        self.min_of_each_group_buf
            .1
            .append_n(self.vals.len(), false);

        let comparator = {
            assert_eq!(orderings.len(), self.ordering_req.len());
            let sort_columns = orderings
                .iter()
                .zip(self.ordering_req.iter())
                .map(|(array, req)| SortColumn {
                    values: Arc::clone(array),
                    options: Some(req.options),
                })
                .collect::<Vec<_>>();

            LexicographicalComparator::try_new(&sort_columns)?
        };

        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
            let group_idx = *group_idx;

            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

            if !passed_filter || !is_set {
                continue;
            }

            if self.ignore_nulls && vals.is_null(idx_in_val) {
                continue;
            }

            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);

            if !is_valid {
                self.min_of_each_group_buf.1.set_bit(group_idx, true);
                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
            } else {
                let ordering = comparator
                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);

                if (ordering.is_gt() && self.pick_first_in_group)
                    || (ordering.is_lt() && !self.pick_first_in_group)
                {
                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
                }
            }
        }

        Ok(self
            .min_of_each_group_buf
            .0
            .iter()
            .enumerate()
            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
            .collect::<Vec<_>>())
    }

    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
        let r = emit_to.take_needed(&mut self.vals);

        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));

        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf))
            .with_data_type(self.data_type.clone());
        Arc::new(values)
    }
}

impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    fn update_batch(
        &mut self,
        values_and_order_cols: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.resize_states(total_num_groups);

        // The first array holds the aggregated values; the remaining arrays
        // hold the ORDER BY expressions for each input row.
        let vals = values_and_order_cols[0].as_primitive::<T>();

        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());

        for (group_idx, idx) in self
            .get_filtered_min_of_each_group(
                &values_and_order_cols[1..],
                group_indices,
                opt_filter,
                vals,
                None,
            )?
            .into_iter()
        {
            extract_row_at_idx_to_buf(
                &values_and_order_cols[1..],
                idx,
                &mut ordering_buf,
            )?;

            if self.should_update_state(group_idx, &ordering_buf)? {
                self.update_state(
                    group_idx,
                    &ordering_buf,
                    vals.value(idx),
                    vals.is_null(idx),
                );
            }
        }

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        Ok(self.take_state(emit_to).0)
    }

    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
        let mut result = Vec::with_capacity(self.orderings.len() + 2);

        result.push(val_arr);

        let ordering_cols = {
            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
            for _ in 0..self.ordering_req.len() {
                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
            }
            for row in orderings.into_iter() {
                assert_eq!(row.len(), self.ordering_req.len());
                for (col_idx, ordering) in row.into_iter().enumerate() {
                    ordering_cols[col_idx].push(ordering);
                }
            }

            ordering_cols
        };
        for ordering_col in ordering_cols {
            result.push(ScalarValue::iter_to_array(ordering_col)?);
        }

        result.push(Arc::new(BooleanArray::new(is_sets, None)));

        Ok(result)
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.resize_states(total_num_groups);

        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());

        let (is_set_arr, val_and_order_cols) = match values.split_last() {
            Some(result) => result,
            None => return internal_err!("Empty row in FIRST_VALUE"),
        };

        let is_set_arr = as_boolean_array(is_set_arr)?;

        let vals = values[0].as_primitive::<T>();
        let groups = self.get_filtered_min_of_each_group(
            &val_and_order_cols[1..],
            group_indices,
            opt_filter,
            vals,
            Some(is_set_arr),
        )?;

        for (group_idx, idx) in groups.into_iter() {
            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;

            if self.should_update_state(group_idx, &ordering_buf)? {
                self.update_state(
                    group_idx,
                    &ordering_buf,
                    vals.value(idx),
                    vals.is_null(idx),
                );
            }
        }

        Ok(())
    }

    fn size(&self) -> usize {
        self.vals.capacity() * size_of::<T::Native>()
            + self.null_builder.capacity() / 8
            + self.is_sets.capacity() / 8
            + self.size_of_orderings
            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
            + self.min_of_each_group_buf.1.capacity() / 8
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        let mut result = values.to_vec();
        match opt_filter {
            Some(f) => {
                result.push(Arc::new(f.clone()));
                Ok(result)
            }
            None => {
                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
                Ok(result)
            }
        }
    }
}

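/// Accumulator for `FIRST_VALUE` when no `ORDER BY` is given: it keeps the
/// first value it sees (skipping nulls when `ignore_nulls` is true) and never
/// replaces it afterwards.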
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    first: ScalarValue,
    // Whether a first value has been seen yet.
    is_set: bool,
    // Whether to ignore null values.
    ignore_nulls: bool,
}

impl TrivialFirstValueAccumulator {
    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
        ScalarValue::try_from(data_type).map(|first| Self {
            first,
            is_set: false,
            ignore_nulls,
        })
    }
}

impl Accumulator for TrivialFirstValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if !self.is_set {
            let value = &values[0];
            let mut first_idx = None;
            if self.ignore_nulls {
                // Find the index of the first non-null value.
                for i in 0..value.len() {
                    if !value.is_null(i) {
                        first_idx = Some(i);
                        break;
                    }
                }
            } else if !value.is_empty() {
                // Without `ignore_nulls`, the first row wins even if it is null.
                first_idx = Some(0);
            }
            if let Some(first_idx) = first_idx {
                let mut row = get_row_at_idx(values, first_idx)?;
                self.first = row.swap_remove(0);
                self.first.compact();
                self.is_set = true;
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if !self.is_set {
            // Consider only the partial states that have actually been set.
            let flags = states[1].as_boolean();
            let filtered_states =
                filter_states_according_to_is_set(&states[0..1], flags)?;
            if let Some(first) = filtered_states.first() {
                if !first.is_empty() {
                    self.first = ScalarValue::try_from_array(first, 0)?;
                    self.is_set = true;
                }
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.first) + self.first.size()
    }
}

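/// Accumulator for `FIRST_VALUE` with an `ORDER BY` requirement: alongside the
/// candidate value it stores that value's ordering key, so rows arriving in any
/// order can be compared against the current "first".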
#[derive(Debug)]
pub struct FirstValueAccumulator {
    first: ScalarValue,
    // Whether a first value has been found yet.
    is_set: bool,
    // Ordering key associated with `first`.
    orderings: Vec<ScalarValue>,
    // The ordering requirement (ORDER BY expressions).
    ordering_req: LexOrdering,
    // Whether the input is already sorted by `ordering_req`.
    is_input_pre_ordered: bool,
    // Whether to ignore null values.
    ignore_nulls: bool,
}

impl FirstValueAccumulator {
    /// Creates a new `FirstValueAccumulator` for the given value and ordering types.
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;
        ScalarValue::try_from(data_type).map(|first| Self {
            first,
            is_set: false,
            orderings,
            ordering_req,
            is_input_pre_ordered,
            ignore_nulls,
        })
    }

    /// Replaces the stored value and ordering key with the given row.
    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
        // Compact the scalars so they do not keep the source arrays alive.
        for s in row.iter_mut() {
            s.compact();
        }
        self.first = row.remove(0);
        self.orderings = row;
        self.is_set = true;
    }

    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
        let [value, ordering_values @ ..] = values else {
            return internal_err!("Empty row in FIRST_VALUE");
        };
        if self.is_input_pre_ordered {
            // The input is already sorted, so the first qualifying row wins.
            if self.ignore_nulls {
                for i in 0..value.len() {
                    if !value.is_null(i) {
                        return Ok(Some(i));
                    }
                }
                return Ok(None);
            } else {
                return Ok((!value.is_empty()).then_some(0));
            }
        }

        let sort_columns = ordering_values
            .iter()
            .zip(self.ordering_req.iter())
            .map(|(values, req)| SortColumn {
                values: Arc::clone(values),
                options: Some(req.options),
            })
            .collect::<Vec<_>>();

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;

        let min_index = if self.ignore_nulls {
            (0..value.len())
                .filter(|&index| !value.is_null(index))
                .min_by(|&a, &b| comparator.compare(a, b))
        } else {
            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
        };

        Ok(min_index)
    }
}

impl Accumulator for FirstValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.first.clone()];
        result.extend(self.orderings.iter().cloned());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(first_idx) = self.get_first_idx(values)? {
            let row = get_row_at_idx(values, first_idx)?;
            if !self.is_set
                || (!self.is_input_pre_ordered
                    && compare_rows(
                        &self.orderings,
                        &row[1..],
                        &get_sort_options(&self.ordering_req),
                    )?
                    .is_gt())
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // The last state column is the `is_set` flag; the columns before it are
        // the value followed by the ordering key.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));

        if let Some(first_idx) = min {
            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
            let first_ordering = &first_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            // Replace the current value when the incoming row's ordering key
            // sorts before it.
            if !self.is_set
                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
            {
                assert!(is_set_idx <= first_row.len());
                first_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(first_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.first)
            + self.first.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct LastValue {
    signature: Signature,
    is_input_pre_ordered: bool,
}

impl Debug for LastValue {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("LastValue")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .field("accumulator", &"<FUNC>")
            .finish()
    }
}

impl Default for LastValue {
    fn default() -> Self {
        Self::new()
    }
}

impl LastValue {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
            is_input_pre_ordered: false,
        }
    }
}

impl AggregateUDFImpl for LastValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "last_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialLastValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(LastValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new(
            format_state_name(args.name, "last_value"),
            args.return_field.data_type().clone(),
            true,
        )
        .into()];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(Field::new("is_set", DataType::Boolean, true).into());
        Ok(fields)
    }

    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(first_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        fn create_accumulator<T>(
            args: AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>>
        where
            T: ArrowPrimitiveType + Send,
        {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                false,
            )?))
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(args),
            DataType::Int16 => create_accumulator::<Int16Type>(args),
            DataType::Int32 => create_accumulator::<Int32Type>(args),
            DataType::Int64 => create_accumulator::<Int64Type>(args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
            DataType::Float16 => create_accumulator::<Float16Type>(args),
            DataType::Float32 => create_accumulator::<Float32Type>(args),
            DataType::Float64 => create_accumulator::<Float64Type>(args),

            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(args),
            DataType::Date64 => create_accumulator::<Date64Type>(args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(args)
            }

            _ => {
                internal_err!(
                    "GroupsAccumulator not supported for last_value({})",
                    args.return_field.data_type()
                )
            }
        }
    }
}

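/// Accumulator for `LAST_VALUE` when no `ORDER BY` is given: it keeps the most
/// recently seen value (skipping nulls when `ignore_nulls` is true).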
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    last: ScalarValue,
    // Whether a last value has been seen yet.
    is_set: bool,
    // Whether to ignore null values.
    ignore_nulls: bool,
}

impl TrivialLastValueAccumulator {
    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
        ScalarValue::try_from(data_type).map(|last| Self {
            last,
            is_set: false,
            ignore_nulls,
        })
    }
}

impl Accumulator for TrivialLastValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let value = &values[0];
        let mut last_idx = None;
        if self.ignore_nulls {
            // Find the index of the last non-null value.
            for i in (0..value.len()).rev() {
                if !value.is_null(i) {
                    last_idx = Some(i);
                    break;
                }
            }
        } else if !value.is_empty() {
            last_idx = Some(value.len() - 1);
        }
        if let Some(last_idx) = last_idx {
            let mut row = get_row_at_idx(values, last_idx)?;
            self.last = row.swap_remove(0);
            self.last.compact();
            self.is_set = true;
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Consider only the partial states that have actually been set.
        let flags = states[1].as_boolean();
        let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
        if let Some(last) = filtered_states.last() {
            if !last.is_empty() {
                self.last = ScalarValue::try_from_array(last, 0)?;
                self.is_set = true;
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.last.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.last) + self.last.size()
    }
}

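/// Accumulator for `LAST_VALUE` with an `ORDER BY` requirement: alongside the
/// candidate value it stores that value's ordering key, so rows arriving in any
/// order can be compared against the current "last".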
#[derive(Debug)]
struct LastValueAccumulator {
    last: ScalarValue,
    // Whether a last value has been found yet.
    is_set: bool,
    // Ordering key associated with `last`.
    orderings: Vec<ScalarValue>,
    // The ordering requirement (ORDER BY expressions).
    ordering_req: LexOrdering,
    // Whether the input is already sorted by `ordering_req`.
    is_input_pre_ordered: bool,
    // Whether to ignore null values.
    ignore_nulls: bool,
}

impl LastValueAccumulator {
    /// Creates a new `LastValueAccumulator` for the given value and ordering types.
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;
        ScalarValue::try_from(data_type).map(|last| Self {
            last,
            is_set: false,
            orderings,
            ordering_req,
            is_input_pre_ordered,
            ignore_nulls,
        })
    }

    /// Replaces the stored value and ordering key with the given row.
    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
        // Compact the scalars so they do not keep the source arrays alive.
        for s in row.iter_mut() {
            s.compact();
        }
        self.last = row.remove(0);
        self.orderings = row;
        self.is_set = true;
    }

    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
        let [value, ordering_values @ ..] = values else {
            return internal_err!("Empty row in LAST_VALUE");
        };
        if self.is_input_pre_ordered {
            // The input is already sorted, so the last qualifying row wins.
            if self.ignore_nulls {
                for i in (0..value.len()).rev() {
                    if !value.is_null(i) {
                        return Ok(Some(i));
                    }
                }
                return Ok(None);
            } else {
                return Ok((!value.is_empty()).then_some(value.len() - 1));
            }
        }

        let sort_columns = ordering_values
            .iter()
            .zip(self.ordering_req.iter())
            .map(|(values, req)| SortColumn {
                values: Arc::clone(values),
                options: Some(req.options),
            })
            .collect::<Vec<_>>();

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let max_ind = if self.ignore_nulls {
            (0..value.len())
                .filter(|&index| !value.is_null(index))
                .max_by(|&a, &b| comparator.compare(a, b))
        } else {
            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
        };

        Ok(max_ind)
    }
}

impl Accumulator for LastValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.last.clone()];
        result.extend(self.orderings.clone());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(last_idx) = self.get_last_idx(values)? {
            let row = get_row_at_idx(values, last_idx)?;
            let orderings = &row[1..];
            // Replace the current value when the incoming row's ordering key
            // sorts after it.
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(
                    &self.orderings,
                    orderings,
                    &get_sort_options(&self.ordering_req),
                )?
                .is_lt()
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // The last state column is the `is_set` flag; the columns before it are
        // the value followed by the ordering key.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));

        if let Some(last_idx) = max {
            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
            let last_ordering = &last_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            // Replace the current value when the incoming row's ordering key
            // sorts after it.
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
            {
                assert!(is_set_idx <= last_row.len());
                last_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(last_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.last.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.last)
            + self.last.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}

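/// Filters each state column down to the rows whose `is_set` flag is true, so
/// that unset partial states are ignored during merging.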
fn filter_states_according_to_is_set(
    states: &[ArrayRef],
    flags: &BooleanArray,
) -> Result<Vec<ArrayRef>> {
    states
        .iter()
        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
        .collect()
}

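/// Pairs each state array with the sort options of the corresponding ordering
/// expression, producing the [`SortColumn`]s used for lexicographical comparison.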
fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
    arrs.iter()
        .zip(sort_exprs.iter())
        .map(|(item, sort_expr)| SortColumn {
            values: Arc::clone(item),
            options: Some(sort_expr.options),
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use arrow::{
        array::{Int64Array, ListArray},
        compute::SortOptions,
        datatypes::Schema,
    };
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    use super::*;

    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        // Each tuple is a range: (start inclusive, end exclusive).
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect::<Vec<_>>();
        for arr in arrs {
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }
        // The first value seen is 0 (first element of the first batch); the
        // last value seen is 12 (last element of the last batch).
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect::<Vec<_>>();

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;

        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;

        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    #[test]
    fn test_first_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);

        assert_eq!(eval_result, &expect);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        Ok(())
    }

    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&BooleanArray::from(vec![true, true, false, true])),
                100,
            )?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.state(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            let s = group_acc.state(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );
        }

        Ok(())
    }

    #[test]
    fn test_last_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            false,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

        assert_eq!(eval_result, &expect);

        Ok(())
    }

    #[test]
    fn test_first_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            first_accumulator.update_batch(values)?;

            Ok(first_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    #[test]
    fn test_last_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            last_accumulator.update_batch(values)?;

            Ok(last_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }
}