Skip to main content

datafusion_functions_aggregate/
first_last.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the FIRST_VALUE/LAST_VALUE aggregations.
19
20use std::fmt::Debug;
21use std::hash::Hash;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
26use arrow::buffer::BooleanBuffer;
27use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
28use arrow::datatypes::{
29    DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type,
30    Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int8Type,
31    Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
32    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
33    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
34    UInt16Type, UInt32Type, UInt64Type,
35};
36use datafusion_common::cast::as_boolean_array;
37use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
38use datafusion_common::{
39    DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err,
40    not_impl_err,
41};
42use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
43use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
44use datafusion_expr::{
45    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
46    GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
47};
48use datafusion_functions_aggregate_common::utils::get_sort_options;
49use datafusion_macros::user_doc;
50use datafusion_physical_expr_common::sort_expr::LexOrdering;
51
52mod state;
53
54use state::{BytesValueState, PrimitiveValueState, ValueState};
55
// Macro invocations for the two UDAFs in this file. `create_func!` is not defined
// in this section; presumably it generates the `first_value_udaf()` /
// `last_value_udaf()` accessors that the code below calls — confirm against the
// macro definition.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
58
59/// Returns the first value in a group of values.
60pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
61    first_value_udaf()
62        .call(vec![expression])
63        .order_by(order_by)
64        .build()
65        // guaranteed to be `Expr::AggregateFunction`
66        .unwrap()
67}
68
69/// Returns the last value in a group of values.
70pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
71    last_value_udaf()
72        .call(vec![expression])
73        .order_by(order_by)
74        .build()
75        // guaranteed to be `Expr::AggregateFunction`
76        .unwrap()
77}
78
79fn create_groups_accumulator_helper<S: ValueState + 'static>(
80    args: &AccumulatorArgs,
81    is_first: bool,
82    state: S,
83) -> Result<Box<dyn GroupsAccumulator>> {
84    let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
85        return internal_err!("Groups accumulator must have an ordering.");
86    };
87
88    let ordering_dtypes = ordering
89        .iter()
90        .map(|e| e.expr.data_type(args.schema))
91        .collect::<Result<Vec<_>>>()?;
92
93    Ok(Box::new(FirstLastGroupsAccumulator::try_new(
94        state,
95        ordering,
96        args.ignore_nulls,
97        &ordering_dtypes,
98        is_first,
99    )?))
100}
101
102fn create_groups_accumulator(
103    args: &AccumulatorArgs,
104    is_first: bool,
105    function_name: &str,
106) -> Result<Box<dyn GroupsAccumulator>> {
107    let data_type = args.return_field.data_type();
108
109    macro_rules! instantiate_primitive {
110        ($t:ty) => {
111            create_groups_accumulator_helper(
112                args,
113                is_first,
114                PrimitiveValueState::<$t>::new(data_type.clone()),
115            )
116        };
117    }
118
119    match data_type {
120        DataType::Int8 => instantiate_primitive!(Int8Type),
121        DataType::Int16 => instantiate_primitive!(Int16Type),
122        DataType::Int32 => instantiate_primitive!(Int32Type),
123        DataType::Int64 => instantiate_primitive!(Int64Type),
124        DataType::UInt8 => instantiate_primitive!(UInt8Type),
125        DataType::UInt16 => instantiate_primitive!(UInt16Type),
126        DataType::UInt32 => instantiate_primitive!(UInt32Type),
127        DataType::UInt64 => instantiate_primitive!(UInt64Type),
128        DataType::Float16 => instantiate_primitive!(Float16Type),
129        DataType::Float32 => instantiate_primitive!(Float32Type),
130        DataType::Float64 => instantiate_primitive!(Float64Type),
131
132        DataType::Decimal32(_, _) => instantiate_primitive!(Decimal32Type),
133        DataType::Decimal64(_, _) => instantiate_primitive!(Decimal64Type),
134        DataType::Decimal128(_, _) => instantiate_primitive!(Decimal128Type),
135        DataType::Decimal256(_, _) => instantiate_primitive!(Decimal256Type),
136
137        DataType::Timestamp(TimeUnit::Second, _) => {
138            instantiate_primitive!(TimestampSecondType)
139        }
140        DataType::Timestamp(TimeUnit::Millisecond, _) => {
141            instantiate_primitive!(TimestampMillisecondType)
142        }
143        DataType::Timestamp(TimeUnit::Microsecond, _) => {
144            instantiate_primitive!(TimestampMicrosecondType)
145        }
146        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
147            instantiate_primitive!(TimestampNanosecondType)
148        }
149
150        DataType::Date32 => instantiate_primitive!(Date32Type),
151        DataType::Date64 => instantiate_primitive!(Date64Type),
152        DataType::Time32(TimeUnit::Second) => instantiate_primitive!(Time32SecondType),
153        DataType::Time32(TimeUnit::Millisecond) => {
154            instantiate_primitive!(Time32MillisecondType)
155        }
156        DataType::Time64(TimeUnit::Microsecond) => {
157            instantiate_primitive!(Time64MicrosecondType)
158        }
159        DataType::Time64(TimeUnit::Nanosecond) => {
160            instantiate_primitive!(Time64NanosecondType)
161        }
162
163        DataType::Utf8
164        | DataType::LargeUtf8
165        | DataType::Utf8View
166        | DataType::Binary
167        | DataType::LargeBinary
168        | DataType::BinaryView => create_groups_accumulator_helper(
169            args,
170            is_first,
171            BytesValueState::try_new(data_type.clone())?,
172        ),
173
174        _ => internal_err!(
175            "GroupsAccumulator not supported for {}({})",
176            function_name,
177            data_type
178        ),
179    }
180}
181
182fn groups_accumulator_supported(args: &AccumulatorArgs) -> bool {
183    use DataType::*;
184    !args.order_bys.is_empty()
185        && matches!(
186            args.return_field.data_type(),
187            Int8 | Int16
188                | Int32
189                | Int64
190                | UInt8
191                | UInt16
192                | UInt32
193                | UInt64
194                | Float16
195                | Float32
196                | Float64
197                | Decimal32(_, _)
198                | Decimal64(_, _)
199                | Decimal128(_, _)
200                | Decimal256(_, _)
201                | Date32
202                | Date64
203                | Time32(_)
204                | Time64(_)
205                | Timestamp(_, _)
206                | Utf8
207                | LargeUtf8
208                | Utf8View
209                | Binary
210                | LargeBinary
211                | BinaryView
212        )
213}
214
// User-facing documentation for `first_value`. `AggregateUDFImpl::documentation`
// below returns it through `self.doc()`, which this attribute presumably
// generates — confirm against the `user_doc` macro.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element                                 |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct FirstValue {
    // Accepts exactly one argument of any type (built in `FirstValue::new`).
    signature: Signature,
    // When true, row accumulators trust the incoming row order and take the
    // first eligible row instead of comparing ordering values. It is false by
    // default and set through `with_beneficial_ordering`.
    is_input_pre_ordered: bool,
}
234
235impl Default for FirstValue {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241impl FirstValue {
242    pub fn new() -> Self {
243        Self {
244            signature: Signature::any(1, Volatility::Immutable),
245            is_input_pre_ordered: false,
246        }
247    }
248}
249
250impl AggregateUDFImpl for FirstValue {
251    fn name(&self) -> &str {
252        "first_value"
253    }
254
255    fn signature(&self) -> &Signature {
256        &self.signature
257    }
258
259    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
260        not_impl_err!("Not called because the return_field_from_args is implemented")
261    }
262
263    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
264        // Preserve metadata from the first argument field
265        Ok(Arc::new(
266            Field::new(
267                self.name(),
268                arg_fields[0].data_type().clone(),
269                true, // always nullable, there may be no rows
270            )
271            .with_metadata(arg_fields[0].metadata().clone()),
272        ))
273    }
274
275    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
276        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
277            return TrivialFirstValueAccumulator::try_new(
278                acc_args.return_field.data_type(),
279                acc_args.ignore_nulls,
280            )
281            .map(|acc| Box::new(acc) as _);
282        };
283        let ordering_dtypes = ordering
284            .iter()
285            .map(|e| e.expr.data_type(acc_args.schema))
286            .collect::<Result<Vec<_>>>()?;
287        Ok(Box::new(FirstValueAccumulator::try_new(
288            acc_args.return_field.data_type(),
289            &ordering_dtypes,
290            ordering,
291            self.is_input_pre_ordered,
292            acc_args.ignore_nulls,
293        )?))
294    }
295
296    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
297        let mut fields = vec![
298            Field::new(
299                format_state_name(args.name, "first_value"),
300                args.return_type().clone(),
301                true,
302            )
303            .into(),
304        ];
305        fields.extend(args.ordering_fields.iter().cloned());
306        fields.push(
307            Field::new(
308                format_state_name(args.name, "first_value_is_set"),
309                DataType::Boolean,
310                true,
311            )
312            .into(),
313        );
314        Ok(fields)
315    }
316
317    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
318        groups_accumulator_supported(&args)
319    }
320
321    fn create_groups_accumulator(
322        &self,
323        args: AccumulatorArgs,
324    ) -> Result<Box<dyn GroupsAccumulator>> {
325        create_groups_accumulator(&args, true, self.name())
326    }
327
328    fn with_beneficial_ordering(
329        self: Arc<Self>,
330        beneficial_ordering: bool,
331    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
332        Ok(Some(Arc::new(Self {
333            signature: self.signature.clone(),
334            is_input_pre_ordered: beneficial_ordering,
335        })))
336    }
337
338    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
339        AggregateOrderSensitivity::Beneficial
340    }
341
342    fn reverse_expr(&self) -> ReversedUDAF {
343        ReversedUDAF::Reversed(last_value_udaf())
344    }
345
346    fn supports_null_handling_clause(&self) -> bool {
347        true
348    }
349
350    fn documentation(&self) -> Option<&Documentation> {
351        self.doc()
352    }
353}
354
/// Groups accumulator shared by FIRST_VALUE and LAST_VALUE.
///
/// For every group it keeps the currently selected value (in `state`) together
/// with that row's ordering values (in `orderings`). A later row replaces the
/// stored one only if it sorts strictly before it (FIRST_VALUE) or strictly
/// after it (LAST_VALUE).
struct FirstLastGroupsAccumulator<S: ValueState> {
    // ================ state ===========
    // Per-group storage for the currently selected value; updated, resized and
    // emitted through the `ValueState` trait.
    state: S,
    // Stores the ordering values (one per ordering expression) of the row whose
    // value is currently selected for each group.
    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
    // represents the ordering values corresponding to the `group_idx`-th group.
    orderings: Vec<Vec<ScalarValue>>,
    // At the beginning, `is_sets[group_idx]` is false, which means no value has
    // been selected for the group yet.
    // Once a value is selected, we set the `is_sets[group_idx]` flag.
    is_sets: BooleanBufferBuilder,
    // Cached total size of `self.orderings`, in bytes.
    // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly.
    // Therefore, we cache it and update it incrementally whenever `orderings`
    // changes, to avoid calling `ScalarValue::size_of_vec` from `Self::size`.
    size_of_orderings: usize,

