1use std::fmt::Debug;
21use std::hash::Hash;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
26use arrow::buffer::BooleanBuffer;
27use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
28use arrow::datatypes::{
29 DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type,
30 Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int8Type,
31 Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
32 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
33 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
34 UInt16Type, UInt32Type, UInt64Type,
35};
36use datafusion_common::cast::as_boolean_array;
37use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
38use datafusion_common::{
39 DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err,
40 not_impl_err,
41};
42use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
43use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
44use datafusion_expr::{
45 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
46 GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
47};
48use datafusion_functions_aggregate_common::utils::get_sort_options;
49use datafusion_macros::user_doc;
50use datafusion_physical_expr_common::sort_expr::LexOrdering;
51
52mod state;
53
54use state::{BytesValueState, PrimitiveValueState, ValueState};
55
// NOTE(review): `create_func!` is defined elsewhere in this crate; it presumably
// generates the `first_value_udaf()` / `last_value_udaf()` accessors that the
// expression builders and `reverse_expr` below call — confirm against the macro.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
58
59pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
61 first_value_udaf()
62 .call(vec![expression])
63 .order_by(order_by)
64 .build()
65 .unwrap()
67}
68
69pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
71 last_value_udaf()
72 .call(vec![expression])
73 .order_by(order_by)
74 .build()
75 .unwrap()
77}
78
79fn create_groups_accumulator_helper<S: ValueState + 'static>(
80 args: &AccumulatorArgs,
81 is_first: bool,
82 state: S,
83) -> Result<Box<dyn GroupsAccumulator>> {
84 let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
85 return internal_err!("Groups accumulator must have an ordering.");
86 };
87
88 let ordering_dtypes = ordering
89 .iter()
90 .map(|e| e.expr.data_type(args.schema))
91 .collect::<Result<Vec<_>>>()?;
92
93 Ok(Box::new(FirstLastGroupsAccumulator::try_new(
94 state,
95 ordering,
96 args.ignore_nulls,
97 &ordering_dtypes,
98 is_first,
99 )?))
100}
101
102fn create_groups_accumulator(
103 args: &AccumulatorArgs,
104 is_first: bool,
105 function_name: &str,
106) -> Result<Box<dyn GroupsAccumulator>> {
107 let data_type = args.return_field.data_type();
108
109 macro_rules! instantiate_primitive {
110 ($t:ty) => {
111 create_groups_accumulator_helper(
112 args,
113 is_first,
114 PrimitiveValueState::<$t>::new(data_type.clone()),
115 )
116 };
117 }
118
119 match data_type {
120 DataType::Int8 => instantiate_primitive!(Int8Type),
121 DataType::Int16 => instantiate_primitive!(Int16Type),
122 DataType::Int32 => instantiate_primitive!(Int32Type),
123 DataType::Int64 => instantiate_primitive!(Int64Type),
124 DataType::UInt8 => instantiate_primitive!(UInt8Type),
125 DataType::UInt16 => instantiate_primitive!(UInt16Type),
126 DataType::UInt32 => instantiate_primitive!(UInt32Type),
127 DataType::UInt64 => instantiate_primitive!(UInt64Type),
128 DataType::Float16 => instantiate_primitive!(Float16Type),
129 DataType::Float32 => instantiate_primitive!(Float32Type),
130 DataType::Float64 => instantiate_primitive!(Float64Type),
131
132 DataType::Decimal32(_, _) => instantiate_primitive!(Decimal32Type),
133 DataType::Decimal64(_, _) => instantiate_primitive!(Decimal64Type),
134 DataType::Decimal128(_, _) => instantiate_primitive!(Decimal128Type),
135 DataType::Decimal256(_, _) => instantiate_primitive!(Decimal256Type),
136
137 DataType::Timestamp(TimeUnit::Second, _) => {
138 instantiate_primitive!(TimestampSecondType)
139 }
140 DataType::Timestamp(TimeUnit::Millisecond, _) => {
141 instantiate_primitive!(TimestampMillisecondType)
142 }
143 DataType::Timestamp(TimeUnit::Microsecond, _) => {
144 instantiate_primitive!(TimestampMicrosecondType)
145 }
146 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
147 instantiate_primitive!(TimestampNanosecondType)
148 }
149
150 DataType::Date32 => instantiate_primitive!(Date32Type),
151 DataType::Date64 => instantiate_primitive!(Date64Type),
152 DataType::Time32(TimeUnit::Second) => instantiate_primitive!(Time32SecondType),
153 DataType::Time32(TimeUnit::Millisecond) => {
154 instantiate_primitive!(Time32MillisecondType)
155 }
156 DataType::Time64(TimeUnit::Microsecond) => {
157 instantiate_primitive!(Time64MicrosecondType)
158 }
159 DataType::Time64(TimeUnit::Nanosecond) => {
160 instantiate_primitive!(Time64NanosecondType)
161 }
162
163 DataType::Utf8
164 | DataType::LargeUtf8
165 | DataType::Utf8View
166 | DataType::Binary
167 | DataType::LargeBinary
168 | DataType::BinaryView => create_groups_accumulator_helper(
169 args,
170 is_first,
171 BytesValueState::try_new(data_type.clone())?,
172 ),
173
174 _ => internal_err!(
175 "GroupsAccumulator not supported for {}({})",
176 function_name,
177 data_type
178 ),
179 }
180}
181
182fn groups_accumulator_supported(args: &AccumulatorArgs) -> bool {
183 use DataType::*;
184 !args.order_bys.is_empty()
185 && matches!(
186 args.return_field.data_type(),
187 Int8 | Int16
188 | Int32
189 | Int64
190 | UInt8
191 | UInt16
192 | UInt32
193 | UInt64
194 | Float16
195 | Float32
196 | Float64
197 | Decimal32(_, _)
198 | Decimal64(_, _)
199 | Decimal128(_, _)
200 | Decimal256(_, _)
201 | Date32
202 | Date64
203 | Time32(_)
204 | Time64(_)
205 | Timestamp(_, _)
206 | Utf8
207 | LargeUtf8
208 | Utf8View
209 | Binary
210 | LargeBinary
211 | BinaryView
212 )
213}
214
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct FirstValue {
    /// Accepted argument signature (one argument of any type, see `FirstValue::new`).
    signature: Signature,
    /// Whether the input already arrives in the requested order; set through
    /// `with_beneficial_ordering` and forwarded to `FirstValueAccumulator`.
    is_input_pre_ordered: bool,
}
234
235impl Default for FirstValue {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241impl FirstValue {
242 pub fn new() -> Self {
243 Self {
244 signature: Signature::any(1, Volatility::Immutable),
245 is_input_pre_ordered: false,
246 }
247 }
248}
249
250impl AggregateUDFImpl for FirstValue {
251 fn name(&self) -> &str {
252 "first_value"
253 }
254
255 fn signature(&self) -> &Signature {
256 &self.signature
257 }
258
259 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
260 not_impl_err!("Not called because the return_field_from_args is implemented")
261 }
262
263 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
264 Ok(Arc::new(
266 Field::new(
267 self.name(),
268 arg_fields[0].data_type().clone(),
269 true, )
271 .with_metadata(arg_fields[0].metadata().clone()),
272 ))
273 }
274
275 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
276 let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
277 return TrivialFirstValueAccumulator::try_new(
278 acc_args.return_field.data_type(),
279 acc_args.ignore_nulls,
280 )
281 .map(|acc| Box::new(acc) as _);
282 };
283 let ordering_dtypes = ordering
284 .iter()
285 .map(|e| e.expr.data_type(acc_args.schema))
286 .collect::<Result<Vec<_>>>()?;
287 Ok(Box::new(FirstValueAccumulator::try_new(
288 acc_args.return_field.data_type(),
289 &ordering_dtypes,
290 ordering,
291 self.is_input_pre_ordered,
292 acc_args.ignore_nulls,
293 )?))
294 }
295
296 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
297 let mut fields = vec![
298 Field::new(
299 format_state_name(args.name, "first_value"),
300 args.return_type().clone(),
301 true,
302 )
303 .into(),
304 ];
305 fields.extend(args.ordering_fields.iter().cloned());
306 fields.push(
307 Field::new(
308 format_state_name(args.name, "first_value_is_set"),
309 DataType::Boolean,
310 true,
311 )
312 .into(),
313 );
314 Ok(fields)
315 }
316
317 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
318 groups_accumulator_supported(&args)
319 }
320
321 fn create_groups_accumulator(
322 &self,
323 args: AccumulatorArgs,
324 ) -> Result<Box<dyn GroupsAccumulator>> {
325 create_groups_accumulator(&args, true, self.name())
326 }
327
328 fn with_beneficial_ordering(
329 self: Arc<Self>,
330 beneficial_ordering: bool,
331 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
332 Ok(Some(Arc::new(Self {
333 signature: self.signature.clone(),
334 is_input_pre_ordered: beneficial_ordering,
335 })))
336 }
337
338 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
339 AggregateOrderSensitivity::Beneficial
340 }
341
342 fn reverse_expr(&self) -> ReversedUDAF {
343 ReversedUDAF::Reversed(last_value_udaf())
344 }
345
346 fn supports_null_handling_clause(&self) -> bool {
347 true
348 }
349
350 fn documentation(&self) -> Option<&Documentation> {
351 self.doc()
352 }
353}
354
/// Groups accumulator shared by `first_value` and `last_value` when an ORDER BY
/// is present.
struct FirstLastGroupsAccumulator<S: ValueState> {
    /// Per-group storage of the currently selected value.
    state: S,
    /// Per-group ordering values of the currently selected row; groups that
    /// have not been set hold a copy of `default_orderings`.
    orderings: Vec<Vec<ScalarValue>>,
    /// Bit `i` is set once group `i` has selected a row.
    is_sets: BooleanBufferBuilder,
    /// Running total of `ScalarValue::size_of_vec` over `orderings`, maintained
    /// incrementally so `size()` does not rescan them.
    size_of_orderings: usize,

