// datafusion_functions_aggregate/count.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use ahash::RandomState;
19use arrow::{
20    array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21    buffer::BooleanBuffer,
22    compute,
23    datatypes::{
24        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25        FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
26        Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29        UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30    },
31};
32use datafusion_common::{
33    HashMap, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
34    stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38    ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39    WindowFunctionDefinition,
40    expr::WindowFunction,
41    function::{AccumulatorArgs, StateFieldsArgs},
42    utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45    count_distinct::BytesDistinctCountAccumulator,
46    count_distinct::BytesViewDistinctCountAccumulator,
47    count_distinct::DictionaryCountAccumulator,
48    count_distinct::FloatDistinctCountAccumulator,
49    count_distinct::PrimitiveDistinctCountAccumulator,
50    groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56    collections::HashSet,
57    fmt::Debug,
58    mem::{size_of, size_of_val},
59    ops::BitAnd,
60    sync::Arc,
61};
62
63make_udaf_expr_and_func!(
64    Count,
65    count,
66    expr,
67    "Count the number of non-null values in the column",
68    count_udaf
69);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73        count_udaf(),
74        vec![expr],
75        true,
76        None,
77        vec![],
78        None,
79    ))
80}
81
82/// Creates aggregation to count all rows.
83///
84/// In SQL this is `SELECT COUNT(*) ... `
85///
86/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
87/// aliased to a column named `"count(*)"` for backward compatibility.
88///
89/// Example
90/// ```
91/// # use datafusion_functions_aggregate::count::count_all;
92/// # use datafusion_expr::col;
93/// // create `count(*)` expression
94/// let expr = count_all();
95/// assert_eq!(expr.schema_name().to_string(), "count(*)");
96/// // if you need to refer to this column, use the `schema_name` function
97/// let expr = col(expr.schema_name().to_string());
98/// ```
99pub fn count_all() -> Expr {
100    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103/// Creates window aggregation to count all rows.
104///
105/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
106///
107/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
108///
109/// Example
110/// ```
111/// # use datafusion_functions_aggregate::count::count_all_window;
112/// # use datafusion_expr::col;
113/// // create `count(*)` OVER ... window function expression
114/// let expr = count_all_window();
115/// assert_eq!(
116///     expr.schema_name().to_string(),
117///     "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
118/// );
119/// // if you need to refer to this column, use the `schema_name` function
120/// let expr = col(expr.schema_name().to_string());
121/// ```
122pub fn count_all_window() -> Expr {
123    Expr::from(WindowFunction::new(
124        WindowFunctionDefinition::AggregateUDF(count_udaf()),
125        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126    ))
127}
128
/// The `COUNT` aggregate function.
///
/// Holds only the function's type [`Signature`]; all behavior is provided by
/// the [`AggregateUDFImpl`] implementation below.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Count {
    // accepted argument shapes (any number of args of any type, or none)
    signature: Signature,
}
154
155impl Default for Count {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
impl Count {
    /// Creates a new `COUNT` UDAF.
    ///
    /// The signature accepts either any number of arguments of any type
    /// (`COUNT(c1)`, `COUNT(c1, c2)`, ...) or no arguments at all (the
    /// nullary form used for `COUNT()`).
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        }
    }
}
/// Returns the best [`Accumulator`] for `COUNT DISTINCT` over values of
/// `data_type`.
///
/// Dispatches to type-specialized accumulators (which hash native values)
/// for integer, decimal, date/time/timestamp, float, and byte/string types,
/// and falls back to the generic [`DistinctCountAccumulator`] (which hashes
/// `ScalarValue`s) for everything else.
fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
    match data_type {
        // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
        DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
            data_type,
        )),
        DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
            data_type,
        )),
        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
            data_type,
        )),
        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
            data_type,
        )),
        DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
            data_type,
        )),
        DataType::UInt16 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
        ),
        DataType::UInt32 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
        ),
        DataType::UInt64 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
        ),
        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal128Type,
        >::new(data_type)),
        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal256Type,
        >::new(data_type)),

        // dates, times and timestamps are primitives under the hood, so the
        // primitive accumulator applies (the `data_type` argument preserves
        // unit/timezone for the state schema)
        DataType::Date32 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
        ),
        DataType::Date64 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Millisecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Second) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Microsecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
        ),

        // floats use a dedicated accumulator (see `FloatDistinctCountAccumulator`)
        DataType::Float16 => {
            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
        }
        DataType::Float32 => {
            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
        }
        DataType::Float64 => {
            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
        }

        // byte / string types: generic over the offset width (i32 vs i64),
        // with separate accumulators for the view representations
        DataType::Utf8 => {
            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
        }
        DataType::Utf8View => {
            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
        }
        DataType::LargeUtf8 => {
            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
        }
        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
            OutputType::Binary,
        )),
        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
            OutputType::BinaryView,
        )),
        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
            OutputType::Binary,
        )),

        // Use the generic accumulator based on `ScalarValue` for all other types
        _ => Box::new(DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: data_type.clone(),
        }),
    }
}
272
impl AggregateUDFImpl for Count {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "count"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `COUNT` always produces `Int64`, regardless of the input types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    /// `COUNT` never returns NULL: with no matching rows the result is 0.
    fn is_nullable(&self) -> bool {
        false
    }