    // Scratch buffer for `get_filtered_extreme_of_each_group`:
    // extreme_of_each_group_buf.0[group_idx] -> idx_in_val (row index in the batch)
    // only valid if extreme_of_each_group_buf.1[group_idx] == true
    extreme_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    // =========== option ============

    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // true: take first element in an aggregation group according to the requested ordering.
    // false: take last element in an aggregation group according to the requested ordering.
    pick_first_in_group: bool,
    // derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    // Ignore null values.
    ignore_nulls: bool,
    // One placeholder scalar per ordering expression, built from its data type
    // via `ScalarValue::try_from`; copied into `orderings` for every newly
    // created group (presumably typed NULLs — confirm).
    default_orderings: Vec<ScalarValue>,
}
390
391impl<S: ValueState> FirstLastGroupsAccumulator<S> {
392    fn try_new(
393        state: S,
394        ordering_req: LexOrdering,
395        ignore_nulls: bool,
396        ordering_dtypes: &[DataType],
397        pick_first_in_group: bool,
398    ) -> Result<Self> {
399        let default_orderings = ordering_dtypes
400            .iter()
401            .map(ScalarValue::try_from)
402            .collect::<Result<_>>()?;
403
404        let sort_options = get_sort_options(&ordering_req);
405
406        Ok(Self {
407            ordering_req,
408            sort_options,
409            ignore_nulls,
410            default_orderings,
411            state,
412            orderings: Vec::new(),
413            is_sets: BooleanBufferBuilder::new(0),
414            size_of_orderings: 0,
415            extreme_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
416            pick_first_in_group,
417        })
418    }
419
420    fn should_update_state(
421        &self,
422        group_idx: usize,
423        new_ordering_values: &[ScalarValue],
424    ) -> Result<bool> {
425        if !self.is_sets.get_bit(group_idx) {
426            return Ok(true);
427        }
428
429        debug_assert!(new_ordering_values.len() == self.ordering_req.len());
430        let current_ordering = &self.orderings[group_idx];
431        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
432            if self.pick_first_in_group {
433                x.is_gt()
434            } else {
435                x.is_lt()
436            }
437        })
438    }
439
440    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
441        let result = emit_to.take_needed(&mut self.orderings);
442
443        match emit_to {
444            EmitTo::All => self.size_of_orderings = 0,
445            EmitTo::First(_) => {
446                self.size_of_orderings -=
447                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
448            }
449        }
450
451        result
452    }
453
454    fn resize_states(&mut self, new_size: usize) {
455        self.state.resize(new_size);
456
457        if self.orderings.len() < new_size {
458            let current_len = self.orderings.len();
459
460            self.orderings
461                .resize(new_size, self.default_orderings.clone());
462
463            self.size_of_orderings += (new_size - current_len)
464                * ScalarValue::size_of_vec(
465                    // Note: In some cases (such as in the unit test below)
466                    // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone())
467                    // This may be caused by the different vec.capacity() values?
468                    self.orderings.last().unwrap(),
469                );
470        }
471
472        self.is_sets.resize(new_size);
473
474        self.extreme_of_each_group_buf.0.resize(new_size, 0);
475        self.extreme_of_each_group_buf.1.resize(new_size);
476    }
477
478    fn update_state(
479        &mut self,
480        group_idx: usize,
481        orderings: &[ScalarValue],
482        array: &ArrayRef,
483        idx: usize,
484    ) -> Result<()> {
485        self.state.update(group_idx, array, idx)?;
486        self.is_sets.set_bit(group_idx, true);
487
488        debug_assert!(orderings.len() == self.ordering_req.len());
489        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
490        self.orderings[group_idx].clear();
491        self.orderings[group_idx].extend_from_slice(orderings);
492        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
493        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
494        Ok(())
495    }
496
497    fn take_state(
498        &mut self,
499        emit_to: EmitTo,
500    ) -> Result<(ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer)> {
501        emit_to.take_needed(&mut self.extreme_of_each_group_buf.0);
502        self.extreme_of_each_group_buf
503            .1
504            .truncate(self.extreme_of_each_group_buf.0.len());
505
506        Ok((
507            self.state.take(emit_to)?,
508            self.take_orderings(emit_to),
509            state::take_need(&mut self.is_sets, emit_to),
510        ))
511    }
512
513    // should be used in test only
514    #[cfg(test)]
515    fn compute_size_of_orderings(&self) -> usize {
516        self.orderings
517            .iter()
518            .map(ScalarValue::size_of_vec)
519            .sum::<usize>()
520    }
521    /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
522    /// minimum value in `orderings` for each group, using lexicographical comparison.
523    /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
524    fn get_filtered_extreme_of_each_group(
525        &mut self,
526        orderings: &[ArrayRef],
527        group_indices: &[usize],
528        opt_filter: Option<&BooleanArray>,
529        vals: &ArrayRef,
530        is_set_arr: Option<&BooleanArray>,
531    ) -> Result<Vec<(usize, usize)>> {
532        // Set all values in min_of_each_group_buf.1 to false.
533        self.extreme_of_each_group_buf.1.truncate(0);
534        self.extreme_of_each_group_buf
535            .1
536            .append_n(self.is_sets.len(), false);
537
538        // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]`
539        // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`.
540
541        let comparator = {
542            assert_eq!(orderings.len(), self.ordering_req.len());
543            let sort_columns = orderings
544                .iter()
545                .zip(self.ordering_req.iter())
546                .map(|(array, req)| SortColumn {
547                    values: Arc::clone(array),
548                    options: Some(req.options),
549                })
550                .collect::<Vec<_>>();
551
552            LexicographicalComparator::try_new(&sort_columns)?
553        };
554
555        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
556            let group_idx = *group_idx;
557
558            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
559            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
560
561            if !passed_filter || !is_set {
562                continue;
563            }
564
565            if self.ignore_nulls && vals.is_null(idx_in_val) {
566                continue;
567            }
568
569            let is_valid = self.extreme_of_each_group_buf.1.get_bit(group_idx);
570
571            if !is_valid {
572                self.extreme_of_each_group_buf.1.set_bit(group_idx, true);
573                self.extreme_of_each_group_buf.0[group_idx] = idx_in_val;
574            } else {
575                let ordering = comparator
576                    .compare(self.extreme_of_each_group_buf.0[group_idx], idx_in_val);
577
578                if (ordering.is_gt() && self.pick_first_in_group)
579                    || (ordering.is_lt() && !self.pick_first_in_group)
580                {
581                    self.extreme_of_each_group_buf.0[group_idx] = idx_in_val;
582                }
583            }
584        }
585
586        Ok(self
587            .extreme_of_each_group_buf
588            .0
589            .iter()
590            .enumerate()
591            .filter(|(group_idx, _)| self.extreme_of_each_group_buf.1.get_bit(*group_idx))
592            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
593            .collect::<Vec<_>>())
594    }
595}
596
597impl<S: ValueState + 'static> GroupsAccumulator for FirstLastGroupsAccumulator<S> {
598    fn update_batch(
599        &mut self,
600        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
601        values_and_order_cols: &[ArrayRef],
602        group_indices: &[usize],
603        opt_filter: Option<&BooleanArray>,
604        total_num_groups: usize,
605    ) -> Result<()> {
606        self.resize_states(total_num_groups);
607
608        let vals = &values_and_order_cols[0];
609
610        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
611
612        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
613        for (group_idx, idx) in self
614            .get_filtered_extreme_of_each_group(
615                &values_and_order_cols[1..],
616                group_indices,
617                opt_filter,
618                vals,
619                None,
620            )?
621            .into_iter()
622        {
623            extract_row_at_idx_to_buf(
624                &values_and_order_cols[1..],
625                idx,
626                &mut ordering_buf,
627            )?;
628
629            if self.should_update_state(group_idx, &ordering_buf)? {
630                self.update_state(group_idx, &ordering_buf, vals, idx)?;
631            }
632        }
633
634        Ok(())
635    }
636
637    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
638        Ok(self.take_state(emit_to)?.0)
639    }
640
641    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
642        let (val_arr, orderings, is_sets) = self.take_state(emit_to)?;
643        let mut result = Vec::with_capacity(self.orderings.len() + 2);
644
645        result.push(val_arr);
646
647        let ordering_cols = {
648            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
649            for _ in 0..self.ordering_req.len() {
650                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
651            }
652            for row in orderings.into_iter() {
653                debug_assert!(row.len() == self.ordering_req.len());
654                for (col_idx, ordering) in row.into_iter().enumerate() {
655                    ordering_cols[col_idx].push(ordering);
656                }
657            }
658
659            ordering_cols
660        };
661        for ordering_col in ordering_cols {
662            result.push(ScalarValue::iter_to_array(ordering_col)?);
663        }
664
665        result.push(Arc::new(BooleanArray::new(is_sets, None)));
666
667        Ok(result)
668    }
669
670    fn merge_batch(
671        &mut self,
672        values: &[ArrayRef],
673        group_indices: &[usize],
674        opt_filter: Option<&BooleanArray>,
675        total_num_groups: usize,
676    ) -> Result<()> {
677        self.resize_states(total_num_groups);
678
679        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
680
681        let (is_set_arr, val_and_order_cols) = match values.split_last() {
682            Some(result) => result,
683            None => return internal_err!("Empty row in FIRST_VALUE"),
684        };
685
686        let is_set_arr = as_boolean_array(is_set_arr)?;
687
688        let vals = &values[0];
689        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
690        let groups = self.get_filtered_extreme_of_each_group(
691            &val_and_order_cols[1..],
692            group_indices,
693            opt_filter,
694            vals,
695            Some(is_set_arr),
696        )?;
697
698        for (group_idx, idx) in groups.into_iter() {
699            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
700
701            if self.should_update_state(group_idx, &ordering_buf)? {
702                self.update_state(group_idx, &ordering_buf, vals, idx)?;
703            }
704        }
705
706        Ok(())
707    }
708
709    fn size(&self) -> usize {
710        self.state.size()
711            + self.is_sets.capacity() / 8 // capacity is in bits, so convert to bytes
712            + self.size_of_orderings
713            + self.extreme_of_each_group_buf.0.capacity() * size_of::<usize>()
714            + self.extreme_of_each_group_buf.1.capacity() / 8
715    }
716
717    fn supports_convert_to_state(&self) -> bool {
718        true
719    }
720
721    fn convert_to_state(
722        &self,
723        values: &[ArrayRef],
724        opt_filter: Option<&BooleanArray>,
725    ) -> Result<Vec<ArrayRef>> {
726        let mut result = values.to_vec();
727        match opt_filter {
728            Some(f) => {
729                result.push(Arc::new(f.clone()));
730                Ok(result)
731            }
732            None => {
733                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
734                Ok(result)
735            }
736        }
737    }
738}
739
/// This accumulator is used when there is no ordering specified for the
/// `FIRST_VALUE` aggregation. It simply returns the first value it sees
/// according to the pre-existing ordering of the input data, and provides
/// a fast path for this case without needing to maintain any ordering state.
///
/// Its intermediate state is `[first, is_set]`.
#[derive(Debug)]
pub struct TrivialFirstValueAccumulator {
    // The selected value. Until a row has been accepted, this holds the
    // placeholder built by `ScalarValue::try_from(data_type)` in `try_new`.
    first: ScalarValue,
    // Whether we have seen the first value yet.
    is_set: bool,
    // Ignore null values: when true, the first non-NULL row is taken.
    ignore_nulls: bool,
}
752
753impl TrivialFirstValueAccumulator {
754    /// Creates a new `TrivialFirstValueAccumulator` for the given `data_type`.
755    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
756        ScalarValue::try_from(data_type).map(|first| Self {
757            first,
758            is_set: false,
759            ignore_nulls,
760        })
761    }
762}
763
764impl Accumulator for TrivialFirstValueAccumulator {
765    fn state(&mut self) -> Result<Vec<ScalarValue>> {
766        Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
767    }
768
769    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
770        if !self.is_set {
771            // Get first entry according to the pre-existing ordering (0th index):
772            let value = &values[0];
773            let mut first_idx = None;
774            if self.ignore_nulls {
775                // If ignoring nulls, find the first non-null value.
776                for i in 0..value.len() {
777                    if !value.is_null(i) {
778                        first_idx = Some(i);
779                        break;
780                    }
781                }
782            } else if !value.is_empty() {
783                // If not ignoring nulls, return the first value if it exists.
784                first_idx = Some(0);
785            }
786            if let Some(first_idx) = first_idx {
787                self.first = ScalarValue::try_from_array(&values[0], first_idx)?;
788                self.first.compact();
789                self.is_set = true;
790            }
791        }
792        Ok(())
793    }
794
795    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
796        // FIRST_VALUE(first1, first2, first3, ...)
797        // Second index contains is_set flag.
798        if !self.is_set {
799            let flags = states[1].as_boolean();
800            validate_is_set_flags(flags, "first_value")?;
801
802            let filtered_states =
803                filter_states_according_to_is_set(&states[0..1], flags)?;
804            if let Some(first) = filtered_states.first()
805                && !first.is_empty()
806            {
807                self.first = ScalarValue::try_from_array(first, 0)?;
808                self.is_set = true;
809            }
810        }
811        Ok(())
812    }
813
814    fn evaluate(&mut self) -> Result<ScalarValue> {
815        Ok(self.first.clone())
816    }
817
818    fn size(&self) -> usize {
819        size_of_val(self) - size_of_val(&self.first) + self.first.size()
820    }
821}
822
/// Row accumulator for `FIRST_VALUE` when an ORDER BY is given.
///
/// Keeps the selected value together with the values of its ordering columns,
/// so that partial results from different partitions can be compared when merged.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    // The selected value. It starts as the placeholder built from the data
    // type in `try_new`.
    first: ScalarValue,
    // Whether we have seen the first value yet.
    is_set: bool,
    // Stores values of the ordering columns corresponding to the first value.
    // These values are used during merging of multiple partitions.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    // Stores whether incoming data already satisfies the ordering requirement.
    // When true, the first eligible row is taken without comparing orderings.
    is_input_pre_ordered: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