    /// Scratch space reused across batches: `.0[g]` is the row index of the
    /// best candidate for group `g` in the current batch, and `.1` marks which
    /// groups have a candidate.
    extreme_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    /// ORDER BY requirement used to compare rows.
    ordering_req: LexOrdering,
    /// `true` for `first_value` (keep the smallest row under `ordering_req`),
    /// `false` for `last_value` (keep the largest).
    pick_first_in_group: bool,
    /// Sort options extracted from `ordering_req`, used with `compare_rows`.
    sort_options: Vec<SortOptions>,
    /// When true, rows whose value is NULL are never selected.
    ignore_nulls: bool,
    /// Ordering values given to newly created groups, built with
    /// `ScalarValue::try_from(&DataType)` for each ordering column.
    default_orderings: Vec<ScalarValue>,
}
390
391impl<S: ValueState> FirstLastGroupsAccumulator<S> {
392 fn try_new(
393 state: S,
394 ordering_req: LexOrdering,
395 ignore_nulls: bool,
396 ordering_dtypes: &[DataType],
397 pick_first_in_group: bool,
398 ) -> Result<Self> {
399 let default_orderings = ordering_dtypes
400 .iter()
401 .map(ScalarValue::try_from)
402 .collect::<Result<_>>()?;
403
404 let sort_options = get_sort_options(&ordering_req);
405
406 Ok(Self {
407 ordering_req,
408 sort_options,
409 ignore_nulls,
410 default_orderings,
411 state,
412 orderings: Vec::new(),
413 is_sets: BooleanBufferBuilder::new(0),
414 size_of_orderings: 0,
415 extreme_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
416 pick_first_in_group,
417 })
418 }
419
420 fn should_update_state(
421 &self,
422 group_idx: usize,
423 new_ordering_values: &[ScalarValue],
424 ) -> Result<bool> {
425 if !self.is_sets.get_bit(group_idx) {
426 return Ok(true);
427 }
428
429 debug_assert!(new_ordering_values.len() == self.ordering_req.len());
430 let current_ordering = &self.orderings[group_idx];
431 compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
432 if self.pick_first_in_group {
433 x.is_gt()
434 } else {
435 x.is_lt()
436 }
437 })
438 }
439
440 fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
441 let result = emit_to.take_needed(&mut self.orderings);
442
443 match emit_to {
444 EmitTo::All => self.size_of_orderings = 0,
445 EmitTo::First(_) => {
446 self.size_of_orderings -=
447 result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
448 }
449 }
450
451 result
452 }
453
454 fn resize_states(&mut self, new_size: usize) {
455 self.state.resize(new_size);
456
457 if self.orderings.len() < new_size {
458 let current_len = self.orderings.len();
459
460 self.orderings
461 .resize(new_size, self.default_orderings.clone());
462
463 self.size_of_orderings += (new_size - current_len)
464 * ScalarValue::size_of_vec(
465 self.orderings.last().unwrap(),
469 );
470 }
471
472 self.is_sets.resize(new_size);
473
474 self.extreme_of_each_group_buf.0.resize(new_size, 0);
475 self.extreme_of_each_group_buf.1.resize(new_size);
476 }
477
478 fn update_state(
479 &mut self,
480 group_idx: usize,
481 orderings: &[ScalarValue],
482 array: &ArrayRef,
483 idx: usize,
484 ) -> Result<()> {
485 self.state.update(group_idx, array, idx)?;
486 self.is_sets.set_bit(group_idx, true);
487
488 debug_assert!(orderings.len() == self.ordering_req.len());
489 let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
490 self.orderings[group_idx].clear();
491 self.orderings[group_idx].extend_from_slice(orderings);
492 let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
493 self.size_of_orderings = self.size_of_orderings - old_size + new_size;
494 Ok(())
495 }
496
497 fn take_state(
498 &mut self,
499 emit_to: EmitTo,
500 ) -> Result<(ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer)> {
501 emit_to.take_needed(&mut self.extreme_of_each_group_buf.0);
502 self.extreme_of_each_group_buf
503 .1
504 .truncate(self.extreme_of_each_group_buf.0.len());
505
506 Ok((
507 self.state.take(emit_to)?,
508 self.take_orderings(emit_to),
509 state::take_need(&mut self.is_sets, emit_to),
510 ))
511 }
512
513 #[cfg(test)]
515 fn compute_size_of_orderings(&self) -> usize {
516 self.orderings
517 .iter()
518 .map(ScalarValue::size_of_vec)
519 .sum::<usize>()
520 }
521 fn get_filtered_extreme_of_each_group(
525 &mut self,
526 orderings: &[ArrayRef],
527 group_indices: &[usize],
528 opt_filter: Option<&BooleanArray>,
529 vals: &ArrayRef,
530 is_set_arr: Option<&BooleanArray>,
531 ) -> Result<Vec<(usize, usize)>> {
532 self.extreme_of_each_group_buf.1.truncate(0);
534 self.extreme_of_each_group_buf
535 .1
536 .append_n(self.is_sets.len(), false);
537
538 let comparator = {
542 assert_eq!(orderings.len(), self.ordering_req.len());
543 let sort_columns = orderings
544 .iter()
545 .zip(self.ordering_req.iter())
546 .map(|(array, req)| SortColumn {
547 values: Arc::clone(array),
548 options: Some(req.options),
549 })
550 .collect::<Vec<_>>();
551
552 LexicographicalComparator::try_new(&sort_columns)?
553 };
554
555 for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
556 let group_idx = *group_idx;
557
558 let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
559 let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
560
561 if !passed_filter || !is_set {
562 continue;
563 }
564
565 if self.ignore_nulls && vals.is_null(idx_in_val) {
566 continue;
567 }
568
569 let is_valid = self.extreme_of_each_group_buf.1.get_bit(group_idx);
570
571 if !is_valid {
572 self.extreme_of_each_group_buf.1.set_bit(group_idx, true);
573 self.extreme_of_each_group_buf.0[group_idx] = idx_in_val;
574 } else {
575 let ordering = comparator
576 .compare(self.extreme_of_each_group_buf.0[group_idx], idx_in_val);
577
578 if (ordering.is_gt() && self.pick_first_in_group)
579 || (ordering.is_lt() && !self.pick_first_in_group)
580 {
581 self.extreme_of_each_group_buf.0[group_idx] = idx_in_val;
582 }
583 }
584 }
585
586 Ok(self
587 .extreme_of_each_group_buf
588 .0
589 .iter()
590 .enumerate()
591 .filter(|(group_idx, _)| self.extreme_of_each_group_buf.1.get_bit(*group_idx))
592 .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
593 .collect::<Vec<_>>())
594 }
595}
596
597impl<S: ValueState + 'static> GroupsAccumulator for FirstLastGroupsAccumulator<S> {
598 fn update_batch(
599 &mut self,
600 values_and_order_cols: &[ArrayRef],
602 group_indices: &[usize],
603 opt_filter: Option<&BooleanArray>,
604 total_num_groups: usize,
605 ) -> Result<()> {
606 self.resize_states(total_num_groups);
607
608 let vals = &values_and_order_cols[0];
609
610 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
611
612 for (group_idx, idx) in self
614 .get_filtered_extreme_of_each_group(
615 &values_and_order_cols[1..],
616 group_indices,
617 opt_filter,
618 vals,
619 None,
620 )?
621 .into_iter()
622 {
623 extract_row_at_idx_to_buf(
624 &values_and_order_cols[1..],
625 idx,
626 &mut ordering_buf,
627 )?;
628
629 if self.should_update_state(group_idx, &ordering_buf)? {
630 self.update_state(group_idx, &ordering_buf, vals, idx)?;
631 }
632 }
633
634 Ok(())
635 }
636
637 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
638 Ok(self.take_state(emit_to)?.0)
639 }
640
641 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
642 let (val_arr, orderings, is_sets) = self.take_state(emit_to)?;
643 let mut result = Vec::with_capacity(self.orderings.len() + 2);
644
645 result.push(val_arr);
646
647 let ordering_cols = {
648 let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
649 for _ in 0..self.ordering_req.len() {
650 ordering_cols.push(Vec::with_capacity(self.orderings.len()));
651 }
652 for row in orderings.into_iter() {
653 debug_assert!(row.len() == self.ordering_req.len());
654 for (col_idx, ordering) in row.into_iter().enumerate() {
655 ordering_cols[col_idx].push(ordering);
656 }
657 }
658
659 ordering_cols
660 };
661 for ordering_col in ordering_cols {
662 result.push(ScalarValue::iter_to_array(ordering_col)?);
663 }
664
665 result.push(Arc::new(BooleanArray::new(is_sets, None)));
666
667 Ok(result)
668 }
669
670 fn merge_batch(
671 &mut self,
672 values: &[ArrayRef],
673 group_indices: &[usize],
674 opt_filter: Option<&BooleanArray>,
675 total_num_groups: usize,
676 ) -> Result<()> {
677 self.resize_states(total_num_groups);
678
679 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
680
681 let (is_set_arr, val_and_order_cols) = match values.split_last() {
682 Some(result) => result,
683 None => return internal_err!("Empty row in FIRST_VALUE"),
684 };
685
686 let is_set_arr = as_boolean_array(is_set_arr)?;
687
688 let vals = &values[0];
689 let groups = self.get_filtered_extreme_of_each_group(
691 &val_and_order_cols[1..],
692 group_indices,
693 opt_filter,
694 vals,
695 Some(is_set_arr),
696 )?;
697
698 for (group_idx, idx) in groups.into_iter() {
699 extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
700
701 if self.should_update_state(group_idx, &ordering_buf)? {
702 self.update_state(group_idx, &ordering_buf, vals, idx)?;
703 }
704 }
705
706 Ok(())
707 }
708
709 fn size(&self) -> usize {
710 self.state.size()
711 + self.is_sets.capacity() / 8 + self.size_of_orderings
713 + self.extreme_of_each_group_buf.0.capacity() * size_of::<usize>()
714 + self.extreme_of_each_group_buf.1.capacity() / 8
715 }
716
717 fn supports_convert_to_state(&self) -> bool {
718 true
719 }
720
721 fn convert_to_state(
722 &self,
723 values: &[ArrayRef],
724 opt_filter: Option<&BooleanArray>,
725 ) -> Result<Vec<ArrayRef>> {
726 let mut result = values.to_vec();
727 match opt_filter {
728 Some(f) => {
729 result.push(Arc::new(f.clone()));
730 Ok(result)
731 }
732 None => {
733 result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
734 Ok(result)
735 }
736 }
737 }
738}
739
/// `first_value` accumulator used when no ORDER BY is given: it captures the
/// first qualifying value it sees and never replaces it afterwards.
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    /// The captured value (the NULL scalar of the input type until `is_set`).
    first: ScalarValue,
    /// Whether `first` holds a value taken from the input.
    is_set: bool,
    /// When true, NULL input values are skipped.
    ignore_nulls: bool,
}
752
753impl TrivialFirstValueAccumulator {
754 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
756 ScalarValue::try_from(data_type).map(|first| Self {
757 first,
758 is_set: false,
759 ignore_nulls,
760 })
761 }
762}
763
764impl Accumulator for TrivialFirstValueAccumulator {
765 fn state(&mut self) -> Result<Vec<ScalarValue>> {
766 Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
767 }
768
769 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
770 if !self.is_set {
771 let value = &values[0];
773 let mut first_idx = None;
774 if self.ignore_nulls {
775 for i in 0..value.len() {
777 if !value.is_null(i) {
778 first_idx = Some(i);
779 break;
780 }
781 }
782 } else if !value.is_empty() {
783 first_idx = Some(0);
785 }
786 if let Some(first_idx) = first_idx {
787 self.first = ScalarValue::try_from_array(&values[0], first_idx)?;
788 self.first.compact();
789 self.is_set = true;
790 }
791 }
792 Ok(())
793 }
794
795 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
796 if !self.is_set {
799 let flags = states[1].as_boolean();
800 validate_is_set_flags(flags, "first_value")?;
801
802 let filtered_states =
803 filter_states_according_to_is_set(&states[0..1], flags)?;
804 if let Some(first) = filtered_states.first()
805 && !first.is_empty()
806 {
807 self.first = ScalarValue::try_from_array(first, 0)?;
808 self.is_set = true;
809 }
810 }
811 Ok(())
812 }
813
814 fn evaluate(&mut self) -> Result<ScalarValue> {
815 Ok(self.first.clone())
816 }
817
818 fn size(&self) -> usize {
819 size_of_val(self) - size_of_val(&self.first) + self.first.size()
820 }
821}
822
/// `first_value` accumulator used when an ORDER BY is given: keeps the value
/// whose ordering values are smallest under `ordering_req`.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    /// The selected value so far.
    first: ScalarValue,
    /// Whether `first` holds a value taken from the input.
    is_set: bool,
    /// Ordering values of the row that produced `first`.
    orderings: Vec<ScalarValue>,
    /// ORDER BY requirement used to rank rows.
    ordering_req: LexOrdering,
    /// Sort options extracted from `ordering_req`, used with `compare_rows`.
    sort_options: Vec<SortOptions>,
    /// When true, rows are assumed to arrive already ordered, so the first
    /// qualifying row of the first non-empty batch is kept.
    is_input_pre_ordered: bool,
    /// When true, NULL input values are skipped.
    ignore_nulls: bool,
}
840
841impl FirstValueAccumulator {
842 pub fn try_new(
844 data_type: &DataType,
845 ordering_dtypes: &[DataType],
846 ordering_req: LexOrdering,
847 is_input_pre_ordered: bool,
848 ignore_nulls: bool,
849 ) -> Result<Self> {
850 let orderings = ordering_dtypes
851 .iter()
852 .map(ScalarValue::try_from)
853 .collect::<Result<_>>()?;
854 let sort_options = get_sort_options(&ordering_req);
855 ScalarValue::try_from(data_type).map(|first| Self {
856 first,
857 is_set: false,
858 orderings,
859 ordering_req,
860 sort_options,
861 is_input_pre_ordered,
862 ignore_nulls,
863 })
864 }
865
866 fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
868 for s in row.iter_mut() {
870 s.compact();
871 }
872 self.first = row.remove(0);
873 self.orderings = row;
874 self.is_set = true;
875 }
876
877 fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
878 let [value, ordering_values @ ..] = values else {
879 return internal_err!("Empty row in FIRST_VALUE");
880 };
881 if self.is_input_pre_ordered {
882 if self.ignore_nulls {
884 for i in 0..value.len() {
886 if !value.is_null(i) {
887 return Ok(Some(i));
888 }
889 }
890 return Ok(None);
891 } else {
892 return Ok((!value.is_empty()).then_some(0));
894 }
895 }
896
897 let sort_columns = ordering_values
898 .iter()
899 .zip(self.ordering_req.iter())
900 .map(|(values, req)| SortColumn {
901 values: Arc::clone(values),
902 options: Some(req.options),
903 })
904 .collect::<Vec<_>>();
905
906 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
907
908 let min_index = if self.ignore_nulls {
909 (0..value.len())
910 .filter(|&index| !value.is_null(index))
911 .min_by(|&a, &b| comparator.compare(a, b))
912 } else {
913 (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
914 };
915
916 Ok(min_index)
917 }
918}
919
920impl Accumulator for FirstValueAccumulator {
921 fn state(&mut self) -> Result<Vec<ScalarValue>> {
922 let mut result = vec![self.first.clone()];
923 result.extend(self.orderings.iter().cloned());
924 result.push(ScalarValue::from(self.is_set));
925 Ok(result)
926 }
927
928 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
929 if let Some(first_idx) = self.get_first_idx(values)? {
930 let row = get_row_at_idx(values, first_idx)?;
931 if !self.is_set
932 || (!self.is_input_pre_ordered
933 && compare_rows(&self.orderings, &row[1..], &self.sort_options)?
934 .is_gt())
935 {
936 self.update_with_new_row(row);
937 }
938 }
939 Ok(())
940 }
941
942 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
943 let is_set_idx = states.len() - 1;
946 let flags = states[is_set_idx].as_boolean();
947 validate_is_set_flags(flags, "first_value")?;
948
949 let filtered_states =
950 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
951 let sort_columns =
953 convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
954
955 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
956 let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
957
958 if let Some(first_idx) = min {
959 let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
960 let first_ordering = &first_row[1..is_set_idx];
962 if !self.is_set
964 || compare_rows(&self.orderings, first_ordering, &self.sort_options)?
965 .is_gt()
966 {
967 assert!(is_set_idx <= first_row.len());
971 first_row.resize(is_set_idx, ScalarValue::Null);
972 self.update_with_new_row(first_row);
973 }
974 }
975 Ok(())
976 }
977
978 fn evaluate(&mut self) -> Result<ScalarValue> {
979 Ok(self.first.clone())
980 }
981
982 fn size(&self) -> usize {
983 size_of_val(self) - size_of_val(&self.first)
984 + self.first.size()
985 + ScalarValue::size_of_vec(&self.orderings)
986 - size_of_val(&self.orderings)
987 }
988}
989
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct LastValue {
    /// Accepted argument signature (one argument of any type, see `LastValue::new`).
    signature: Signature,
    /// Whether the input already arrives in the requested order; set through
    /// `with_beneficial_ordering` and forwarded to `LastValueAccumulator`.
    is_input_pre_ordered: bool,
}
1009
1010impl Default for LastValue {
1011 fn default() -> Self {
1012 Self::new()
1013 }
1014}
1015
1016impl LastValue {
1017 pub fn new() -> Self {
1018 Self {
1019 signature: Signature::any(1, Volatility::Immutable),
1020 is_input_pre_ordered: false,
1021 }
1022 }
1023}
1024
1025impl AggregateUDFImpl for LastValue {
1026 fn name(&self) -> &str {
1027 "last_value"
1028 }
1029
1030 fn signature(&self) -> &Signature {
1031 &self.signature
1032 }
1033
1034 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1035 not_impl_err!("Not called because the return_field_from_args is implemented")
1036 }
1037
1038 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1039 Ok(Arc::new(
1041 Field::new(
1042 self.name(),
1043 arg_fields[0].data_type().clone(),
1044 true, )
1046 .with_metadata(arg_fields[0].metadata().clone()),
1047 ))
1048 }
1049
1050 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1051 let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
1052 return TrivialLastValueAccumulator::try_new(
1053 acc_args.return_field.data_type(),
1054 acc_args.ignore_nulls,
1055 )
1056 .map(|acc| Box::new(acc) as _);
1057 };
1058 let ordering_dtypes = ordering
1059 .iter()
1060 .map(|e| e.expr.data_type(acc_args.schema))
1061 .collect::<Result<Vec<_>>>()?;
1062 Ok(Box::new(LastValueAccumulator::try_new(
1063 acc_args.return_field.data_type(),
1064 &ordering_dtypes,
1065 ordering,
1066 self.is_input_pre_ordered,
1067 acc_args.ignore_nulls,
1068 )?))
1069 }
1070
1071 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1072 let mut fields = vec![
1073 Field::new(
1074 format_state_name(args.name, "last_value"),
1075 args.return_field.data_type().clone(),
1076 true,
1077 )
1078 .into(),
1079 ];
1080 fields.extend(args.ordering_fields.iter().cloned());
1081 fields.push(
1082 Field::new(
1083 format_state_name(args.name, "last_value_is_set"),
1084 DataType::Boolean,
1085 true,
1086 )
1087 .into(),
1088 );
1089 Ok(fields)
1090 }
1091
1092 fn with_beneficial_ordering(
1093 self: Arc<Self>,
1094 beneficial_ordering: bool,
1095 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1096 Ok(Some(Arc::new(Self {
1097 signature: self.signature.clone(),
1098 is_input_pre_ordered: beneficial_ordering,
1099 })))
1100 }
1101
1102 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1103 AggregateOrderSensitivity::Beneficial
1104 }
1105
1106 fn reverse_expr(&self) -> ReversedUDAF {
1107 ReversedUDAF::Reversed(first_value_udaf())
1108 }
1109
1110 fn supports_null_handling_clause(&self) -> bool {
1111 true
1112 }
1113
1114 fn documentation(&self) -> Option<&Documentation> {
1115 self.doc()
1116 }
1117
1118 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1119 groups_accumulator_supported(&args)
1120 }
1121
1122 fn create_groups_accumulator(
1123 &self,
1124 args: AccumulatorArgs,
1125 ) -> Result<Box<dyn GroupsAccumulator>> {
1126 create_groups_accumulator(&args, false, self.name())
1127 }
1128}
1129
/// `last_value` accumulator used when no ORDER BY is given: each batch that
/// contains a qualifying value overwrites the stored value.
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    /// The stored value (the NULL scalar of the input type until `is_set`).
    last: ScalarValue,
    /// Whether `last` holds a value taken from the input.
    is_set: bool,
    /// When true, NULL input values are skipped.
    ignore_nulls: bool,
}
1144
1145impl TrivialLastValueAccumulator {
1146 pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
1148 ScalarValue::try_from(data_type).map(|last| Self {
1149 last,
1150 is_set: false,
1151 ignore_nulls,
1152 })
1153 }
1154}
1155
1156impl Accumulator for TrivialLastValueAccumulator {
1157 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1158 Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
1159 }
1160
1161 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1162 let value = &values[0];
1164 let mut last_idx = None;
1165 if self.ignore_nulls {
1166 for i in (0..value.len()).rev() {
1168 if !value.is_null(i) {
1169 last_idx = Some(i);
1170 break;
1171 }
1172 }
1173 } else if !value.is_empty() {
1174 last_idx = Some(value.len() - 1);
1176 }
1177 if let Some(last_idx) = last_idx {
1178 self.last = ScalarValue::try_from_array(&values[0], last_idx)?;
1179 self.last.compact();
1180 self.is_set = true;
1181 }
1182 Ok(())
1183 }
1184
1185 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1186 let flags = states[1].as_boolean();
1189 validate_is_set_flags(flags, "last_value")?;
1190
1191 let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
1192 if let Some(last) = filtered_states.last()
1193 && !last.is_empty()
1194 {
1195 self.last = ScalarValue::try_from_array(last, 0)?;
1196 self.is_set = true;
1197 }
1198 Ok(())
1199 }
1200
1201 fn evaluate(&mut self) -> Result<ScalarValue> {
1202 Ok(self.last.clone())
1203 }
1204
1205 fn size(&self) -> usize {
1206 size_of_val(self) - size_of_val(&self.last) + self.last.size()
1207 }
1208}
1209
/// `last_value` accumulator used when an ORDER BY is given: keeps the value
/// whose ordering values are largest under `ordering_req`.
#[derive(Debug)]
struct LastValueAccumulator {
    /// The selected value so far.
    last: ScalarValue,
    /// Whether `last` holds a value taken from the input.
    is_set: bool,
    /// Ordering values of the row that produced `last`.
    orderings: Vec<ScalarValue>,
    /// ORDER BY requirement used to rank rows.
    ordering_req: LexOrdering,
    /// Sort options extracted from `ordering_req`, used with `compare_rows`.
    sort_options: Vec<SortOptions>,
    /// When true, rows are assumed to arrive already ordered, so the last
    /// qualifying row of each batch replaces the stored one.
    is_input_pre_ordered: bool,
    /// When true, NULL input values are skipped.
    ignore_nulls: bool,
}
1229
1230impl LastValueAccumulator {
1231 pub fn try_new(
1233 data_type: &DataType,
1234 ordering_dtypes: &[DataType],
1235 ordering_req: LexOrdering,
1236 is_input_pre_ordered: bool,
1237 ignore_nulls: bool,
1238 ) -> Result<Self> {
1239 let orderings = ordering_dtypes
1240 .iter()
1241 .map(ScalarValue::try_from)
1242 .collect::<Result<_>>()?;
1243 let sort_options = get_sort_options(&ordering_req);
1244 ScalarValue::try_from(data_type).map(|last| Self {
1245 last,
1246 is_set: false,
1247 orderings,
1248 ordering_req,
1249 sort_options,
1250 is_input_pre_ordered,
1251 ignore_nulls,
1252 })
1253 }
1254
1255 fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1257 for s in row.iter_mut() {
1259 s.compact();
1260 }
1261 self.last = row.remove(0);
1262 self.orderings = row;
1263 self.is_set = true;
1264 }
1265
1266 fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1267 let [value, ordering_values @ ..] = values else {
1268 return internal_err!("Empty row in LAST_VALUE");
1269 };
1270 if self.is_input_pre_ordered {
1271 if self.ignore_nulls {
1273 for i in (0..value.len()).rev() {
1275 if !value.is_null(i) {
1276 return Ok(Some(i));
1277 }
1278 }
1279 return Ok(None);
1280 } else {
1281 return Ok((!value.is_empty()).then_some(value.len() - 1));
1282 }
1283 }
1284
1285 let sort_columns = ordering_values
1286 .iter()
1287 .zip(self.ordering_req.iter())
1288 .map(|(values, req)| SortColumn {
1289 values: Arc::clone(values),
1290 options: Some(req.options),
1291 })
1292 .collect::<Vec<_>>();
1293
1294 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1295 let max_ind = if self.ignore_nulls {
1296 (0..value.len())
1297 .filter(|&index| !(value.is_null(index)))
1298 .max_by(|&a, &b| comparator.compare(a, b))
1299 } else {
1300 (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1301 };
1302
1303 Ok(max_ind)
1304 }
1305}
1306
1307impl Accumulator for LastValueAccumulator {
1308 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1309 let mut result = vec![self.last.clone()];
1310 result.extend(self.orderings.clone());
1311 result.push(ScalarValue::from(self.is_set));
1312 Ok(result)
1313 }
1314
1315 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1316 if let Some(last_idx) = self.get_last_idx(values)? {
1317 let row = get_row_at_idx(values, last_idx)?;
1318 let orderings = &row[1..];
1319 if !self.is_set
1321 || self.is_input_pre_ordered
1322 || compare_rows(&self.orderings, orderings, &self.sort_options)?.is_lt()
1323 {
1324 self.update_with_new_row(row);
1325 }
1326 }
1327 Ok(())
1328 }
1329
1330 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1331 let is_set_idx = states.len() - 1;
1334 let flags = states[is_set_idx].as_boolean();
1335 validate_is_set_flags(flags, "last_value")?;
1336
1337 let filtered_states =
1338 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1339 let sort_columns =
1341 convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
1342
1343 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1344 let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1345
1346 if let Some(last_idx) = max {
1347 let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
1348 let last_ordering = &last_row[1..is_set_idx];
1350 if !self.is_set
1353 || self.is_input_pre_ordered
1354 || compare_rows(&self.orderings, last_ordering, &self.sort_options)?
1355 .is_lt()
1356 {
1357 assert!(is_set_idx <= last_row.len());
1361 last_row.resize(is_set_idx, ScalarValue::Null);
1362 self.update_with_new_row(last_row);
1363 }
1364 }
1365 Ok(())
1366 }
1367
1368 fn evaluate(&mut self) -> Result<ScalarValue> {
1369 Ok(self.last.clone())
1370 }
1371
1372 fn size(&self) -> usize {
1373 size_of_val(self) - size_of_val(&self.last)
1374 + self.last.size()
1375 + ScalarValue::size_of_vec(&self.orderings)
1376 - size_of_val(&self.orderings)
1377 }
1378}
1379
1380fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
1382 if flags.null_count() > 0 {
1383 return Err(DataFusionError::Internal(format!(
1384 "{function_name}: is_set flags contain nulls"
1385 )));
1386 }
1387 Ok(())
1388}
1389
1390fn filter_states_according_to_is_set(
1393 states: &[ArrayRef],
1394 flags: &BooleanArray,
1395) -> Result<Vec<ArrayRef>> {
1396 states
1397 .iter()
1398 .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1399 .collect()
1400}
1401
1402fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1404 arrs.iter()
1405 .zip(sort_exprs.iter())
1406 .map(|(item, sort_expr)| SortColumn {
1407 values: Arc::clone(item),
1408 options: Some(sort_expr.options),
1409 })
1410 .collect()
1411}
1412
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use arrow::{
        array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray},
        compute::SortOptions,
        datatypes::Schema,
    };
    use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};

    use super::*;

    /// Feeds three integer ranges to the trivial accumulators and checks
    /// the results. FIRST_VALUE must see the very first element and
    /// LAST_VALUE the very last one.
    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect::<Vec<_>>();
        for arr in arrs {
            // The first accumulator gets its own handle; the last one takes ownership.
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }
        // 0 is the first element of 0..10; 12 is the last element of 2..13.
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    /// Builds two partial states from separate accumulators and merges them
    /// into a fresh accumulator. Each state column is concatenated first. The
    /// merged state must keep the same number of columns.
    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
        let arrs = ranges
            .into_iter()
            .map(|(start, end)| {
                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
            })
            .collect::<Vec<_>>();

        // FIRST_VALUE: partial state from the first batch.
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;

        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        // FIRST_VALUE: partial state from the second batch.
        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        // Concatenate the two partial states column by column.
        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut first_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        // Same procedure for LAST_VALUE.
        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;

        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        assert_eq!(state1.len(), state2.len());

        let mut states = vec![];

        for idx in 0..state1.len() {
            states.push(compute::concat(&[
                &state1[idx].to_array()?,
                &state2[idx].to_array()?,
            ])?);
        }

        let mut last_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    /// Exercises the grouped FIRST_VALUE accumulator through update, state,
    /// merge and evaluate. After every step it checks that the tracked
    /// ordering size matches a fresh computation.
    #[test]
    fn test_first_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        // Order by column "c" with default sort options.
        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        // NOTE(review): the two boolean arguments are presumably ignore_nulls
        // (3rd) and "is first" (5th, false in test_last_group_acc); confirm
        // against `FirstLastGroupsAccumulator::try_new`.
        let mut group_acc = FirstLastGroupsAccumulator::try_new(
            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
            sort_keys.into(),
            true,
            &[DataType::Int64],
            true,
        )?;

        // Input columns: [values, ordering].
        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        // Rows go to groups [0, 1, 2, 1]; the third row is filtered out.
        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        let state = group_acc.state(EmitTo::All)?;

        // State columns: [value, ordering, is_set]. Group 2 received no rows.
        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        // Merge back only group 0's state.
        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        // Group 1 is 6 (not -6): per this expectation, emitting all groups
        // above reset them, and only group 0 was merged back. Group 3 never
        // received a row.
        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);

        assert_eq!(eval_result, &expect);

        assert_eq!(
            group_acc.size_of_orderings,
            group_acc.compute_size_of_orderings()
        );

        Ok(())
    }

    /// Repeats update/state/merge/evaluate cycles, including partial emits.
    /// After each step the incrementally tracked ordering size must match a
    /// fresh computation.
    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstLastGroupsAccumulator::try_new(
            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
            sort_keys.into(),
            true,
            &[DataType::Int64],
            true,
        )?;

        let val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&BooleanArray::from(vec![true, true, false, true])),
                100,
            )?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            // Partial emit of the first two groups.
            group_acc.state(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            let s = group_acc.state(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            // Merge the emitted state back, one row per group index.
            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::First(2))?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );

            group_acc.evaluate(EmitTo::All)?;
            assert_eq!(
                group_acc.size_of_orderings,
                group_acc.compute_size_of_orderings()
            );
        }

        Ok(())
    }

    /// Same scenario as `test_first_group_acc`, but with the final `try_new`
    /// flag set to false (LAST_VALUE behaviour, per the expectations below).
    #[test]
    fn test_last_group_acc() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]));

        let sort_keys = [PhysicalSortExpr {
            expr: col("c", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let mut group_acc = FirstLastGroupsAccumulator::try_new(
            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
            sort_keys.into(),
            true,
            &[DataType::Int64],
            false,
        )?;

        let mut val_with_orderings = {
            let mut val_with_orderings = Vec::<ArrayRef>::new();

            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));

            val_with_orderings.push(vals);
            val_with_orderings.push(orderings);

            val_with_orderings
        };

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        val_with_orderings.clear();
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));

        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

        assert_eq!(eval_result, &expect);

        Ok(())
    }

    /// The reported size of a list-typed FIRST_VALUE accumulator must not
    /// depend on how large the input batch was.
    #[test]
    fn test_first_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            first_accumulator.update_batch(values)?;

            Ok(first_accumulator.size())
        }

        // A 10_000-row batch versus a single-row batch.
        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    /// The reported size of a list-typed LAST_VALUE accumulator must not
    /// depend on how large the input batch was.
    #[test]
    fn test_last_list_acc_size() -> Result<()> {
        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
                false,
            )?;

            last_accumulator.update_batch(values)?;

            Ok(last_accumulator.size())
        }

        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
            repeat_with(|| Some(vec![Some(1)])).take(10000),
        );
        let batch2 =
            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);

        let size1 = size_after_batch(&[Arc::new(batch1)])?;
        let size2 = size_after_batch(&[Arc::new(batch2)])?;
        assert_eq!(size1, size2);

        Ok(())
    }

    /// Merging a FIRST_VALUE state whose is_set flag is null must fail with
    /// the "is_set flags contain nulls" error. This holds for both the trivial
    /// and the ordered accumulator.
    #[test]
    fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        // Trivial accumulator: state is [value, is_set].
        let mut trivial_accumulator =
            TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        // Ordered accumulator: state is [value, ordering, is_set].
        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = FirstValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }

    /// Merging a LAST_VALUE state whose is_set flag is null must fail with
    /// the "is_set flags contain nulls" error. This holds for both the trivial
    /// and the ordered accumulator.
    #[test]
    fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
        let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;

        // Trivial accumulator: state is [value, is_set].
        let mut trivial_accumulator =
            TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
        let result = trivial_accumulator.merge_batch(&trivial_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        // Ordered accumulator: state is [value, ordering, is_set].
        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
        let ordering_expr = col("ordering", &schema)?;
        let mut ordered_accumulator = LastValueAccumulator::try_new(
            &DataType::Utf8,
            &[DataType::Int64],
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: ordering_expr,
                options: SortOptions::default(),
            }])
            .unwrap(),
            false,
            false,
        )?;
        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
        let ordered_states = vec![value, ordering, corrupted_flag];
        let result = ordered_accumulator.merge_batch(&ordered_states);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("is_set flags contain nulls")
        );

        Ok(())
    }
}