1use std::any::Any;
21use std::fmt::Debug;
22use std::hash::Hash;
23use std::mem::size_of_val;
24use std::sync::Arc;
25
26use arrow::array::{
27 Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
28 PrimitiveArray,
29};
30use arrow::buffer::{BooleanBuffer, NullBuffer};
31use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
32use arrow::datatypes::{
33 DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type,
34 Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int8Type,
35 Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
36 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
38 UInt16Type, UInt32Type, UInt64Type,
39};
40use datafusion_common::cast::as_boolean_array;
41use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
42use datafusion_common::{
43 DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err,
44 not_impl_err,
45};
46use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
47use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
48use datafusion_expr::{
49 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
50 GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
51};
52use datafusion_functions_aggregate_common::utils::get_sort_options;
53use datafusion_macros::user_doc;
54use datafusion_physical_expr_common::sort_expr::LexOrdering;
55
// Generate the `first_value_udaf()` / `last_value_udaf()` accessors (used by the
// expression helpers and `reverse_expr` below) via the crate-local `create_func!` macro.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
58
59pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
61 first_value_udaf()
62 .call(vec![expression])
63 .order_by(order_by)
64 .build()
65 .unwrap()
67}
68
69pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
71 last_value_udaf()
72 .call(vec![expression])
73 .order_by(order_by)
74 .build()
75 .unwrap()
77}
78
// SQL-facing `first_value` aggregate UDF. User documentation is generated from
// the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct FirstValue {
    /// Accepts exactly one argument of any type (built in `FirstValue::new`).
    signature: Signature,
    /// Set through `with_beneficial_ordering`. It is forwarded to
    /// `FirstValueAccumulator`, which then takes the first qualifying row of the
    /// input instead of comparing orderings.
    is_input_pre_ordered: bool,
}
98
99impl Debug for FirstValue {
100 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
101 f.debug_struct("FirstValue")
102 .field("name", &self.name())
103 .field("signature", &self.signature)
104 .field("accumulator", &"<FUNC>")
105 .finish()
106 }
107}
108
109impl Default for FirstValue {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl FirstValue {
116 pub fn new() -> Self {
117 Self {
118 signature: Signature::any(1, Volatility::Immutable),
119 is_input_pre_ordered: false,
120 }
121 }
122}
123
124impl AggregateUDFImpl for FirstValue {
125 fn as_any(&self) -> &dyn Any {
126 self
127 }
128
129 fn name(&self) -> &str {
130 "first_value"
131 }
132
133 fn signature(&self) -> &Signature {
134 &self.signature
135 }
136
137 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
138 not_impl_err!("Not called because the return_field_from_args is implemented")
139 }
140
141 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
142 Ok(Arc::new(
144 Field::new(
145 self.name(),
146 arg_fields[0].data_type().clone(),
147 true, )
149 .with_metadata(arg_fields[0].metadata().clone()),
150 ))
151 }
152
153 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
154 let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
155 return TrivialFirstValueAccumulator::try_new(
156 acc_args.return_field.data_type(),
157 acc_args.ignore_nulls,
158 )
159 .map(|acc| Box::new(acc) as _);
160 };
161 let ordering_dtypes = ordering
162 .iter()
163 .map(|e| e.expr.data_type(acc_args.schema))
164 .collect::<Result<Vec<_>>>()?;
165 Ok(Box::new(FirstValueAccumulator::try_new(
166 acc_args.return_field.data_type(),
167 &ordering_dtypes,
168 ordering,
169 self.is_input_pre_ordered,
170 acc_args.ignore_nulls,
171 )?))
172 }
173
174 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
175 let mut fields = vec![
176 Field::new(
177 format_state_name(args.name, "first_value"),
178 args.return_type().clone(),
179 true,
180 )
181 .into(),
182 ];
183 fields.extend(args.ordering_fields.iter().cloned());
184 fields.push(
185 Field::new(
186 format_state_name(args.name, "first_value_is_set"),
187 DataType::Boolean,
188 true,
189 )
190 .into(),
191 );
192 Ok(fields)
193 }
194
195 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
196 use DataType::*;
197 !args.order_bys.is_empty()
198 && matches!(
199 args.return_field.data_type(),
200 Int8 | Int16
201 | Int32
202 | Int64
203 | UInt8
204 | UInt16
205 | UInt32
206 | UInt64
207 | Float16
208 | Float32
209 | Float64
210 | Decimal32(_, _)
211 | Decimal64(_, _)
212 | Decimal128(_, _)
213 | Decimal256(_, _)
214 | Date32
215 | Date64
216 | Time32(_)
217 | Time64(_)
218 | Timestamp(_, _)
219 )
220 }
221
222 fn create_groups_accumulator(
223 &self,
224 args: AccumulatorArgs,
225 ) -> Result<Box<dyn GroupsAccumulator>> {
226 fn create_accumulator<T: ArrowPrimitiveType + Send>(
227 args: &AccumulatorArgs,
228 ) -> Result<Box<dyn GroupsAccumulator>> {
229 let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
230 return internal_err!("Groups accumulator must have an ordering.");
231 };
232
233 let ordering_dtypes = ordering
234 .iter()
235 .map(|e| e.expr.data_type(args.schema))
236 .collect::<Result<Vec<_>>>()?;
237
238 FirstPrimitiveGroupsAccumulator::<T>::try_new(
239 ordering,
240 args.ignore_nulls,
241 args.return_field.data_type(),
242 &ordering_dtypes,
243 true,
244 )
245 .map(|acc| Box::new(acc) as _)
246 }
247
248 match args.return_field.data_type() {
249 DataType::Int8 => create_accumulator::<Int8Type>(&args),
250 DataType::Int16 => create_accumulator::<Int16Type>(&args),
251 DataType::Int32 => create_accumulator::<Int32Type>(&args),
252 DataType::Int64 => create_accumulator::<Int64Type>(&args),
253 DataType::UInt8 => create_accumulator::<UInt8Type>(&args),
254 DataType::UInt16 => create_accumulator::<UInt16Type>(&args),
255 DataType::UInt32 => create_accumulator::<UInt32Type>(&args),
256 DataType::UInt64 => create_accumulator::<UInt64Type>(&args),
257 DataType::Float16 => create_accumulator::<Float16Type>(&args),
258 DataType::Float32 => create_accumulator::<Float32Type>(&args),
259 DataType::Float64 => create_accumulator::<Float64Type>(&args),
260
261 DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(&args),
262 DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(&args),
263 DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(&args),
264 DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(&args),
265
266 DataType::Timestamp(TimeUnit::Second, _) => {
267 create_accumulator::<TimestampSecondType>(&args)
268 }
269 DataType::Timestamp(TimeUnit::Millisecond, _) => {
270 create_accumulator::<TimestampMillisecondType>(&args)
271 }
272 DataType::Timestamp(TimeUnit::Microsecond, _) => {
273 create_accumulator::<TimestampMicrosecondType>(&args)
274 }
275 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
276 create_accumulator::<TimestampNanosecondType>(&args)
277 }
278
279 DataType::Date32 => create_accumulator::<Date32Type>(&args),
280 DataType::Date64 => create_accumulator::<Date64Type>(&args),
281 DataType::Time32(TimeUnit::Second) => {
282 create_accumulator::<Time32SecondType>(&args)
283 }
284 DataType::Time32(TimeUnit::Millisecond) => {
285 create_accumulator::<Time32MillisecondType>(&args)
286 }
287
288 DataType::Time64(TimeUnit::Microsecond) => {
289 create_accumulator::<Time64MicrosecondType>(&args)
290 }
291 DataType::Time64(TimeUnit::Nanosecond) => {
292 create_accumulator::<Time64NanosecondType>(&args)
293 }
294
295 _ => internal_err!(
296 "GroupsAccumulator not supported for first_value({})",
297 args.return_field.data_type()
298 ),
299 }
300 }
301
302 fn with_beneficial_ordering(
303 self: Arc<Self>,
304 beneficial_ordering: bool,
305 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
306 Ok(Some(Arc::new(Self {
307 signature: self.signature.clone(),
308 is_input_pre_ordered: beneficial_ordering,
309 })))
310 }
311
312 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
313 AggregateOrderSensitivity::Beneficial
314 }
315
316 fn reverse_expr(&self) -> ReversedUDAF {
317 ReversedUDAF::Reversed(last_value_udaf())
318 }
319
320 fn supports_null_handling_clause(&self) -> bool {
321 true
322 }
323
324 fn documentation(&self) -> Option<&Documentation> {
325 self.doc()
326 }
327}
328
/// Grouped `first_value` / `last_value` state for primitive value types.
///
/// `pick_first_in_group` chooses whether each group keeps the row with the
/// smallest ordering (`true`, used by `FirstValue`) or the largest (`false`,
/// used by `LastValue`).
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    /// Chosen value per group; meaningful only where `is_sets` is set.
    vals: Vec<T::Native>,
    /// Ordering-column values of the chosen row, one `Vec` per group.
    orderings: Vec<Vec<ScalarValue>>,
    /// Per-group flag: has a row been recorded for this group yet?
    is_sets: BooleanBufferBuilder,
    /// Per-group validity of `vals` (`false` means the chosen value is null).
    null_builder: BooleanBufferBuilder,
    /// Running byte estimate of `orderings`, maintained incrementally for `size()`.
    size_of_orderings: usize,
    /// Per-batch scratch buffer. `.0[g]` is the row index of the best candidate
    /// for group `g` in the current batch, and is valid only where `.1[g]` is set.
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
    /// The ORDER BY requirement used to compare rows.
    ordering_req: LexOrdering,
    /// `true` keeps the minimum ordering per group, `false` keeps the maximum.
    pick_first_in_group: bool,
    /// Sort options extracted from `ordering_req`.
    sort_options: Vec<SortOptions>,
    /// When `true`, rows whose value is null are never chosen.
    ignore_nulls: bool,
    /// Output data type applied to the emitted value array.
    data_type: DataType,
    /// `ScalarValue`s built from the ordering data types; used to initialise the
    /// orderings of newly added groups.
    default_orderings: Vec<ScalarValue>,
}
373
374impl<T> FirstPrimitiveGroupsAccumulator<T>
375where
376 T: ArrowPrimitiveType + Send,
377{
378 fn try_new(
379 ordering_req: LexOrdering,
380 ignore_nulls: bool,
381 data_type: &DataType,
382 ordering_dtypes: &[DataType],
383 pick_first_in_group: bool,
384 ) -> Result<Self> {
385 let default_orderings = ordering_dtypes
386 .iter()
387 .map(ScalarValue::try_from)
388 .collect::<Result<_>>()?;
389
390 let sort_options = get_sort_options(&ordering_req);
391
392 Ok(Self {
393 null_builder: BooleanBufferBuilder::new(0),
394 ordering_req,
395 sort_options,
396 ignore_nulls,
397 default_orderings,
398 data_type: data_type.clone(),
399 vals: Vec::new(),
400 orderings: Vec::new(),
401 is_sets: BooleanBufferBuilder::new(0),
402 size_of_orderings: 0,
403 min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
404 pick_first_in_group,
405 })
406 }
407
408 fn should_update_state(
409 &self,
410 group_idx: usize,
411 new_ordering_values: &[ScalarValue],
412 ) -> Result<bool> {
413 if !self.is_sets.get_bit(group_idx) {
414 return Ok(true);
415 }
416
417 assert!(new_ordering_values.len() == self.ordering_req.len());
418 let current_ordering = &self.orderings[group_idx];
419 compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
420 if self.pick_first_in_group {
421 x.is_gt()
422 } else {
423 x.is_lt()
424 }
425 })
426 }
427
428 fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
429 let result = emit_to.take_needed(&mut self.orderings);
430
431 match emit_to {
432 EmitTo::All => self.size_of_orderings = 0,
433 EmitTo::First(_) => {
434 self.size_of_orderings -=
435 result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
436 }
437 }
438
439 result
440 }
441
442 fn take_need(
443 bool_buf_builder: &mut BooleanBufferBuilder,
444 emit_to: EmitTo,
445 ) -> BooleanBuffer {
446 let bool_buf = bool_buf_builder.finish();
447 match emit_to {
448 EmitTo::All => bool_buf,
449 EmitTo::First(n) => {
450 let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
455 for b in bool_buf.iter().skip(n) {
457 bool_buf_builder.append(b);
458 }
459 first_n
460 }
461 }
462 }
463
464 fn resize_states(&mut self, new_size: usize) {
465 self.vals.resize(new_size, T::default_value());
466
467 self.null_builder.resize(new_size);
468
469 if self.orderings.len() < new_size {
470 let current_len = self.orderings.len();
471
472 self.orderings
473 .resize(new_size, self.default_orderings.clone());
474
475 self.size_of_orderings += (new_size - current_len)
476 * ScalarValue::size_of_vec(
477 self.orderings.last().unwrap(),
481 );
482 }
483
484 self.is_sets.resize(new_size);
485
486 self.min_of_each_group_buf.0.resize(new_size, 0);
487 self.min_of_each_group_buf.1.resize(new_size);
488 }
489
490 fn update_state(
491 &mut self,
492 group_idx: usize,
493 orderings: &[ScalarValue],
494 new_val: T::Native,
495 is_null: bool,
496 ) {
497 self.vals[group_idx] = new_val;
498 self.is_sets.set_bit(group_idx, true);
499
500 self.null_builder.set_bit(group_idx, !is_null);
501
502 assert!(orderings.len() == self.ordering_req.len());
503 let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
504 self.orderings[group_idx].clear();
505 self.orderings[group_idx].extend_from_slice(orderings);
506 let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
507 self.size_of_orderings = self.size_of_orderings - old_size + new_size;
508 }
509
510 fn take_state(
511 &mut self,
512 emit_to: EmitTo,
513 ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
514 emit_to.take_needed(&mut self.min_of_each_group_buf.0);
515 self.min_of_each_group_buf
516 .1
517 .truncate(self.min_of_each_group_buf.0.len());
518
519 (
520 self.take_vals_and_null_buf(emit_to),
521 self.take_orderings(emit_to),
522 Self::take_need(&mut self.is_sets, emit_to),
523 )
524 }
525
526 #[cfg(test)]
528 fn compute_size_of_orderings(&self) -> usize {
529 self.orderings
530 .iter()
531 .map(ScalarValue::size_of_vec)
532 .sum::<usize>()
533 }
534 fn get_filtered_min_of_each_group(
539 &mut self,
540 orderings: &[ArrayRef],
541 group_indices: &[usize],
542 opt_filter: Option<&BooleanArray>,
543 vals: &PrimitiveArray<T>,
544 is_set_arr: Option<&BooleanArray>,
545 ) -> Result<Vec<(usize, usize)>> {
546 self.min_of_each_group_buf.1.truncate(0);
548 self.min_of_each_group_buf
549 .1
550 .append_n(self.vals.len(), false);
551
552 let comparator = {
556 assert_eq!(orderings.len(), self.ordering_req.len());
557 let sort_columns = orderings
558 .iter()
559 .zip(self.ordering_req.iter())
560 .map(|(array, req)| SortColumn {
561 values: Arc::clone(array),
562 options: Some(req.options),
563 })
564 .collect::<Vec<_>>();
565
566 LexicographicalComparator::try_new(&sort_columns)?
567 };
568
569 for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
570 let group_idx = *group_idx;
571
572 let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
573 let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
574
575 if !passed_filter || !is_set {
576 continue;
577 }
578
579 if self.ignore_nulls && vals.is_null(idx_in_val) {
580 continue;
581 }
582
583 let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
584
585 if !is_valid {
586 self.min_of_each_group_buf.1.set_bit(group_idx, true);
587 self.min_of_each_group_buf.0[group_idx] = idx_in_val;
588 } else {
589 let ordering = comparator
590 .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
591
592 if (ordering.is_gt() && self.pick_first_in_group)
593 || (ordering.is_lt() && !self.pick_first_in_group)
594 {
595 self.min_of_each_group_buf.0[group_idx] = idx_in_val;
596 }
597 }
598 }
599
600 Ok(self
601 .min_of_each_group_buf
602 .0
603 .iter()
604 .enumerate()
605 .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
606 .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
607 .collect::<Vec<_>>())
608 }
609
610 fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
611 let r = emit_to.take_needed(&mut self.vals);
612
613 let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
614
615 let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) .with_data_type(self.data_type.clone());
617 Arc::new(values)
618 }
619}
620
621impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
622where
623 T: ArrowPrimitiveType + Send,
624{
625 fn update_batch(
626 &mut self,
627 values_and_order_cols: &[ArrayRef],
629 group_indices: &[usize],
630 opt_filter: Option<&BooleanArray>,
631 total_num_groups: usize,
632 ) -> Result<()> {
633 self.resize_states(total_num_groups);
634
635 let vals = values_and_order_cols[0].as_primitive::<T>();
636
637 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
638
639 for (group_idx, idx) in self
641 .get_filtered_min_of_each_group(
642 &values_and_order_cols[1..],
643 group_indices,
644 opt_filter,
645 vals,
646 None,
647 )?
648 .into_iter()
649 {
650 extract_row_at_idx_to_buf(
651 &values_and_order_cols[1..],
652 idx,
653 &mut ordering_buf,
654 )?;
655
656 if self.should_update_state(group_idx, &ordering_buf)? {
657 self.update_state(
658 group_idx,
659 &ordering_buf,
660 vals.value(idx),
661 vals.is_null(idx),
662 );
663 }
664 }
665
666 Ok(())
667 }
668
669 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
670 Ok(self.take_state(emit_to).0)
671 }
672
673 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
674 let (val_arr, orderings, is_sets) = self.take_state(emit_to);
675 let mut result = Vec::with_capacity(self.orderings.len() + 2);
676
677 result.push(val_arr);
678
679 let ordering_cols = {
680 let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
681 for _ in 0..self.ordering_req.len() {
682 ordering_cols.push(Vec::with_capacity(self.orderings.len()));
683 }
684 for row in orderings.into_iter() {
685 assert_eq!(row.len(), self.ordering_req.len());
686 for (col_idx, ordering) in row.into_iter().enumerate() {
687 ordering_cols[col_idx].push(ordering);
688 }
689 }
690
691 ordering_cols
692 };
693 for ordering_col in ordering_cols {
694 result.push(ScalarValue::iter_to_array(ordering_col)?);
695 }
696
697 result.push(Arc::new(BooleanArray::new(is_sets, None)));
698
699 Ok(result)
700 }
701
702 fn merge_batch(
703 &mut self,
704 values: &[ArrayRef],
705 group_indices: &[usize],
706 opt_filter: Option<&BooleanArray>,
707 total_num_groups: usize,
708 ) -> Result<()> {
709 self.resize_states(total_num_groups);
710
711 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
712
713 let (is_set_arr, val_and_order_cols) = match values.split_last() {
714 Some(result) => result,
715 None => return internal_err!("Empty row in FIRST_VALUE"),
716 };
717
718 let is_set_arr = as_boolean_array(is_set_arr)?;
719
720 let vals = values[0].as_primitive::<T>();
721 let groups = self.get_filtered_min_of_each_group(
723 &val_and_order_cols[1..],
724 group_indices,
725 opt_filter,
726 vals,
727 Some(is_set_arr),
728 )?;
729
730 for (group_idx, idx) in groups.into_iter() {
731 extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
732
733 if self.should_update_state(group_idx, &ordering_buf)? {
734 self.update_state(
735 group_idx,
736 &ordering_buf,
737 vals.value(idx),
738 vals.is_null(idx),
739 );
740 }
741 }
742
743 Ok(())
744 }
745
746 fn size(&self) -> usize {
747 self.vals.capacity() * size_of::<T::Native>()
748 + self.null_builder.capacity() / 8 + self.is_sets.capacity() / 8
750 + self.size_of_orderings
751 + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
752 + self.min_of_each_group_buf.1.capacity() / 8
753 }
754
755 fn supports_convert_to_state(&self) -> bool {
756 true
757 }
758
759 fn convert_to_state(
760 &self,
761 values: &[ArrayRef],
762 opt_filter: Option<&BooleanArray>,
763 ) -> Result<Vec<ArrayRef>> {
764 let mut result = values.to_vec();
765 match opt_filter {
766 Some(f) => {
767 result.push(Arc::new(f.clone()));
768 Ok(result)
769 }
770 None => {
771 result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
772 Ok(result)
773 }
774 }
775 }
776}
777
/// `first_value` accumulator used when no ORDER BY is given (see
/// `FirstValue::accumulator`). It keeps the first accepted row and ignores
/// later input.
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    /// Retained value. It is initialised from `ScalarValue::try_from(data_type)`
    /// (presumably a typed null) until a row is accepted.
    first: ScalarValue,
    /// Whether `first` has been assigned from input or merged state.
    is_set: bool,
    /// When `true`, `update_batch` skips null values.
    ignore_nulls: bool,
}
790
791impl TrivialFirstValueAccumulator {
792 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
794 ScalarValue::try_from(data_type).map(|first| Self {
795 first,
796 is_set: false,
797 ignore_nulls,
798 })
799 }
800}
801
802impl Accumulator for TrivialFirstValueAccumulator {
803 fn state(&mut self) -> Result<Vec<ScalarValue>> {
804 Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
805 }
806
807 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
808 if !self.is_set {
809 let value = &values[0];
811 let mut first_idx = None;
812 if self.ignore_nulls {
813 for i in 0..value.len() {
815 if !value.is_null(i) {
816 first_idx = Some(i);
817 break;
818 }
819 }
820 } else if !value.is_empty() {
821 first_idx = Some(0);
823 }
824 if let Some(first_idx) = first_idx {
825 let mut row = get_row_at_idx(values, first_idx)?;
826 self.first = row.swap_remove(0);
827 self.first.compact();
828 self.is_set = true;
829 }
830 }
831 Ok(())
832 }
833
834 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
835 if !self.is_set {
838 let flags = states[1].as_boolean();
839 validate_is_set_flags(flags, "first_value")?;
840
841 let filtered_states =
842 filter_states_according_to_is_set(&states[0..1], flags)?;
843 if let Some(first) = filtered_states.first()
844 && !first.is_empty()
845 {
846 self.first = ScalarValue::try_from_array(first, 0)?;
847 self.is_set = true;
848 }
849 }
850 Ok(())
851 }
852
853 fn evaluate(&mut self) -> Result<ScalarValue> {
854 Ok(self.first.clone())
855 }
856
857 fn size(&self) -> usize {
858 size_of_val(self) - size_of_val(&self.first) + self.first.size()
859 }
860}
861
/// Order-sensitive `first_value` accumulator used when an ORDER BY is given
/// (see `FirstValue::accumulator`).
#[derive(Debug)]
pub struct FirstValueAccumulator {
    /// Value of the best row retained so far.
    first: ScalarValue,
    /// Whether `first` / `orderings` have been assigned from input or merged state.
    is_set: bool,
    /// ORDER BY values of the row that produced `first`.
    orderings: Vec<ScalarValue>,
    /// The ORDER BY requirement used to compare rows.
    ordering_req: LexOrdering,
    /// When `true`, `update_batch` takes the first qualifying row of the input
    /// without comparing orderings.
    is_input_pre_ordered: bool,
    /// When `true`, `update_batch` skips rows whose value is null.
    ignore_nulls: bool,
}
877
878impl FirstValueAccumulator {
879 pub fn try_new(
881 data_type: &DataType,
882 ordering_dtypes: &[DataType],
883 ordering_req: LexOrdering,
884 is_input_pre_ordered: bool,
885 ignore_nulls: bool,
886 ) -> Result<Self> {
887 let orderings = ordering_dtypes
888 .iter()
889 .map(ScalarValue::try_from)
890 .collect::<Result<_>>()?;
891 ScalarValue::try_from(data_type).map(|first| Self {
892 first,
893 is_set: false,
894 orderings,
895 ordering_req,
896 is_input_pre_ordered,
897 ignore_nulls,
898 })
899 }
900
901 fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
903 for s in row.iter_mut() {
905 s.compact();
906 }
907 self.first = row.remove(0);
908 self.orderings = row;
909 self.is_set = true;
910 }
911
912 fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
913 let [value, ordering_values @ ..] = values else {
914 return internal_err!("Empty row in FIRST_VALUE");
915 };
916 if self.is_input_pre_ordered {
917 if self.ignore_nulls {
919 for i in 0..value.len() {
921 if !value.is_null(i) {
922 return Ok(Some(i));
923 }
924 }
925 return Ok(None);
926 } else {
927 return Ok((!value.is_empty()).then_some(0));
929 }
930 }
931
932 let sort_columns = ordering_values
933 .iter()
934 .zip(self.ordering_req.iter())
935 .map(|(values, req)| SortColumn {
936 values: Arc::clone(values),
937 options: Some(req.options),
938 })
939 .collect::<Vec<_>>();
940
941 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
942
943 let min_index = if self.ignore_nulls {
944 (0..value.len())
945 .filter(|&index| !value.is_null(index))
946 .min_by(|&a, &b| comparator.compare(a, b))
947 } else {
948 (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
949 };
950
951 Ok(min_index)
952 }
953}
954
955impl Accumulator for FirstValueAccumulator {
956 fn state(&mut self) -> Result<Vec<ScalarValue>> {
957 let mut result = vec![self.first.clone()];
958 result.extend(self.orderings.iter().cloned());
959 result.push(ScalarValue::from(self.is_set));
960 Ok(result)
961 }
962
963 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
964 if let Some(first_idx) = self.get_first_idx(values)? {
965 let row = get_row_at_idx(values, first_idx)?;
966 if !self.is_set
967 || (!self.is_input_pre_ordered
968 && compare_rows(
969 &self.orderings,
970 &row[1..],
971 &get_sort_options(&self.ordering_req),
972 )?
973 .is_gt())
974 {
975 self.update_with_new_row(row);
976 }
977 }
978 Ok(())
979 }
980
981 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
982 let is_set_idx = states.len() - 1;
985 let flags = states[is_set_idx].as_boolean();
986 validate_is_set_flags(flags, "first_value")?;
987
988 let filtered_states =
989 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
990 let sort_columns =
992 convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
993
994 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
995 let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
996
997 if let Some(first_idx) = min {
998 let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
999 let first_ordering = &first_row[1..is_set_idx];
1001 let sort_options = get_sort_options(&self.ordering_req);
1002 if !self.is_set
1004 || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
1005 {
1006 assert!(is_set_idx <= first_row.len());
1010 first_row.resize(is_set_idx, ScalarValue::Null);
1011 self.update_with_new_row(first_row);
1012 }
1013 }
1014 Ok(())
1015 }
1016
1017 fn evaluate(&mut self) -> Result<ScalarValue> {
1018 Ok(self.first.clone())
1019 }
1020
1021 fn size(&self) -> usize {
1022 size_of_val(self) - size_of_val(&self.first)
1023 + self.first.size()
1024 + ScalarValue::size_of_vec(&self.orderings)
1025 - size_of_val(&self.orderings)
1026 }
1027}
1028
// SQL-facing `last_value` aggregate UDF. User documentation is generated from
// the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct LastValue {
    /// Accepts exactly one argument of any type (built in `LastValue::new`).
    signature: Signature,
    /// Set through `with_beneficial_ordering` and forwarded to
    /// `LastValueAccumulator::try_new`.
    is_input_pre_ordered: bool,
}
1052 .field("name", &self.name())
1053 .field("signature", &self.signature)
1054 .field("accumulator", &"<FUNC>")
1055 .finish()
1056 }
1057}
1058
1059impl Default for LastValue {
1060 fn default() -> Self {
1061 Self::new()
1062 }
1063}
1064
1065impl LastValue {
1066 pub fn new() -> Self {
1067 Self {
1068 signature: Signature::any(1, Volatility::Immutable),
1069 is_input_pre_ordered: false,
1070 }
1071 }
1072}
1073
1074impl AggregateUDFImpl for LastValue {
1075 fn as_any(&self) -> &dyn Any {
1076 self
1077 }
1078
1079 fn name(&self) -> &str {
1080 "last_value"
1081 }
1082
1083 fn signature(&self) -> &Signature {
1084 &self.signature
1085 }
1086
1087 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1088 not_impl_err!("Not called because the return_field_from_args is implemented")
1089 }
1090
1091 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1092 Ok(Arc::new(
1094 Field::new(
1095 self.name(),
1096 arg_fields[0].data_type().clone(),
1097 true, )
1099 .with_metadata(arg_fields[0].metadata().clone()),
1100 ))
1101 }
1102
1103 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1104 let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
1105 return TrivialLastValueAccumulator::try_new(
1106 acc_args.return_field.data_type(),
1107 acc_args.ignore_nulls,
1108 )
1109 .map(|acc| Box::new(acc) as _);
1110 };
1111 let ordering_dtypes = ordering
1112 .iter()
1113 .map(|e| e.expr.data_type(acc_args.schema))
1114 .collect::<Result<Vec<_>>>()?;
1115 Ok(Box::new(LastValueAccumulator::try_new(
1116 acc_args.return_field.data_type(),
1117 &ordering_dtypes,
1118 ordering,
1119 self.is_input_pre_ordered,
1120 acc_args.ignore_nulls,
1121 )?))
1122 }
1123
1124 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1125 let mut fields = vec![
1126 Field::new(
1127 format_state_name(args.name, "last_value"),
1128 args.return_field.data_type().clone(),
1129 true,
1130 )
1131 .into(),
1132 ];
1133 fields.extend(args.ordering_fields.iter().cloned());
1134 fields.push(
1135 Field::new(
1136 format_state_name(args.name, "last_value_is_set"),
1137 DataType::Boolean,
1138 true,
1139 )
1140 .into(),
1141 );
1142 Ok(fields)
1143 }
1144
1145 fn with_beneficial_ordering(
1146 self: Arc<Self>,
1147 beneficial_ordering: bool,
1148 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1149 Ok(Some(Arc::new(Self {
1150 signature: self.signature.clone(),
1151 is_input_pre_ordered: beneficial_ordering,
1152 })))
1153 }
1154
1155 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1156 AggregateOrderSensitivity::Beneficial
1157 }
1158
1159 fn reverse_expr(&self) -> ReversedUDAF {
1160 ReversedUDAF::Reversed(first_value_udaf())
1161 }
1162
1163 fn supports_null_handling_clause(&self) -> bool {
1164 true
1165 }
1166
1167 fn documentation(&self) -> Option<&Documentation> {
1168 self.doc()
1169 }
1170
1171 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1172 use DataType::*;
1173 !args.order_bys.is_empty()
1174 && matches!(
1175 args.return_field.data_type(),
1176 Int8 | Int16
1177 | Int32
1178 | Int64
1179 | UInt8
1180 | UInt16
1181 | UInt32
1182 | UInt64
1183 | Float16
1184 | Float32
1185 | Float64
1186 | Decimal32(_, _)
1187 | Decimal64(_, _)
1188 | Decimal128(_, _)
1189 | Decimal256(_, _)
1190 | Date32
1191 | Date64
1192 | Time32(_)
1193 | Time64(_)
1194 | Timestamp(_, _)
1195 )
1196 }
1197
1198 fn create_groups_accumulator(
1199 &self,
1200 args: AccumulatorArgs,
1201 ) -> Result<Box<dyn GroupsAccumulator>> {
1202 fn create_accumulator<T>(
1203 args: &AccumulatorArgs,
1204 ) -> Result<Box<dyn GroupsAccumulator>>
1205 where
1206 T: ArrowPrimitiveType + Send,
1207 {
1208 let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
1209 return internal_err!("Groups accumulator must have an ordering.");
1210 };
1211
1212 let ordering_dtypes = ordering
1213 .iter()
1214 .map(|e| e.expr.data_type(args.schema))
1215 .collect::<Result<Vec<_>>>()?;
1216
1217 Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1218 ordering,
1219 args.ignore_nulls,
1220 args.return_field.data_type(),
1221 &ordering_dtypes,
1222 false,
1223 )?))
1224 }
1225
1226 match args.return_field.data_type() {
1227 DataType::Int8 => create_accumulator::<Int8Type>(&args),
1228 DataType::Int16 => create_accumulator::<Int16Type>(&args),
1229 DataType::Int32 => create_accumulator::<Int32Type>(&args),
1230 DataType::Int64 => create_accumulator::<Int64Type>(&args),
1231 DataType::UInt8 => create_accumulator::<UInt8Type>(&args),
1232 DataType::UInt16 => create_accumulator::<UInt16Type>(&args),
1233 DataType::UInt32 => create_accumulator::<UInt32Type>(&args),
1234 DataType::UInt64 => create_accumulator::<UInt64Type>(&args),
1235 DataType::Float16 => create_accumulator::<Float16Type>(&args),
1236 DataType::Float32 => create_accumulator::<Float32Type>(&args),
1237 DataType::Float64 => create_accumulator::<Float64Type>(&args),
1238
1239 DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(&args),
1240 DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(&args),
1241 DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(&args),
1242 DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(&args),
1243
1244 DataType::Timestamp(TimeUnit::Second, _) => {
1245 create_accumulator::<TimestampSecondType>(&args)
1246 }
1247 DataType::Timestamp(TimeUnit::Millisecond, _) => {
1248 create_accumulator::<TimestampMillisecondType>(&args)
1249 }
1250 DataType::Timestamp(TimeUnit::Microsecond, _) => {
1251 create_accumulator::<TimestampMicrosecondType>(&args)
1252 }
1253 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1254 create_accumulator::<TimestampNanosecondType>(&args)
1255 }
1256
1257 DataType::Date32 => create_accumulator::<Date32Type>(&args),
1258 DataType::Date64 => create_accumulator::<Date64Type>(&args),
1259 DataType::Time32(TimeUnit::Second) => {
1260 create_accumulator::<Time32SecondType>(&args)
1261 }
1262 DataType::Time32(TimeUnit::Millisecond) => {
1263 create_accumulator::<Time32MillisecondType>(&args)
1264 }
1265
1266 DataType::Time64(TimeUnit::Microsecond) => {
1267 create_accumulator::<Time64MicrosecondType>(&args)
1268 }
1269 DataType::Time64(TimeUnit::Nanosecond) => {
1270 create_accumulator::<Time64NanosecondType>(&args)
1271 }
1272
1273 _ => {
1274 internal_err!(
1275 "GroupsAccumulator not supported for last_value({})",
1276 args.return_field.data_type()
1277 )
1278 }
1279 }
1280 }
1281}
1282
/// `last_value` accumulator that carries no ordering columns.
///
/// It keeps the most recently observed row in input order (see
/// `update_batch`), optionally skipping nulls.
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    /// Current result. It is initialized from the return data type via
    /// `ScalarValue::try_from`, presumably a typed null.
    last: ScalarValue,
    /// Whether `last` has been populated from input rows or a merged state.
    is_set: bool,
    /// When true, null input rows are skipped when picking the last row.
    ignore_nulls: bool,
}
1297
1298impl TrivialLastValueAccumulator {
1299 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
1301 ScalarValue::try_from(data_type).map(|last| Self {
1302 last,
1303 is_set: false,
1304 ignore_nulls,
1305 })
1306 }
1307}
1308
1309impl Accumulator for TrivialLastValueAccumulator {
1310 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1311 Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
1312 }
1313
1314 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1315 let value = &values[0];
1317 let mut last_idx = None;
1318 if self.ignore_nulls {
1319 for i in (0..value.len()).rev() {
1321 if !value.is_null(i) {
1322 last_idx = Some(i);
1323 break;
1324 }
1325 }
1326 } else if !value.is_empty() {
1327 last_idx = Some(value.len() - 1);
1329 }
1330 if let Some(last_idx) = last_idx {
1331 let mut row = get_row_at_idx(values, last_idx)?;
1332 self.last = row.swap_remove(0);
1333 self.last.compact();
1334 self.is_set = true;
1335 }
1336 Ok(())
1337 }
1338
1339 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1340 let flags = states[1].as_boolean();
1343 validate_is_set_flags(flags, "last_value")?;
1344
1345 let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
1346 if let Some(last) = filtered_states.last()
1347 && !last.is_empty()
1348 {
1349 self.last = ScalarValue::try_from_array(last, 0)?;
1350 self.is_set = true;
1351 }
1352 Ok(())
1353 }
1354
1355 fn evaluate(&mut self) -> Result<ScalarValue> {
1356 Ok(self.last.clone())
1357 }
1358
1359 fn size(&self) -> usize {
1360 size_of_val(self) - size_of_val(&self.last) + self.last.size()
1361 }
1362}
1363
/// `last_value` accumulator driven by an explicit ordering requirement.
///
/// It keeps the row whose ordering columns compare greatest under
/// `ordering_req`, together with those ordering values, so that later
/// batches and merged states can be compared against it.
#[derive(Debug)]
struct LastValueAccumulator {
    /// Current result. It is initialized from the return data type, presumably
    /// a typed null.
    last: ScalarValue,
    /// Whether `last` / `orderings` have been populated.
    is_set: bool,
    /// Ordering-column values of the row `last` came from, one per entry of
    /// `ordering_req`.
    orderings: Vec<ScalarValue>,
    /// Ordering requirement used to compare candidate rows.
    ordering_req: LexOrdering,
    /// When true, input order is trusted: the last (non-null) row of each
    /// batch is taken and replaces the current value without comparing
    /// ordering columns.
    is_input_pre_ordered: bool,
    /// When true, rows with a null value are never selected.
    ignore_nulls: bool,
}
1381
1382impl LastValueAccumulator {
1383 pub fn try_new(
1385 data_type: &DataType,
1386 ordering_dtypes: &[DataType],
1387 ordering_req: LexOrdering,
1388 is_input_pre_ordered: bool,
1389 ignore_nulls: bool,
1390 ) -> Result<Self> {
1391 let orderings = ordering_dtypes
1392 .iter()
1393 .map(ScalarValue::try_from)
1394 .collect::<Result<_>>()?;
1395 ScalarValue::try_from(data_type).map(|last| Self {
1396 last,
1397 is_set: false,
1398 orderings,
1399 ordering_req,
1400 is_input_pre_ordered,
1401 ignore_nulls,
1402 })
1403 }
1404
1405 fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1407 for s in row.iter_mut() {
1409 s.compact();
1410 }
1411 self.last = row.remove(0);
1412 self.orderings = row;
1413 self.is_set = true;
1414 }
1415
1416 fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1417 let [value, ordering_values @ ..] = values else {
1418 return internal_err!("Empty row in LAST_VALUE");
1419 };
1420 if self.is_input_pre_ordered {
1421 if self.ignore_nulls {
1423 for i in (0..value.len()).rev() {
1425 if !value.is_null(i) {
1426 return Ok(Some(i));
1427 }
1428 }
1429 return Ok(None);
1430 } else {
1431 return Ok((!value.is_empty()).then_some(value.len() - 1));
1432 }
1433 }
1434
1435 let sort_columns = ordering_values
1436 .iter()
1437 .zip(self.ordering_req.iter())
1438 .map(|(values, req)| SortColumn {
1439 values: Arc::clone(values),
1440 options: Some(req.options),
1441 })
1442 .collect::<Vec<_>>();
1443
1444 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1445 let max_ind = if self.ignore_nulls {
1446 (0..value.len())
1447 .filter(|&index| !(value.is_null(index)))
1448 .max_by(|&a, &b| comparator.compare(a, b))
1449 } else {
1450 (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1451 };
1452
1453 Ok(max_ind)
1454 }
1455}
1456
1457impl Accumulator for LastValueAccumulator {
1458 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1459 let mut result = vec![self.last.clone()];
1460 result.extend(self.orderings.clone());
1461 result.push(ScalarValue::from(self.is_set));
1462 Ok(result)
1463 }
1464
1465 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1466 if let Some(last_idx) = self.get_last_idx(values)? {
1467 let row = get_row_at_idx(values, last_idx)?;
1468 let orderings = &row[1..];
1469 if !self.is_set
1471 || self.is_input_pre_ordered
1472 || compare_rows(
1473 &self.orderings,
1474 orderings,
1475 &get_sort_options(&self.ordering_req),
1476 )?
1477 .is_lt()
1478 {
1479 self.update_with_new_row(row);
1480 }
1481 }
1482 Ok(())
1483 }
1484
1485 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1486 let is_set_idx = states.len() - 1;
1489 let flags = states[is_set_idx].as_boolean();
1490 validate_is_set_flags(flags, "last_value")?;
1491
1492 let filtered_states =
1493 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1494 let sort_columns =
1496 convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
1497
1498 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1499 let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1500
1501 if let Some(last_idx) = max {
1502 let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
1503 let last_ordering = &last_row[1..is_set_idx];
1505 let sort_options = get_sort_options(&self.ordering_req);
1506 if !self.is_set
1509 || self.is_input_pre_ordered
1510 || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
1511 {
1512 assert!(is_set_idx <= last_row.len());
1516 last_row.resize(is_set_idx, ScalarValue::Null);
1517 self.update_with_new_row(last_row);
1518 }
1519 }
1520 Ok(())
1521 }
1522
1523 fn evaluate(&mut self) -> Result<ScalarValue> {
1524 Ok(self.last.clone())
1525 }
1526
1527 fn size(&self) -> usize {
1528 size_of_val(self) - size_of_val(&self.last)
1529 + self.last.size()
1530 + ScalarValue::size_of_vec(&self.orderings)
1531 - size_of_val(&self.orderings)
1532 }
1533}
1534
1535fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
1537 if flags.null_count() > 0 {
1538 return Err(DataFusionError::Internal(format!(
1539 "{function_name}: is_set flags contain nulls"
1540 )));
1541 }
1542 Ok(())
1543}
1544
1545fn filter_states_according_to_is_set(
1548 states: &[ArrayRef],
1549 flags: &BooleanArray,
1550) -> Result<Vec<ArrayRef>> {
1551 states
1552 .iter()
1553 .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1554 .collect()
1555}
1556
1557fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1559 arrs.iter()
1560 .zip(sort_exprs.iter())
1561 .map(|(item, sort_expr)| SortColumn {
1562 values: Arc::clone(item),
1563 options: Some(sort_expr.options),
1564 })
1565 .collect()
1566}
1567
// Unit tests for the first/last value accumulators and the primitive
// groups accumulator.
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use arrow::{
        array::{BooleanArray, Int64Array, ListArray, StringArray},
        compute::SortOptions,
        datatypes::Schema,
    };
    use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};

    use super::*;

    // Feeds three batches to the unordered accumulators and checks the
    // first and last values seen overall.
    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        // Batches are the half-open ranges [0, 10), [1, 11) and [2, 13).
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect::<Vec<_>>();
        for arr in arrs {
            // Both accumulators see the same batch.
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }
        // Expected: 0 is the very first value fed in.
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        // Expected: 12 is the final value of the last batch (13 is exclusive).
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    // Builds two partial states per function, stacks them column-wise, merges
    // them, and checks that the merged state keeps the same number of columns.
    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect::<Vec<_>>();

        // first_value: two independent partial states.
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;

        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        // Concatenate state column `idx` of both partials into one array,
        // producing the column layout merge_batch consumes.
        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        // last_value: same procedure.
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;

        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    // Exercises the primitive groups accumulator in its first_value
    // configuration (final constructor argument `true`) through update, state
    // emission, merge and evaluate. After each step it checks that the
    // incrementally tracked ordering size matches a full recomputation.
    #[test]
    fn test_first_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        // Order by column "c" with default sort options.
        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        // Arguments: ordering, ignore_nulls = true, value type, ordering
        // types, and `true` (test_last_group_acc passes `false`).
        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        // Column 0 holds the values; column 1 holds the ordering key.
        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        // Rows map to groups [0, 1, 2, 1]. The filter drops row 2, so group 2
        // stays unset. Row 1's value is null and nulls are ignored, so
        // group 1 ends up with -6 (see expected_state).
        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        let state = group_acc.state(EmitTo::All)?;

        // State columns: values, orderings, is_set flags.
        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        // Merge the emitted state back, letting only group 0's row through.
        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));

        // Groups 1 and 2 receive the value 6. With 4 total groups, group 3
        // never receives a row and evaluates to null.
        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);

        assert_eq!(eval_result, &expect);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        Ok(())
    }

    // Repeats update / partial and full state emission / merge / partial and
    // full evaluation ten times. After every step it checks that the tracked
    // ordering size stays in sync with a full recomputation.
    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            true,
        )?;

        let val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&BooleanArray::from(vec![true, true, false, true])),
                100,
            )?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            // Partial emission of the first two groups.
            group_acc.state(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            let s = group_acc.state(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            // Merge every emitted row back into groups 0..s[0].len().
            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );
        }

        Ok(())
    }

    // Same flow as test_first_group_acc, but with the last_value
    // configuration (final constructor argument `false`) and different
    // second-batch values.
    #[test]
    fn test_last_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_keys.into(),
            true,
            &DataType::Int64,
            &[DataType::Int64],
            false,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

        assert_eq!(eval_result, &expect);

        Ok(())
    }

    // After seeing a 10000-row list batch, the reported size must equal the
    // size after a 1-row batch. This means the stored value must not keep the
    // whole input buffer alive.
    // NOTE(review): the accumulator is declared with Int64 list items while
    // the batches hold Int32 items; confirm whether this is intentional.
    #[test]
    fn test_first_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            first_accumulator.update_batch(values)?;

            Ok(first_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    // last_value counterpart of test_first_list_acc_size (same Int64/Int32
    // item-type note applies).
    #[test]
    fn test_last_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            last_accumulator.update_batch(values)?;

            Ok(last_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    // A null is_set flag in a merged state must be rejected with an error, by
    // both the unordered and the ordered first_value accumulators.
    #[test]
    fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        // Unordered accumulator: state is [value, is_set].
        let mut trivial_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        // Ordered accumulator: state is [value, ordering, is_set].
        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = FirstValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }

    // last_value counterpart of test_first_value_merge_with_is_set_nulls.
    #[test]
    fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        // Unordered accumulator: state is [value, is_set].
        let mut trivial_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        // Ordered accumulator: state is [value, ordering, is_set].
        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = LastValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }
}