use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
    PrimitiveArray,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
use arrow::datatypes::{
    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type,
    Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type,
    Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType,
    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
    UInt32Type, UInt64Type, UInt8Type,
};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
use datafusion_common::{
    arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
    GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);

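/// Returns the first value in a group of values, optionally ordered by
/// `order_by`, as an [`Expr`] for use with the expression and DataFrame APIs.
///
/// Illustrative usage sketch (the `col` helper and the `"a"`/`"ts"` column
/// names are assumptions for the example, not part of this module):
///
/// ```ignore
/// use datafusion_expr::col;
///
/// // first value of "a" within each group, taken in ascending "ts" order
/// let expr = first_value(col("a"), vec![col("ts").sort(true, true)]);
/// ```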
pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
    first_value_udaf()
        .call(vec![expression])
        .order_by(order_by)
        .build()
        .unwrap()
}

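/// Returns the last value in a group of values, optionally ordered by
/// `order_by`, as an [`Expr`]; the counterpart of [`first_value`].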
pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
    last_value_udaf()
        .call(vec![expression])
        .order_by(order_by)
        .build()
        .unwrap()
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element                                 |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct FirstValue {
    signature: Signature,
    is_input_pre_ordered: bool,
}

impl Debug for FirstValue {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("FirstValue")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .field("accumulator", &"<FUNC>")
            .finish()
    }
}

impl Default for FirstValue {
    fn default() -> Self {
        Self::new()
    }
}

impl FirstValue {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
            is_input_pre_ordered: false,
        }
    }
}

impl AggregateUDFImpl for FirstValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "first_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialFirstValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(FirstValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new(
            format_state_name(args.name, "first_value"),
            args.return_type().clone(),
            true,
        )
        .into()];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(
            Field::new(
                format_state_name(args.name, "first_value_is_set"),
                DataType::Boolean,
                true,
            )
            .into(),
        );
        Ok(fields)
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal32(_, _)
                    | Decimal64(_, _)
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        fn create_accumulator<T: ArrowPrimitiveType + Send>(
            args: AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>> {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                true,
            )
            .map(|acc| Box::new(acc) as _)
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(args),
            DataType::Int16 => create_accumulator::<Int16Type>(args),
            DataType::Int32 => create_accumulator::<Int32Type>(args),
            DataType::Int64 => create_accumulator::<Int64Type>(args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
            DataType::Float16 => create_accumulator::<Float16Type>(args),
            DataType::Float32 => create_accumulator::<Float32Type>(args),
            DataType::Float64 => create_accumulator::<Float64Type>(args),

            DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(args),
            DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(args),
            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(args),
            DataType::Date64 => create_accumulator::<Date64Type>(args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(args)
            }

            _ => internal_err!(
                "GroupsAccumulator not supported for first_value({})",
                args.return_field.data_type()
            ),
        }
    }

    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(last_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}

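/// A [`GroupsAccumulator`] for primitive types that tracks, per group, the
/// first (or last, when `pick_first_in_group` is `false`) value according to
/// the required ordering, together with the ordering values that selected it.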
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    // The selected value for each group.
    vals: Vec<T::Native>,
    // The ordering values of the selected row, for each group.
    orderings: Vec<Vec<ScalarValue>>,
    // Whether a value has been selected for each group.
    is_sets: BooleanBufferBuilder,
    // Validity (non-null) flags for `vals`.
    null_builder: BooleanBufferBuilder,
    // Memory used by `orderings`, maintained incrementally.
    size_of_orderings: usize,

    // Scratch buffer: for each group, the index of the best row seen in the
    // current batch (`.0`) and whether that entry is populated (`.1`).
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    ordering_req: LexOrdering,
    // `true` for FIRST_VALUE, `false` for LAST_VALUE.
    pick_first_in_group: bool,
    sort_options: Vec<SortOptions>,
    ignore_nulls: bool,
    data_type: DataType,
    // Default (null) ordering values used when resizing `orderings`.
    default_orderings: Vec<ScalarValue>,
}

impl<T> FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    fn try_new(
        ordering_req: LexOrdering,
        ignore_nulls: bool,
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        pick_first_in_group: bool,
    ) -> Result<Self> {
        let default_orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;

        let sort_options = get_sort_options(&ordering_req);

        Ok(Self {
            null_builder: BooleanBufferBuilder::new(0),
            ordering_req,
            sort_options,
            ignore_nulls,
            default_orderings,
            data_type: data_type.clone(),
            vals: Vec::new(),
            orderings: Vec::new(),
            is_sets: BooleanBufferBuilder::new(0),
            size_of_orderings: 0,
            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
            pick_first_in_group,
        })
    }

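    /// Returns `true` if the state for `group_idx` should be replaced by a row
    /// with the given ordering values, i.e. the group is still unset or the
    /// new row sorts before (for FIRST) / after (for LAST) the stored one.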
    fn should_update_state(
        &self,
        group_idx: usize,
        new_ordering_values: &[ScalarValue],
    ) -> Result<bool> {
        if !self.is_sets.get_bit(group_idx) {
            return Ok(true);
        }

        assert!(new_ordering_values.len() == self.ordering_req.len());
        let current_ordering = &self.orderings[group_idx];
        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
            if self.pick_first_in_group {
                x.is_gt()
            } else {
                x.is_lt()
            }
        })
    }

    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
        let result = emit_to.take_needed(&mut self.orderings);

        match emit_to {
            EmitTo::All => self.size_of_orderings = 0,
            EmitTo::First(_) => {
                self.size_of_orderings -=
                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
            }
        }

        result
    }

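    /// Removes the first `n` (or all) bits from `bool_buf_builder` according
    /// to `emit_to`, returning them and retaining the remainder in the builder.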
    fn take_need(
        bool_buf_builder: &mut BooleanBufferBuilder,
        emit_to: EmitTo,
    ) -> BooleanBuffer {
        let bool_buf = bool_buf_builder.finish();
        match emit_to {
            EmitTo::All => bool_buf,
            EmitTo::First(n) => {
                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
                for b in bool_buf.iter().skip(n) {
                    bool_buf_builder.append(b);
                }
                first_n
            }
        }
    }

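    /// Ensures all per-group state can hold `new_size` groups, filling new
    /// slots with defaults and keeping `size_of_orderings` in sync.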
    fn resize_states(&mut self, new_size: usize) {
        self.vals.resize(new_size, T::default_value());

        self.null_builder.resize(new_size);

        if self.orderings.len() < new_size {
            let current_len = self.orderings.len();

            self.orderings
                .resize(new_size, self.default_orderings.clone());

            self.size_of_orderings += (new_size - current_len)
                * ScalarValue::size_of_vec(
                    self.orderings.last().unwrap(),
                );
        }

        self.is_sets.resize(new_size);

        self.min_of_each_group_buf.0.resize(new_size, 0);
        self.min_of_each_group_buf.1.resize(new_size);
    }

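    /// Overwrites the stored value and ordering for `group_idx` with `new_val`
    /// and `orderings`, marking the group as set and updating the tracked size.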
    fn update_state(
        &mut self,
        group_idx: usize,
        orderings: &[ScalarValue],
        new_val: T::Native,
        is_null: bool,
    ) {
        self.vals[group_idx] = new_val;
        self.is_sets.set_bit(group_idx, true);

        self.null_builder.set_bit(group_idx, !is_null);

        assert!(orderings.len() == self.ordering_req.len());
        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.orderings[group_idx].clear();
        self.orderings[group_idx].extend_from_slice(orderings);
        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
    }

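    /// Emits the requested groups as `(values, orderings, is_set flags)`,
    /// draining them from the accumulator.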
    fn take_state(
        &mut self,
        emit_to: EmitTo,
    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
        self.min_of_each_group_buf
            .1
            .truncate(self.min_of_each_group_buf.0.len());

        (
            self.take_vals_and_null_buf(emit_to),
            self.take_orderings(emit_to),
            Self::take_need(&mut self.is_sets, emit_to),
        )
    }

    #[cfg(test)]
    fn compute_size_of_orderings(&self) -> usize {
        self.orderings
            .iter()
            .map(ScalarValue::size_of_vec)
            .sum::<usize>()
    }
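
    /// Scans one input batch and returns `(group_idx, row_idx)` pairs: for
    /// each group with at least one qualifying row (passing the filter, marked
    /// as set, and non-null when nulls are ignored), the index of the row that
    /// best matches the ordering.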
    fn get_filtered_min_of_each_group(
        &mut self,
        orderings: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        vals: &PrimitiveArray<T>,
        is_set_arr: Option<&BooleanArray>,
    ) -> Result<Vec<(usize, usize)>> {
        self.min_of_each_group_buf.1.truncate(0);
        self.min_of_each_group_buf
            .1
            .append_n(self.vals.len(), false);

        let comparator = {
            assert_eq!(orderings.len(), self.ordering_req.len());
            let sort_columns = orderings
                .iter()
                .zip(self.ordering_req.iter())
                .map(|(array, req)| SortColumn {
                    values: Arc::clone(array),
                    options: Some(req.options),
                })
                .collect::<Vec<_>>();

            LexicographicalComparator::try_new(&sort_columns)?
        };

        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
            let group_idx = *group_idx;

            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

            if !passed_filter || !is_set {
                continue;
            }

            if self.ignore_nulls && vals.is_null(idx_in_val) {
                continue;
            }

            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);

            if !is_valid {
                self.min_of_each_group_buf.1.set_bit(group_idx, true);
                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
            } else {
                let ordering = comparator
                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);

                if (ordering.is_gt() && self.pick_first_in_group)
                    || (ordering.is_lt() && !self.pick_first_in_group)
                {
                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
                }
            }
        }

        Ok(self
            .min_of_each_group_buf
            .0
            .iter()
            .enumerate()
            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
            .collect::<Vec<_>>())
    }

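    /// Drains the requested values together with their null buffer into a
    /// primitive array of the accumulator's data type.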
    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
        let r = emit_to.take_needed(&mut self.vals);

        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));

        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf))
            .with_data_type(self.data_type.clone());
        Arc::new(values)
    }
}

impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
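    // `values_and_order_cols` layout: index 0 is the value column, the
    // remaining columns are the ordering expressions, in requirement order.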
    fn update_batch(
        &mut self,
        values_and_order_cols: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.resize_states(total_num_groups);

        let vals = values_and_order_cols[0].as_primitive::<T>();

        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());

        for (group_idx, idx) in self
            .get_filtered_min_of_each_group(
                &values_and_order_cols[1..],
                group_indices,
                opt_filter,
                vals,
                None,
            )?
            .into_iter()
        {
            extract_row_at_idx_to_buf(
                &values_and_order_cols[1..],
                idx,
                &mut ordering_buf,
            )?;

            if self.should_update_state(group_idx, &ordering_buf)? {
                self.update_state(
                    group_idx,
                    &ordering_buf,
                    vals.value(idx),
                    vals.is_null(idx),
                );
            }
        }

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        Ok(self.take_state(emit_to).0)
    }

    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
        let mut result = Vec::with_capacity(self.orderings.len() + 2);

        result.push(val_arr);

        let ordering_cols = {
            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
            for _ in 0..self.ordering_req.len() {
                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
            }
            for row in orderings.into_iter() {
                assert_eq!(row.len(), self.ordering_req.len());
                for (col_idx, ordering) in row.into_iter().enumerate() {
                    ordering_cols[col_idx].push(ordering);
                }
            }

            ordering_cols
        };
        for ordering_col in ordering_cols {
            result.push(ScalarValue::iter_to_array(ordering_col)?);
        }

        result.push(Arc::new(BooleanArray::new(is_sets, None)));

        Ok(result)
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.resize_states(total_num_groups);

        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());

        let (is_set_arr, val_and_order_cols) = match values.split_last() {
            Some(result) => result,
            None => return internal_err!("Empty row in FIRST_VALUE"),
        };

        let is_set_arr = as_boolean_array(is_set_arr)?;

        let vals = values[0].as_primitive::<T>();
        let groups = self.get_filtered_min_of_each_group(
            &val_and_order_cols[1..],
            group_indices,
            opt_filter,
            vals,
            Some(is_set_arr),
        )?;

        for (group_idx, idx) in groups.into_iter() {
            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;

            if self.should_update_state(group_idx, &ordering_buf)? {
                self.update_state(
                    group_idx,
                    &ordering_buf,
                    vals.value(idx),
                    vals.is_null(idx),
                );
            }
        }

        Ok(())
    }

    fn size(&self) -> usize {
        self.vals.capacity() * size_of::<T::Native>()
            + self.null_builder.capacity() / 8
            + self.is_sets.capacity() / 8
            + self.size_of_orderings
            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
            + self.min_of_each_group_buf.1.capacity() / 8
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        let mut result = values.to_vec();
        match opt_filter {
            Some(f) => {
                result.push(Arc::new(f.clone()));
                Ok(result)
            }
            None => {
                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
                Ok(result)
            }
        }
    }
}

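/// Accumulator for FIRST_VALUE without an ORDER BY requirement: it simply
/// keeps the first (optionally non-null) value it encounters.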
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    first: ScalarValue,
    is_set: bool,
    ignore_nulls: bool,
}

impl TrivialFirstValueAccumulator {
    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
        ScalarValue::try_from(data_type).map(|first| Self {
            first,
            is_set: false,
            ignore_nulls,
        })
    }
}

impl Accumulator for TrivialFirstValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if !self.is_set {
            let value = &values[0];
            let mut first_idx = None;
            if self.ignore_nulls {
                for i in 0..value.len() {
                    if !value.is_null(i) {
                        first_idx = Some(i);
                        break;
                    }
                }
            } else if !value.is_empty() {
                first_idx = Some(0);
            }
            if let Some(first_idx) = first_idx {
                let mut row = get_row_at_idx(values, first_idx)?;
                self.first = row.swap_remove(0);
                self.first.compact();
                self.is_set = true;
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if !self.is_set {
            let flags = states[1].as_boolean();
            validate_is_set_flags(flags, "first_value")?;

            let filtered_states =
                filter_states_according_to_is_set(&states[0..1], flags)?;
            if let Some(first) = filtered_states.first() {
                if !first.is_empty() {
                    self.first = ScalarValue::try_from_array(first, 0)?;
                    self.is_set = true;
                }
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.first) + self.first.size()
    }
}

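/// Accumulator for FIRST_VALUE with an ORDER BY requirement: it keeps the
/// value of the row that sorts first, along with its ordering values.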
#[derive(Debug)]
pub struct FirstValueAccumulator {
    first: ScalarValue,
    is_set: bool,
    orderings: Vec<ScalarValue>,
    ordering_req: LexOrdering,
    is_input_pre_ordered: bool,
    ignore_nulls: bool,
}

impl FirstValueAccumulator {
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;
        ScalarValue::try_from(data_type).map(|first| Self {
            first,
            is_set: false,
            orderings,
            ordering_req,
            is_input_pre_ordered,
            ignore_nulls,
        })
    }

    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
        for s in row.iter_mut() {
            s.compact();
        }
        self.first = row.remove(0);
        self.orderings = row;
        self.is_set = true;
    }

    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
        let [value, ordering_values @ ..] = values else {
            return internal_err!("Empty row in FIRST_VALUE");
        };
        if self.is_input_pre_ordered {
            if self.ignore_nulls {
                for i in 0..value.len() {
                    if !value.is_null(i) {
                        return Ok(Some(i));
                    }
                }
                return Ok(None);
            } else {
                return Ok((!value.is_empty()).then_some(0));
            }
        }

        let sort_columns = ordering_values
            .iter()
            .zip(self.ordering_req.iter())
            .map(|(values, req)| SortColumn {
                values: Arc::clone(values),
                options: Some(req.options),
            })
            .collect::<Vec<_>>();

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;

        let min_index = if self.ignore_nulls {
            (0..value.len())
                .filter(|&index| !value.is_null(index))
                .min_by(|&a, &b| comparator.compare(a, b))
        } else {
            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
        };

        Ok(min_index)
    }
}

impl Accumulator for FirstValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.first.clone()];
        result.extend(self.orderings.iter().cloned());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(first_idx) = self.get_first_idx(values)? {
            let row = get_row_at_idx(values, first_idx)?;
            if !self.is_set
                || (!self.is_input_pre_ordered
                    && compare_rows(
                        &self.orderings,
                        &row[1..],
                        &get_sort_options(&self.ordering_req),
                    )?
                    .is_gt())
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        validate_is_set_flags(flags, "first_value")?;

        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));

        if let Some(first_idx) = min {
            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
            let first_ordering = &first_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            if !self.is_set
                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
            {
                assert!(is_set_idx <= first_row.len());
                first_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(first_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.first)
            + self.first.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct LastValue {
    signature: Signature,
    is_input_pre_ordered: bool,
}

impl Debug for LastValue {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("LastValue")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .field("accumulator", &"<FUNC>")
            .finish()
    }
}

impl Default for LastValue {
    fn default() -> Self {
        Self::new()
    }
}

impl LastValue {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
            is_input_pre_ordered: false,
        }
    }
}

impl AggregateUDFImpl for LastValue {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "last_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
            return TrivialLastValueAccumulator::try_new(
                acc_args.return_field.data_type(),
                acc_args.ignore_nulls,
            )
            .map(|acc| Box::new(acc) as _);
        };
        let ordering_dtypes = ordering
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(LastValueAccumulator::try_new(
            acc_args.return_field.data_type(),
            &ordering_dtypes,
            ordering,
            self.is_input_pre_ordered,
            acc_args.ignore_nulls,
        )?))
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let mut fields = vec![Field::new(
            format_state_name(args.name, "last_value"),
            args.return_field.data_type().clone(),
            true,
        )
        .into()];
        fields.extend(args.ordering_fields.iter().cloned());
        fields.push(
            Field::new(
                format_state_name(args.name, "last_value_is_set"),
                DataType::Boolean,
                true,
            )
            .into(),
        );
        Ok(fields)
    }

    fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        Ok(Some(Arc::new(Self {
            signature: self.signature.clone(),
            is_input_pre_ordered: beneficial_ordering,
        })))
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(first_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        !args.order_bys.is_empty()
            && matches!(
                args.return_field.data_type(),
                Int8 | Int16
                    | Int32
                    | Int64
                    | UInt8
                    | UInt16
                    | UInt32
                    | UInt64
                    | Float16
                    | Float32
                    | Float64
                    | Decimal32(_, _)
                    | Decimal64(_, _)
                    | Decimal128(_, _)
                    | Decimal256(_, _)
                    | Date32
                    | Date64
                    | Time32(_)
                    | Time64(_)
                    | Timestamp(_, _)
            )
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        fn create_accumulator<T>(
            args: AccumulatorArgs,
        ) -> Result<Box<dyn GroupsAccumulator>>
        where
            T: ArrowPrimitiveType + Send,
        {
            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
                return internal_err!("Groups accumulator must have an ordering.");
            };

            let ordering_dtypes = ordering
                .iter()
                .map(|e| e.expr.data_type(args.schema))
                .collect::<Result<Vec<_>>>()?;

            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
                ordering,
                args.ignore_nulls,
                args.return_field.data_type(),
                &ordering_dtypes,
                false,
            )?))
        }

        match args.return_field.data_type() {
            DataType::Int8 => create_accumulator::<Int8Type>(args),
            DataType::Int16 => create_accumulator::<Int16Type>(args),
            DataType::Int32 => create_accumulator::<Int32Type>(args),
            DataType::Int64 => create_accumulator::<Int64Type>(args),
            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
            DataType::Float16 => create_accumulator::<Float16Type>(args),
            DataType::Float32 => create_accumulator::<Float32Type>(args),
            DataType::Float64 => create_accumulator::<Float64Type>(args),

            DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(args),
            DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(args),
            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),

            DataType::Timestamp(TimeUnit::Second, _) => {
                create_accumulator::<TimestampSecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Millisecond, _) => {
                create_accumulator::<TimestampMillisecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Microsecond, _) => {
                create_accumulator::<TimestampMicrosecondType>(args)
            }
            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                create_accumulator::<TimestampNanosecondType>(args)
            }

            DataType::Date32 => create_accumulator::<Date32Type>(args),
            DataType::Date64 => create_accumulator::<Date64Type>(args),
            DataType::Time32(TimeUnit::Second) => {
                create_accumulator::<Time32SecondType>(args)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                create_accumulator::<Time32MillisecondType>(args)
            }

            DataType::Time64(TimeUnit::Microsecond) => {
                create_accumulator::<Time64MicrosecondType>(args)
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                create_accumulator::<Time64NanosecondType>(args)
            }

            _ => {
                internal_err!(
                    "GroupsAccumulator not supported for last_value({})",
                    args.return_field.data_type()
                )
            }
        }
    }
}

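/// Accumulator for LAST_VALUE without an ORDER BY requirement: it simply
/// keeps the last (optionally non-null) value it encounters.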
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    last: ScalarValue,
    is_set: bool,
    ignore_nulls: bool,
}

impl TrivialLastValueAccumulator {
    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
        ScalarValue::try_from(data_type).map(|last| Self {
            last,
            is_set: false,
            ignore_nulls,
        })
    }
}

impl Accumulator for TrivialLastValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let value = &values[0];
        let mut last_idx = None;
        if self.ignore_nulls {
            for i in (0..value.len()).rev() {
                if !value.is_null(i) {
                    last_idx = Some(i);
                    break;
                }
            }
        } else if !value.is_empty() {
            last_idx = Some(value.len() - 1);
        }
        if let Some(last_idx) = last_idx {
            let mut row = get_row_at_idx(values, last_idx)?;
            self.last = row.swap_remove(0);
            self.last.compact();
            self.is_set = true;
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let flags = states[1].as_boolean();
        validate_is_set_flags(flags, "last_value")?;

        let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
        if let Some(last) = filtered_states.last() {
            if !last.is_empty() {
                self.last = ScalarValue::try_from_array(last, 0)?;
                self.is_set = true;
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.last.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.last) + self.last.size()
    }
}

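/// Accumulator for LAST_VALUE with an ORDER BY requirement: it keeps the
/// value of the row that sorts last, along with its ordering values.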
#[derive(Debug)]
struct LastValueAccumulator {
    last: ScalarValue,
    is_set: bool,
    orderings: Vec<ScalarValue>,
    ordering_req: LexOrdering,
    is_input_pre_ordered: bool,
    ignore_nulls: bool,
}

impl LastValueAccumulator {
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
        is_input_pre_ordered: bool,
        ignore_nulls: bool,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<_>>()?;
        ScalarValue::try_from(data_type).map(|last| Self {
            last,
            is_set: false,
            orderings,
            ordering_req,
            is_input_pre_ordered,
            ignore_nulls,
        })
    }

    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
        for s in row.iter_mut() {
            s.compact();
        }
        self.last = row.remove(0);
        self.orderings = row;
        self.is_set = true;
    }

    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
        let [value, ordering_values @ ..] = values else {
            return internal_err!("Empty row in LAST_VALUE");
        };
        if self.is_input_pre_ordered {
            if self.ignore_nulls {
                for i in (0..value.len()).rev() {
                    if !value.is_null(i) {
                        return Ok(Some(i));
                    }
                }
                return Ok(None);
            } else {
                return Ok((!value.is_empty()).then_some(value.len() - 1));
            }
        }

        let sort_columns = ordering_values
            .iter()
            .zip(self.ordering_req.iter())
            .map(|(values, req)| SortColumn {
                values: Arc::clone(values),
                options: Some(req.options),
            })
            .collect::<Vec<_>>();

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let max_ind = if self.ignore_nulls {
            (0..value.len())
                .filter(|&index| !(value.is_null(index)))
                .max_by(|&a, &b| comparator.compare(a, b))
        } else {
            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
        };

        Ok(max_ind)
    }
}

impl Accumulator for LastValueAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.last.clone()];
        result.extend(self.orderings.clone());
        result.push(ScalarValue::from(self.is_set));
        Ok(result)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(last_idx) = self.get_last_idx(values)? {
            let row = get_row_at_idx(values, last_idx)?;
            let orderings = &row[1..];
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(
                    &self.orderings,
                    orderings,
                    &get_sort_options(&self.ordering_req),
                )?
                .is_lt()
            {
                self.update_with_new_row(row);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        validate_is_set_flags(flags, "last_value")?;

        let filtered_states =
            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
        let sort_columns =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);

        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));

        if let Some(last_idx) = max {
            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
            let last_ordering = &last_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            if !self.is_set
                || self.is_input_pre_ordered
                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
            {
                assert!(is_set_idx <= last_row.len());
                last_row.resize(is_set_idx, ScalarValue::Null);
                self.update_with_new_row(last_row);
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(self.last.clone())
    }

    fn size(&self) -> usize {
        size_of_val(self) - size_of_val(&self.last)
            + self.last.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - size_of_val(&self.orderings)
    }
}

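/// Returns an internal error if the `is_set` state flags contain any nulls,
/// which would indicate corrupted intermediate state.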
fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
    if flags.null_count() > 0 {
        return Err(DataFusionError::Internal(format!(
            "{function_name}: is_set flags contain nulls"
        )));
    }
    Ok(())
}

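/// Filters state arrays by the `is_set` flags, keeping only rows whose state
/// has actually been set.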
fn filter_states_according_to_is_set(
    states: &[ArrayRef],
    flags: &BooleanArray,
) -> Result<Vec<ArrayRef>> {
    states
        .iter()
        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
        .collect()
}

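/// Combines the ordering arrays with their sort expressions into
/// [`SortColumn`]s suitable for lexicographical comparison.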
fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
    arrs.iter()
        .zip(sort_exprs.iter())
        .map(|(item, sort_expr)| SortColumn {
            values: Arc::clone(item),
            options: Some(sort_expr.options),
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use arrow::{
        array::{BooleanArray, Int64Array, ListArray, StringArray},
        compute::SortOptions,
        datatypes::Schema,
    };
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    use super::*;

    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect::<Vec<_>>();
        for arr in arrs {
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect::<Vec<_>>();

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;

        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;

        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    #[test]
    fn test_first_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);

        assert_eq!(eval_result, &expect);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        Ok(())
    }

    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&BooleanArray::from(vec![true, true, false, true])),
                100,
            )?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.state(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            let s = group_acc.state(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );
        }

        Ok(())
    }

    #[test]
    fn test_last_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            false,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

        assert_eq!(eval_result, &expect);

        Ok(())
    }

    #[test]
    fn test_first_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            first_accumulator.update_batch(values)?;

            Ok(first_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    #[test]
    fn test_last_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            last_accumulator.update_batch(values)?;

            Ok(last_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    #[test]
    fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        let mut trivial_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("is_set flags contain nulls"));

        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = FirstValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("is_set flags contain nulls"));

        Ok(())
    }

    #[test]
    fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        let mut trivial_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("is_set flags contain nulls"));

        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = LastValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("is_set flags contain nulls"));

        Ok(())
    }
}