840
841impl FirstValueAccumulator {
842    /// Creates a new `FirstValueAccumulator` for the given `data_type`.
843    pub fn try_new(
844        data_type: &DataType,
845        ordering_dtypes: &[DataType],
846        ordering_req: LexOrdering,
847        is_input_pre_ordered: bool,
848        ignore_nulls: bool,
849    ) -> Result<Self> {
850        let orderings = ordering_dtypes
851            .iter()
852            .map(ScalarValue::try_from)
853            .collect::<Result<_>>()?;
854        let sort_options = get_sort_options(&ordering_req);
855        ScalarValue::try_from(data_type).map(|first| Self {
856            first,
857            is_set: false,
858            orderings,
859            ordering_req,
860            sort_options,
861            is_input_pre_ordered,
862            ignore_nulls,
863        })
864    }
865
866    // Updates state with the values in the given row.
867    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
868        // Ensure any Array based scalars hold have a single value to reduce memory pressure
869        for s in row.iter_mut() {
870            s.compact();
871        }
872        self.first = row.remove(0);
873        self.orderings = row;
874        self.is_set = true;
875    }
876
877    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
878        let [value, ordering_values @ ..] = values else {
879            return internal_err!("Empty row in FIRST_VALUE");
880        };
881        if self.is_input_pre_ordered {
882            // Get first entry according to the pre-existing ordering (0th index):
883            if self.ignore_nulls {
884                // If ignoring nulls, find the first non-null value.
885                for i in 0..value.len() {
886                    if !value.is_null(i) {
887                        return Ok(Some(i));
888                    }
889                }
890                return Ok(None);
891            } else {
892                // If not ignoring nulls, return the first value if it exists.
893                return Ok((!value.is_empty()).then_some(0));
894            }
895        }
896
897        let sort_columns = ordering_values
898            .iter()
899            .zip(self.ordering_req.iter())
900            .map(|(values, req)| SortColumn {
901                values: Arc::clone(values),
902                options: Some(req.options),
903            })
904            .collect::<Vec<_>>();
905
906        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
907
908        let min_index = if self.ignore_nulls {
909            (0..value.len())
910                .filter(|&index| !value.is_null(index))
911                .min_by(|&a, &b| comparator.compare(a, b))
912        } else {
913            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
914        };
915
916        Ok(min_index)
917    }
918}
919
920impl Accumulator for FirstValueAccumulator {
921    fn state(&mut self) -> Result<Vec<ScalarValue>> {
922        let mut result = vec![self.first.clone()];
923        result.extend(self.orderings.iter().cloned());
924        result.push(ScalarValue::from(self.is_set));
925        Ok(result)
926    }
927
928    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
929        if let Some(first_idx) = self.get_first_idx(values)? {
930            let row = get_row_at_idx(values, first_idx)?;
931            if !self.is_set
932                || (!self.is_input_pre_ordered
933                    && compare_rows(&self.orderings, &row[1..], &self.sort_options)?
934                        .is_gt())
935            {
936                self.update_with_new_row(row);
937            }
938        }
939        Ok(())
940    }
941
942    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
943        // FIRST_VALUE(first1, first2, first3, ...)
944        // last index contains is_set flag.
945        let is_set_idx = states.len() - 1;
946        let flags = states[is_set_idx].as_boolean();
947        validate_is_set_flags(flags, "first_value")?;
948
949        let filtered_states =
950            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
951        // 1..is_set_idx range corresponds to ordering section
952        let sort_columns =
953            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
954
955        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
956        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
957
958        if let Some(first_idx) = min {
959            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
960            // When collecting orderings, we exclude the is_set flag from the state.
961            let first_ordering = &first_row[1..is_set_idx];
962            // Either there is no existing value, or there is an earlier version in new data.
963            if !self.is_set
964                || compare_rows(&self.orderings, first_ordering, &self.sort_options)?
965                    .is_gt()
966            {
967                // Update with first value in the state. Note that we should exclude the
968                // is_set flag from the state. Otherwise, we will end up with a state
969                // containing two is_set flags.
970                assert!(is_set_idx <= first_row.len());
971                first_row.resize(is_set_idx, ScalarValue::Null);
972                self.update_with_new_row(first_row);
973            }
974        }
975        Ok(())
976    }
977
978    fn evaluate(&mut self) -> Result<ScalarValue> {
979        Ok(self.first.clone())
980    }
981
982    fn size(&self) -> usize {
983        size_of_val(self) - size_of_val(&self.first)
984            + self.first.size()
985            + ScalarValue::size_of_vec(&self.orderings)
986            - size_of_val(&self.orderings)
987    }
988}
989
// `LAST_VALUE` aggregate UDF. Its user-facing documentation is declared in the
// `user_doc` attribute below and returned by `documentation()`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct LastValue {
    // Argument signature of this UDF.
    signature: Signature,
    // Whether the input is known to already satisfy the requested ordering;
    // set through `with_beneficial_ordering` and forwarded to the accumulator.
    is_input_pre_ordered: bool,
}
1009
1010impl Default for LastValue {
1011    fn default() -> Self {
1012        Self::new()
1013    }
1014}
1015
1016impl LastValue {
1017    pub fn new() -> Self {
1018        Self {
1019            signature: Signature::any(1, Volatility::Immutable),
1020            is_input_pre_ordered: false,
1021        }
1022    }
1023}
1024
1025impl AggregateUDFImpl for LastValue {
1026    fn name(&self) -> &str {
1027        "last_value"
1028    }
1029
1030    fn signature(&self) -> &Signature {
1031        &self.signature
1032    }
1033
1034    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1035        not_impl_err!("Not called because the return_field_from_args is implemented")
1036    }
1037
1038    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
1039        // Preserve metadata from the first argument field
1040        Ok(Arc::new(
1041            Field::new(
1042                self.name(),
1043                arg_fields[0].data_type().clone(),
1044                true, // always nullable, there may be no rows
1045            )
1046            .with_metadata(arg_fields[0].metadata().clone()),
1047        ))
1048    }
1049
1050    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1051        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
1052            return TrivialLastValueAccumulator::try_new(
1053                acc_args.return_field.data_type(),
1054                acc_args.ignore_nulls,
1055            )
1056            .map(|acc| Box::new(acc) as _);
1057        };
1058        let ordering_dtypes = ordering
1059            .iter()
1060            .map(|e| e.expr.data_type(acc_args.schema))
1061            .collect::<Result<Vec<_>>>()?;
1062        Ok(Box::new(LastValueAccumulator::try_new(
1063            acc_args.return_field.data_type(),
1064            &ordering_dtypes,
1065            ordering,
1066            self.is_input_pre_ordered,
1067            acc_args.ignore_nulls,
1068        )?))
1069    }
1070
1071    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1072        let mut fields = vec![
1073            Field::new(
1074                format_state_name(args.name, "last_value"),
1075                args.return_field.data_type().clone(),
1076                true,
1077            )
1078            .into(),
1079        ];
1080        fields.extend(args.ordering_fields.iter().cloned());
1081        fields.push(
1082            Field::new(
1083                format_state_name(args.name, "last_value_is_set"),
1084                DataType::Boolean,
1085                true,
1086            )
1087            .into(),
1088        );
1089        Ok(fields)
1090    }
1091
1092    fn with_beneficial_ordering(
1093        self: Arc<Self>,
1094        beneficial_ordering: bool,
1095    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1096        Ok(Some(Arc::new(Self {
1097            signature: self.signature.clone(),
1098            is_input_pre_ordered: beneficial_ordering,
1099        })))
1100    }
1101
1102    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1103        AggregateOrderSensitivity::Beneficial
1104    }
1105
1106    fn reverse_expr(&self) -> ReversedUDAF {
1107        ReversedUDAF::Reversed(first_value_udaf())
1108    }
1109
1110    fn supports_null_handling_clause(&self) -> bool {
1111        true
1112    }
1113
1114    fn documentation(&self) -> Option<&Documentation> {
1115        self.doc()
1116    }
1117
1118    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1119        groups_accumulator_supported(&args)
1120    }
1121
1122    fn create_groups_accumulator(
1123        &self,
1124        args: AccumulatorArgs,
1125    ) -> Result<Box<dyn GroupsAccumulator>> {
1126        create_groups_accumulator(&args, false, self.name())
1127    }
1128}
1129
/// This accumulator is used when there is no ordering specified for the
/// `LAST_VALUE` aggregation. It simply updates the last value it sees
/// according to the pre-existing ordering of the input data, and provides
/// a fast path for this case without needing to maintain any ordering state.
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    /// Most recently recorded value; initially `ScalarValue::try_from(data_type)`.
    last: ScalarValue,
    /// The `is_set` flag keeps track of whether the last value is finalized.
    /// This information is used to discriminate genuine NULLs and NULLs that
    /// occur due to empty partitions.
    is_set: bool,
    /// Ignore null values (rows whose value is NULL are skipped).
    ignore_nulls: bool,
}
1144
1145impl TrivialLastValueAccumulator {
1146    /// Creates a new `TrivialLastValueAccumulator` for the given `data_type`.
1147    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
1148        ScalarValue::try_from(data_type).map(|last| Self {
1149            last,
1150            is_set: false,
1151            ignore_nulls,
1152        })
1153    }
1154}
1155
1156impl Accumulator for TrivialLastValueAccumulator {
1157    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1158        Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
1159    }
1160
1161    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1162        // Get last entry according to the pre-existing ordering (0th index):
1163        let value = &values[0];
1164        let mut last_idx = None;
1165        if self.ignore_nulls {
1166            // If ignoring nulls, find the last non-null value.
1167            for i in (0..value.len()).rev() {
1168                if !value.is_null(i) {
1169                    last_idx = Some(i);
1170                    break;
1171                }
1172            }
1173        } else if !value.is_empty() {
1174            // If not ignoring nulls, return the last value if it exists.
1175            last_idx = Some(value.len() - 1);
1176        }
1177        if let Some(last_idx) = last_idx {
1178            self.last = ScalarValue::try_from_array(&values[0], last_idx)?;
1179            self.last.compact();
1180            self.is_set = true;
1181        }
1182        Ok(())
1183    }
1184
1185    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1186        // LAST_VALUE(last1, last2, last3, ...)
1187        // Second index contains is_set flag.
1188        let flags = states[1].as_boolean();
1189        validate_is_set_flags(flags, "last_value")?;
1190
1191        let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
1192        if let Some(last) = filtered_states.last()
1193            && !last.is_empty()
1194        {
1195            self.last = ScalarValue::try_from_array(last, 0)?;
1196            self.is_set = true;
1197        }
1198        Ok(())
1199    }
1200
1201    fn evaluate(&mut self) -> Result<ScalarValue> {
1202        Ok(self.last.clone())
1203    }
1204
1205    fn size(&self) -> usize {
1206        size_of_val(self) - size_of_val(&self.last) + self.last.size()
1207    }
1208}
1209
/// Accumulator for `LAST_VALUE` when an ordering requirement is present.
///
/// It keeps the current candidate value together with the values of its
/// ordering columns, so that newer candidates from later batches or from
/// other partitions can be compared against it.
#[derive(Debug)]
struct LastValueAccumulator {
    /// Current last value; initially `ScalarValue::try_from(data_type)`.
    last: ScalarValue,
    /// The `is_set` flag keeps track of whether the last value is finalized.
    /// This information is used to discriminate genuine NULLs and NULLs that
    /// occur due to empty partitions.
    is_set: bool,
    /// Stores values of the ordering columns corresponding to the last value.
    /// These values are used during merging of multiple partitions.
    orderings: Vec<ScalarValue>,
    /// Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    /// Sort options derived from `ordering_req` (via `get_sort_options`).
    sort_options: Vec<SortOptions>,
    /// Stores whether incoming data already satisfies the ordering requirement.
    is_input_pre_ordered: bool,
    /// Ignore null values (rows whose value is NULL are skipped).
    ignore_nulls: bool,
}
1229
1230impl LastValueAccumulator {
1231    /// Creates a new `LastValueAccumulator` for the given `data_type`.
1232    pub fn try_new(
1233        data_type: &DataType,
1234        ordering_dtypes: &[DataType],
1235        ordering_req: LexOrdering,
1236        is_input_pre_ordered: bool,
1237        ignore_nulls: bool,
1238    ) -> Result<Self> {
1239        let orderings = ordering_dtypes
1240            .iter()
1241            .map(ScalarValue::try_from)
1242            .collect::<Result<_>>()?;
1243        let sort_options = get_sort_options(&ordering_req);
1244        ScalarValue::try_from(data_type).map(|last| Self {
1245            last,
1246            is_set: false,
1247            orderings,
1248            ordering_req,
1249            sort_options,
1250            is_input_pre_ordered,
1251            ignore_nulls,
1252        })
1253    }
1254
1255    // Updates state with the values in the given row.
1256    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1257        // Ensure any Array based scalars hold have a single value to reduce memory pressure
1258        for s in row.iter_mut() {
1259            s.compact();
1260        }
1261        self.last = row.remove(0);
1262        self.orderings = row;
1263        self.is_set = true;
1264    }
1265
1266    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1267        let [value, ordering_values @ ..] = values else {
1268            return internal_err!("Empty row in LAST_VALUE");
1269        };
1270        if self.is_input_pre_ordered {
1271            // Get last entry according to the order of data:
1272            if self.ignore_nulls {
1273                // If ignoring nulls, find the last non-null value.
1274                for i in (0..value.len()).rev() {
1275                    if !value.is_null(i) {
1276                        return Ok(Some(i));
1277                    }
1278                }
1279                return Ok(None);
1280            } else {
1281                return Ok((!value.is_empty()).then_some(value.len() - 1));
1282            }
1283        }
1284
1285        let sort_columns = ordering_values
1286            .iter()
1287            .zip(self.ordering_req.iter())
1288            .map(|(values, req)| SortColumn {
1289                values: Arc::clone(values),
1290                options: Some(req.options),
1291            })
1292            .collect::<Vec<_>>();
1293
1294        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1295        let max_ind = if self.ignore_nulls {
1296            (0..value.len())
1297                .filter(|&index| !(value.is_null(index)))
1298                .max_by(|&a, &b| comparator.compare(a, b))
1299        } else {
1300            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1301        };
1302
1303        Ok(max_ind)
1304    }
1305}
1306
1307impl Accumulator for LastValueAccumulator {
1308    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1309        let mut result = vec![self.last.clone()];
1310        result.extend(self.orderings.clone());
1311        result.push(ScalarValue::from(self.is_set));
1312        Ok(result)
1313    }
1314
1315    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1316        if let Some(last_idx) = self.get_last_idx(values)? {
1317            let row = get_row_at_idx(values, last_idx)?;
1318            let orderings = &row[1..];
1319            // Update when there is a more recent entry
1320            if !self.is_set
1321                || self.is_input_pre_ordered
1322                || compare_rows(&self.orderings, orderings, &self.sort_options)?.is_lt()
1323            {
1324                self.update_with_new_row(row);
1325            }
1326        }
1327        Ok(())
1328    }
1329
1330    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1331        // LAST_VALUE(last1, last2, last3, ...)
1332        // last index contains is_set flag.
1333        let is_set_idx = states.len() - 1;
1334        let flags = states[is_set_idx].as_boolean();
1335        validate_is_set_flags(flags, "last_value")?;
1336
1337        let filtered_states =
1338            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1339        // 1..is_set_idx range corresponds to ordering section
1340        let sort_columns =
1341            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
1342
1343        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1344        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1345
1346        if let Some(last_idx) = max {
1347            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
1348            // When collecting orderings, we exclude the is_set flag from the state.
1349            let last_ordering = &last_row[1..is_set_idx];
1350            // Either there is no existing value, or there is a newer (latest)
1351            // version in the new data:
1352            if !self.is_set
1353                || self.is_input_pre_ordered
1354                || compare_rows(&self.orderings, last_ordering, &self.sort_options)?
1355                    .is_lt()
1356            {
1357                // Update with last value in the state. Note that we should exclude the
1358                // is_set flag from the state. Otherwise, we will end up with a state
1359                // containing two is_set flags.
1360                assert!(is_set_idx <= last_row.len());
1361                last_row.resize(is_set_idx, ScalarValue::Null);
1362                self.update_with_new_row(last_row);
1363            }
1364        }
1365        Ok(())
1366    }
1367
1368    fn evaluate(&mut self) -> Result<ScalarValue> {
1369        Ok(self.last.clone())
1370    }
1371
1372    fn size(&self) -> usize {
1373        size_of_val(self) - size_of_val(&self.last)
1374            + self.last.size()
1375            + ScalarValue::size_of_vec(&self.orderings)
1376            - size_of_val(&self.orderings)
1377    }
1378}
1379
1380/// Validates that `is_set flags` do not contain NULL values.
1381fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
1382    if flags.null_count() > 0 {
1383        return Err(DataFusionError::Internal(format!(
1384            "{function_name}: is_set flags contain nulls"
1385        )));
1386    }
1387    Ok(())
1388}
1389
1390/// Filters states according to the `is_set` flag at the last column and returns
1391/// the resulting states.
1392fn filter_states_according_to_is_set(
1393    states: &[ArrayRef],
1394    flags: &BooleanArray,
1395) -> Result<Vec<ArrayRef>> {
1396    states
1397        .iter()
1398        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1399        .collect()
1400}
1401
1402/// Combines array refs and their corresponding orderings to construct `SortColumn`s.
1403fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1404    arrs.iter()
1405        .zip(sort_exprs.iter())
1406        .map(|(item, sort_expr)| SortColumn {
1407            values: Arc::clone(item),
1408            options: Some(sort_expr.options),
1409        })
1410        .collect()
1411}
1412
1413#[cfg(test)]
1414mod tests {
1415    use std::iter::repeat_with;
1416
1417    use arrow::{
1418        array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray},
1419        compute::SortOptions,
1420        datatypes::Schema,
1421    };
1422    use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};
1423
1424    use super::*;
1425
1426    #[test]
1427    fn test_first_last_value_value() -> Result<()> {
1428        let mut first_accumulator =
1429            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1430        let mut last_accumulator =
1431            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1432        // first value in the tuple is start of the range (inclusive),
1433        // second value in the tuple is end of the range (exclusive)
1434        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1435        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1436        let arrs = ranges
1437            .into_iter()
1438            .map(|(start, end)| {
1439                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
1440            })
1441            .collect::<Vec<_>>();
1442        for arr in arrs {
1443            // Once first_value is set, accumulator should remember it.
1444            // It shouldn't update first_value for each new batch
1445            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
1446            // last_value should be updated for each new batch.
1447            last_accumulator.update_batch(&[arr])?;
1448        }
1449        // First Value comes from the first value of the first batch which is 0
1450        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
1451        // Last value comes from the last value of the last batch which is 12
1452        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
1453        Ok(())
1454    }
1455
1456    #[test]
1457    fn test_first_last_state_after_merge() -> Result<()> {
1458        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1459        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1460        let arrs = ranges
1461            .into_iter()
1462            .map(|(start, end)| {
1463                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
1464            })
1465            .collect::<Vec<_>>();
1466
1467        // FirstValueAccumulator
1468        let mut first_accumulator =
1469            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1470
1471        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1472        let state1 = first_accumulator.state()?;
1473
1474        let mut first_accumulator =
1475            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1476        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1477        let state2 = first_accumulator.state()?;
1478
1479        assert_eq!(state1.len(), state2.len());
1480
1481        let mut states = vec![];
1482
1483        for idx in 0..state1.len() {
1484            states.push(compute::concat(&[
1485                &state1[idx].to_array()?,
1486                &state2[idx].to_array()?,
1487            ])?);
1488        }
1489
1490        let mut first_accumulator =
1491            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1492        first_accumulator.merge_batch(&states)?;
1493
1494        let merged_state = first_accumulator.state()?;
1495        assert_eq!(merged_state.len(), state1.len());
1496
1497        // LastValueAccumulator
1498        let mut last_accumulator =
1499            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1500
1501        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1502        let state1 = last_accumulator.state()?;
1503
1504        let mut last_accumulator =
1505            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1506        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1507        let state2 = last_accumulator.state()?;
1508
1509        assert_eq!(state1.len(), state2.len());
1510
1511        let mut states = vec![];
1512
1513        for idx in 0..state1.len() {
1514            states.push(compute::concat(&[
1515                &state1[idx].to_array()?,
1516                &state2[idx].to_array()?,
1517            ])?);
1518        }
1519
1520        let mut last_accumulator =
1521            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1522        last_accumulator.merge_batch(&states)?;
1523
1524        let merged_state = last_accumulator.state()?;
1525        assert_eq!(merged_state.len(), state1.len());
1526
1527        Ok(())
1528    }
1529
1530    #[test]
1531    fn test_first_group_acc() -> Result<()> {
1532        let schema = Arc::new(Schema::new(vec![
1533            Field::new("a", DataType::Int64, true),
1534            Field::new("b", DataType::Int64, true),
1535            Field::new("c", DataType::Int64, true),
1536            Field::new("d", DataType::Int32, true),
1537            Field::new("e", DataType::Boolean, true),
1538        ]));
1539
1540        let sort_keys = [PhysicalSortExpr {
1541            expr: col("c", &schema).unwrap(),
1542            options: SortOptions::default(),
1543        }];
1544
1545        let mut group_acc = FirstLastGroupsAccumulator::try_new(
1546            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
1547            sort_keys.into(),
1548            true,
1549            &[DataType::Int64],
1550            true,
1551        )?;
1552
1553        let mut val_with_orderings = {
1554            let mut val_with_orderings = Vec::<ArrayRef>::new();
1555
1556            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1557            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1558
1559            val_with_orderings.push(vals);
1560            val_with_orderings.push(orderings);
1561
1562            val_with_orderings
1563        };
1564
1565        group_acc.update_batch(
1566            &val_with_orderings,
1567            &[0, 1, 2, 1],
1568            Some(&BooleanArray::from(vec![true, true, false, true])),
1569            3,
1570        )?;
1571        assert_eq!(
1572            group_acc.size_of_orderings,
1573            group_acc.compute_size_of_orderings()
1574        );
1575
1576        let state = group_acc.state(EmitTo::All)?;
1577
1578        let expected_state: Vec<Arc<dyn Array>> = vec![
1579            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1580            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1581            Arc::new(BooleanArray::from(vec![true, true, false])),
1582        ];
1583        assert_eq!(state, expected_state);
1584
1585        assert_eq!(
1586            group_acc.size_of_orderings,
1587            group_acc.compute_size_of_orderings()
1588        );
1589
1590        group_acc.merge_batch(
1591            &state,
1592            &[0, 1, 2],
1593            Some(&BooleanArray::from(vec![true, false, false])),
1594            3,
1595        )?;
1596
1597        assert_eq!(
1598            group_acc.size_of_orderings,
1599            group_acc.compute_size_of_orderings()
1600        );
1601
1602        val_with_orderings.clear();
1603        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1604        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1605
1606        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1607
1608        let binding = group_acc.evaluate(EmitTo::All)?;
1609        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1610
1611        let expect: PrimitiveArray<Int64Type> =
1612            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
1613
1614        assert_eq!(eval_result, &expect);
1615
1616        assert_eq!(
1617            group_acc.size_of_orderings,
1618            group_acc.compute_size_of_orderings()
1619        );
1620
1621        Ok(())
1622    }
1623
1624    #[test]
1625    fn test_group_acc_size_of_ordering() -> Result<()> {
1626        let schema = Arc::new(Schema::new(vec![
1627            Field::new("a", DataType::Int64, true),
1628            Field::new("b", DataType::Int64, true),
1629            Field::new("c", DataType::Int64, true),
1630            Field::new("d", DataType::Int32, true),
1631            Field::new("e", DataType::Boolean, true),
1632        ]));
1633
1634        let sort_keys = [PhysicalSortExpr {
1635            expr: col("c", &schema).unwrap(),
1636            options: SortOptions::default(),
1637        }];
1638
1639        let mut group_acc = FirstLastGroupsAccumulator::try_new(
1640            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
1641            sort_keys.into(),
1642            true,
1643            &[DataType::Int64],
1644            true,
1645        )?;
1646
1647        let val_with_orderings = {
1648            let mut val_with_orderings = Vec::<ArrayRef>::new();
1649
1650            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1651            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1652
1653            val_with_orderings.push(vals);
1654            val_with_orderings.push(orderings);
1655
1656            val_with_orderings
1657        };
1658
1659        for _ in 0..10 {
1660            group_acc.update_batch(
1661                &val_with_orderings,
1662                &[0, 1, 2, 1],
1663                Some(&BooleanArray::from(vec![true, true, false, true])),
1664                100,
1665            )?;
1666            assert_eq!(
1667                group_acc.size_of_orderings,
1668                group_acc.compute_size_of_orderings()
1669            );
1670
1671            group_acc.state(EmitTo::First(2))?;
1672            assert_eq!(
1673                group_acc.size_of_orderings,
1674                group_acc.compute_size_of_orderings()
1675            );
1676
1677            let s = group_acc.state(EmitTo::All)?;
1678            assert_eq!(
1679                group_acc.size_of_orderings,
1680                group_acc.compute_size_of_orderings()
1681            );
1682
1683            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
1684            assert_eq!(
1685                group_acc.size_of_orderings,
1686                group_acc.compute_size_of_orderings()
1687            );
1688
1689            group_acc.evaluate(EmitTo::First(2))?;
1690            assert_eq!(
1691                group_acc.size_of_orderings,
1692                group_acc.compute_size_of_orderings()
1693            );
1694
1695            group_acc.evaluate(EmitTo::All)?;
1696            assert_eq!(
1697                group_acc.size_of_orderings,
1698                group_acc.compute_size_of_orderings()
1699            );
1700        }
1701
1702        Ok(())
1703    }
1704
1705    #[test]
1706    fn test_last_group_acc() -> Result<()> {
1707        let schema = Arc::new(Schema::new(vec![
1708            Field::new("a", DataType::Int64, true),
1709            Field::new("b", DataType::Int64, true),
1710            Field::new("c", DataType::Int64, true),
1711            Field::new("d", DataType::Int32, true),
1712            Field::new("e", DataType::Boolean, true),
1713        ]));
1714
1715        let sort_keys = [PhysicalSortExpr {
1716            expr: col("c", &schema).unwrap(),
1717            options: SortOptions::default(),
1718        }];
1719
1720        let mut group_acc = FirstLastGroupsAccumulator::try_new(
1721            PrimitiveValueState::<Int64Type>::new(DataType::Int64),
1722            sort_keys.into(),
1723            true,
1724            &[DataType::Int64],
1725            false,
1726        )?;
1727
1728        let mut val_with_orderings = {
1729            let mut val_with_orderings = Vec::<ArrayRef>::new();
1730
1731            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1732            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1733
1734            val_with_orderings.push(vals);
1735            val_with_orderings.push(orderings);
1736
1737            val_with_orderings
1738        };
1739
1740        group_acc.update_batch(
1741            &val_with_orderings,
1742            &[0, 1, 2, 1],
1743            Some(&BooleanArray::from(vec![true, true, false, true])),
1744            3,
1745        )?;
1746
1747        let state = group_acc.state(EmitTo::All)?;
1748
1749        let expected_state: Vec<Arc<dyn Array>> = vec![
1750            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1751            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1752            Arc::new(BooleanArray::from(vec![true, true, false])),
1753        ];
1754        assert_eq!(state, expected_state);
1755
1756        group_acc.merge_batch(
1757            &state,
1758            &[0, 1, 2],
1759            Some(&BooleanArray::from(vec![true, false, false])),
1760            3,
1761        )?;
1762
1763        val_with_orderings.clear();
1764        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1765        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1766
1767        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1768
1769        let binding = group_acc.evaluate(EmitTo::All)?;
1770        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1771
1772        let expect: PrimitiveArray<Int64Type> =
1773            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
1774
1775        assert_eq!(eval_result, &expect);
1776
1777        Ok(())
1778    }
1779
1780    #[test]
1781    fn test_first_list_acc_size() -> Result<()> {
1782        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1783            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
1784                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1785                false,
1786            )?;
1787
1788            first_accumulator.update_batch(values)?;
1789
1790            Ok(first_accumulator.size())
1791        }
1792
1793        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1794            repeat_with(|| Some(vec![Some(1)])).take(10000),
1795        );
1796        let batch2 =
1797            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1798
1799        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1800        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1801        assert_eq!(size1, size2);
1802
1803        Ok(())
1804    }
1805
1806    #[test]
1807    fn test_last_list_acc_size() -> Result<()> {
1808        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1809            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
1810                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1811                false,
1812            )?;
1813
1814            last_accumulator.update_batch(values)?;
1815
1816            Ok(last_accumulator.size())
1817        }
1818
1819        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1820            repeat_with(|| Some(vec![Some(1)])).take(10000),
1821        );
1822        let batch2 =
1823            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1824
1825        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1826        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1827        assert_eq!(size1, size2);
1828
1829        Ok(())
1830    }
1831
1832    #[test]
1833    fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
1834        // Test data with corrupted is_set flag
1835        let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
1836        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;
1837
1838        // Test TrivialFirstValueAccumulator
1839        let mut trivial_accumulator =
1840            TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
1841        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
1842        let result = trivial_accumulator.merge_batch(&trivial_states);
1843        assert!(result.is_err());
1844        assert!(
1845            result
1846                .unwrap_err()
1847                .to_string()
1848                .contains("is_set flags contain nulls")
1849        );
1850
1851        // Test FirstValueAccumulator (with ordering)
1852        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
1853        let ordering_expr = col("ordering", &schema)?;
1854        let mut ordered_accumulator = FirstValueAccumulator::try_new(
1855            &DataType::Utf8,
1856            &[DataType::Int64],
1857            LexOrdering::new(vec![PhysicalSortExpr {
1858                expr: ordering_expr,
1859                options: SortOptions::default(),
1860            }])
1861            .unwrap(),
1862            false,
1863            false,
1864        )?;
1865        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
1866        let ordered_states = vec![value, ordering, corrupted_flag];
1867        let result = ordered_accumulator.merge_batch(&ordered_states);
1868        assert!(result.is_err());
1869        assert!(
1870            result
1871                .unwrap_err()
1872                .to_string()
1873                .contains("is_set flags contain nulls")
1874        );
1875
1876        Ok(())
1877    }
1878
1879    #[test]
1880    fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
1881        // Test data with corrupted is_set flag
1882        let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
1883        let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;
1884
1885        // Test TrivialLastValueAccumulator
1886        let mut trivial_accumulator =
1887            TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
1888        let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
1889        let result = trivial_accumulator.merge_batch(&trivial_states);
1890        assert!(result.is_err());
1891        assert!(
1892            result
1893                .unwrap_err()
1894                .to_string()
1895                .contains("is_set flags contain nulls")
1896        );
1897
1898        // Test LastValueAccumulator (with ordering)
1899        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
1900        let ordering_expr = col("ordering", &schema)?;
1901        let mut ordered_accumulator = LastValueAccumulator::try_new(
1902            &DataType::Utf8,
1903            &[DataType::Int64],
1904            LexOrdering::new(vec![PhysicalSortExpr {
1905                expr: ordering_expr,
1906                options: SortOptions::default(),
1907            }])
1908            .unwrap(),
1909            false,
1910            false,
1911        )?;
1912        let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
1913        let ordered_states = vec![value, ordering, corrupted_flag];
1914        let result = ordered_accumulator.merge_batch(&ordered_states);
1915        assert!(result.is_err());
1916        assert!(
1917            result
1918                .unwrap_err()
1919                .to_string()
1920                .contains("is_set flags contain nulls")
1921        );
1922
1923        Ok(())
1924    }
1925}