datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use ahash::RandomState;
19use datafusion_common::stats::Precision;
20use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
21use datafusion_macros::user_doc;
22use datafusion_physical_expr::expressions;
23use std::collections::HashSet;
24use std::fmt::Debug;
25use std::mem::{size_of, size_of_val};
26use std::ops::BitAnd;
27use std::sync::Arc;
28
29use arrow::{
30    array::{ArrayRef, AsArray},
31    compute,
32    datatypes::{
33        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
34        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
35        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
36        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
38        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
39    },
40};
41
42use arrow::{
43    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
44    buffer::BooleanBuffer,
45};
46use datafusion_common::{
47    downcast_value, internal_err, not_impl_err, Result, ScalarValue,
48};
49use datafusion_expr::function::StateFieldsArgs;
50use datafusion_expr::{
51    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
52    Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
53};
54use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
55use datafusion_functions_aggregate_common::aggregate::count_distinct::{
56    BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
57    PrimitiveDistinctCountAccumulator,
58};
59use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
60use datafusion_physical_expr_common::binary_map::OutputType;
61
62use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
// Invokes the crate-level `make_udaf_expr_and_func!` macro for `Count`.
// NOTE(review): the `count(..)` expression builder and the `count_udaf()`
// accessor that are called further down in this file are presumably generated
// by this invocation, and the quoted string is presumably the documentation
// attached to the generated builder — confirm against the macro definition.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73        count_udaf(),
74        vec![expr],
75        true,
76        None,
77        None,
78        None,
79    ))
80}
81
82/// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
83pub fn count_all() -> Expr {
84    count(Expr::Literal(COUNT_STAR_EXPANSION))
85}
86
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `count` aggregate function.
///
/// Its behavior is provided by the `AggregateUDFImpl` implementation for this
/// type; the user-facing documentation is supplied through the `user_doc`
/// attribute above.
pub struct Count {
    /// Accepted argument signature; built in `Count::new` and handed out by
    /// `AggregateUDFImpl::signature`.
    signature: Signature,
}
111
112impl Debug for Count {
113    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114        f.debug_struct("Count")
115            .field("name", &self.name())
116            .field("signature", &self.signature)
117            .finish()
118    }
119}
120
121impl Default for Count {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl Count {
128    pub fn new() -> Self {
129        Self {
130            signature: Signature::one_of(
131                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
132                Volatility::Immutable,
133            ),
134        }
135    }
136}
137
138impl AggregateUDFImpl for Count {
139    fn as_any(&self) -> &dyn std::any::Any {
140        self
141    }
142
143    fn name(&self) -> &str {
144        "count"
145    }
146
147    fn signature(&self) -> &Signature {
148        &self.signature
149    }
150
151    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
152        Ok(DataType::Int64)
153    }
154
155    fn is_nullable(&self) -> bool {
156        false
157    }
158
159    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
160        if args.is_distinct {
161            Ok(vec![Field::new_list(
162                format_state_name(args.name, "count distinct"),
163                // See COMMENTS.md to understand why nullable is set to true
164                Field::new_list_field(args.input_types[0].clone(), true),
165                false,
166            )])
167        } else {
168            Ok(vec![Field::new(
169                format_state_name(args.name, "count"),
170                DataType::Int64,
171                false,
172            )])
173        }
174    }
175
176    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
177        if !acc_args.is_distinct {
178            return Ok(Box::new(CountAccumulator::new()));
179        }
180
181        if acc_args.exprs.len() > 1 {
182            return not_impl_err!("COUNT DISTINCT with multiple arguments");
183        }
184
185        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
186        Ok(match data_type {
187            // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
188            DataType::Int8 => Box::new(
189                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
190            ),
191            DataType::Int16 => Box::new(
192                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
193            ),
194            DataType::Int32 => Box::new(
195                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
196            ),
197            DataType::Int64 => Box::new(
198                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
199            ),
200            DataType::UInt8 => Box::new(
201                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
202            ),
203            DataType::UInt16 => Box::new(
204                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
205            ),
206            DataType::UInt32 => Box::new(
207                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
208            ),
209            DataType::UInt64 => Box::new(
210                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
211            ),
212            DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
213                Decimal128Type,
214            >::new(data_type)),
215            DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
216                Decimal256Type,
217            >::new(data_type)),
218
219            DataType::Date32 => Box::new(
220                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
221            ),
222            DataType::Date64 => Box::new(
223                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
224            ),
225            DataType::Time32(TimeUnit::Millisecond) => Box::new(
226                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
227                    data_type,
228                ),
229            ),
230            DataType::Time32(TimeUnit::Second) => Box::new(
231                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
232            ),
233            DataType::Time64(TimeUnit::Microsecond) => Box::new(
234                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
235                    data_type,
236                ),
237            ),
238            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
239                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
240            ),
241            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
242                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
243                    data_type,
244                ),
245            ),
246            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
247                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
248                    data_type,
249                ),
250            ),
251            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
252                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
253                    data_type,
254                ),
255            ),
256            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
257                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
258            ),
259
260            DataType::Float16 => {
261                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
262            }
263            DataType::Float32 => {
264                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
265            }
266            DataType::Float64 => {
267                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
268            }
269
270            DataType::Utf8 => {
271                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
272            }
273            DataType::Utf8View => {
274                Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
275            }
276            DataType::LargeUtf8 => {
277                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
278            }
279            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
280                OutputType::Binary,
281            )),
282            DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
283                OutputType::BinaryView,
284            )),
285            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
286                OutputType::Binary,
287            )),
288
289            // Use the generic accumulator based on `ScalarValue` for all other types
290            _ => Box::new(DistinctCountAccumulator {
291                values: HashSet::default(),
292                state_data_type: data_type.clone(),
293            }),
294        })
295    }
296
297    fn aliases(&self) -> &[String] {
298        &[]
299    }
300
301    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
302        // groups accumulator only supports `COUNT(c1)`, not
303        // `COUNT(c1, c2)`, etc
304        if args.is_distinct {
305            return false;
306        }
307        args.exprs.len() == 1
308    }
309
310    fn create_groups_accumulator(
311        &self,
312        _args: AccumulatorArgs,
313    ) -> Result<Box<dyn GroupsAccumulator>> {
314        // instantiate specialized accumulator
315        Ok(Box::new(CountGroupsAccumulator::new()))
316    }
317
318    fn reverse_expr(&self) -> ReversedUDAF {
319        ReversedUDAF::Identical
320    }
321
322    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
323        Ok(ScalarValue::Int64(Some(0)))
324    }
325
326    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
327        if statistics_args.is_distinct {
328            return None;
329        }
330        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
331            if statistics_args.exprs.len() == 1 {
332                // TODO optimize with exprs other than Column
333                if let Some(col_expr) = statistics_args.exprs[0]
334                    .as_any()
335                    .downcast_ref::<expressions::Column>()
336                {
337                    let current_val = &statistics_args.statistics.column_statistics
338                        [col_expr.index()]
339                    .null_count;
340                    if let &Precision::Exact(val) = current_val {
341                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
342                    }
343                } else if let Some(lit_expr) = statistics_args.exprs[0]
344                    .as_any()
345                    .downcast_ref::<expressions::Literal>()
346                {
347                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
348                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
349                    }
350                }
351            }
352        }
353        None
354    }
355
356    fn documentation(&self) -> Option<&Documentation> {
357        self.doc()
358    }
359
360    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
361        // `COUNT` is monotonically increasing as it always increases or stays
362        // the same as new values are seen.
363        SetMonotonicity::Increasing
364    }
365}
366
/// Row-oriented accumulator for a non-distinct `count`.
///
/// Created by `Count::accumulator` when the aggregate is not distinct.
#[derive(Debug)]
struct CountAccumulator {
    /// Running count. Kept as `i64` so it can be wrapped directly in
    /// `ScalarValue::Int64` by `state` and `evaluate`.
    count: i64,
}
371
372impl CountAccumulator {
373    /// new count accumulator
374    pub fn new() -> Self {
375        Self { count: 0 }
376    }
377}
378
379impl Accumulator for CountAccumulator {
380    fn state(&mut self) -> Result<Vec<ScalarValue>> {
381        Ok(vec![ScalarValue::Int64(Some(self.count))])
382    }
383
384    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
385        let array = &values[0];
386        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
387        Ok(())
388    }
389
390    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
391        let array = &values[0];
392        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
393        Ok(())
394    }
395
396    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
397        let counts = downcast_value!(states[0], Int64Array);
398        let delta = &compute::sum(counts);
399        if let Some(d) = delta {
400            self.count += *d;
401        }
402        Ok(())
403    }
404
405    fn evaluate(&mut self) -> Result<ScalarValue> {
406        Ok(ScalarValue::Int64(Some(self.count)))
407    }
408
409    fn supports_retract_batch(&self) -> bool {
410        true
411    }
412
413    fn size(&self) -> usize {
414        size_of_val(self)
415    }
416}
417
/// A [`GroupsAccumulator`] that keeps one `COUNT` per group.
///
/// Counts are stored as native `i64` values and updated with plain `+=`.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}
434
435impl CountGroupsAccumulator {
436    pub fn new() -> Self {
437        Self { counts: vec![] }
438    }
439}
440
441impl GroupsAccumulator for CountGroupsAccumulator {
442    fn update_batch(
443        &mut self,
444        values: &[ArrayRef],
445        group_indices: &[usize],
446        opt_filter: Option<&BooleanArray>,
447        total_num_groups: usize,
448    ) -> Result<()> {
449        assert_eq!(values.len(), 1, "single argument to update_batch");
450        let values = &values[0];
451
452        // Add one to each group's counter for each non null, non
453        // filtered value
454        self.counts.resize(total_num_groups, 0);
455        accumulate_indices(
456            group_indices,
457            values.logical_nulls().as_ref(),
458            opt_filter,
459            |group_index| {
460                self.counts[group_index] += 1;
461            },
462        );
463
464        Ok(())
465    }
466
467    fn merge_batch(
468        &mut self,
469        values: &[ArrayRef],
470        group_indices: &[usize],
471        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
472        _opt_filter: Option<&BooleanArray>,
473        total_num_groups: usize,
474    ) -> Result<()> {
475        assert_eq!(values.len(), 1, "one argument to merge_batch");
476        // first batch is counts, second is partial sums
477        let partial_counts = values[0].as_primitive::<Int64Type>();
478
479        // intermediate counts are always created as non null
480        assert_eq!(partial_counts.null_count(), 0);
481        let partial_counts = partial_counts.values();
482
483        // Adds the counts with the partial counts
484        self.counts.resize(total_num_groups, 0);
485        group_indices.iter().zip(partial_counts.iter()).for_each(
486            |(&group_index, partial_count)| {
487                self.counts[group_index] += partial_count;
488            },
489        );
490
491        Ok(())
492    }
493
494    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
495        let counts = emit_to.take_needed(&mut self.counts);
496
497        // Count is always non null (null inputs just don't contribute to the overall values)
498        let nulls = None;
499        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
500
501        Ok(Arc::new(array))
502    }
503
504    // return arrays for counts
505    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
506        let counts = emit_to.take_needed(&mut self.counts);
507        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
508        Ok(vec![Arc::new(counts) as ArrayRef])
509    }
510
511    /// Converts an input batch directly to a state batch
512    ///
513    /// The state of `COUNT` is always a single Int64Array:
514    /// * `1` (for non-null, non filtered values)
515    /// * `0` (for null values)
516    fn convert_to_state(
517        &self,
518        values: &[ArrayRef],
519        opt_filter: Option<&BooleanArray>,
520    ) -> Result<Vec<ArrayRef>> {
521        let values = &values[0];
522
523        let state_array = match (values.logical_nulls(), opt_filter) {
524            (None, None) => {
525                // In case there is no nulls in input and no filter, returning array of 1
526                Arc::new(Int64Array::from_value(1, values.len()))
527            }
528            (Some(nulls), None) => {
529                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
530                // of input array to Int64
531                let nulls = BooleanArray::new(nulls.into_inner(), None);
532                compute::cast(&nulls, &DataType::Int64)?
533            }
534            (None, Some(filter)) => {
535                // If there is only filter
536                // - applying filter null mask to filter values by bitand filter values and nulls buffers
537                //   (using buffers guarantees absence of nulls in result)
538                // - casting result of bitand to Int64 array
539                let (filter_values, filter_nulls) = filter.clone().into_parts();
540
541                let state_buf = match filter_nulls {
542                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
543                    None => filter_values,
544                };
545
546                let boolean_state = BooleanArray::new(state_buf, None);
547                compute::cast(&boolean_state, &DataType::Int64)?
548            }
549            (Some(nulls), Some(filter)) => {
550                // For both input nulls and filter
551                // - applying filter null mask to filter values by bitand filter values and nulls buffers
552                //   (using buffers guarantees absence of nulls in result)
553                // - applying values null mask to filter buffer by another bitand on filter result and
554                //   nulls from input values
555                // - casting result to Int64 array
556                let (filter_values, filter_nulls) = filter.clone().into_parts();
557
558                let filter_buf = match filter_nulls {
559                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
560                    None => filter_values,
561                };
562                let state_buf = &filter_buf & nulls.inner();
563
564                let boolean_state = BooleanArray::new(state_buf, None);
565                compute::cast(&boolean_state, &DataType::Int64)?
566            }
567        };
568
569        Ok(vec![state_array])
570    }
571
572    fn supports_convert_to_state(&self) -> bool {
573        true
574    }
575
576    fn size(&self) -> usize {
577        self.counts.capacity() * size_of::<usize>()
578    }
579}
580
581/// count null values for multiple columns
582/// for each row if one column value is null, then null_count + 1
583fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
584    if values.len() > 1 {
585        let result_bool_buf: Option<BooleanBuffer> = values
586            .iter()
587            .map(|a| a.logical_nulls())
588            .fold(None, |acc, b| match (acc, b) {
589                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
590                (Some(acc), None) => Some(acc),
591                (None, Some(b)) => Some(b.into_inner()),
592                _ => None,
593            });
594        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
595    } else {
596        values[0]
597            .logical_nulls()
598            .map_or(0, |nulls| nulls.null_count())
599    }
600}
601
/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// The distinct non-null values seen so far; the final count is the size
    /// of this set.
    values: HashSet<ScalarValue, RandomState>,
    /// Data type of the counted expression, used as the element type when the
    /// values are turned into the list state.
    state_data_type: DataType,
}
615
616impl DistinctCountAccumulator {
617    // calculating the size for fixed length values, taking first batch size *
618    // number of batches This method is faster than .full_size(), however it is
619    // not suitable for variable length values like strings or complex types
620    fn fixed_size(&self) -> usize {
621        size_of_val(self)
622            + (size_of::<ScalarValue>() * self.values.capacity())
623            + self
624                .values
625                .iter()
626                .next()
627                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
628                .unwrap_or(0)
629            + size_of::<DataType>()
630    }
631
632    // calculates the size as accurately as possible. Note that calling this
633    // method is expensive
634    fn full_size(&self) -> usize {
635        size_of_val(self)
636            + (size_of::<ScalarValue>() * self.values.capacity())
637            + self
638                .values
639                .iter()
640                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
641                .sum::<usize>()
642            + size_of::<DataType>()
643    }
644}
645
646impl Accumulator for DistinctCountAccumulator {
647    /// Returns the distinct values seen so far as (one element) ListArray.
648    fn state(&mut self) -> Result<Vec<ScalarValue>> {
649        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
650        let arr =
651            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
652        Ok(vec![ScalarValue::List(arr)])
653    }
654
655    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
656        if values.is_empty() {
657            return Ok(());
658        }
659
660        let arr = &values[0];
661        if arr.data_type() == &DataType::Null {
662            return Ok(());
663        }
664
665        (0..arr.len()).try_for_each(|index| {
666            if !arr.is_null(index) {
667                let scalar = ScalarValue::try_from_array(arr, index)?;
668                self.values.insert(scalar);
669            }
670            Ok(())
671        })
672    }
673
674    /// Merges multiple sets of distinct values into the current set.
675    ///
676    /// The input to this function is a `ListArray` with **multiple** rows,
677    /// where each row contains the values from a partial aggregate's phase (e.g.
678    /// the result of calling `Self::state` on multiple accumulators).
679    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
680        if states.is_empty() {
681            return Ok(());
682        }
683        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
684        let array = &states[0];
685        let list_array = array.as_list::<i32>();
686        for inner_array in list_array.iter() {
687            let Some(inner_array) = inner_array else {
688                return internal_err!(
689                    "Intermediate results of COUNT DISTINCT should always be non null"
690                );
691            };
692            self.update_batch(&[inner_array])?;
693        }
694        Ok(())
695    }
696
697    fn evaluate(&mut self) -> Result<ScalarValue> {
698        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
699    }
700
701    fn size(&self) -> usize {
702        match &self.state_data_type {
703            DataType::Boolean | DataType::Null => self.fixed_size(),
704            d if d.is_primitive() => self.fixed_size(),
705            _ => self.full_size(),
706        }
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;

    /// A batch made only of nulls must leave the count at zero.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let all_nulls: ArrayRef = Arc::new(NullArray::new(10));

        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[all_nulls])?;

        let counted = accumulator.evaluate()?;
        assert_eq!(counted, ScalarValue::Int64(Some(0)));
        Ok(())
    }
}