    /// Intermediate state schema:
    /// * distinct: a single list column holding the distinct values seen
    /// * non-distinct: a single non-null `Int64` partial count
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            // for dictionaries, distinct values are tracked on the value type
            let dtype: DataType = match &args.input_fields[0].data_type() {
                DataType::Dictionary(_, values_type) => (**values_type).clone(),
                &dtype => dtype.clone(),
            };

            Ok(vec![
                Field::new_list(
                    format_state_name(args.name, "count distinct"),
                    // See COMMENTS.md to understand why nullable is set to true
                    Field::new_list_field(dtype, true),
                    false,
                )
                .into(),
            ])
        } else {
            Ok(vec![
                Field::new(
                    format_state_name(args.name, "count"),
                    DataType::Int64,
                    false,
                )
                .into(),
            ])
        }
    }

    /// Non-distinct counts use the simple [`CountAccumulator`]; distinct
    /// counts dispatch on the input type (unwrapping dictionaries) via
    /// [`get_count_accumulator`].
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if !acc_args.is_distinct {
            return Ok(Box::new(CountAccumulator::new()));
        }

        if acc_args.exprs.len() > 1 {
            return not_impl_err!("COUNT DISTINCT with multiple arguments");
        }

        let data_type = acc_args.expr_fields[0].data_type();

        Ok(match data_type {
            DataType::Dictionary(_, values_type) => {
                // count distinct values, not dictionary keys
                let inner = get_count_accumulator(values_type);
                Box::new(DictionaryCountAccumulator::new(inner))
            }
            _ => get_count_accumulator(data_type),
        })
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        // groups accumulator only supports `COUNT(c1)`, not
        // `COUNT(c1, c2)`, etc
        if args.is_distinct {
            return false;
        }
        args.exprs.len() == 1
    }

    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        // instantiate specialized accumulator
        Ok(Box::new(CountGroupsAccumulator::new()))
    }

    /// `COUNT` yields the same result regardless of input order.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    /// Value for groups with no rows: `0`, never NULL.
    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(0)))
    }

    /// Attempts to answer the aggregate from table statistics alone:
    /// * `COUNT(col)`: exact row count minus the column's exact null count
    /// * `COUNT(*)` (literal `1` argument): exact row count
    ///
    /// Returns `None` whenever either statistic is not `Precision::Exact`,
    /// or for `COUNT DISTINCT`.
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        if statistics_args.is_distinct {
            return None;
        }
        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
            && statistics_args.exprs.len() == 1
        {
            // TODO optimize with exprs other than Column
            if let Some(col_expr) = statistics_args.exprs[0]
                .as_any()
                .downcast_ref::<expressions::Column>()
            {
                let current_val = &statistics_args.statistics.column_statistics
                    [col_expr.index()]
                .null_count;
                if let &Precision::Exact(val) = current_val {
                    return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
                }
            } else if let Some(lit_expr) = statistics_args.exprs[0]
                .as_any()
                .downcast_ref::<expressions::Literal>()
                && lit_expr.value() == &COUNT_STAR_EXPANSION
            {
                return Some(ScalarValue::Int64(Some(num_rows as i64)));
            }
        }
        None
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        // `COUNT` is monotonically increasing as it always increases or stays
        // the same as new values are seen.
        SetMonotonicity::Increasing
    }

    /// Sliding-window accumulators: distinct counts need retraction support,
    /// provided by [`SlidingDistinctCountAccumulator`]; plain counts already
    /// support `retract_batch`.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            // NOTE(review): this passes the *return* field's type (Int64) to
            // the accumulator, not the input expression's type, while the
            // accumulator stores input values — verify the state schema
            // matches `state_fields` for non-Int64 inputs.
            let acc =
                SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
            Ok(Box::new(acc))
        } else {
            let acc = CountAccumulator::new();
            Ok(Box::new(acc))
        }
    }
}
420
/// `COUNT(DISTINCT)` accumulator supporting `retract_batch` and therefore
/// sliding windows.
///
/// [`DistinctCountAccumulator`] does not support `retract_batch`; this
/// accumulator additionally tracks how many live occurrences each distinct
/// value has, so values can be dropped as the window slides past them.
#[derive(Debug)]
pub struct SlidingDistinctCountAccumulator {
    /// Number of live (non-retracted) occurrences per distinct non-null value.
    counts: HashMap<ScalarValue, usize, RandomState>,
    /// Data type of the values being counted (used for the state list).
    data_type: DataType,
}
429
430impl SlidingDistinctCountAccumulator {
431    pub fn try_new(data_type: &DataType) -> Result<Self> {
432        Ok(Self {
433            counts: HashMap::default(),
434            data_type: data_type.clone(),
435        })
436    }
437}
438
439impl Accumulator for SlidingDistinctCountAccumulator {
440    fn state(&mut self) -> Result<Vec<ScalarValue>> {
441        let keys = self.counts.keys().cloned().collect::<Vec<_>>();
442        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
443            keys.as_slice(),
444            &self.data_type,
445        ))])
446    }
447
448    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
449        let arr = &values[0];
450        for i in 0..arr.len() {
451            let v = ScalarValue::try_from_array(arr, i)?;
452            if !v.is_null() {
453                *self.counts.entry(v).or_default() += 1;
454            }
455        }
456        Ok(())
457    }
458
459    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
460        let arr = &values[0];
461        for i in 0..arr.len() {
462            let v = ScalarValue::try_from_array(arr, i)?;
463            if !v.is_null()
464                && let Some(cnt) = self.counts.get_mut(&v)
465            {
466                *cnt -= 1;
467                if *cnt == 0 {
468                    self.counts.remove(&v);
469                }
470            }
471        }
472        Ok(())
473    }
474
475    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
476        let list_arr = states[0].as_list::<i32>();
477        for inner in list_arr.iter().flatten() {
478            for j in 0..inner.len() {
479                let v = ScalarValue::try_from_array(&*inner, j)?;
480                *self.counts.entry(v).or_default() += 1;
481            }
482        }
483        Ok(())
484    }
485
486    fn evaluate(&mut self) -> Result<ScalarValue> {
487        Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
488    }
489
490    fn supports_retract_batch(&self) -> bool {
491        true
492    }
493
494    fn size(&self) -> usize {
495        size_of_val(self)
496    }
497}
498
/// Accumulator for non-distinct `COUNT`: a single running count of
/// rows where all arguments are non-null.
#[derive(Debug)]
struct CountAccumulator {
    // i64 (not usize) so it maps directly onto COUNT's Int64 output
    count: i64,
}
503
504impl CountAccumulator {
505    /// new count accumulator
506    pub fn new() -> Self {
507        Self { count: 0 }
508    }
509}
510
511impl Accumulator for CountAccumulator {
512    fn state(&mut self) -> Result<Vec<ScalarValue>> {
513        Ok(vec![ScalarValue::Int64(Some(self.count))])
514    }
515
516    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517        let array = &values[0];
518        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
519        Ok(())
520    }
521
522    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523        let array = &values[0];
524        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
525        Ok(())
526    }
527
528    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
529        let counts = downcast_value!(states[0], Int64Array);
530        let delta = &compute::sum(counts);
531        if let Some(d) = delta {
532            self.count += *d;
533        }
534        Ok(())
535    }
536
537    fn evaluate(&mut self) -> Result<ScalarValue> {
538        Ok(ScalarValue::Int64(Some(self.count)))
539    }
540
541    fn supports_retract_batch(&self) -> bool {
542        true
543    }
544
545    fn size(&self) -> usize {
546        size_of_val(self)
547    }
548}
549
/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
/// Stores values as native types, and does overflow checking
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}
566
567impl CountGroupsAccumulator {
568    pub fn new() -> Self {
569        Self { counts: vec![] }
570    }
571}
572
573impl GroupsAccumulator for CountGroupsAccumulator {
574    fn update_batch(
575        &mut self,
576        values: &[ArrayRef],
577        group_indices: &[usize],
578        opt_filter: Option<&BooleanArray>,
579        total_num_groups: usize,
580    ) -> Result<()> {
581        assert_eq!(values.len(), 1, "single argument to update_batch");
582        let values = &values[0];
583
584        // Add one to each group's counter for each non null, non
585        // filtered value
586        self.counts.resize(total_num_groups, 0);
587        accumulate_indices(
588            group_indices,
589            values.logical_nulls().as_ref(),
590            opt_filter,
591            |group_index| {
592                // SAFETY: group_index is guaranteed to be in bounds
593                let count = unsafe { self.counts.get_unchecked_mut(group_index) };
594                *count += 1;
595            },
596        );
597
598        Ok(())
599    }
600
601    fn merge_batch(
602        &mut self,
603        values: &[ArrayRef],
604        group_indices: &[usize],
605        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
606        _opt_filter: Option<&BooleanArray>,
607        total_num_groups: usize,
608    ) -> Result<()> {
609        assert_eq!(values.len(), 1, "one argument to merge_batch");
610        // first batch is counts, second is partial sums
611        let partial_counts = values[0].as_primitive::<Int64Type>();
612
613        // intermediate counts are always created as non null
614        assert_eq!(partial_counts.null_count(), 0);
615        let partial_counts = partial_counts.values();
616
617        // Adds the counts with the partial counts
618        self.counts.resize(total_num_groups, 0);
619        group_indices.iter().zip(partial_counts.iter()).for_each(
620            |(&group_index, partial_count)| {
621                self.counts[group_index] += partial_count;
622            },
623        );
624
625        Ok(())
626    }
627
628    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
629        let counts = emit_to.take_needed(&mut self.counts);
630
631        // Count is always non null (null inputs just don't contribute to the overall values)
632        let nulls = None;
633        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
634
635        Ok(Arc::new(array))
636    }
637
638    // return arrays for counts
639    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
640        let counts = emit_to.take_needed(&mut self.counts);
641        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
642        Ok(vec![Arc::new(counts) as ArrayRef])
643    }
644
645    /// Converts an input batch directly to a state batch
646    ///
647    /// The state of `COUNT` is always a single Int64Array:
648    /// * `1` (for non-null, non filtered values)
649    /// * `0` (for null values)
650    fn convert_to_state(
651        &self,
652        values: &[ArrayRef],
653        opt_filter: Option<&BooleanArray>,
654    ) -> Result<Vec<ArrayRef>> {
655        let values = &values[0];
656
657        let state_array = match (values.logical_nulls(), opt_filter) {
658            (None, None) => {
659                // In case there is no nulls in input and no filter, returning array of 1
660                Arc::new(Int64Array::from_value(1, values.len()))
661            }
662            (Some(nulls), None) => {
663                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
664                // of input array to Int64
665                let nulls = BooleanArray::new(nulls.into_inner(), None);
666                compute::cast(&nulls, &DataType::Int64)?
667            }
668            (None, Some(filter)) => {
669                // If there is only filter
670                // - applying filter null mask to filter values by bitand filter values and nulls buffers
671                //   (using buffers guarantees absence of nulls in result)
672                // - casting result of bitand to Int64 array
673                let (filter_values, filter_nulls) = filter.clone().into_parts();
674
675                let state_buf = match filter_nulls {
676                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
677                    None => filter_values,
678                };
679
680                let boolean_state = BooleanArray::new(state_buf, None);
681                compute::cast(&boolean_state, &DataType::Int64)?
682            }
683            (Some(nulls), Some(filter)) => {
684                // For both input nulls and filter
685                // - applying filter null mask to filter values by bitand filter values and nulls buffers
686                //   (using buffers guarantees absence of nulls in result)
687                // - applying values null mask to filter buffer by another bitand on filter result and
688                //   nulls from input values
689                // - casting result to Int64 array
690                let (filter_values, filter_nulls) = filter.clone().into_parts();
691
692                let filter_buf = match filter_nulls {
693                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
694                    None => filter_values,
695                };
696                let state_buf = &filter_buf & nulls.inner();
697
698                let boolean_state = BooleanArray::new(state_buf, None);
699                compute::cast(&boolean_state, &DataType::Int64)?
700            }
701        };
702
703        Ok(vec![state_array])
704    }
705
706    fn supports_convert_to_state(&self) -> bool {
707        true
708    }
709
710    fn size(&self) -> usize {
711        self.counts.capacity() * size_of::<usize>()
712    }
713}
714
715/// count null values for multiple columns
716/// for each row if one column value is null, then null_count + 1
717fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
718    if values.len() > 1 {
719        let result_bool_buf: Option<BooleanBuffer> = values
720            .iter()
721            .map(|a| a.logical_nulls())
722            .fold(None, |acc, b| match (acc, b) {
723                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
724                (Some(acc), None) => Some(acc),
725                (None, Some(b)) => Some(b.into_inner()),
726                _ => None,
727            });
728        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
729    } else {
730        values[0]
731            .logical_nulls()
732            .map_or(0, |nulls| nulls.null_count())
733    }
734}
735
/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
    // distinct non-null values seen so far
    values: HashSet<ScalarValue, RandomState>,
    // element type used when serializing `values` as a list in `state`
    state_data_type: DataType,
}
749
impl DistinctCountAccumulator {
    /// Estimates memory by measuring the first stored value and assuming all
    /// entries are the same size.
    ///
    /// This is faster than [`Self::full_size`], but only valid for fixed
    /// length types — it under/over-estimates for variable length values
    /// such as strings or complex types.
    fn fixed_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            + self
                .values
                .iter()
                .next()
                // per-entry heap overhead beyond the inline ScalarValue
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .unwrap_or(0)
            + size_of::<DataType>()
    }

    /// Calculates the size as accurately as possible by walking every stored
    /// value. Note that calling this method is expensive.
    fn full_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            + self
                .values
                .iter()
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .sum::<usize>()
            + size_of::<DataType>()
    }
}
779
780impl Accumulator for DistinctCountAccumulator {
781    /// Returns the distinct values seen so far as (one element) ListArray.
782    fn state(&mut self) -> Result<Vec<ScalarValue>> {
783        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
784        let arr =
785            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
786        Ok(vec![ScalarValue::List(arr)])
787    }
788
789    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
790        if values.is_empty() {
791            return Ok(());
792        }
793
794        let arr = &values[0];
795        if arr.data_type() == &DataType::Null {
796            return Ok(());
797        }
798
799        (0..arr.len()).try_for_each(|index| {
800            let scalar = ScalarValue::try_from_array(arr, index)?;
801            if !scalar.is_null() {
802                self.values.insert(scalar);
803            }
804            Ok(())
805        })
806    }
807
808    /// Merges multiple sets of distinct values into the current set.
809    ///
810    /// The input to this function is a `ListArray` with **multiple** rows,
811    /// where each row contains the values from a partial aggregate's phase (e.g.
812    /// the result of calling `Self::state` on multiple accumulators).
813    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
814        if states.is_empty() {
815            return Ok(());
816        }
817        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
818        let array = &states[0];
819        let list_array = array.as_list::<i32>();
820        for inner_array in list_array.iter() {
821            let Some(inner_array) = inner_array else {
822                return internal_err!(
823                    "Intermediate results of COUNT DISTINCT should always be non null"
824                );
825            };
826            self.update_batch(&[inner_array])?;
827        }
828        Ok(())
829    }
830
831    fn evaluate(&mut self) -> Result<ScalarValue> {
832        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
833    }
834
835    fn size(&self) -> usize {
836        match &self.state_data_type {
837            DataType::Boolean | DataType::Null => self.fixed_size(),
838            d if d.is_primitive() => self.fixed_size(),
839            _ => self.full_size(),
840        }
841    }
842}
843
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
    use std::sync::Arc;
    /// Helper function to create a dictionary array with non-null keys but some null values
    /// Returns a dictionary array where:
    /// - keys are [0, 1, 2, 0, 1] (all non-null)
    /// - values are ["a", null, "c"]
    /// - so the keys reference: "a", null, "c", "a", null
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    /// A plain COUNT over an all-null array must be 0 (COUNT skips nulls).
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// COUNT(DISTINCT) over a dictionary-of-dictionary column: the distinct
    /// count must be computed on the fully-resolved logical values, not on
    /// the raw keys of either dictionary level.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        // Using Count UDAF's accumulator
        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let expr_field = expr.return_field(&schema)?;
        let args = AccumulatorArgs {
            schema: &schema,
            expr_fields: &[expr_field],
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        // Inner dictionary's logical values are "a","b","c","d","a","b".
        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        // Outer keys [0,1,2,0,3,1] resolve to "a","b","c","a","d","b"
        // → 4 distinct values {"a","b","c","d"}
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        // The expected behavior is that count_distinct should count only non-null values
        // which in this case are "a" and "c" (appearing as 0 and 2 in keys)
        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Should have 2 distinct non-null values ("a" and "c")
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        // The expected behavior is that count should only count non-null values
        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // 5 elements in the array, of which 2 reference null values (the two 1s in the keys)
        // So we should count 3 non-null values
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        // Create a dictionary array that only contains null values
        // (value "abc" exists in the dictionary but is never referenced by a key)
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // All referenced values are null so count(distinct) should be 0
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        // Basic update_batch + evaluate functionality
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        // Create an Int32Array: [1, 2, 2, 3, null]
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(2),
            Some(3),
            None,
        ]));
        acc.update_batch(&[values])?;
        // Expect distinct values {1,2,3} → count = 3
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        // Test that retract_batch properly decrements counts
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
        // Initial batch: ["a", "b", "a"]
        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
            as ArrayRef;
        acc.update_batch(&[arr1])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"}

        // Retract batch: ["a", null, "b"]
        let arr2 =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
        acc.retract_batch(&[arr2])?;
        // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"}
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        // Test merging multiple accumulator states with merge_batch
        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        // acc1 sees [1, 2]
        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
        // acc2 sees [2, 3]
        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
        // Extract their states as Vec<ScalarValue>
        let state_sv1 = acc1.state()?;
        let state_sv2 = acc2.state()?;
        // Convert ScalarValue states into Vec<ArrayRef>, propagating errors
        // NOTE: each `ScalarValue::to_array` call produces a single-row array
        // (the serialized state of one accumulator)
        let state_arr1: Vec<ArrayRef> = state_sv1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state_arr2: Vec<ArrayRef> = state_sv2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        // Merge both states into a fresh accumulator
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        merged.merge_batch(&state_arr1)?;
        merged.merge_batch(&state_arr2)?;
        // Expect distinct {1,2,3} → count = 3
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}