Skip to main content

datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{
19    array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
20    buffer::BooleanBuffer,
21    compute,
22    datatypes::{
23        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
24        FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
25        Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
26        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
27        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28        UInt8Type, UInt16Type, UInt32Type, UInt64Type,
29    },
30};
31use datafusion_common::hash_utils::RandomState;
32use datafusion_common::{
33    HashMap, Result, ScalarValue, downcast_value, exec_err, internal_err, not_impl_err,
34    stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38    ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39    WindowFunctionDefinition,
40    expr::WindowFunction,
41    function::{AccumulatorArgs, StateFieldsArgs},
42    utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
45use datafusion_functions_aggregate_common::aggregate::{
46    count_distinct::Bitmap65536DistinctCountAccumulator,
47    count_distinct::Bitmap65536DistinctCountAccumulatorI16,
48    count_distinct::BoolArray256DistinctCountAccumulator,
49    count_distinct::BoolArray256DistinctCountAccumulatorI8,
50    count_distinct::BytesDistinctCountAccumulator,
51    count_distinct::BytesViewDistinctCountAccumulator,
52    count_distinct::DictionaryCountAccumulator,
53    count_distinct::FloatDistinctCountAccumulator,
54    count_distinct::PrimitiveDistinctCountAccumulator,
55    groups_accumulator::accumulate::accumulate_indices,
56};
57use datafusion_macros::user_doc;
58use datafusion_physical_expr::expressions;
59use datafusion_physical_expr_common::binary_map::OutputType;
60use std::{
61    collections::HashSet,
62    fmt::Debug,
63    mem::{size_of, size_of_val},
64    ops::BitAnd,
65    sync::Arc,
66};
67
// Declares the `count` aggregate for the `Count` type defined below.
// NOTE(review): presumably this expands to the `count(expr)` expression builder
// and the `count_udaf()` accessor that later functions in this file call —
// confirm against the `make_udaf_expr_and_func!` definition, which is not
// visible here.
68make_udaf_expr_and_func!(
69    Count,
70    count,
71    expr,
72    "Count the number of non-null values in the column",
73    count_udaf
74);
75
76pub fn count_distinct(expr: Expr) -> Expr {
77    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
78        count_udaf(),
79        vec![expr],
80        true,
81        None,
82        vec![],
83        None,
84    ))
85}
86
87/// Creates aggregation to count all rows.
88///
89/// In SQL this is `SELECT COUNT(*) ... `
90///
91/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
92/// aliased to a column named `"count(*)"` for backward compatibility.
93///
94/// Example
95/// ```
96/// # use datafusion_functions_aggregate::count::count_all;
97/// # use datafusion_expr::col;
98/// // create `count(*)` expression
99/// let expr = count_all();
100/// assert_eq!(expr.schema_name().to_string(), "count(*)");
101/// // if you need to refer to this column, use the `schema_name` function
102/// let expr = col(expr.schema_name().to_string());
103/// ```
104pub fn count_all() -> Expr {
105    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
106}
107
108/// Creates window aggregation to count all rows.
109///
110/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
111///
112/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
113///
114/// Example
115/// ```
116/// # use datafusion_functions_aggregate::count::count_all_window;
117/// # use datafusion_expr::col;
118/// // create `count(*)` OVER ... window function expression
119/// let expr = count_all_window();
120/// assert_eq!(
121///     expr.schema_name().to_string(),
122///     "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
123/// );
124/// // if you need to refer to this column, use the `schema_name` function
125/// let expr = col(expr.schema_name().to_string());
126/// ```
127pub fn count_all_window() -> Expr {
128    Expr::from(WindowFunction::new(
129        WindowFunctionDefinition::AggregateUDF(count_udaf()),
130        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
131    ))
132}
133
134#[user_doc(
135    doc_section(label = "General Functions"),
136    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
137    syntax_example = "count(expression)",
138    sql_example = r#"```sql
139> SELECT count(column_name) FROM table_name;
140+-----------------------+
141| count(column_name)     |
142+-----------------------+
143| 100                   |
144+-----------------------+
145
146> SELECT count(*) FROM table_name;
147+------------------+
148| count(*)         |
149+------------------+
150| 120              |
151+------------------+
152```"#,
153    standard_argument(name = "expression",)
154)]
155#[derive(PartialEq, Eq, Hash, Debug)]
156pub struct Count {
157    signature: Signature,
158}
159
160impl Default for Count {
161    fn default() -> Self {
162        Self::new()
163    }
164}
165
166impl Count {
167    pub fn new() -> Self {
168        Self {
169            signature: Signature::one_of(
170                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
171                Volatility::Immutable,
172            ),
173        }
174    }
175}
176fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
177    match data_type {
178        // HashSet-based accumulator for larger integer types
179        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
180            data_type,
181        )),
182        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
183            data_type,
184        )),
185        DataType::UInt32 => Box::new(
186            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
187        ),
188        DataType::UInt64 => Box::new(
189            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
190        ),
191        // Small int types - cold path
192        DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => {
193            get_small_int_accumulator(data_type).unwrap()
194        }
195        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
196            Decimal128Type,
197        >::new(data_type)),
198        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
199            Decimal256Type,
200        >::new(data_type)),
201
202        DataType::Date32 => Box::new(
203            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
204        ),
205        DataType::Date64 => Box::new(
206            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
207        ),
208        DataType::Time32(TimeUnit::Millisecond) => Box::new(
209            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
210        ),
211        DataType::Time32(TimeUnit::Second) => Box::new(
212            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
213        ),
214        DataType::Time64(TimeUnit::Microsecond) => Box::new(
215            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
216        ),
217        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
218            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
219        ),
220        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
221            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
222        ),
223        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
224            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
225        ),
226        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
227            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
228        ),
229        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
230            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
231        ),
232
233        DataType::Float16 => {
234            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
235        }
236        DataType::Float32 => {
237            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
238        }
239        DataType::Float64 => {
240            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
241        }
242
243        DataType::Utf8 => {
244            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
245        }
246        DataType::Utf8View => {
247            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
248        }
249        DataType::LargeUtf8 => {
250            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
251        }
252        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
253            OutputType::Binary,
254        )),
255        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
256            OutputType::BinaryView,
257        )),
258        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
259            OutputType::Binary,
260        )),
261
262        // Use the generic accumulator based on `ScalarValue` for all other types
263        _ => Box::new(DistinctCountAccumulator {
264            values: HashSet::default(),
265            state_data_type: data_type.clone(),
266        }),
267    }
268}
269
270/// Uses optimized bitmap accumulators but separated to keep hot path small
271#[cold]
272fn get_small_int_accumulator(data_type: &DataType) -> Result<Box<dyn Accumulator>> {
273    match data_type {
274        DataType::UInt8 => Ok(Box::new(BoolArray256DistinctCountAccumulator::new())),
275        DataType::Int8 => Ok(Box::new(BoolArray256DistinctCountAccumulatorI8::new())),
276        DataType::UInt16 => Ok(Box::new(Bitmap65536DistinctCountAccumulator::new())),
277        DataType::Int16 => Ok(Box::new(Bitmap65536DistinctCountAccumulatorI16::new())),
278        _ => exec_err!("unsupported accumulator for datatype: {}", data_type),
279    }
280}
281
282impl AggregateUDFImpl for Count {
283    fn name(&self) -> &str {
284        "count"
285    }
286
287    fn signature(&self) -> &Signature {
288        &self.signature
289    }
290
291    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
292        Ok(DataType::Int64)
293    }
294
295    fn is_nullable(&self) -> bool {
296        false
297    }
298
299    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
300        if args.is_distinct {
301            let dtype: DataType = match &args.input_fields[0].data_type() {
302                DataType::Dictionary(_, values_type) => (**values_type).clone(),
303                &dtype => dtype.clone(),
304            };
305
306            Ok(vec![
307                Field::new_list(
308                    format_state_name(args.name, "count distinct"),
309                    // See COMMENTS.md to understand why nullable is set to true
310                    Field::new_list_field(dtype, true),
311                    false,
312                )
313                .into(),
314            ])
315        } else {
316            Ok(vec![
317                Field::new(
318                    format_state_name(args.name, "count"),
319                    DataType::Int64,
320                    false,
321                )
322                .into(),
323            ])
324        }
325    }
326
327    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
328        if !acc_args.is_distinct {
329            return Ok(Box::new(CountAccumulator::new()));
330        }
331
332        if acc_args.exprs.len() > 1 {
333            return not_impl_err!("COUNT DISTINCT with multiple arguments");
334        }
335
336        let data_type = acc_args.expr_fields[0].data_type();
337
338        Ok(match data_type {
339            DataType::Dictionary(_, values_type) => {
340                let inner = get_count_accumulator(values_type);
341                Box::new(DictionaryCountAccumulator::new(inner))
342            }
343            _ => get_count_accumulator(data_type),
344        })
345    }
346
347    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348        if args.exprs.len() != 1 {
349            return false;
350        }
351        if !args.is_distinct {
352            return true;
353        }
354        matches!(
355            args.expr_fields[0].data_type(),
356            DataType::Int8
357                | DataType::Int16
358                | DataType::Int32
359                | DataType::Int64
360                | DataType::UInt8
361                | DataType::UInt16
362                | DataType::UInt32
363                | DataType::UInt64
364        )
365    }
366
367    fn create_groups_accumulator(
368        &self,
369        args: AccumulatorArgs,
370    ) -> Result<Box<dyn GroupsAccumulator>> {
371        if !args.is_distinct {
372            return Ok(Box::new(CountGroupsAccumulator::new()));
373        }
374        create_distinct_count_groups_accumulator(&args)
375    }
376
377    fn reverse_expr(&self) -> ReversedUDAF {
378        ReversedUDAF::Identical
379    }
380
381    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
382        Ok(ScalarValue::Int64(Some(0)))
383    }
384
385    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
386        let [expr] = statistics_args.exprs else {
387            return None;
388        };
389        let col_stats = &statistics_args.statistics.column_statistics;
390
391        if statistics_args.is_distinct {
392            // Only column references can be resolved from statistics;
393            // expressions like casts or literals are not supported.
394            let col_expr = expr.downcast_ref::<expressions::Column>()?;
395            if let Precision::Exact(dc) = col_stats[col_expr.index()].distinct_count {
396                let dc = i64::try_from(dc).ok()?;
397                return Some(ScalarValue::Int64(Some(dc)));
398            }
399            return None;
400        }
401
402        let Precision::Exact(num_rows) = statistics_args.statistics.num_rows else {
403            return None;
404        };
405
406        // TODO optimize with exprs other than Column
407        if let Some(col_expr) = expr.downcast_ref::<expressions::Column>() {
408            if let Precision::Exact(val) = col_stats[col_expr.index()].null_count {
409                let count = i64::try_from(num_rows - val).ok()?;
410                return Some(ScalarValue::Int64(Some(count)));
411            }
412        } else if let Some(lit_expr) = expr.downcast_ref::<expressions::Literal>()
413            && lit_expr.value() == &COUNT_STAR_EXPANSION
414        {
415            let num_rows = i64::try_from(num_rows).ok()?;
416            return Some(ScalarValue::Int64(Some(num_rows)));
417        }
418
419        None
420    }
421
422    fn documentation(&self) -> Option<&Documentation> {
423        self.doc()
424    }
425
426    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
427        // `COUNT` is monotonically increasing as it always increases or stays
428        // the same as new values are seen.
429        SetMonotonicity::Increasing
430    }
431
432    fn create_sliding_accumulator(
433        &self,
434        args: AccumulatorArgs,
435    ) -> Result<Box<dyn Accumulator>> {
436        if args.is_distinct {
437            let acc =
438                SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
439            Ok(Box::new(acc))
440        } else {
441            let acc = CountAccumulator::new();
442            Ok(Box::new(acc))
443        }
444    }
445}
446
447#[cold]
448fn create_distinct_count_groups_accumulator(
449    args: &AccumulatorArgs,
450) -> Result<Box<dyn GroupsAccumulator>> {
451    let data_type = args.expr_fields[0].data_type();
452    match data_type {
453        DataType::Int8 => Ok(Box::new(
454            PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
455        )),
456        DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
457            Int16Type,
458        >::new())),
459        DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
460            Int32Type,
461        >::new())),
462        DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
463            Int64Type,
464        >::new())),
465        DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
466            UInt8Type,
467        >::new())),
468        DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
469            UInt16Type,
470        >::new())),
471        DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
472            UInt32Type,
473        >::new())),
474        DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
475            UInt64Type,
476        >::new())),
477        _ => not_impl_err!(
478            "GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
479            data_type
480        ),
481    }
482}
483
484// DistinctCountAccumulator does not support retract_batch and sliding window
485// this is a specialized accumulator for distinct count that supports retract_batch
486// and sliding window.
487#[derive(Debug)]
488pub struct SlidingDistinctCountAccumulator {
489    counts: HashMap<ScalarValue, usize, RandomState>,
490    data_type: DataType,
491}
492
493impl SlidingDistinctCountAccumulator {
494    pub fn try_new(data_type: &DataType) -> Result<Self> {
495        Ok(Self {
496            counts: HashMap::default(),
497            data_type: data_type.clone(),
498        })
499    }
500}
501
502impl Accumulator for SlidingDistinctCountAccumulator {
503    fn state(&mut self) -> Result<Vec<ScalarValue>> {
504        let keys = self.counts.keys().cloned().collect::<Vec<_>>();
505        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
506            keys.as_slice(),
507            &self.data_type,
508        ))])
509    }
510
511    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
512        let arr = &values[0];
513        for i in 0..arr.len() {
514            let v = ScalarValue::try_from_array(arr, i)?;
515            if !v.is_null() {
516                *self.counts.entry(v).or_default() += 1;
517            }
518        }
519        Ok(())
520    }
521
522    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523        let arr = &values[0];
524        for i in 0..arr.len() {
525            let v = ScalarValue::try_from_array(arr, i)?;
526            if !v.is_null()
527                && let Some(cnt) = self.counts.get_mut(&v)
528            {
529                *cnt -= 1;
530                if *cnt == 0 {
531                    self.counts.remove(&v);
532                }
533            }
534        }
535        Ok(())
536    }
537
538    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
539        let list_arr = states[0].as_list::<i32>();
540        for inner in list_arr.iter().flatten() {
541            for j in 0..inner.len() {
542                let v = ScalarValue::try_from_array(&*inner, j)?;
543                *self.counts.entry(v).or_default() += 1;
544            }
545        }
546        Ok(())
547    }
548
549    fn evaluate(&mut self) -> Result<ScalarValue> {
550        Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
551    }
552
553    fn supports_retract_batch(&self) -> bool {
554        true
555    }
556
557    fn size(&self) -> usize {
558        size_of_val(self)
559    }
560}
561
/// Accumulator for plain (non-distinct) `COUNT`: a single running total.
#[derive(Debug)]
struct CountAccumulator {
    /// Number of counted rows so far.
    count: i64,
}

impl CountAccumulator {
    /// Creates an accumulator whose count starts at zero.
    pub fn new() -> Self {
        CountAccumulator { count: 0 }
    }
}
573
574impl Accumulator for CountAccumulator {
575    fn state(&mut self) -> Result<Vec<ScalarValue>> {
576        Ok(vec![ScalarValue::Int64(Some(self.count))])
577    }
578
579    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
580        let array = &values[0];
581        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
582        Ok(())
583    }
584
585    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
586        let array = &values[0];
587        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
588        Ok(())
589    }
590
591    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
592        let counts = downcast_value!(states[0], Int64Array);
593        let delta = &compute::sum(counts);
594        if let Some(d) = delta {
595            self.count += *d;
596        }
597        Ok(())
598    }
599
600    fn evaluate(&mut self) -> Result<ScalarValue> {
601        Ok(ScalarValue::Int64(Some(self.count)))
602    }
603
604    fn supports_retract_batch(&self) -> bool {
605        true
606    }
607
608    fn size(&self) -> usize {
609        size_of_val(self)
610    }
611}
612
/// Groups accumulator for plain `COUNT`, keeping one counter per group.
///
/// Unlike most other accumulators, COUNT never produces NULLs: a group
/// without any non-null value simply reports 0. This accumulator therefore
/// needs no additional null or "seen" tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// This is deliberately `i64` rather than `u64` / `usize`: the output
    /// type of count is `DataType::Int64`, so the output [`Int64Array`] can
    /// be created from this vector without copying.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    /// Creates an accumulator that tracks no groups yet.
    pub fn new() -> Self {
        Self { counts: Vec::new() }
    }
}
635
636impl GroupsAccumulator for CountGroupsAccumulator {
637    fn update_batch(
638        &mut self,
639        values: &[ArrayRef],
640        group_indices: &[usize],
641        opt_filter: Option<&BooleanArray>,
642        total_num_groups: usize,
643    ) -> Result<()> {
644        assert_eq!(values.len(), 1, "single argument to update_batch");
645        let values = &values[0];
646
647        // Add one to each group's counter for each non null, non
648        // filtered value
649        self.counts.resize(total_num_groups, 0);
650        accumulate_indices(
651            group_indices,
652            values.logical_nulls().as_ref(),
653            opt_filter,
654            |group_index| {
655                // SAFETY: group_index is guaranteed to be in bounds
656                let count = unsafe { self.counts.get_unchecked_mut(group_index) };
657                *count += 1;
658            },
659        );
660
661        Ok(())
662    }
663
664    fn merge_batch(
665        &mut self,
666        values: &[ArrayRef],
667        group_indices: &[usize],
668        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
669        _opt_filter: Option<&BooleanArray>,
670        total_num_groups: usize,
671    ) -> Result<()> {
672        assert_eq!(values.len(), 1, "one argument to merge_batch");
673        // first batch is counts, second is partial sums
674        let partial_counts = values[0].as_primitive::<Int64Type>();
675
676        // intermediate counts are always created as non null
677        assert_eq!(partial_counts.null_count(), 0);
678        let partial_counts = partial_counts.values();
679
680        // Adds the counts with the partial counts
681        self.counts.resize(total_num_groups, 0);
682        group_indices.iter().zip(partial_counts.iter()).for_each(
683            |(&group_index, partial_count)| {
684                self.counts[group_index] += partial_count;
685            },
686        );
687
688        Ok(())
689    }
690
691    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
692        let counts = emit_to.take_needed(&mut self.counts);
693
694        // Count is always non null (null inputs just don't contribute to the overall values)
695        let nulls = None;
696        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
697
698        Ok(Arc::new(array))
699    }
700
701    // return arrays for counts
702    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
703        let counts = emit_to.take_needed(&mut self.counts);
704        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
705        Ok(vec![Arc::new(counts) as ArrayRef])
706    }
707
708    /// Converts an input batch directly to a state batch
709    ///
710    /// The state of `COUNT` is always a single Int64Array:
711    /// * `1` (for non-null, non filtered values)
712    /// * `0` (for null values)
713    fn convert_to_state(
714        &self,
715        values: &[ArrayRef],
716        opt_filter: Option<&BooleanArray>,
717    ) -> Result<Vec<ArrayRef>> {
718        let values = &values[0];
719
720        let state_array = match (values.logical_nulls(), opt_filter) {
721            (None, None) => {
722                // In case there is no nulls in input and no filter, returning array of 1
723                Arc::new(Int64Array::from_value(1, values.len()))
724            }
725            (Some(nulls), None) => {
726                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
727                // of input array to Int64
728                let nulls = BooleanArray::new(nulls.into_inner(), None);
729                compute::cast(&nulls, &DataType::Int64)?
730            }
731            (None, Some(filter)) => {
732                // If there is only filter
733                // - applying filter null mask to filter values by bitand filter values and nulls buffers
734                //   (using buffers guarantees absence of nulls in result)
735                // - casting result of bitand to Int64 array
736                let (filter_values, filter_nulls) = filter.clone().into_parts();
737
738                let state_buf = match filter_nulls {
739                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
740                    None => filter_values,
741                };
742
743                let boolean_state = BooleanArray::new(state_buf, None);
744                compute::cast(&boolean_state, &DataType::Int64)?
745            }
746            (Some(nulls), Some(filter)) => {
747                // For both input nulls and filter
748                // - applying filter null mask to filter values by bitand filter values and nulls buffers
749                //   (using buffers guarantees absence of nulls in result)
750                // - applying values null mask to filter buffer by another bitand on filter result and
751                //   nulls from input values
752                // - casting result to Int64 array
753                let (filter_values, filter_nulls) = filter.clone().into_parts();
754
755                let filter_buf = match filter_nulls {
756                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
757                    None => filter_values,
758                };
759                let state_buf = &filter_buf & nulls.inner();
760
761                let boolean_state = BooleanArray::new(state_buf, None);
762                compute::cast(&boolean_state, &DataType::Int64)?
763            }
764        };
765
766        Ok(vec![state_array])
767    }
768
769    fn supports_convert_to_state(&self) -> bool {
770        true
771    }
772
773    fn size(&self) -> usize {
774        self.counts.capacity() * size_of::<usize>()
775    }
776}
777
778/// count null values for multiple columns
779/// for each row if one column value is null, then null_count + 1
780fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
781    if values.len() > 1 {
782        let result_bool_buf: Option<BooleanBuffer> = values
783            .iter()
784            .map(|a| a.logical_nulls())
785            .fold(None, |acc, b| match (acc, b) {
786                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
787                (Some(acc), None) => Some(acc),
788                (None, Some(b)) => Some(b.into_inner()),
789                _ => None,
790            });
791        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
792    } else {
793        values[0]
794            .logical_nulls()
795            .map_or(0, |nulls| nulls.null_count())
796    }
797}
798
799/// General purpose distinct accumulator that works for any DataType by using
800/// [`ScalarValue`].
801///
802/// It stores intermediate results as a `ListArray`
803///
804/// Note that many types have specialized accumulators that are (much)
805/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
806/// [`BytesDistinctCountAccumulator`]
807#[derive(Debug)]
808struct DistinctCountAccumulator {
809    values: HashSet<ScalarValue, RandomState>,
810    state_data_type: DataType,
811}
812
813impl DistinctCountAccumulator {
814    // calculating the size for fixed length values, taking first batch size *
815    // number of batches This method is faster than .full_size(), however it is
816    // not suitable for variable length values like strings or complex types
817    fn fixed_size(&self) -> usize {
818        size_of_val(self)
819            + (size_of::<ScalarValue>() * self.values.capacity())
820            + self
821                .values
822                .iter()
823                .next()
824                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
825                .unwrap_or(0)
826            + size_of::<DataType>()
827    }
828
829    // calculates the size as accurately as possible. Note that calling this
830    // method is expensive
831    fn full_size(&self) -> usize {
832        size_of_val(self)
833            + (size_of::<ScalarValue>() * self.values.capacity())
834            + self
835                .values
836                .iter()
837                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
838                .sum::<usize>()
839            + size_of::<DataType>()
840    }
841}
842
843impl Accumulator for DistinctCountAccumulator {
844    /// Returns the distinct values seen so far as (one element) ListArray.
845    fn state(&mut self) -> Result<Vec<ScalarValue>> {
846        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
847        let arr =
848            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
849        Ok(vec![ScalarValue::List(arr)])
850    }
851
852    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
853        if values.is_empty() {
854            return Ok(());
855        }
856
857        let arr = &values[0];
858        if arr.data_type() == &DataType::Null {
859            return Ok(());
860        }
861
862        (0..arr.len()).try_for_each(|index| {
863            let scalar = ScalarValue::try_from_array(arr, index)?;
864            if !scalar.is_null() {
865                self.values.insert(scalar);
866            }
867            Ok(())
868        })
869    }
870
871    /// Merges multiple sets of distinct values into the current set.
872    ///
873    /// The input to this function is a `ListArray` with **multiple** rows,
874    /// where each row contains the values from a partial aggregate's phase (e.g.
875    /// the result of calling `Self::state` on multiple accumulators).
876    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
877        if states.is_empty() {
878            return Ok(());
879        }
880        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
881        let array = &states[0];
882        let list_array = array.as_list::<i32>();
883        for inner_array in list_array.iter() {
884            let Some(inner_array) = inner_array else {
885                return internal_err!(
886                    "Intermediate results of COUNT DISTINCT should always be non null"
887                );
888            };
889            self.update_batch(&[inner_array])?;
890        }
891        Ok(())
892    }
893
894    fn evaluate(&mut self) -> Result<ScalarValue> {
895        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
896    }
897
898    fn size(&self) -> usize {
899        match &self.state_data_type {
900            DataType::Boolean | DataType::Null => self.fixed_size(),
901            d if d.is_primitive() => self.fixed_size(),
902            _ => self.full_size(),
903        }
904    }
905}
906
907#[cfg(test)]
908mod tests {
909
910    use super::*;
911    use arrow::{
912        array::{DictionaryArray, Int32Array, NullArray, StringArray},
913        datatypes::{DataType, Field, Int32Type, Schema},
914    };
915    use datafusion_expr::function::AccumulatorArgs;
916    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
917    use std::sync::Arc;
918    /// Helper function to create a dictionary array with non-null keys but some null values
919    /// Returns a dictionary array where:
920    /// - keys are [0, 1, 2, 0, 1] (all non-null)
921    /// - values are ["a", null, "c"]
922    /// - so the keys reference: "a", null, "c", "a", null
923    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
924        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
925        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
926        Ok(DictionaryArray::<Int32Type>::try_new(
927            keys,
928            Arc::new(values),
929        )?)
930    }
931
932    #[test]
933    fn count_accumulator_nulls() -> Result<()> {
934        let mut accumulator = CountAccumulator::new();
935        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
936        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
937        Ok(())
938    }
939
940    #[test]
941    fn test_nested_dictionary() -> Result<()> {
942        let schema = Arc::new(Schema::new(vec![Field::new(
943            "dict_col",
944            DataType::Dictionary(
945                Box::new(DataType::Int32),
946                Box::new(DataType::Dictionary(
947                    Box::new(DataType::Int32),
948                    Box::new(DataType::Utf8),
949                )),
950            ),
951            true,
952        )]));
953
954        // Using Count UDAF's accumulator
955        let count = Count::new();
956        let expr = Arc::new(Column::new("dict_col", 0));
957        let expr_field = expr.return_field(&schema)?;
958        let args = AccumulatorArgs {
959            schema: &schema,
960            expr_fields: &[expr_field],
961            exprs: &[expr],
962            is_distinct: true,
963            name: "count",
964            ignore_nulls: false,
965            is_reversed: false,
966            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
967            order_bys: &[],
968        };
969
970        let inner_dict =
971            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);
972
973        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
974        let dict_of_dict =
975            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;
976
977        let mut acc = count.accumulator(args)?;
978        acc.update_batch(&[Arc::new(dict_of_dict)])?;
979        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));
980
981        Ok(())
982    }
983
984    #[test]
985    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
986        let dict_array = create_dictionary_with_null_values()?;
987
988        // The expected behavior is that count_distinct should count only non-null values
989        // which in this case are "a" and "c" (appearing as 0 and 2 in keys)
990        let mut accumulator = DistinctCountAccumulator {
991            values: HashSet::default(),
992            state_data_type: dict_array.data_type().clone(),
993        };
994
995        accumulator.update_batch(&[Arc::new(dict_array)])?;
996
997        // Should have 2 distinct non-null values ("a" and "c")
998        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
999        Ok(())
1000    }
1001
1002    #[test]
1003    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
1004        let dict_array = create_dictionary_with_null_values()?;
1005
1006        // The expected behavior is that count should only count non-null values
1007        let mut accumulator = CountAccumulator::new();
1008
1009        accumulator.update_batch(&[Arc::new(dict_array)])?;
1010
1011        // 5 elements in the array, of which 2 reference null values (the two 1s in the keys)
1012        // So we should count 3 non-null values
1013        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
1014        Ok(())
1015    }
1016
1017    #[test]
1018    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
1019        // Create a dictionary array that only contains null values
1020        let dict_values = StringArray::from(vec![None, Some("abc")]);
1021        let dict_indices = Int32Array::from(vec![0; 5]);
1022        let dict_array =
1023            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;
1024
1025        let mut accumulator = DistinctCountAccumulator {
1026            values: HashSet::default(),
1027            state_data_type: dict_array.data_type().clone(),
1028        };
1029
1030        accumulator.update_batch(&[Arc::new(dict_array)])?;
1031
1032        // All referenced values are null so count(distinct) should be 0
1033        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
1034        Ok(())
1035    }
1036
1037    #[test]
1038    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
1039        // Basic update_batch + evaluate functionality
1040        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1041        // Create an Int32Array: [1, 2, 2, 3, null]
1042        let values: ArrayRef = Arc::new(Int32Array::from(vec![
1043            Some(1),
1044            Some(2),
1045            Some(2),
1046            Some(3),
1047            None,
1048        ]));
1049        acc.update_batch(&[values])?;
1050        // Expect distinct values {1,2,3} → count = 3
1051        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
1052        Ok(())
1053    }
1054
1055    #[test]
1056    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
1057        // Test that retract_batch properly decrements counts
1058        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
1059        // Initial batch: ["a", "b", "a"]
1060        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
1061            as ArrayRef;
1062        acc.update_batch(&[arr1])?;
1063        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"}
1064
1065        // Retract batch: ["a", null, "b"]
1066        let arr2 =
1067            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
1068        acc.retract_batch(&[arr2])?;
1069        // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"}
1070        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
1071        Ok(())
1072    }
1073
1074    #[test]
1075    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
1076        // Test merging multiple accumulator states with merge_batch
1077        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1078        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1079        // acc1 sees [1, 2]
1080        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
1081        // acc2 sees [2, 3]
1082        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
1083        // Extract their states as Vec<ScalarValue>
1084        let state_sv1 = acc1.state()?;
1085        let state_sv2 = acc2.state()?;
1086        // Convert ScalarValue states into Vec<ArrayRef>, propagating errors
1087        // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray
1088        let state_arr1: Vec<ArrayRef> = state_sv1
1089            .into_iter()
1090            .map(|sv| sv.to_array())
1091            .collect::<Result<_>>()?;
1092        let state_arr2: Vec<ArrayRef> = state_sv2
1093            .into_iter()
1094            .map(|sv| sv.to_array())
1095            .collect::<Result<_>>()?;
1096        // Merge both states into a fresh accumulator
1097        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1098        merged.merge_batch(&state_arr1)?;
1099        merged.merge_batch(&state_arr2)?;
1100        // Expect distinct {1,2,3} → count = 3
1101        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
1102        Ok(())
1103    }
1104}