datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use ahash::RandomState;
19use arrow::{
20    array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21    buffer::BooleanBuffer,
22    compute,
23    datatypes::{
24        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25        FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
26        Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30    },
31};
32use datafusion_common::{
33    downcast_value, internal_err, not_impl_err, stats::Precision,
34    utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue,
35};
36use datafusion_expr::{
37    expr::WindowFunction,
38    function::{AccumulatorArgs, StateFieldsArgs},
39    utils::format_state_name,
40    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
41    ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
42    WindowFunctionDefinition,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45    count_distinct::BytesDistinctCountAccumulator,
46    count_distinct::BytesViewDistinctCountAccumulator,
47    count_distinct::DictionaryCountAccumulator,
48    count_distinct::FloatDistinctCountAccumulator,
49    count_distinct::PrimitiveDistinctCountAccumulator,
50    groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56    collections::HashSet,
57    fmt::Debug,
58    mem::{size_of, size_of_val},
59    ops::BitAnd,
60    sync::Arc,
61};
// Registers the `Count` aggregate with the expression/UDAF helper macro.
// NOTE(review): `count(..)` and `count_udaf()` are called further down in this
// file but are not defined here, so they are presumably generated by this
// invocation — confirm against the macro definition.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
69
70pub fn count_distinct(expr: Expr) -> Expr {
71    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
72        count_udaf(),
73        vec![expr],
74        true,
75        None,
76        vec![],
77        None,
78    ))
79}
80
81/// Creates aggregation to count all rows.
82///
83/// In SQL this is `SELECT COUNT(*) ... `
84///
85/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
86/// aliased to a column named `"count(*)"` for backward compatibility.
87///
88/// Example
89/// ```
90/// # use datafusion_functions_aggregate::count::count_all;
91/// # use datafusion_expr::col;
92/// // create `count(*)` expression
93/// let expr = count_all();
94/// assert_eq!(expr.schema_name().to_string(), "count(*)");
95/// // if you need to refer to this column, use the `schema_name` function
96/// let expr = col(expr.schema_name().to_string());
97/// ```
98pub fn count_all() -> Expr {
99    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
100}
101
102/// Creates window aggregation to count all rows.
103///
104/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
105///
106/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
107///
108/// Example
109/// ```
110/// # use datafusion_functions_aggregate::count::count_all_window;
111/// # use datafusion_expr::col;
112/// // create `count(*)` OVER ... window function expression
113/// let expr = count_all_window();
114/// assert_eq!(
115///   expr.schema_name().to_string(),
116///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
117/// );
118/// // if you need to refer to this column, use the `schema_name` function
119/// let expr = col(expr.schema_name().to_string());
120/// ```
121pub fn count_all_window() -> Expr {
122    Expr::from(WindowFunction::new(
123        WindowFunctionDefinition::AggregateUDF(count_udaf()),
124        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
125    ))
126}
127
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `count` aggregate function.
///
/// Its behaviour is provided by the `AggregateUDFImpl` implementation in
/// this file; the struct itself only carries the accepted signature.
pub struct Count {
    /// Signature accepted by `count`, built once in [`Count::new`].
    signature: Signature,
}
152
153impl Debug for Count {
154    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
155        f.debug_struct("Count")
156            .field("name", &self.name())
157            .field("signature", &self.signature)
158            .finish()
159    }
160}
161
162impl Default for Count {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168impl Count {
169    pub fn new() -> Self {
170        Self {
171            signature: Signature::one_of(
172                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
173                Volatility::Immutable,
174            ),
175        }
176    }
177}
178fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
179    match data_type {
180        // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
181        DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
182            data_type,
183        )),
184        DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
185            data_type,
186        )),
187        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
188            data_type,
189        )),
190        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
191            data_type,
192        )),
193        DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
194            data_type,
195        )),
196        DataType::UInt16 => Box::new(
197            PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
198        ),
199        DataType::UInt32 => Box::new(
200            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
201        ),
202        DataType::UInt64 => Box::new(
203            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
204        ),
205        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
206            Decimal128Type,
207        >::new(data_type)),
208        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
209            Decimal256Type,
210        >::new(data_type)),
211
212        DataType::Date32 => Box::new(
213            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
214        ),
215        DataType::Date64 => Box::new(
216            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
217        ),
218        DataType::Time32(TimeUnit::Millisecond) => Box::new(
219            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
220        ),
221        DataType::Time32(TimeUnit::Second) => Box::new(
222            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
223        ),
224        DataType::Time64(TimeUnit::Microsecond) => Box::new(
225            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
226        ),
227        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
228            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
229        ),
230        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
231            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
232        ),
233        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
234            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
235        ),
236        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
237            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
238        ),
239        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
240            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
241        ),
242
243        DataType::Float16 => {
244            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
245        }
246        DataType::Float32 => {
247            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
248        }
249        DataType::Float64 => {
250            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
251        }
252
253        DataType::Utf8 => {
254            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
255        }
256        DataType::Utf8View => {
257            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
258        }
259        DataType::LargeUtf8 => {
260            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
261        }
262        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
263            OutputType::Binary,
264        )),
265        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
266            OutputType::BinaryView,
267        )),
268        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
269            OutputType::Binary,
270        )),
271
272        // Use the generic accumulator based on `ScalarValue` for all other types
273        _ => Box::new(DistinctCountAccumulator {
274            values: HashSet::default(),
275            state_data_type: data_type.clone(),
276        }),
277    }
278}
279
280impl AggregateUDFImpl for Count {
281    fn as_any(&self) -> &dyn std::any::Any {
282        self
283    }
284
285    fn name(&self) -> &str {
286        "count"
287    }
288
289    fn signature(&self) -> &Signature {
290        &self.signature
291    }
292
293    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
294        Ok(DataType::Int64)
295    }
296
297    fn is_nullable(&self) -> bool {
298        false
299    }
300
301    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
302        if args.is_distinct {
303            let dtype: DataType = match &args.input_fields[0].data_type() {
304                DataType::Dictionary(_, values_type) => (**values_type).clone(),
305                &dtype => dtype.clone(),
306            };
307
308            Ok(vec![Field::new_list(
309                format_state_name(args.name, "count distinct"),
310                // See COMMENTS.md to understand why nullable is set to true
311                Field::new_list_field(dtype, true),
312                false,
313            )
314            .into()])
315        } else {
316            Ok(vec![Field::new(
317                format_state_name(args.name, "count"),
318                DataType::Int64,
319                false,
320            )
321            .into()])
322        }
323    }
324
325    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
326        if !acc_args.is_distinct {
327            return Ok(Box::new(CountAccumulator::new()));
328        }
329
330        if acc_args.exprs.len() > 1 {
331            return not_impl_err!("COUNT DISTINCT with multiple arguments");
332        }
333
334        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
335
336        Ok(match data_type {
337            DataType::Dictionary(_, values_type) => {
338                let inner = get_count_accumulator(values_type);
339                Box::new(DictionaryCountAccumulator::new(inner))
340            }
341            _ => get_count_accumulator(data_type),
342        })
343    }
344
345    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
346        // groups accumulator only supports `COUNT(c1)`, not
347        // `COUNT(c1, c2)`, etc
348        if args.is_distinct {
349            return false;
350        }
351        args.exprs.len() == 1
352    }
353
354    fn create_groups_accumulator(
355        &self,
356        _args: AccumulatorArgs,
357    ) -> Result<Box<dyn GroupsAccumulator>> {
358        // instantiate specialized accumulator
359        Ok(Box::new(CountGroupsAccumulator::new()))
360    }
361
362    fn reverse_expr(&self) -> ReversedUDAF {
363        ReversedUDAF::Identical
364    }
365
366    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
367        Ok(ScalarValue::Int64(Some(0)))
368    }
369
370    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
371        if statistics_args.is_distinct {
372            return None;
373        }
374        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
375            if statistics_args.exprs.len() == 1 {
376                // TODO optimize with exprs other than Column
377                if let Some(col_expr) = statistics_args.exprs[0]
378                    .as_any()
379                    .downcast_ref::<expressions::Column>()
380                {
381                    let current_val = &statistics_args.statistics.column_statistics
382                        [col_expr.index()]
383                    .null_count;
384                    if let &Precision::Exact(val) = current_val {
385                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
386                    }
387                } else if let Some(lit_expr) = statistics_args.exprs[0]
388                    .as_any()
389                    .downcast_ref::<expressions::Literal>()
390                {
391                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
392                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
393                    }
394                }
395            }
396        }
397        None
398    }
399
400    fn documentation(&self) -> Option<&Documentation> {
401        self.doc()
402    }
403
404    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
405        // `COUNT` is monotonically increasing as it always increases or stays
406        // the same as new values are seen.
407        SetMonotonicity::Increasing
408    }
409}
410
/// Row-based accumulator for non-distinct `COUNT`.
///
/// Keeps a single running total that is increased by `update_batch` /
/// `merge_batch` and decreased by `retract_batch`.
#[derive(Debug)]
struct CountAccumulator {
    /// Running count. Stored as `i64` to match the `Int64` state and result
    /// values this accumulator emits.
    count: i64,
}
415
416impl CountAccumulator {
417    /// new count accumulator
418    pub fn new() -> Self {
419        Self { count: 0 }
420    }
421}
422
423impl Accumulator for CountAccumulator {
424    fn state(&mut self) -> Result<Vec<ScalarValue>> {
425        Ok(vec![ScalarValue::Int64(Some(self.count))])
426    }
427
428    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
429        let array = &values[0];
430        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
431        Ok(())
432    }
433
434    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
435        let array = &values[0];
436        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
437        Ok(())
438    }
439
440    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
441        let counts = downcast_value!(states[0], Int64Array);
442        let delta = &compute::sum(counts);
443        if let Some(d) = delta {
444            self.count += *d;
445        }
446        Ok(())
447    }
448
449    fn evaluate(&mut self) -> Result<ScalarValue> {
450        Ok(ScalarValue::Int64(Some(self.count)))
451    }
452
453    fn supports_retract_batch(&self) -> bool {
454        true
455    }
456
457    fn size(&self) -> usize {
458        size_of_val(self)
459    }
460}
461
/// A [`GroupsAccumulator`] for non-distinct `COUNT` that keeps one running
/// `i64` counter per group.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}
478
479impl CountGroupsAccumulator {
480    pub fn new() -> Self {
481        Self { counts: vec![] }
482    }
483}
484
485impl GroupsAccumulator for CountGroupsAccumulator {
486    fn update_batch(
487        &mut self,
488        values: &[ArrayRef],
489        group_indices: &[usize],
490        opt_filter: Option<&BooleanArray>,
491        total_num_groups: usize,
492    ) -> Result<()> {
493        assert_eq!(values.len(), 1, "single argument to update_batch");
494        let values = &values[0];
495
496        // Add one to each group's counter for each non null, non
497        // filtered value
498        self.counts.resize(total_num_groups, 0);
499        accumulate_indices(
500            group_indices,
501            values.logical_nulls().as_ref(),
502            opt_filter,
503            |group_index| {
504                self.counts[group_index] += 1;
505            },
506        );
507
508        Ok(())
509    }
510
511    fn merge_batch(
512        &mut self,
513        values: &[ArrayRef],
514        group_indices: &[usize],
515        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
516        _opt_filter: Option<&BooleanArray>,
517        total_num_groups: usize,
518    ) -> Result<()> {
519        assert_eq!(values.len(), 1, "one argument to merge_batch");
520        // first batch is counts, second is partial sums
521        let partial_counts = values[0].as_primitive::<Int64Type>();
522
523        // intermediate counts are always created as non null
524        assert_eq!(partial_counts.null_count(), 0);
525        let partial_counts = partial_counts.values();
526
527        // Adds the counts with the partial counts
528        self.counts.resize(total_num_groups, 0);
529        group_indices.iter().zip(partial_counts.iter()).for_each(
530            |(&group_index, partial_count)| {
531                self.counts[group_index] += partial_count;
532            },
533        );
534
535        Ok(())
536    }
537
538    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
539        let counts = emit_to.take_needed(&mut self.counts);
540
541        // Count is always non null (null inputs just don't contribute to the overall values)
542        let nulls = None;
543        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
544
545        Ok(Arc::new(array))
546    }
547
548    // return arrays for counts
549    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
550        let counts = emit_to.take_needed(&mut self.counts);
551        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
552        Ok(vec![Arc::new(counts) as ArrayRef])
553    }
554
555    /// Converts an input batch directly to a state batch
556    ///
557    /// The state of `COUNT` is always a single Int64Array:
558    /// * `1` (for non-null, non filtered values)
559    /// * `0` (for null values)
560    fn convert_to_state(
561        &self,
562        values: &[ArrayRef],
563        opt_filter: Option<&BooleanArray>,
564    ) -> Result<Vec<ArrayRef>> {
565        let values = &values[0];
566
567        let state_array = match (values.logical_nulls(), opt_filter) {
568            (None, None) => {
569                // In case there is no nulls in input and no filter, returning array of 1
570                Arc::new(Int64Array::from_value(1, values.len()))
571            }
572            (Some(nulls), None) => {
573                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
574                // of input array to Int64
575                let nulls = BooleanArray::new(nulls.into_inner(), None);
576                compute::cast(&nulls, &DataType::Int64)?
577            }
578            (None, Some(filter)) => {
579                // If there is only filter
580                // - applying filter null mask to filter values by bitand filter values and nulls buffers
581                //   (using buffers guarantees absence of nulls in result)
582                // - casting result of bitand to Int64 array
583                let (filter_values, filter_nulls) = filter.clone().into_parts();
584
585                let state_buf = match filter_nulls {
586                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
587                    None => filter_values,
588                };
589
590                let boolean_state = BooleanArray::new(state_buf, None);
591                compute::cast(&boolean_state, &DataType::Int64)?
592            }
593            (Some(nulls), Some(filter)) => {
594                // For both input nulls and filter
595                // - applying filter null mask to filter values by bitand filter values and nulls buffers
596                //   (using buffers guarantees absence of nulls in result)
597                // - applying values null mask to filter buffer by another bitand on filter result and
598                //   nulls from input values
599                // - casting result to Int64 array
600                let (filter_values, filter_nulls) = filter.clone().into_parts();
601
602                let filter_buf = match filter_nulls {
603                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
604                    None => filter_values,
605                };
606                let state_buf = &filter_buf & nulls.inner();
607
608                let boolean_state = BooleanArray::new(state_buf, None);
609                compute::cast(&boolean_state, &DataType::Int64)?
610            }
611        };
612
613        Ok(vec![state_array])
614    }
615
616    fn supports_convert_to_state(&self) -> bool {
617        true
618    }
619
620    fn size(&self) -> usize {
621        self.counts.capacity() * size_of::<usize>()
622    }
623}
624
625/// count null values for multiple columns
626/// for each row if one column value is null, then null_count + 1
627fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
628    if values.len() > 1 {
629        let result_bool_buf: Option<BooleanBuffer> = values
630            .iter()
631            .map(|a| a.logical_nulls())
632            .fold(None, |acc, b| match (acc, b) {
633                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
634                (Some(acc), None) => Some(acc),
635                (None, Some(b)) => Some(b.into_inner()),
636                _ => None,
637            });
638        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
639    } else {
640        values[0]
641            .logical_nulls()
642            .map_or(0, |nulls| nulls.null_count())
643    }
644}
645
/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// Distinct non-null values seen so far; its length is the result.
    values: HashSet<ScalarValue, RandomState>,
    /// Element type used when emitting the state list; also selects the
    /// strategy used by `size()`.
    state_data_type: DataType,
}
659
660impl DistinctCountAccumulator {
661    // calculating the size for fixed length values, taking first batch size *
662    // number of batches This method is faster than .full_size(), however it is
663    // not suitable for variable length values like strings or complex types
664    fn fixed_size(&self) -> usize {
665        size_of_val(self)
666            + (size_of::<ScalarValue>() * self.values.capacity())
667            + self
668                .values
669                .iter()
670                .next()
671                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
672                .unwrap_or(0)
673            + size_of::<DataType>()
674    }
675
676    // calculates the size as accurately as possible. Note that calling this
677    // method is expensive
678    fn full_size(&self) -> usize {
679        size_of_val(self)
680            + (size_of::<ScalarValue>() * self.values.capacity())
681            + self
682                .values
683                .iter()
684                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
685                .sum::<usize>()
686            + size_of::<DataType>()
687    }
688}
689
690impl Accumulator for DistinctCountAccumulator {
691    /// Returns the distinct values seen so far as (one element) ListArray.
692    fn state(&mut self) -> Result<Vec<ScalarValue>> {
693        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
694        let arr =
695            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
696        Ok(vec![ScalarValue::List(arr)])
697    }
698
699    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
700        if values.is_empty() {
701            return Ok(());
702        }
703
704        let arr = &values[0];
705        if arr.data_type() == &DataType::Null {
706            return Ok(());
707        }
708
709        (0..arr.len()).try_for_each(|index| {
710            let scalar = ScalarValue::try_from_array(arr, index)?;
711            if !scalar.is_null() {
712                self.values.insert(scalar);
713            }
714            Ok(())
715        })
716    }
717
718    /// Merges multiple sets of distinct values into the current set.
719    ///
720    /// The input to this function is a `ListArray` with **multiple** rows,
721    /// where each row contains the values from a partial aggregate's phase (e.g.
722    /// the result of calling `Self::state` on multiple accumulators).
723    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
724        if states.is_empty() {
725            return Ok(());
726        }
727        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
728        let array = &states[0];
729        let list_array = array.as_list::<i32>();
730        for inner_array in list_array.iter() {
731            let Some(inner_array) = inner_array else {
732                return internal_err!(
733                    "Intermediate results of COUNT DISTINCT should always be non null"
734                );
735            };
736            self.update_batch(&[inner_array])?;
737        }
738        Ok(())
739    }
740
741    fn evaluate(&mut self) -> Result<ScalarValue> {
742        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
743    }
744
745    fn size(&self) -> usize {
746        match &self.state_data_type {
747            DataType::Boolean | DataType::Null => self.fixed_size(),
748            d if d.is_primitive() => self.fixed_size(),
749            _ => self.full_size(),
750        }
751    }
752}
753
/// Unit tests for the count accumulators, focusing on null handling and
/// dictionary-encoded input.
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use std::sync::Arc;
    /// Helper function to create a dictionary array with non-null keys but some null values
    /// Returns a dictionary array where:
    /// - keys are [0, 1, 2, 0, 1] (all non-null)
    /// - values are ["a", null, "c"]
    /// - so the keys reference: "a", null, "c", "a", null
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    /// A `NullArray` contributes nothing to a plain count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// `COUNT(DISTINCT ..)` over a dictionary whose values are themselves a
    /// dictionary counts the distinct underlying strings.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        // Using Count UDAF's accumulator
        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let args = AccumulatorArgs {
            schema: &schema,
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        // Inner dictionary elements, by position: "a", "b", "c", "d", "a", "b"
        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        // Outer keys select inner positions 0, 1, 2, 0, 3, 1,
        // i.e. "a", "b", "c", "a", "d", "b" -> 4 distinct strings
        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    /// The generic distinct accumulator ignores dictionary entries whose
    /// referenced value is null.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        // The expected behavior is that count_distinct should count only non-null values
        // which in this case are "a" and "c" (appearing as 0 and 2 in keys)
        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Should have 2 distinct non-null values ("a" and "c")
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    /// The plain count accumulator uses logical nulls, so non-null keys that
    /// point at null dictionary values are not counted.
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        // The expected behavior is that count should only count non-null values
        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // 5 elements in the array, of which 2 reference null values (the two 1s in the keys)
        // So we should count 3 non-null values
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    /// When every key references a null dictionary value the distinct count
    /// is zero, even though the dictionary also holds an unreferenced
    /// non-null value.
    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        // Create a dictionary array that only contains null values
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // All referenced values are null so count(distinct) should be 0
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }
}
881}