datafusion_functions_aggregate/count.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use ahash::RandomState;
use datafusion_common::stats::Precision;
use datafusion_expr::expr::WindowFunction;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions;
use std::collections::HashSet;
use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::ops::BitAnd;
use std::sync::Arc;

use arrow::{
    array::{ArrayRef, AsArray},
    compute,
    datatypes::{
        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
    },
};

use arrow::datatypes::FieldRef;
use arrow::{
    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
    buffer::BooleanBuffer,
};
use datafusion_common::{
    downcast_value, internal_err, not_impl_err, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
    Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
};
use datafusion_expr::{
    Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
    BytesDistinctCountAccumulator, DictionaryCountAccumulator,
    FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::binary_map::OutputType;

use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;

make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);

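/// Creates a `COUNT(DISTINCT expr)` aggregate expression.
///
/// Example (illustrative sketch, mirroring the `count_all` docs below):
/// ```
/// # use datafusion_functions_aggregate::count::count_distinct;
/// # use datafusion_expr::col;
/// // create a `count(DISTINCT a)` expression for column "a"
/// let expr = count_distinct(col("a"));
/// ```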
pub fn count_distinct(expr: Expr) -> Expr {
    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
        count_udaf(),
        vec![expr],
        true,
        None,
        None,
        None,
    ))
}

/// Creates an aggregation to count all rows.
///
/// In SQL this is `SELECT COUNT(*) ... `
///
/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
/// aliased to a column named `"count(*)"` for backward compatibility.
///
/// Example
/// ```
/// # use datafusion_functions_aggregate::count::count_all;
/// # use datafusion_expr::col;
/// // create `count(*)` expression
/// let expr = count_all();
/// assert_eq!(expr.schema_name().to_string(), "count(*)");
/// // if you need to refer to this column, use the `schema_name` function
/// let expr = col(expr.schema_name().to_string());
/// ```
pub fn count_all() -> Expr {
    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
}

/// Creates a window aggregation to count all rows.
///
/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
///
/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`.
///
/// Example
/// ```
/// # use datafusion_functions_aggregate::count::count_all_window;
/// # use datafusion_expr::col;
/// // create `count(*)` OVER ... window function expression
/// let expr = count_all_window();
/// assert_eq!(
///   expr.schema_name().to_string(),
///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
/// );
/// // if you need to refer to this column, use the `schema_name` function
/// let expr = col(expr.schema_name().to_string());
/// ```
pub fn count_all_window() -> Expr {
    Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(count_udaf()),
        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
    ))
}

#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)    |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct Count {
    signature: Signature,
}

impl Debug for Count {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Count")
            .field("name", &self.name())
            .field("signature", &self.signature)
            .finish()
    }
}

impl Default for Count {
    fn default() -> Self {
        Self::new()
    }
}

impl Count {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        }
    }
}

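/// Returns a distinct count accumulator specialized for `data_type` where one
/// exists, falling back to the generic [`DistinctCountAccumulator`].
///
/// Dictionary inputs are handled by the caller, which wraps the accumulator
/// for the dictionary's value type in a [`DictionaryCountAccumulator`].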
fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
    match data_type {
        // try to use a specialized accumulator if possible, otherwise fall back to the generic accumulator
        DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
            data_type,
        )),
        DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
            data_type,
        )),
        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
            data_type,
        )),
        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
            data_type,
        )),
        DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
            data_type,
        )),
        DataType::UInt16 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
        ),
        DataType::UInt32 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
        ),
        DataType::UInt64 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
        ),
        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal128Type,
        >::new(data_type)),
        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal256Type,
        >::new(data_type)),

        DataType::Date32 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
        ),
        DataType::Date64 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Millisecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Second) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Microsecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
        ),

        DataType::Float16 => {
            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
        }
        DataType::Float32 => {
            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
        }
        DataType::Float64 => {
            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
        }

        DataType::Utf8 => {
            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
        }
        DataType::Utf8View => {
            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
        }
        DataType::LargeUtf8 => {
            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
        }
        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
            OutputType::Binary,
        )),
        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
            OutputType::BinaryView,
        )),
        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
            OutputType::Binary,
        )),

        // Use the generic accumulator based on `ScalarValue` for all other types
        _ => Box::new(DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: data_type.clone(),
        }),
    }
}

impl AggregateUDFImpl for Count {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "count"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn is_nullable(&self) -> bool {
        false
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            let dtype: DataType = match &args.input_fields[0].data_type() {
                DataType::Dictionary(_, values_type) => (**values_type).clone(),
                &dtype => dtype.clone(),
            };

            Ok(vec![Field::new_list(
                format_state_name(args.name, "count distinct"),
                // See COMMENTS.md to understand why nullable is set to true
                Field::new_list_field(dtype, true),
                false,
            )
            .into()])
        } else {
            Ok(vec![Field::new(
                format_state_name(args.name, "count"),
                DataType::Int64,
                false,
            )
            .into()])
        }
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if !acc_args.is_distinct {
            return Ok(Box::new(CountAccumulator::new()));
        }

        if acc_args.exprs.len() > 1 {
            return not_impl_err!("COUNT DISTINCT with multiple arguments");
        }

        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;

        Ok(match data_type {
            DataType::Dictionary(_, values_type) => {
                let inner = get_count_accumulator(values_type);
                Box::new(DictionaryCountAccumulator::new(inner))
            }
            _ => get_count_accumulator(data_type),
        })
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        // groups accumulator only supports `COUNT(c1)`, not
        // `COUNT(c1, c2)`, etc
        if args.is_distinct {
            return false;
        }
        args.exprs.len() == 1
    }

    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        // instantiate specialized accumulator
        Ok(Box::new(CountGroupsAccumulator::new()))
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(0)))
    }

    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        if statistics_args.is_distinct {
            return None;
        }
        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
            if statistics_args.exprs.len() == 1 {
                // TODO optimize with exprs other than Column
                if let Some(col_expr) = statistics_args.exprs[0]
                    .as_any()
                    .downcast_ref::<expressions::Column>()
                {
                    let current_val = &statistics_args.statistics.column_statistics
                        [col_expr.index()]
                    .null_count;
                    if let &Precision::Exact(val) = current_val {
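                        // COUNT(col) ignores nulls, so with exact statistics it
                        // is the total row count minus the column's null count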
                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
                    }
                } else if let Some(lit_expr) = statistics_args.exprs[0]
                    .as_any()
                    .downcast_ref::<expressions::Literal>()
                {
                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
                    }
                }
            }
        }
        None
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        // `COUNT` is monotonically increasing as it always increases or stays
        // the same as new values are seen.
        SetMonotonicity::Increasing
    }
}

#[derive(Debug)]
struct CountAccumulator {
    count: i64,
}

impl CountAccumulator {
    /// new count accumulator
    pub fn new() -> Self {
        Self { count: 0 }
    }
}

impl Accumulator for CountAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.count))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let array = &values[0];
        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
        Ok(())
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let array = &values[0];
        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let counts = downcast_value!(states[0], Int64Array);
        let delta = &compute::sum(counts);
        if let Some(d) = delta {
            self.count += *d;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.count)))
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}

/// A [`GroupsAccumulator`] for `COUNT` that stores one `i64` count per group.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    pub fn new() -> Self {
        Self { counts: vec![] }
    }
}

impl GroupsAccumulator for CountGroupsAccumulator {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = &values[0];

        // Add one to each group's counter for each non null, non
        // filtered value
        self.counts.resize(total_num_groups, 0);
        accumulate_indices(
            group_indices,
            values.logical_nulls().as_ref(),
            opt_filter,
            |group_index| {
                self.counts[group_index] += 1;
            },
        );

        Ok(())
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        // The aggregate filter is applied in the partial stage, so there is no
        // filter in the final stage
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        // the state is a single array of partial counts
        let partial_counts = values[0].as_primitive::<Int64Type>();

        // intermediate counts are always created as non null
        assert_eq!(partial_counts.null_count(), 0);
        let partial_counts = partial_counts.values();

        // Add the partial counts to the running counts
        self.counts.resize(total_num_groups, 0);
        group_indices.iter().zip(partial_counts.iter()).for_each(
            |(&group_index, partial_count)| {
                self.counts[group_index] += partial_count;
            },
        );

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let counts = emit_to.take_needed(&mut self.counts);

        // Count is always non null (null inputs just don't contribute to the overall values)
        let nulls = None;
        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);

        Ok(Arc::new(array))
    }

    // return arrays for counts
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let counts = emit_to.take_needed(&mut self.counts);
        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
        Ok(vec![Arc::new(counts) as ArrayRef])
    }

    /// Converts an input batch directly to a state batch
    ///
    /// The state of `COUNT` is always a single Int64Array:
    /// * `1` (for non-null, non filtered values)
    /// * `0` (for null or filtered-out values)
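    ///
    /// For example (illustrative), input `[1, NULL, 3]` with filter
    /// `[true, true, false]` converts to the state `[1, 0, 0]`.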
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        let values = &values[0];

        let state_array = match (values.logical_nulls(), opt_filter) {
            (None, None) => {
                // If the input has no nulls and there is no filter, every row
                // counts: return an array of 1s
                Arc::new(Int64Array::from_value(1, values.len()))
            }
            (Some(nulls), None) => {
                // If the input has nulls but no filter, cast the null mask of
                // the input array (true for valid values, false for nulls) to Int64
                let nulls = BooleanArray::new(nulls.into_inner(), None);
                compute::cast(&nulls, &DataType::Int64)?
            }
            (None, Some(filter)) => {
                // If there is only a filter:
                // - apply the filter's null mask to its values by AND-ing the
                //   filter values buffer with its null buffer (working on
                //   buffers guarantees the result has no nulls)
                // - cast the result of the AND to an Int64 array
                let (filter_values, filter_nulls) = filter.clone().into_parts();

                let state_buf = match filter_nulls {
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
                    None => filter_values,
                };

                let boolean_state = BooleanArray::new(state_buf, None);
                compute::cast(&boolean_state, &DataType::Int64)?
            }
            (Some(nulls), Some(filter)) => {
                // If there are both input nulls and a filter:
                // - apply the filter's null mask to its values by AND-ing the
                //   filter values buffer with its null buffer (working on
                //   buffers guarantees the result has no nulls)
                // - AND the result with the null mask of the input values
                // - cast the result to an Int64 array
                let (filter_values, filter_nulls) = filter.clone().into_parts();

                let filter_buf = match filter_nulls {
                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
                    None => filter_values,
                };
                let state_buf = &filter_buf & nulls.inner();

                let boolean_state = BooleanArray::new(state_buf, None);
                compute::cast(&boolean_state, &DataType::Int64)?
            }
        };

        Ok(vec![state_array])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        self.counts.capacity() * size_of::<i64>()
    }
}

/// Counts null rows across multiple columns:
/// a row contributes 1 to the count if any of its column values is null.
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
    if values.len() > 1 {
        let result_bool_buf: Option<BooleanBuffer> = values
            .iter()
            .map(|a| a.logical_nulls())
            .fold(None, |acc, b| match (acc, b) {
                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
                (Some(acc), None) => Some(acc),
                (None, Some(b)) => Some(b.into_inner()),
                _ => None,
            });
        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
    } else {
        values[0]
            .logical_nulls()
            .map_or(0, |nulls| nulls.null_count())
    }
}

/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
    values: HashSet<ScalarValue, RandomState>,
    state_data_type: DataType,
}

impl DistinctCountAccumulator {
    // Estimates the size assuming the values have a fixed length: counts the
    // `ScalarValue` slots in the hash set plus the extra heap size of the first
    // value only. This method is faster than .full_size(), however it is not
    // suitable for variable length values like strings or complex types
    fn fixed_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            + self
                .values
                .iter()
                .next()
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .unwrap_or(0)
            + size_of::<DataType>()
    }

    // calculates the size as accurately as possible. Note that calling this
    // method is expensive
    fn full_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            + self
                .values
                .iter()
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .sum::<usize>()
            + size_of::<DataType>()
    }
}

impl Accumulator for DistinctCountAccumulator {
    /// Returns the distinct values seen so far as a (one element) ListArray.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
        let arr =
            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        (0..arr.len()).try_for_each(|index| {
            if !arr.is_null(index) {
                let scalar = ScalarValue::try_from_array(arr, index)?;
                self.values.insert(scalar);
            }
            Ok(())
        })
    }

    /// Merges multiple sets of distinct values into the current set.
    ///
    /// The input to this function is a `ListArray` with **multiple** rows,
    /// where each row contains the values from a partial aggregation phase (e.g.
    /// the result of calling `Self::state` on multiple accumulators).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "count distinct states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
        for inner_array in list_array.iter() {
            let Some(inner_array) = inner_array else {
                return internal_err!(
                    "Intermediate results of COUNT DISTINCT should always be non null"
                );
            };
            self.update_batch(&[inner_array])?;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
        match &self.state_data_type {
            DataType::Boolean | DataType::Null => self.fixed_size(),
            d if d.is_primitive() => self.fixed_size(),
            _ => self.full_size(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, NullArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::LexOrdering;
    use std::sync::Arc;

    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        // Using Count UDAF's accumulator
        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let args = AccumulatorArgs {
            schema: &schema,
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            ordering_req: &LexOrdering::default(),
        };

        let inner_dict = arrow::array::DictionaryArray::<Int32Type>::from_iter([
            "a", "b", "c", "d", "a", "b",
        ]);

        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict = arrow::array::DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(inner_dict),
        )?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }
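
    // Illustrative sketch (not part of the original test suite) showing the
    // semantics of `null_count_for_multiple_cols`: a row adds 1 to the count
    // if any of its columns is null, so rows 0 and 1 below count and row 2
    // does not.
    #[test]
    fn null_count_multiple_columns() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(2), Some(3)]));
        assert_eq!(null_count_for_multiple_cols(&[a, b]), 2);
    }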
}