datafusion_functions_aggregate/
first_last.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the FIRST_VALUE/LAST_VALUE aggregations.
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::hash::Hash;
23use std::mem::size_of_val;
24use std::sync::Arc;
25
26use arrow::array::{
27    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
28    PrimitiveArray,
29};
30use arrow::buffer::{BooleanBuffer, NullBuffer};
31use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
32use arrow::datatypes::{
33    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef,
34    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
35    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
36    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
37    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
38    UInt8Type,
39};
40use datafusion_common::cast::as_boolean_array;
41use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
42use datafusion_common::{
43    arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
44};
45use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
46use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
47use datafusion_expr::{
48    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
49    GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
50};
51use datafusion_functions_aggregate_common::utils::get_sort_options;
52use datafusion_macros::user_doc;
53use datafusion_physical_expr_common::sort_expr::LexOrdering;
54
55create_func!(FirstValue, first_value_udaf);
56create_func!(LastValue, last_value_udaf);
57
58/// Returns the first value in a group of values.
59pub fn first_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
60    first_value_udaf()
61        .call(vec![expression])
62        .order_by(order_by)
63        .build()
64        // guaranteed to be `Expr::AggregateFunction`
65        .unwrap()
66}
67
68/// Returns the last value in a group of values.
69pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
70    last_value_udaf()
71        .call(vec![expression])
72        .order_by(order_by)
73        .build()
74        // guaranteed to be `Expr::AggregateFunction`
75        .unwrap()
76}
77
78#[user_doc(
79    doc_section(label = "General Functions"),
80    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
81    syntax_example = "first_value(expression [ORDER BY expression])",
82    sql_example = r#"```sql
83> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
84+-----------------------------------------------+
85| first_value(column_name ORDER BY other_column)|
86+-----------------------------------------------+
87| first_element                                 |
88+-----------------------------------------------+
89```"#,
90    standard_argument(name = "expression",)
91)]
92#[derive(PartialEq, Eq, Hash)]
93pub struct FirstValue {
94    signature: Signature,
95    is_input_pre_ordered: bool,
96}
97
98impl Debug for FirstValue {
99    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
100        f.debug_struct("FirstValue")
101            .field("name", &self.name())
102            .field("signature", &self.signature)
103            .field("accumulator", &"<FUNC>")
104            .finish()
105    }
106}
107
108impl Default for FirstValue {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl FirstValue {
115    pub fn new() -> Self {
116        Self {
117            signature: Signature::any(1, Volatility::Immutable),
118            is_input_pre_ordered: false,
119        }
120    }
121}
122
123impl AggregateUDFImpl for FirstValue {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn name(&self) -> &str {
129        "first_value"
130    }
131
132    fn signature(&self) -> &Signature {
133        &self.signature
134    }
135
136    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137        Ok(arg_types[0].clone())
138    }
139
140    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
141        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
142            return TrivialFirstValueAccumulator::try_new(
143                acc_args.return_field.data_type(),
144                acc_args.ignore_nulls,
145            )
146            .map(|acc| Box::new(acc) as _);
147        };
148        let ordering_dtypes = ordering
149            .iter()
150            .map(|e| e.expr.data_type(acc_args.schema))
151            .collect::<Result<Vec<_>>>()?;
152        Ok(Box::new(FirstValueAccumulator::try_new(
153            acc_args.return_field.data_type(),
154            &ordering_dtypes,
155            ordering,
156            self.is_input_pre_ordered,
157            acc_args.ignore_nulls,
158        )?))
159    }
160
161    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
162        let mut fields = vec![Field::new(
163            format_state_name(args.name, "first_value"),
164            args.return_type().clone(),
165            true,
166        )
167        .into()];
168        fields.extend(args.ordering_fields.iter().cloned());
169        fields.push(Field::new("is_set", DataType::Boolean, true).into());
170        Ok(fields)
171    }
172
173    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
174        use DataType::*;
175        !args.order_bys.is_empty()
176            && matches!(
177                args.return_field.data_type(),
178                Int8 | Int16
179                    | Int32
180                    | Int64
181                    | UInt8
182                    | UInt16
183                    | UInt32
184                    | UInt64
185                    | Float16
186                    | Float32
187                    | Float64
188                    | Decimal128(_, _)
189                    | Decimal256(_, _)
190                    | Date32
191                    | Date64
192                    | Time32(_)
193                    | Time64(_)
194                    | Timestamp(_, _)
195            )
196    }
197
198    fn create_groups_accumulator(
199        &self,
200        args: AccumulatorArgs,
201    ) -> Result<Box<dyn GroupsAccumulator>> {
202        fn create_accumulator<T: ArrowPrimitiveType + Send>(
203            args: AccumulatorArgs,
204        ) -> Result<Box<dyn GroupsAccumulator>> {
205            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
206                return internal_err!("Groups accumulator must have an ordering.");
207            };
208
209            let ordering_dtypes = ordering
210                .iter()
211                .map(|e| e.expr.data_type(args.schema))
212                .collect::<Result<Vec<_>>>()?;
213
214            FirstPrimitiveGroupsAccumulator::<T>::try_new(
215                ordering,
216                args.ignore_nulls,
217                args.return_field.data_type(),
218                &ordering_dtypes,
219                true,
220            )
221            .map(|acc| Box::new(acc) as _)
222        }
223
224        match args.return_field.data_type() {
225            DataType::Int8 => create_accumulator::<Int8Type>(args),
226            DataType::Int16 => create_accumulator::<Int16Type>(args),
227            DataType::Int32 => create_accumulator::<Int32Type>(args),
228            DataType::Int64 => create_accumulator::<Int64Type>(args),
229            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
230            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
231            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
232            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
233            DataType::Float16 => create_accumulator::<Float16Type>(args),
234            DataType::Float32 => create_accumulator::<Float32Type>(args),
235            DataType::Float64 => create_accumulator::<Float64Type>(args),
236
237            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
238            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
239
240            DataType::Timestamp(TimeUnit::Second, _) => {
241                create_accumulator::<TimestampSecondType>(args)
242            }
243            DataType::Timestamp(TimeUnit::Millisecond, _) => {
244                create_accumulator::<TimestampMillisecondType>(args)
245            }
246            DataType::Timestamp(TimeUnit::Microsecond, _) => {
247                create_accumulator::<TimestampMicrosecondType>(args)
248            }
249            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
250                create_accumulator::<TimestampNanosecondType>(args)
251            }
252
253            DataType::Date32 => create_accumulator::<Date32Type>(args),
254            DataType::Date64 => create_accumulator::<Date64Type>(args),
255            DataType::Time32(TimeUnit::Second) => {
256                create_accumulator::<Time32SecondType>(args)
257            }
258            DataType::Time32(TimeUnit::Millisecond) => {
259                create_accumulator::<Time32MillisecondType>(args)
260            }
261
262            DataType::Time64(TimeUnit::Microsecond) => {
263                create_accumulator::<Time64MicrosecondType>(args)
264            }
265            DataType::Time64(TimeUnit::Nanosecond) => {
266                create_accumulator::<Time64NanosecondType>(args)
267            }
268
269            _ => internal_err!(
270                "GroupsAccumulator not supported for first_value({})",
271                args.return_field.data_type()
272            ),
273        }
274    }
275
276    fn with_beneficial_ordering(
277        self: Arc<Self>,
278        beneficial_ordering: bool,
279    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
280        Ok(Some(Arc::new(Self {
281            signature: self.signature.clone(),
282            is_input_pre_ordered: beneficial_ordering,
283        })))
284    }
285
286    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
287        AggregateOrderSensitivity::Beneficial
288    }
289
290    fn reverse_expr(&self) -> ReversedUDAF {
291        ReversedUDAF::Reversed(last_value_udaf())
292    }
293
294    fn documentation(&self) -> Option<&Documentation> {
295        self.doc()
296    }
297}
298
299// TODO: rename to PrimitiveGroupsAccumulator
300struct FirstPrimitiveGroupsAccumulator<T>
301where
302    T: ArrowPrimitiveType + Send,
303{
304    // ================ state ===========
305    vals: Vec<T::Native>,
306    // Stores ordering values, of the aggregator requirement corresponding to first value
307    // of the aggregator.
308    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
309    // represents the ordering values corresponding to the `group_idx`-th group.
310    orderings: Vec<Vec<ScalarValue>>,
311    // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet.
312    // Once we see the first value, we set the `is_sets[group_idx]` flag
313    is_sets: BooleanBufferBuilder,
314    // null_builder[group_idx] == false => vals[group_idx] is null
315    null_builder: BooleanBufferBuilder,
316    // size of `self.orderings`
317    // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly.
318    // Therefore, we cache it and compute `size_of` only after each update
319    // to avoid calling `ScalarValue::size_of_vec` by Self.size.
320    size_of_orderings: usize,
321
322    // buffer for `get_filtered_min_of_each_group`
323    // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val
324    // only valid if filter_min_of_each_group_buf.1[group_idx] == true
325    // TODO: rename to extreme_of_each_group_buf
326    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
327
328    // =========== option ============
329
330    // Stores the applicable ordering requirement.
331    ordering_req: LexOrdering,
332    // true: take first element in an aggregation group according to the requested ordering.
333    // false: take last element in an aggregation group according to the requested ordering.
334    pick_first_in_group: bool,
335    // derived from `ordering_req`.
336    sort_options: Vec<SortOptions>,
337    // Ignore null values.
338    ignore_nulls: bool,
339    /// The output type
340    data_type: DataType,
341    default_orderings: Vec<ScalarValue>,
342}
343
344impl<T> FirstPrimitiveGroupsAccumulator<T>
345where
346    T: ArrowPrimitiveType + Send,
347{
348    fn try_new(
349        ordering_req: LexOrdering,
350        ignore_nulls: bool,
351        data_type: &DataType,
352        ordering_dtypes: &[DataType],
353        pick_first_in_group: bool,
354    ) -> Result<Self> {
355        let default_orderings = ordering_dtypes
356            .iter()
357            .map(ScalarValue::try_from)
358            .collect::<Result<_>>()?;
359
360        let sort_options = get_sort_options(&ordering_req);
361
362        Ok(Self {
363            null_builder: BooleanBufferBuilder::new(0),
364            ordering_req,
365            sort_options,
366            ignore_nulls,
367            default_orderings,
368            data_type: data_type.clone(),
369            vals: Vec::new(),
370            orderings: Vec::new(),
371            is_sets: BooleanBufferBuilder::new(0),
372            size_of_orderings: 0,
373            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
374            pick_first_in_group,
375        })
376    }
377
378    fn should_update_state(
379        &self,
380        group_idx: usize,
381        new_ordering_values: &[ScalarValue],
382    ) -> Result<bool> {
383        if !self.is_sets.get_bit(group_idx) {
384            return Ok(true);
385        }
386
387        assert!(new_ordering_values.len() == self.ordering_req.len());
388        let current_ordering = &self.orderings[group_idx];
389        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
390            if self.pick_first_in_group {
391                x.is_gt()
392            } else {
393                x.is_lt()
394            }
395        })
396    }
397
398    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
399        let result = emit_to.take_needed(&mut self.orderings);
400
401        match emit_to {
402            EmitTo::All => self.size_of_orderings = 0,
403            EmitTo::First(_) => {
404                self.size_of_orderings -=
405                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
406            }
407        }
408
409        result
410    }
411
412    fn take_need(
413        bool_buf_builder: &mut BooleanBufferBuilder,
414        emit_to: EmitTo,
415    ) -> BooleanBuffer {
416        let bool_buf = bool_buf_builder.finish();
417        match emit_to {
418            EmitTo::All => bool_buf,
419            EmitTo::First(n) => {
420                // split off the first N values in seen_values
421                //
422                // TODO make this more efficient rather than two
423                // copies and bitwise manipulation
424                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
425                // reset the existing buffer
426                for b in bool_buf.iter().skip(n) {
427                    bool_buf_builder.append(b);
428                }
429                first_n
430            }
431        }
432    }
433
434    fn resize_states(&mut self, new_size: usize) {
435        self.vals.resize(new_size, T::default_value());
436
437        self.null_builder.resize(new_size);
438
439        if self.orderings.len() < new_size {
440            let current_len = self.orderings.len();
441
442            self.orderings
443                .resize(new_size, self.default_orderings.clone());
444
445            self.size_of_orderings += (new_size - current_len)
446                * ScalarValue::size_of_vec(
447                    // Note: In some cases (such as in the unit test below)
448                    // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone())
449                    // This may be caused by the different vec.capacity() values?
450                    self.orderings.last().unwrap(),
451                );
452        }
453
454        self.is_sets.resize(new_size);
455
456        self.min_of_each_group_buf.0.resize(new_size, 0);
457        self.min_of_each_group_buf.1.resize(new_size);
458    }
459
460    fn update_state(
461        &mut self,
462        group_idx: usize,
463        orderings: &[ScalarValue],
464        new_val: T::Native,
465        is_null: bool,
466    ) {
467        self.vals[group_idx] = new_val;
468        self.is_sets.set_bit(group_idx, true);
469
470        self.null_builder.set_bit(group_idx, !is_null);
471
472        assert!(orderings.len() == self.ordering_req.len());
473        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
474        self.orderings[group_idx].clear();
475        self.orderings[group_idx].extend_from_slice(orderings);
476        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
477        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
478    }
479
480    fn take_state(
481        &mut self,
482        emit_to: EmitTo,
483    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
484        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
485        self.min_of_each_group_buf
486            .1
487            .truncate(self.min_of_each_group_buf.0.len());
488
489        (
490            self.take_vals_and_null_buf(emit_to),
491            self.take_orderings(emit_to),
492            Self::take_need(&mut self.is_sets, emit_to),
493        )
494    }
495
496    // should be used in test only
497    #[cfg(test)]
498    fn compute_size_of_orderings(&self) -> usize {
499        self.orderings
500            .iter()
501            .map(ScalarValue::size_of_vec)
502            .sum::<usize>()
503    }
504    /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
505    /// minimum value in `orderings` for each group, using lexicographical comparison.
506    /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
507    /// TODO: rename to get_filtered_extreme_of_each_group
508    fn get_filtered_min_of_each_group(
509        &mut self,
510        orderings: &[ArrayRef],
511        group_indices: &[usize],
512        opt_filter: Option<&BooleanArray>,
513        vals: &PrimitiveArray<T>,
514        is_set_arr: Option<&BooleanArray>,
515    ) -> Result<Vec<(usize, usize)>> {
516        // Set all values in min_of_each_group_buf.1 to false.
517        self.min_of_each_group_buf.1.truncate(0);
518        self.min_of_each_group_buf
519            .1
520            .append_n(self.vals.len(), false);
521
522        // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]`
523        // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`.
524
525        let comparator = {
526            assert_eq!(orderings.len(), self.ordering_req.len());
527            let sort_columns = orderings
528                .iter()
529                .zip(self.ordering_req.iter())
530                .map(|(array, req)| SortColumn {
531                    values: Arc::clone(array),
532                    options: Some(req.options),
533                })
534                .collect::<Vec<_>>();
535
536            LexicographicalComparator::try_new(&sort_columns)?
537        };
538
539        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
540            let group_idx = *group_idx;
541
542            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
543            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
544
545            if !passed_filter || !is_set {
546                continue;
547            }
548
549            if self.ignore_nulls && vals.is_null(idx_in_val) {
550                continue;
551            }
552
553            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
554
555            if !is_valid {
556                self.min_of_each_group_buf.1.set_bit(group_idx, true);
557                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
558            } else {
559                let ordering = comparator
560                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
561
562                if (ordering.is_gt() && self.pick_first_in_group)
563                    || (ordering.is_lt() && !self.pick_first_in_group)
564                {
565                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
566                }
567            }
568        }
569
570        Ok(self
571            .min_of_each_group_buf
572            .0
573            .iter()
574            .enumerate()
575            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
576            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
577            .collect::<Vec<_>>())
578    }
579
580    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
581        let r = emit_to.take_needed(&mut self.vals);
582
583        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
584
585        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy
586            .with_data_type(self.data_type.clone());
587        Arc::new(values)
588    }
589}
590
591impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
592where
593    T: ArrowPrimitiveType + Send,
594{
595    fn update_batch(
596        &mut self,
597        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
598        values_and_order_cols: &[ArrayRef],
599        group_indices: &[usize],
600        opt_filter: Option<&BooleanArray>,
601        total_num_groups: usize,
602    ) -> Result<()> {
603        self.resize_states(total_num_groups);
604
605        let vals = values_and_order_cols[0].as_primitive::<T>();
606
607        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
608
609        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
610        for (group_idx, idx) in self
611            .get_filtered_min_of_each_group(
612                &values_and_order_cols[1..],
613                group_indices,
614                opt_filter,
615                vals,
616                None,
617            )?
618            .into_iter()
619        {
620            extract_row_at_idx_to_buf(
621                &values_and_order_cols[1..],
622                idx,
623                &mut ordering_buf,
624            )?;
625
626            if self.should_update_state(group_idx, &ordering_buf)? {
627                self.update_state(
628                    group_idx,
629                    &ordering_buf,
630                    vals.value(idx),
631                    vals.is_null(idx),
632                );
633            }
634        }
635
636        Ok(())
637    }
638
639    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
640        Ok(self.take_state(emit_to).0)
641    }
642
643    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
644        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
645        let mut result = Vec::with_capacity(self.orderings.len() + 2);
646
647        result.push(val_arr);
648
649        let ordering_cols = {
650            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
651            for _ in 0..self.ordering_req.len() {
652                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
653            }
654            for row in orderings.into_iter() {
655                assert_eq!(row.len(), self.ordering_req.len());
656                for (col_idx, ordering) in row.into_iter().enumerate() {
657                    ordering_cols[col_idx].push(ordering);
658                }
659            }
660
661            ordering_cols
662        };
663        for ordering_col in ordering_cols {
664            result.push(ScalarValue::iter_to_array(ordering_col)?);
665        }
666
667        result.push(Arc::new(BooleanArray::new(is_sets, None)));
668
669        Ok(result)
670    }
671
672    fn merge_batch(
673        &mut self,
674        values: &[ArrayRef],
675        group_indices: &[usize],
676        opt_filter: Option<&BooleanArray>,
677        total_num_groups: usize,
678    ) -> Result<()> {
679        self.resize_states(total_num_groups);
680
681        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
682
683        let (is_set_arr, val_and_order_cols) = match values.split_last() {
684            Some(result) => result,
685            None => return internal_err!("Empty row in FIRST_VALUE"),
686        };
687
688        let is_set_arr = as_boolean_array(is_set_arr)?;
689
690        let vals = values[0].as_primitive::<T>();
691        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
692        let groups = self.get_filtered_min_of_each_group(
693            &val_and_order_cols[1..],
694            group_indices,
695            opt_filter,
696            vals,
697            Some(is_set_arr),
698        )?;
699
700        for (group_idx, idx) in groups.into_iter() {
701            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
702
703            if self.should_update_state(group_idx, &ordering_buf)? {
704                self.update_state(
705                    group_idx,
706                    &ordering_buf,
707                    vals.value(idx),
708                    vals.is_null(idx),
709                );
710            }
711        }
712
713        Ok(())
714    }
715
716    fn size(&self) -> usize {
717        self.vals.capacity() * size_of::<T::Native>()
718            + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
719            + self.is_sets.capacity() / 8
720            + self.size_of_orderings
721            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
722            + self.min_of_each_group_buf.1.capacity() / 8
723    }
724
725    fn supports_convert_to_state(&self) -> bool {
726        true
727    }
728
729    fn convert_to_state(
730        &self,
731        values: &[ArrayRef],
732        opt_filter: Option<&BooleanArray>,
733    ) -> Result<Vec<ArrayRef>> {
734        let mut result = values.to_vec();
735        match opt_filter {
736            Some(f) => {
737                result.push(Arc::new(f.clone()));
738                Ok(result)
739            }
740            None => {
741                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
742                Ok(result)
743            }
744        }
745    }
746}
747
748/// This accumulator is used when there is no ordering specified for the
749/// `FIRST_VALUE` aggregation. It simply returns the first value it sees
750/// according to the pre-existing ordering of the input data, and provides
751/// a fast path for this case without needing to maintain any ordering state.
752#[derive(Debug)]
753pub struct TrivialFirstValueAccumulator {
754    first: ScalarValue,
755    // Whether we have seen the first value yet.
756    is_set: bool,
757    // Ignore null values.
758    ignore_nulls: bool,
759}
760
761impl TrivialFirstValueAccumulator {
762    /// Creates a new `TrivialFirstValueAccumulator` for the given `data_type`.
763    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
764        ScalarValue::try_from(data_type).map(|first| Self {
765            first,
766            is_set: false,
767            ignore_nulls,
768        })
769    }
770}
771
772impl Accumulator for TrivialFirstValueAccumulator {
773    fn state(&mut self) -> Result<Vec<ScalarValue>> {
774        Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)])
775    }
776
777    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
778        if !self.is_set {
779            // Get first entry according to the pre-existing ordering (0th index):
780            let value = &values[0];
781            let mut first_idx = None;
782            if self.ignore_nulls {
783                // If ignoring nulls, find the first non-null value.
784                for i in 0..value.len() {
785                    if !value.is_null(i) {
786                        first_idx = Some(i);
787                        break;
788                    }
789                }
790            } else if !value.is_empty() {
791                // If not ignoring nulls, return the first value if it exists.
792                first_idx = Some(0);
793            }
794            if let Some(first_idx) = first_idx {
795                let mut row = get_row_at_idx(values, first_idx)?;
796                self.first = row.swap_remove(0);
797                self.first.compact();
798                self.is_set = true;
799            }
800        }
801        Ok(())
802    }
803
804    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
805        // FIRST_VALUE(first1, first2, first3, ...)
806        // Second index contains is_set flag.
807        if !self.is_set {
808            let flags = states[1].as_boolean();
809            let filtered_states =
810                filter_states_according_to_is_set(&states[0..1], flags)?;
811            if let Some(first) = filtered_states.first() {
812                if !first.is_empty() {
813                    self.first = ScalarValue::try_from_array(first, 0)?;
814                    self.is_set = true;
815                }
816            }
817        }
818        Ok(())
819    }
820
821    fn evaluate(&mut self) -> Result<ScalarValue> {
822        Ok(self.first.clone())
823    }
824
825    fn size(&self) -> usize {
826        size_of_val(self) - size_of_val(&self.first) + self.first.size()
827    }
828}
829
/// Row-based accumulator for `FIRST_VALUE` when an ordering requirement is present.
///
/// Besides the current candidate value it keeps the values of that candidate's
/// ordering columns; `update_batch` and `merge_batch` compare new candidates
/// against them.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    // The current candidate for the first value.
    first: ScalarValue,
    // Whether we have seen the first value yet.
    is_set: bool,
    // Stores values of the ordering columns corresponding to the first value.
    // These values are compared against new candidates both when updating from
    // raw batches and when merging partial states of multiple partitions.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    // When true, `update_batch` keeps the first value it stores and skips the
    // ordering comparison for later batches.
    is_input_pre_ordered: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
845
846impl FirstValueAccumulator {
847    /// Creates a new `FirstValueAccumulator` for the given `data_type`.
848    pub fn try_new(
849        data_type: &DataType,
850        ordering_dtypes: &[DataType],
851        ordering_req: LexOrdering,
852        is_input_pre_ordered: bool,
853        ignore_nulls: bool,
854    ) -> Result<Self> {
855        let orderings = ordering_dtypes
856            .iter()
857            .map(ScalarValue::try_from)
858            .collect::<Result<_>>()?;
859        ScalarValue::try_from(data_type).map(|first| Self {
860            first,
861            is_set: false,
862            orderings,
863            ordering_req,
864            is_input_pre_ordered,
865            ignore_nulls,
866        })
867    }
868
869    // Updates state with the values in the given row.
870    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
871        // Ensure any Array based scalars hold have a single value to reduce memory pressure
872        for s in row.iter_mut() {
873            s.compact();
874        }
875        self.first = row.remove(0);
876        self.orderings = row;
877        self.is_set = true;
878    }
879
880    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
881        let [value, ordering_values @ ..] = values else {
882            return internal_err!("Empty row in FIRST_VALUE");
883        };
884        if self.is_input_pre_ordered {
885            // Get first entry according to the pre-existing ordering (0th index):
886            if self.ignore_nulls {
887                // If ignoring nulls, find the first non-null value.
888                for i in 0..value.len() {
889                    if !value.is_null(i) {
890                        return Ok(Some(i));
891                    }
892                }
893                return Ok(None);
894            } else {
895                // If not ignoring nulls, return the first value if it exists.
896                return Ok((!value.is_empty()).then_some(0));
897            }
898        }
899
900        let sort_columns = ordering_values
901            .iter()
902            .zip(self.ordering_req.iter())
903            .map(|(values, req)| SortColumn {
904                values: Arc::clone(values),
905                options: Some(req.options),
906            })
907            .collect::<Vec<_>>();
908
909        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
910
911        let min_index = if self.ignore_nulls {
912            (0..value.len())
913                .filter(|&index| !value.is_null(index))
914                .min_by(|&a, &b| comparator.compare(a, b))
915        } else {
916            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
917        };
918
919        Ok(min_index)
920    }
921}
922
923impl Accumulator for FirstValueAccumulator {
924    fn state(&mut self) -> Result<Vec<ScalarValue>> {
925        let mut result = vec![self.first.clone()];
926        result.extend(self.orderings.iter().cloned());
927        result.push(ScalarValue::from(self.is_set));
928        Ok(result)
929    }
930
931    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
932        if let Some(first_idx) = self.get_first_idx(values)? {
933            let row = get_row_at_idx(values, first_idx)?;
934            if !self.is_set
935                || (!self.is_input_pre_ordered
936                    && compare_rows(
937                        &self.orderings,
938                        &row[1..],
939                        &get_sort_options(&self.ordering_req),
940                    )?
941                    .is_gt())
942            {
943                self.update_with_new_row(row);
944            }
945        }
946        Ok(())
947    }
948
949    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
950        // FIRST_VALUE(first1, first2, first3, ...)
951        // last index contains is_set flag.
952        let is_set_idx = states.len() - 1;
953        let flags = states[is_set_idx].as_boolean();
954        let filtered_states =
955            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
956        // 1..is_set_idx range corresponds to ordering section
957        let sort_columns =
958            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
959
960        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
961        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
962
963        if let Some(first_idx) = min {
964            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
965            // When collecting orderings, we exclude the is_set flag from the state.
966            let first_ordering = &first_row[1..is_set_idx];
967            let sort_options = get_sort_options(&self.ordering_req);
968            // Either there is no existing value, or there is an earlier version in new data.
969            if !self.is_set
970                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
971            {
972                // Update with first value in the state. Note that we should exclude the
973                // is_set flag from the state. Otherwise, we will end up with a state
974                // containing two is_set flags.
975                assert!(is_set_idx <= first_row.len());
976                first_row.resize(is_set_idx, ScalarValue::Null);
977                self.update_with_new_row(first_row);
978            }
979        }
980        Ok(())
981    }
982
983    fn evaluate(&mut self) -> Result<ScalarValue> {
984        Ok(self.first.clone())
985    }
986
987    fn size(&self) -> usize {
988        size_of_val(self) - size_of_val(&self.first)
989            + self.first.size()
990            + ScalarValue::size_of_vec(&self.orderings)
991            - size_of_val(&self.orderings)
992    }
993}
994
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct LastValue {
    // Signature reported through `AggregateUDFImpl::signature`.
    signature: Signature,
    // Set via `with_beneficial_ordering`; forwarded to `LastValueAccumulator`,
    // which then trusts the arrival order of its input instead of comparing
    // ordering values.
    is_input_pre_ordered: bool,
}
1014
1015impl Debug for LastValue {
1016    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1017        f.debug_struct("LastValue")
1018            .field("name", &self.name())
1019            .field("signature", &self.signature)
1020            .field("accumulator", &"<FUNC>")
1021            .finish()
1022    }
1023}
1024
1025impl Default for LastValue {
1026    fn default() -> Self {
1027        Self::new()
1028    }
1029}
1030
1031impl LastValue {
1032    pub fn new() -> Self {
1033        Self {
1034            signature: Signature::any(1, Volatility::Immutable),
1035            is_input_pre_ordered: false,
1036        }
1037    }
1038}
1039
1040impl AggregateUDFImpl for LastValue {
1041    fn as_any(&self) -> &dyn Any {
1042        self
1043    }
1044
1045    fn name(&self) -> &str {
1046        "last_value"
1047    }
1048
1049    fn signature(&self) -> &Signature {
1050        &self.signature
1051    }
1052
1053    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1054        Ok(arg_types[0].clone())
1055    }
1056
1057    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1058        let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else {
1059            return TrivialLastValueAccumulator::try_new(
1060                acc_args.return_field.data_type(),
1061                acc_args.ignore_nulls,
1062            )
1063            .map(|acc| Box::new(acc) as _);
1064        };
1065        let ordering_dtypes = ordering
1066            .iter()
1067            .map(|e| e.expr.data_type(acc_args.schema))
1068            .collect::<Result<Vec<_>>>()?;
1069        Ok(Box::new(LastValueAccumulator::try_new(
1070            acc_args.return_field.data_type(),
1071            &ordering_dtypes,
1072            ordering,
1073            self.is_input_pre_ordered,
1074            acc_args.ignore_nulls,
1075        )?))
1076    }
1077
1078    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1079        let mut fields = vec![Field::new(
1080            format_state_name(args.name, "last_value"),
1081            args.return_field.data_type().clone(),
1082            true,
1083        )
1084        .into()];
1085        fields.extend(args.ordering_fields.iter().cloned());
1086        fields.push(Field::new("is_set", DataType::Boolean, true).into());
1087        Ok(fields)
1088    }
1089
1090    fn with_beneficial_ordering(
1091        self: Arc<Self>,
1092        beneficial_ordering: bool,
1093    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1094        Ok(Some(Arc::new(Self {
1095            signature: self.signature.clone(),
1096            is_input_pre_ordered: beneficial_ordering,
1097        })))
1098    }
1099
1100    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1101        AggregateOrderSensitivity::Beneficial
1102    }
1103
1104    fn reverse_expr(&self) -> ReversedUDAF {
1105        ReversedUDAF::Reversed(first_value_udaf())
1106    }
1107
1108    fn documentation(&self) -> Option<&Documentation> {
1109        self.doc()
1110    }
1111
1112    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1113        use DataType::*;
1114        !args.order_bys.is_empty()
1115            && matches!(
1116                args.return_field.data_type(),
1117                Int8 | Int16
1118                    | Int32
1119                    | Int64
1120                    | UInt8
1121                    | UInt16
1122                    | UInt32
1123                    | UInt64
1124                    | Float16
1125                    | Float32
1126                    | Float64
1127                    | Decimal128(_, _)
1128                    | Decimal256(_, _)
1129                    | Date32
1130                    | Date64
1131                    | Time32(_)
1132                    | Time64(_)
1133                    | Timestamp(_, _)
1134            )
1135    }
1136
1137    fn create_groups_accumulator(
1138        &self,
1139        args: AccumulatorArgs,
1140    ) -> Result<Box<dyn GroupsAccumulator>> {
1141        fn create_accumulator<T>(
1142            args: AccumulatorArgs,
1143        ) -> Result<Box<dyn GroupsAccumulator>>
1144        where
1145            T: ArrowPrimitiveType + Send,
1146        {
1147            let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else {
1148                return internal_err!("Groups accumulator must have an ordering.");
1149            };
1150
1151            let ordering_dtypes = ordering
1152                .iter()
1153                .map(|e| e.expr.data_type(args.schema))
1154                .collect::<Result<Vec<_>>>()?;
1155
1156            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1157                ordering,
1158                args.ignore_nulls,
1159                args.return_field.data_type(),
1160                &ordering_dtypes,
1161                false,
1162            )?))
1163        }
1164
1165        match args.return_field.data_type() {
1166            DataType::Int8 => create_accumulator::<Int8Type>(args),
1167            DataType::Int16 => create_accumulator::<Int16Type>(args),
1168            DataType::Int32 => create_accumulator::<Int32Type>(args),
1169            DataType::Int64 => create_accumulator::<Int64Type>(args),
1170            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
1171            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
1172            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
1173            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
1174            DataType::Float16 => create_accumulator::<Float16Type>(args),
1175            DataType::Float32 => create_accumulator::<Float32Type>(args),
1176            DataType::Float64 => create_accumulator::<Float64Type>(args),
1177
1178            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
1179            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
1180
1181            DataType::Timestamp(TimeUnit::Second, _) => {
1182                create_accumulator::<TimestampSecondType>(args)
1183            }
1184            DataType::Timestamp(TimeUnit::Millisecond, _) => {
1185                create_accumulator::<TimestampMillisecondType>(args)
1186            }
1187            DataType::Timestamp(TimeUnit::Microsecond, _) => {
1188                create_accumulator::<TimestampMicrosecondType>(args)
1189            }
1190            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1191                create_accumulator::<TimestampNanosecondType>(args)
1192            }
1193
1194            DataType::Date32 => create_accumulator::<Date32Type>(args),
1195            DataType::Date64 => create_accumulator::<Date64Type>(args),
1196            DataType::Time32(TimeUnit::Second) => {
1197                create_accumulator::<Time32SecondType>(args)
1198            }
1199            DataType::Time32(TimeUnit::Millisecond) => {
1200                create_accumulator::<Time32MillisecondType>(args)
1201            }
1202
1203            DataType::Time64(TimeUnit::Microsecond) => {
1204                create_accumulator::<Time64MicrosecondType>(args)
1205            }
1206            DataType::Time64(TimeUnit::Nanosecond) => {
1207                create_accumulator::<Time64NanosecondType>(args)
1208            }
1209
1210            _ => {
1211                internal_err!(
1212                    "GroupsAccumulator not supported for last_value({})",
1213                    args.return_field.data_type()
1214                )
1215            }
1216        }
1217    }
1218}
1219
/// This accumulator is used when there is no ordering specified for the
/// `LAST_VALUE` aggregation. It simply updates the last value it sees
/// according to the pre-existing ordering of the input data, and provides
/// a fast path for this case without needing to maintain any ordering state.
#[derive(Debug)]
pub struct TrivialLastValueAccumulator {
    // The most recently stored value.
    last: ScalarValue,
    // The `is_set` flag records whether any value has been stored yet.
    // This information is used to discriminate genuine NULLs and NULLs that
    // occur due to empty partitions.
    is_set: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
1234
1235impl TrivialLastValueAccumulator {
1236    /// Creates a new `TrivialLastValueAccumulator` for the given `data_type`.
1237    pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result<Self> {
1238        ScalarValue::try_from(data_type).map(|last| Self {
1239            last,
1240            is_set: false,
1241            ignore_nulls,
1242        })
1243    }
1244}
1245
1246impl Accumulator for TrivialLastValueAccumulator {
1247    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1248        Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)])
1249    }
1250
1251    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1252        // Get last entry according to the pre-existing ordering (0th index):
1253        let value = &values[0];
1254        let mut last_idx = None;
1255        if self.ignore_nulls {
1256            // If ignoring nulls, find the last non-null value.
1257            for i in (0..value.len()).rev() {
1258                if !value.is_null(i) {
1259                    last_idx = Some(i);
1260                    break;
1261                }
1262            }
1263        } else if !value.is_empty() {
1264            // If not ignoring nulls, return the last value if it exists.
1265            last_idx = Some(value.len() - 1);
1266        }
1267        if let Some(last_idx) = last_idx {
1268            let mut row = get_row_at_idx(values, last_idx)?;
1269            self.last = row.swap_remove(0);
1270            self.last.compact();
1271            self.is_set = true;
1272        }
1273        Ok(())
1274    }
1275
1276    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1277        // LAST_VALUE(last1, last2, last3, ...)
1278        // Second index contains is_set flag.
1279        let flags = states[1].as_boolean();
1280        let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
1281        if let Some(last) = filtered_states.last() {
1282            if !last.is_empty() {
1283                self.last = ScalarValue::try_from_array(last, 0)?;
1284                self.is_set = true;
1285            }
1286        }
1287        Ok(())
1288    }
1289
1290    fn evaluate(&mut self) -> Result<ScalarValue> {
1291        Ok(self.last.clone())
1292    }
1293
1294    fn size(&self) -> usize {
1295        size_of_val(self) - size_of_val(&self.last) + self.last.size()
1296    }
1297}
1298
/// Row-based accumulator for `LAST_VALUE` when an ordering requirement is present.
///
/// Besides the current candidate value it keeps the values of that candidate's
/// ordering columns; `update_batch` and `merge_batch` compare new candidates
/// against them.
#[derive(Debug)]
struct LastValueAccumulator {
    // The current candidate for the last value.
    last: ScalarValue,
    // The `is_set` flag records whether any value has been stored yet.
    // This information is used to discriminate genuine NULLs and NULLs that
    // occur due to empty partitions.
    is_set: bool,
    // Stores values of the ordering columns corresponding to the last value.
    // These values are compared against new candidates both when updating from
    // raw batches and when merging partial states of multiple partitions.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    // When true, each new candidate replaces the stored one without an
    // ordering comparison.
    is_input_pre_ordered: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
1316
1317impl LastValueAccumulator {
1318    /// Creates a new `LastValueAccumulator` for the given `data_type`.
1319    pub fn try_new(
1320        data_type: &DataType,
1321        ordering_dtypes: &[DataType],
1322        ordering_req: LexOrdering,
1323        is_input_pre_ordered: bool,
1324        ignore_nulls: bool,
1325    ) -> Result<Self> {
1326        let orderings = ordering_dtypes
1327            .iter()
1328            .map(ScalarValue::try_from)
1329            .collect::<Result<_>>()?;
1330        ScalarValue::try_from(data_type).map(|last| Self {
1331            last,
1332            is_set: false,
1333            orderings,
1334            ordering_req,
1335            is_input_pre_ordered,
1336            ignore_nulls,
1337        })
1338    }
1339
1340    // Updates state with the values in the given row.
1341    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1342        // Ensure any Array based scalars hold have a single value to reduce memory pressure
1343        for s in row.iter_mut() {
1344            s.compact();
1345        }
1346        self.last = row.remove(0);
1347        self.orderings = row;
1348        self.is_set = true;
1349    }
1350
1351    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1352        let [value, ordering_values @ ..] = values else {
1353            return internal_err!("Empty row in LAST_VALUE");
1354        };
1355        if self.is_input_pre_ordered {
1356            // Get last entry according to the order of data:
1357            if self.ignore_nulls {
1358                // If ignoring nulls, find the last non-null value.
1359                for i in (0..value.len()).rev() {
1360                    if !value.is_null(i) {
1361                        return Ok(Some(i));
1362                    }
1363                }
1364                return Ok(None);
1365            } else {
1366                return Ok((!value.is_empty()).then_some(value.len() - 1));
1367            }
1368        }
1369
1370        let sort_columns = ordering_values
1371            .iter()
1372            .zip(self.ordering_req.iter())
1373            .map(|(values, req)| SortColumn {
1374                values: Arc::clone(values),
1375                options: Some(req.options),
1376            })
1377            .collect::<Vec<_>>();
1378
1379        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1380        let max_ind = if self.ignore_nulls {
1381            (0..value.len())
1382                .filter(|&index| !(value.is_null(index)))
1383                .max_by(|&a, &b| comparator.compare(a, b))
1384        } else {
1385            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1386        };
1387
1388        Ok(max_ind)
1389    }
1390}
1391
1392impl Accumulator for LastValueAccumulator {
1393    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1394        let mut result = vec![self.last.clone()];
1395        result.extend(self.orderings.clone());
1396        result.push(ScalarValue::from(self.is_set));
1397        Ok(result)
1398    }
1399
1400    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1401        if let Some(last_idx) = self.get_last_idx(values)? {
1402            let row = get_row_at_idx(values, last_idx)?;
1403            let orderings = &row[1..];
1404            // Update when there is a more recent entry
1405            if !self.is_set
1406                || self.is_input_pre_ordered
1407                || compare_rows(
1408                    &self.orderings,
1409                    orderings,
1410                    &get_sort_options(&self.ordering_req),
1411                )?
1412                .is_lt()
1413            {
1414                self.update_with_new_row(row);
1415            }
1416        }
1417        Ok(())
1418    }
1419
1420    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1421        // LAST_VALUE(last1, last2, last3, ...)
1422        // last index contains is_set flag.
1423        let is_set_idx = states.len() - 1;
1424        let flags = states[is_set_idx].as_boolean();
1425        let filtered_states =
1426            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1427        // 1..is_set_idx range corresponds to ordering section
1428        let sort_columns =
1429            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
1430
1431        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1432        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1433
1434        if let Some(last_idx) = max {
1435            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
1436            // When collecting orderings, we exclude the is_set flag from the state.
1437            let last_ordering = &last_row[1..is_set_idx];
1438            let sort_options = get_sort_options(&self.ordering_req);
1439            // Either there is no existing value, or there is a newer (latest)
1440            // version in the new data:
1441            if !self.is_set
1442                || self.is_input_pre_ordered
1443                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
1444            {
1445                // Update with last value in the state. Note that we should exclude the
1446                // is_set flag from the state. Otherwise, we will end up with a state
1447                // containing two is_set flags.
1448                assert!(is_set_idx <= last_row.len());
1449                last_row.resize(is_set_idx, ScalarValue::Null);
1450                self.update_with_new_row(last_row);
1451            }
1452        }
1453        Ok(())
1454    }
1455
1456    fn evaluate(&mut self) -> Result<ScalarValue> {
1457        Ok(self.last.clone())
1458    }
1459
1460    fn size(&self) -> usize {
1461        size_of_val(self) - size_of_val(&self.last)
1462            + self.last.size()
1463            + ScalarValue::size_of_vec(&self.orderings)
1464            - size_of_val(&self.orderings)
1465    }
1466}
1467
1468/// Filters states according to the `is_set` flag at the last column and returns
1469/// the resulting states.
1470fn filter_states_according_to_is_set(
1471    states: &[ArrayRef],
1472    flags: &BooleanArray,
1473) -> Result<Vec<ArrayRef>> {
1474    states
1475        .iter()
1476        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1477        .collect()
1478}
1479
1480/// Combines array refs and their corresponding orderings to construct `SortColumn`s.
1481fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1482    arrs.iter()
1483        .zip(sort_exprs.iter())
1484        .map(|(item, sort_expr)| SortColumn {
1485            values: Arc::clone(item),
1486            options: Some(sort_expr.options),
1487        })
1488        .collect()
1489}
1490
1491#[cfg(test)]
1492mod tests {
1493    use std::iter::repeat_with;
1494
1495    use arrow::{
1496        array::{Int64Array, ListArray},
1497        compute::SortOptions,
1498        datatypes::Schema,
1499    };
1500    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
1501
1502    use super::*;
1503
1504    #[test]
1505    fn test_first_last_value_value() -> Result<()> {
1506        let mut first_accumulator =
1507            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1508        let mut last_accumulator =
1509            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1510        // first value in the tuple is start of the range (inclusive),
1511        // second value in the tuple is end of the range (exclusive)
1512        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1513        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1514        let arrs = ranges
1515            .into_iter()
1516            .map(|(start, end)| {
1517                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
1518            })
1519            .collect::<Vec<_>>();
1520        for arr in arrs {
1521            // Once first_value is set, accumulator should remember it.
1522            // It shouldn't update first_value for each new batch
1523            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
1524            // last_value should be updated for each new batch.
1525            last_accumulator.update_batch(&[arr])?;
1526        }
1527        // First Value comes from the first value of the first batch which is 0
1528        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
1529        // Last value comes from the last value of the last batch which is 12
1530        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
1531        Ok(())
1532    }
1533
1534    #[test]
1535    fn test_first_last_state_after_merge() -> Result<()> {
1536        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1537        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1538        let arrs = ranges
1539            .into_iter()
1540            .map(|(start, end)| {
1541                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
1542            })
1543            .collect::<Vec<_>>();
1544
1545        // FirstValueAccumulator
1546        let mut first_accumulator =
1547            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1548
1549        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1550        let state1 = first_accumulator.state()?;
1551
1552        let mut first_accumulator =
1553            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1554        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1555        let state2 = first_accumulator.state()?;
1556
1557        assert_eq!(state1.len(), state2.len());
1558
1559        let mut states = vec![];
1560
1561        for idx in 0..state1.len() {
1562            states.push(compute::concat(&[
1563                &state1[idx].to_array()?,
1564                &state2[idx].to_array()?,
1565            ])?);
1566        }
1567
1568        let mut first_accumulator =
1569            TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?;
1570        first_accumulator.merge_batch(&states)?;
1571
1572        let merged_state = first_accumulator.state()?;
1573        assert_eq!(merged_state.len(), state1.len());
1574
1575        // LastValueAccumulator
1576        let mut last_accumulator =
1577            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1578
1579        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1580        let state1 = last_accumulator.state()?;
1581
1582        let mut last_accumulator =
1583            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1584        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1585        let state2 = last_accumulator.state()?;
1586
1587        assert_eq!(state1.len(), state2.len());
1588
1589        let mut states = vec![];
1590
1591        for idx in 0..state1.len() {
1592            states.push(compute::concat(&[
1593                &state1[idx].to_array()?,
1594                &state2[idx].to_array()?,
1595            ])?);
1596        }
1597
1598        let mut last_accumulator =
1599            TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?;
1600        last_accumulator.merge_batch(&states)?;
1601
1602        let merged_state = last_accumulator.state()?;
1603        assert_eq!(merged_state.len(), state1.len());
1604
1605        Ok(())
1606    }
1607
1608    #[test]
1609    fn test_first_group_acc() -> Result<()> {
1610        let schema = Arc::new(Schema::new(vec![
1611            Field::new("a", DataType::Int64, true),
1612            Field::new("b", DataType::Int64, true),
1613            Field::new("c", DataType::Int64, true),
1614            Field::new("d", DataType::Int32, true),
1615            Field::new("e", DataType::Boolean, true),
1616        ]));
1617
1618        let sort_keys = [PhysicalSortExpr {
1619            expr: col("c", &schema).unwrap(),
1620            options: SortOptions::default(),
1621        }];
1622
1623        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1624            sort_keys.into(),
1625            true,
1626            &DataType::Int64,
1627            &[DataType::Int64],
1628            true,
1629        )?;
1630
1631        let mut val_with_orderings = {
1632            let mut val_with_orderings = Vec::<ArrayRef>::new();
1633
1634            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1635            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1636
1637            val_with_orderings.push(vals);
1638            val_with_orderings.push(orderings);
1639
1640            val_with_orderings
1641        };
1642
1643        group_acc.update_batch(
1644            &val_with_orderings,
1645            &[0, 1, 2, 1],
1646            Some(&BooleanArray::from(vec![true, true, false, true])),
1647            3,
1648        )?;
1649        assert_eq!(
1650            group_acc.size_of_orderings,
1651            group_acc.compute_size_of_orderings()
1652        );
1653
1654        let state = group_acc.state(EmitTo::All)?;
1655
1656        let expected_state: Vec<Arc<dyn Array>> = vec![
1657            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1658            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1659            Arc::new(BooleanArray::from(vec![true, true, false])),
1660        ];
1661        assert_eq!(state, expected_state);
1662
1663        assert_eq!(
1664            group_acc.size_of_orderings,
1665            group_acc.compute_size_of_orderings()
1666        );
1667
1668        group_acc.merge_batch(
1669            &state,
1670            &[0, 1, 2],
1671            Some(&BooleanArray::from(vec![true, false, false])),
1672            3,
1673        )?;
1674
1675        assert_eq!(
1676            group_acc.size_of_orderings,
1677            group_acc.compute_size_of_orderings()
1678        );
1679
1680        val_with_orderings.clear();
1681        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1682        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1683
1684        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1685
1686        let binding = group_acc.evaluate(EmitTo::All)?;
1687        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1688
1689        let expect: PrimitiveArray<Int64Type> =
1690            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
1691
1692        assert_eq!(eval_result, &expect);
1693
1694        assert_eq!(
1695            group_acc.size_of_orderings,
1696            group_acc.compute_size_of_orderings()
1697        );
1698
1699        Ok(())
1700    }
1701
1702    #[test]
1703    fn test_group_acc_size_of_ordering() -> Result<()> {
1704        let schema = Arc::new(Schema::new(vec![
1705            Field::new("a", DataType::Int64, true),
1706            Field::new("b", DataType::Int64, true),
1707            Field::new("c", DataType::Int64, true),
1708            Field::new("d", DataType::Int32, true),
1709            Field::new("e", DataType::Boolean, true),
1710        ]));
1711
1712        let sort_keys = [PhysicalSortExpr {
1713            expr: col("c", &schema).unwrap(),
1714            options: SortOptions::default(),
1715        }];
1716
1717        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1718            sort_keys.into(),
1719            true,
1720            &DataType::Int64,
1721            &[DataType::Int64],
1722            true,
1723        )?;
1724
1725        let val_with_orderings = {
1726            let mut val_with_orderings = Vec::<ArrayRef>::new();
1727
1728            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1729            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1730
1731            val_with_orderings.push(vals);
1732            val_with_orderings.push(orderings);
1733
1734            val_with_orderings
1735        };
1736
1737        for _ in 0..10 {
1738            group_acc.update_batch(
1739                &val_with_orderings,
1740                &[0, 1, 2, 1],
1741                Some(&BooleanArray::from(vec![true, true, false, true])),
1742                100,
1743            )?;
1744            assert_eq!(
1745                group_acc.size_of_orderings,
1746                group_acc.compute_size_of_orderings()
1747            );
1748
1749            group_acc.state(EmitTo::First(2))?;
1750            assert_eq!(
1751                group_acc.size_of_orderings,
1752                group_acc.compute_size_of_orderings()
1753            );
1754
1755            let s = group_acc.state(EmitTo::All)?;
1756            assert_eq!(
1757                group_acc.size_of_orderings,
1758                group_acc.compute_size_of_orderings()
1759            );
1760
1761            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
1762            assert_eq!(
1763                group_acc.size_of_orderings,
1764                group_acc.compute_size_of_orderings()
1765            );
1766
1767            group_acc.evaluate(EmitTo::First(2))?;
1768            assert_eq!(
1769                group_acc.size_of_orderings,
1770                group_acc.compute_size_of_orderings()
1771            );
1772
1773            group_acc.evaluate(EmitTo::All)?;
1774            assert_eq!(
1775                group_acc.size_of_orderings,
1776                group_acc.compute_size_of_orderings()
1777            );
1778        }
1779
1780        Ok(())
1781    }
1782
1783    #[test]
1784    fn test_last_group_acc() -> Result<()> {
1785        let schema = Arc::new(Schema::new(vec![
1786            Field::new("a", DataType::Int64, true),
1787            Field::new("b", DataType::Int64, true),
1788            Field::new("c", DataType::Int64, true),
1789            Field::new("d", DataType::Int32, true),
1790            Field::new("e", DataType::Boolean, true),
1791        ]));
1792
1793        let sort_keys = [PhysicalSortExpr {
1794            expr: col("c", &schema).unwrap(),
1795            options: SortOptions::default(),
1796        }];
1797
1798        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1799            sort_keys.into(),
1800            true,
1801            &DataType::Int64,
1802            &[DataType::Int64],
1803            false,
1804        )?;
1805
1806        let mut val_with_orderings = {
1807            let mut val_with_orderings = Vec::<ArrayRef>::new();
1808
1809            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1810            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1811
1812            val_with_orderings.push(vals);
1813            val_with_orderings.push(orderings);
1814
1815            val_with_orderings
1816        };
1817
1818        group_acc.update_batch(
1819            &val_with_orderings,
1820            &[0, 1, 2, 1],
1821            Some(&BooleanArray::from(vec![true, true, false, true])),
1822            3,
1823        )?;
1824
1825        let state = group_acc.state(EmitTo::All)?;
1826
1827        let expected_state: Vec<Arc<dyn Array>> = vec![
1828            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1829            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1830            Arc::new(BooleanArray::from(vec![true, true, false])),
1831        ];
1832        assert_eq!(state, expected_state);
1833
1834        group_acc.merge_batch(
1835            &state,
1836            &[0, 1, 2],
1837            Some(&BooleanArray::from(vec![true, false, false])),
1838            3,
1839        )?;
1840
1841        val_with_orderings.clear();
1842        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1843        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1844
1845        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1846
1847        let binding = group_acc.evaluate(EmitTo::All)?;
1848        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1849
1850        let expect: PrimitiveArray<Int64Type> =
1851            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
1852
1853        assert_eq!(eval_result, &expect);
1854
1855        Ok(())
1856    }
1857
1858    #[test]
1859    fn test_first_list_acc_size() -> Result<()> {
1860        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1861            let mut first_accumulator = TrivialFirstValueAccumulator::try_new(
1862                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1863                false,
1864            )?;
1865
1866            first_accumulator.update_batch(values)?;
1867
1868            Ok(first_accumulator.size())
1869        }
1870
1871        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1872            repeat_with(|| Some(vec![Some(1)])).take(10000),
1873        );
1874        let batch2 =
1875            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1876
1877        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1878        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1879        assert_eq!(size1, size2);
1880
1881        Ok(())
1882    }
1883
1884    #[test]
1885    fn test_last_list_acc_size() -> Result<()> {
1886        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1887            let mut last_accumulator = TrivialLastValueAccumulator::try_new(
1888                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1889                false,
1890            )?;
1891
1892            last_accumulator.update_batch(values)?;
1893
1894            Ok(last_accumulator.size())
1895        }
1896
1897        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1898            repeat_with(|| Some(vec![Some(1)])).take(10000),
1899        );
1900        let batch2 =
1901            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1902
1903        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1904        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1905        assert_eq!(size1, size2);
1906
1907        Ok(())
1908    }
1909}