datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use ahash::RandomState;
19use arrow::{
20    array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21    buffer::BooleanBuffer,
22    compute,
23    datatypes::{
24        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25        FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
26        Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30    },
31};
32use datafusion_common::{
33    downcast_value, internal_err, not_impl_err, stats::Precision,
34    utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue,
35};
36use datafusion_expr::{
37    expr::WindowFunction,
38    function::{AccumulatorArgs, StateFieldsArgs},
39    utils::format_state_name,
40    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
41    ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
42    WindowFunctionDefinition,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45    count_distinct::BytesDistinctCountAccumulator,
46    count_distinct::BytesViewDistinctCountAccumulator,
47    count_distinct::DictionaryCountAccumulator,
48    count_distinct::FloatDistinctCountAccumulator,
49    count_distinct::PrimitiveDistinctCountAccumulator,
50    groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56    collections::HashSet,
57    fmt::Debug,
58    mem::{size_of, size_of_val},
59    ops::BitAnd,
60    sync::Arc,
61};
62
// NOTE(review): this macro invocation is presumably what defines the `count`
// expression helper and the `count_udaf` accessor that the functions below
// call; the macro body is not visible in this file -- confirm against its
// definition.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73        count_udaf(),
74        vec![expr],
75        true,
76        None,
77        vec![],
78        None,
79    ))
80}
81
82/// Creates aggregation to count all rows.
83///
84/// In SQL this is `SELECT COUNT(*) ... `
85///
86/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
87/// aliased to a column named `"count(*)"` for backward compatibility.
88///
89/// Example
90/// ```
91/// # use datafusion_functions_aggregate::count::count_all;
92/// # use datafusion_expr::col;
93/// // create `count(*)` expression
94/// let expr = count_all();
95/// assert_eq!(expr.schema_name().to_string(), "count(*)");
96/// // if you need to refer to this column, use the `schema_name` function
97/// let expr = col(expr.schema_name().to_string());
98/// ```
99pub fn count_all() -> Expr {
100    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103/// Creates window aggregation to count all rows.
104///
105/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
106///
107/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
108///
109/// Example
110/// ```
111/// # use datafusion_functions_aggregate::count::count_all_window;
112/// # use datafusion_expr::col;
113/// // create `count(*)` OVER ... window function expression
114/// let expr = count_all_window();
115/// assert_eq!(
116///   expr.schema_name().to_string(),
117///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
118/// );
119/// // if you need to refer to this column, use the `schema_name` function
120/// let expr = col(expr.schema_name().to_string());
121/// ```
122pub fn count_all_window() -> Expr {
123    Expr::from(WindowFunction::new(
124        WindowFunctionDefinition::AggregateUDF(count_udaf()),
125        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126    ))
127}
128
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
/// The `count` aggregate function: its [`AggregateUDFImpl`] implementation
/// lives further down in this file.
pub struct Count {
    // Accepted argument shapes; built once in `Count::new` and handed out by
    // `AggregateUDFImpl::signature`.
    signature: Signature,
}
154
155impl Debug for Count {
156    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157        f.debug_struct("Count")
158            .field("name", &self.name())
159            .field("signature", &self.signature)
160            .finish()
161    }
162}
163
164impl Default for Count {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl Count {
171    pub fn new() -> Self {
172        Self {
173            signature: Signature::one_of(
174                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
175                Volatility::Immutable,
176            ),
177        }
178    }
179}
180fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
181    match data_type {
182        // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
183        DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
184            data_type,
185        )),
186        DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
187            data_type,
188        )),
189        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
190            data_type,
191        )),
192        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
193            data_type,
194        )),
195        DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
196            data_type,
197        )),
198        DataType::UInt16 => Box::new(
199            PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
200        ),
201        DataType::UInt32 => Box::new(
202            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
203        ),
204        DataType::UInt64 => Box::new(
205            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
206        ),
207        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
208            Decimal128Type,
209        >::new(data_type)),
210        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
211            Decimal256Type,
212        >::new(data_type)),
213
214        DataType::Date32 => Box::new(
215            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
216        ),
217        DataType::Date64 => Box::new(
218            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
219        ),
220        DataType::Time32(TimeUnit::Millisecond) => Box::new(
221            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
222        ),
223        DataType::Time32(TimeUnit::Second) => Box::new(
224            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
225        ),
226        DataType::Time64(TimeUnit::Microsecond) => Box::new(
227            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
228        ),
229        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
230            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
231        ),
232        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
233            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
234        ),
235        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
236            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
237        ),
238        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
239            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
240        ),
241        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
242            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
243        ),
244
245        DataType::Float16 => {
246            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
247        }
248        DataType::Float32 => {
249            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
250        }
251        DataType::Float64 => {
252            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
253        }
254
255        DataType::Utf8 => {
256            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
257        }
258        DataType::Utf8View => {
259            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
260        }
261        DataType::LargeUtf8 => {
262            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
263        }
264        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
265            OutputType::Binary,
266        )),
267        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
268            OutputType::BinaryView,
269        )),
270        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
271            OutputType::Binary,
272        )),
273
274        // Use the generic accumulator based on `ScalarValue` for all other types
275        _ => Box::new(DistinctCountAccumulator {
276            values: HashSet::default(),
277            state_data_type: data_type.clone(),
278        }),
279    }
280}
281
282impl AggregateUDFImpl for Count {
283    fn as_any(&self) -> &dyn std::any::Any {
284        self
285    }
286
287    fn name(&self) -> &str {
288        "count"
289    }
290
291    fn signature(&self) -> &Signature {
292        &self.signature
293    }
294
295    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
296        Ok(DataType::Int64)
297    }
298
299    fn is_nullable(&self) -> bool {
300        false
301    }
302
303    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
304        if args.is_distinct {
305            let dtype: DataType = match &args.input_fields[0].data_type() {
306                DataType::Dictionary(_, values_type) => (**values_type).clone(),
307                &dtype => dtype.clone(),
308            };
309
310            Ok(vec![Field::new_list(
311                format_state_name(args.name, "count distinct"),
312                // See COMMENTS.md to understand why nullable is set to true
313                Field::new_list_field(dtype, true),
314                false,
315            )
316            .into()])
317        } else {
318            Ok(vec![Field::new(
319                format_state_name(args.name, "count"),
320                DataType::Int64,
321                false,
322            )
323            .into()])
324        }
325    }
326
327    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
328        if !acc_args.is_distinct {
329            return Ok(Box::new(CountAccumulator::new()));
330        }
331
332        if acc_args.exprs.len() > 1 {
333            return not_impl_err!("COUNT DISTINCT with multiple arguments");
334        }
335
336        let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
337
338        Ok(match data_type {
339            DataType::Dictionary(_, values_type) => {
340                let inner = get_count_accumulator(values_type);
341                Box::new(DictionaryCountAccumulator::new(inner))
342            }
343            _ => get_count_accumulator(data_type),
344        })
345    }
346
347    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348        // groups accumulator only supports `COUNT(c1)`, not
349        // `COUNT(c1, c2)`, etc
350        if args.is_distinct {
351            return false;
352        }
353        args.exprs.len() == 1
354    }
355
356    fn create_groups_accumulator(
357        &self,
358        _args: AccumulatorArgs,
359    ) -> Result<Box<dyn GroupsAccumulator>> {
360        // instantiate specialized accumulator
361        Ok(Box::new(CountGroupsAccumulator::new()))
362    }
363
364    fn reverse_expr(&self) -> ReversedUDAF {
365        ReversedUDAF::Identical
366    }
367
368    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
369        Ok(ScalarValue::Int64(Some(0)))
370    }
371
372    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
373        if statistics_args.is_distinct {
374            return None;
375        }
376        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
377            if statistics_args.exprs.len() == 1 {
378                // TODO optimize with exprs other than Column
379                if let Some(col_expr) = statistics_args.exprs[0]
380                    .as_any()
381                    .downcast_ref::<expressions::Column>()
382                {
383                    let current_val = &statistics_args.statistics.column_statistics
384                        [col_expr.index()]
385                    .null_count;
386                    if let &Precision::Exact(val) = current_val {
387                        return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
388                    }
389                } else if let Some(lit_expr) = statistics_args.exprs[0]
390                    .as_any()
391                    .downcast_ref::<expressions::Literal>()
392                {
393                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
394                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
395                    }
396                }
397            }
398        }
399        None
400    }
401
402    fn documentation(&self) -> Option<&Documentation> {
403        self.doc()
404    }
405
406    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
407        // `COUNT` is monotonically increasing as it always increases or stays
408        // the same as new values are seen.
409        SetMonotonicity::Increasing
410    }
411
412    fn create_sliding_accumulator(
413        &self,
414        args: AccumulatorArgs,
415    ) -> Result<Box<dyn Accumulator>> {
416        if args.is_distinct {
417            let acc =
418                SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
419            Ok(Box::new(acc))
420        } else {
421            let acc = CountAccumulator::new();
422            Ok(Box::new(acc))
423        }
424    }
425}
426
427// DistinctCountAccumulator does not support retract_batch and sliding window
428// this is a specialized accumulator for distinct count that supports retract_batch
429// and sliding window.
430#[derive(Debug)]
431pub struct SlidingDistinctCountAccumulator {
432    counts: HashMap<ScalarValue, usize, RandomState>,
433    data_type: DataType,
434}
435
436impl SlidingDistinctCountAccumulator {
437    pub fn try_new(data_type: &DataType) -> Result<Self> {
438        Ok(Self {
439            counts: HashMap::default(),
440            data_type: data_type.clone(),
441        })
442    }
443}
444
445impl Accumulator for SlidingDistinctCountAccumulator {
446    fn state(&mut self) -> Result<Vec<ScalarValue>> {
447        let keys = self.counts.keys().cloned().collect::<Vec<_>>();
448        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
449            keys.as_slice(),
450            &self.data_type,
451        ))])
452    }
453
454    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
455        let arr = &values[0];
456        for i in 0..arr.len() {
457            let v = ScalarValue::try_from_array(arr, i)?;
458            if !v.is_null() {
459                *self.counts.entry(v).or_default() += 1;
460            }
461        }
462        Ok(())
463    }
464
465    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
466        let arr = &values[0];
467        for i in 0..arr.len() {
468            let v = ScalarValue::try_from_array(arr, i)?;
469            if !v.is_null() {
470                if let Some(cnt) = self.counts.get_mut(&v) {
471                    *cnt -= 1;
472                    if *cnt == 0 {
473                        self.counts.remove(&v);
474                    }
475                }
476            }
477        }
478        Ok(())
479    }
480
481    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
482        let list_arr = states[0].as_list::<i32>();
483        for inner in list_arr.iter().flatten() {
484            for j in 0..inner.len() {
485                let v = ScalarValue::try_from_array(&*inner, j)?;
486                *self.counts.entry(v).or_default() += 1;
487            }
488        }
489        Ok(())
490    }
491
492    fn evaluate(&mut self) -> Result<ScalarValue> {
493        Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
494    }
495
496    fn supports_retract_batch(&self) -> bool {
497        true
498    }
499
500    fn size(&self) -> usize {
501        size_of_val(self)
502    }
503}
504
/// Accumulator for plain (non-distinct) `COUNT`.
#[derive(Debug)]
struct CountAccumulator {
    /// Running count.
    count: i64,
}

impl CountAccumulator {
    /// Creates an accumulator whose count starts at zero.
    pub fn new() -> Self {
        CountAccumulator { count: 0 }
    }
}
516
517impl Accumulator for CountAccumulator {
518    fn state(&mut self) -> Result<Vec<ScalarValue>> {
519        Ok(vec![ScalarValue::Int64(Some(self.count))])
520    }
521
522    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523        let array = &values[0];
524        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
525        Ok(())
526    }
527
528    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
529        let array = &values[0];
530        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
531        Ok(())
532    }
533
534    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
535        let counts = downcast_value!(states[0], Int64Array);
536        let delta = &compute::sum(counts);
537        if let Some(d) = delta {
538            self.count += *d;
539        }
540        Ok(())
541    }
542
543    fn evaluate(&mut self) -> Result<ScalarValue> {
544        Ok(ScalarValue::Int64(Some(self.count)))
545    }
546
547    fn supports_retract_batch(&self) -> bool {
548        true
549    }
550
551    fn size(&self) -> usize {
552        size_of_val(self)
553    }
554}
555
/// Groups accumulator for `COUNT`: keeps one running count per group.
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    /// Creates an accumulator with no groups yet.
    pub fn new() -> Self {
        Self { counts: Vec::new() }
    }
}
578
579impl GroupsAccumulator for CountGroupsAccumulator {
580    fn update_batch(
581        &mut self,
582        values: &[ArrayRef],
583        group_indices: &[usize],
584        opt_filter: Option<&BooleanArray>,
585        total_num_groups: usize,
586    ) -> Result<()> {
587        assert_eq!(values.len(), 1, "single argument to update_batch");
588        let values = &values[0];
589
590        // Add one to each group's counter for each non null, non
591        // filtered value
592        self.counts.resize(total_num_groups, 0);
593        accumulate_indices(
594            group_indices,
595            values.logical_nulls().as_ref(),
596            opt_filter,
597            |group_index| {
598                self.counts[group_index] += 1;
599            },
600        );
601
602        Ok(())
603    }
604
605    fn merge_batch(
606        &mut self,
607        values: &[ArrayRef],
608        group_indices: &[usize],
609        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
610        _opt_filter: Option<&BooleanArray>,
611        total_num_groups: usize,
612    ) -> Result<()> {
613        assert_eq!(values.len(), 1, "one argument to merge_batch");
614        // first batch is counts, second is partial sums
615        let partial_counts = values[0].as_primitive::<Int64Type>();
616
617        // intermediate counts are always created as non null
618        assert_eq!(partial_counts.null_count(), 0);
619        let partial_counts = partial_counts.values();
620
621        // Adds the counts with the partial counts
622        self.counts.resize(total_num_groups, 0);
623        group_indices.iter().zip(partial_counts.iter()).for_each(
624            |(&group_index, partial_count)| {
625                self.counts[group_index] += partial_count;
626            },
627        );
628
629        Ok(())
630    }
631
632    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
633        let counts = emit_to.take_needed(&mut self.counts);
634
635        // Count is always non null (null inputs just don't contribute to the overall values)
636        let nulls = None;
637        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
638
639        Ok(Arc::new(array))
640    }
641
642    // return arrays for counts
643    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
644        let counts = emit_to.take_needed(&mut self.counts);
645        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
646        Ok(vec![Arc::new(counts) as ArrayRef])
647    }
648
649    /// Converts an input batch directly to a state batch
650    ///
651    /// The state of `COUNT` is always a single Int64Array:
652    /// * `1` (for non-null, non filtered values)
653    /// * `0` (for null values)
654    fn convert_to_state(
655        &self,
656        values: &[ArrayRef],
657        opt_filter: Option<&BooleanArray>,
658    ) -> Result<Vec<ArrayRef>> {
659        let values = &values[0];
660
661        let state_array = match (values.logical_nulls(), opt_filter) {
662            (None, None) => {
663                // In case there is no nulls in input and no filter, returning array of 1
664                Arc::new(Int64Array::from_value(1, values.len()))
665            }
666            (Some(nulls), None) => {
667                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
668                // of input array to Int64
669                let nulls = BooleanArray::new(nulls.into_inner(), None);
670                compute::cast(&nulls, &DataType::Int64)?
671            }
672            (None, Some(filter)) => {
673                // If there is only filter
674                // - applying filter null mask to filter values by bitand filter values and nulls buffers
675                //   (using buffers guarantees absence of nulls in result)
676                // - casting result of bitand to Int64 array
677                let (filter_values, filter_nulls) = filter.clone().into_parts();
678
679                let state_buf = match filter_nulls {
680                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
681                    None => filter_values,
682                };
683
684                let boolean_state = BooleanArray::new(state_buf, None);
685                compute::cast(&boolean_state, &DataType::Int64)?
686            }
687            (Some(nulls), Some(filter)) => {
688                // For both input nulls and filter
689                // - applying filter null mask to filter values by bitand filter values and nulls buffers
690                //   (using buffers guarantees absence of nulls in result)
691                // - applying values null mask to filter buffer by another bitand on filter result and
692                //   nulls from input values
693                // - casting result to Int64 array
694                let (filter_values, filter_nulls) = filter.clone().into_parts();
695
696                let filter_buf = match filter_nulls {
697                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
698                    None => filter_values,
699                };
700                let state_buf = &filter_buf & nulls.inner();
701
702                let boolean_state = BooleanArray::new(state_buf, None);
703                compute::cast(&boolean_state, &DataType::Int64)?
704            }
705        };
706
707        Ok(vec![state_array])
708    }
709
710    fn supports_convert_to_state(&self) -> bool {
711        true
712    }
713
714    fn size(&self) -> usize {
715        self.counts.capacity() * size_of::<usize>()
716    }
717}
718
719/// count null values for multiple columns
720/// for each row if one column value is null, then null_count + 1
721fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
722    if values.len() > 1 {
723        let result_bool_buf: Option<BooleanBuffer> = values
724            .iter()
725            .map(|a| a.logical_nulls())
726            .fold(None, |acc, b| match (acc, b) {
727                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
728                (Some(acc), None) => Some(acc),
729                (None, Some(b)) => Some(b.into_inner()),
730                _ => None,
731            });
732        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
733    } else {
734        values[0]
735            .logical_nulls()
736            .map_or(0, |nulls| nulls.null_count())
737    }
738}
739
740/// General purpose distinct accumulator that works for any DataType by using
741/// [`ScalarValue`].
742///
743/// It stores intermediate results as a `ListArray`
744///
745/// Note that many types have specialized accumulators that are (much)
746/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
747/// [`BytesDistinctCountAccumulator`]
748#[derive(Debug)]
749struct DistinctCountAccumulator {
750    values: HashSet<ScalarValue, RandomState>,
751    state_data_type: DataType,
752}
753
754impl DistinctCountAccumulator {
755    // calculating the size for fixed length values, taking first batch size *
756    // number of batches This method is faster than .full_size(), however it is
757    // not suitable for variable length values like strings or complex types
758    fn fixed_size(&self) -> usize {
759        size_of_val(self)
760            + (size_of::<ScalarValue>() * self.values.capacity())
761            + self
762                .values
763                .iter()
764                .next()
765                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
766                .unwrap_or(0)
767            + size_of::<DataType>()
768    }
769
770    // calculates the size as accurately as possible. Note that calling this
771    // method is expensive
772    fn full_size(&self) -> usize {
773        size_of_val(self)
774            + (size_of::<ScalarValue>() * self.values.capacity())
775            + self
776                .values
777                .iter()
778                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
779                .sum::<usize>()
780            + size_of::<DataType>()
781    }
782}
783
784impl Accumulator for DistinctCountAccumulator {
785    /// Returns the distinct values seen so far as (one element) ListArray.
786    fn state(&mut self) -> Result<Vec<ScalarValue>> {
787        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
788        let arr =
789            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
790        Ok(vec![ScalarValue::List(arr)])
791    }
792
793    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
794        if values.is_empty() {
795            return Ok(());
796        }
797
798        let arr = &values[0];
799        if arr.data_type() == &DataType::Null {
800            return Ok(());
801        }
802
803        (0..arr.len()).try_for_each(|index| {
804            let scalar = ScalarValue::try_from_array(arr, index)?;
805            if !scalar.is_null() {
806                self.values.insert(scalar);
807            }
808            Ok(())
809        })
810    }
811
812    /// Merges multiple sets of distinct values into the current set.
813    ///
814    /// The input to this function is a `ListArray` with **multiple** rows,
815    /// where each row contains the values from a partial aggregate's phase (e.g.
816    /// the result of calling `Self::state` on multiple accumulators).
817    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
818        if states.is_empty() {
819            return Ok(());
820        }
821        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
822        let array = &states[0];
823        let list_array = array.as_list::<i32>();
824        for inner_array in list_array.iter() {
825            let Some(inner_array) = inner_array else {
826                return internal_err!(
827                    "Intermediate results of COUNT DISTINCT should always be non null"
828                );
829            };
830            self.update_batch(&[inner_array])?;
831        }
832        Ok(())
833    }
834
835    fn evaluate(&mut self) -> Result<ScalarValue> {
836        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
837    }
838
839    fn size(&self) -> usize {
840        match &self.state_data_type {
841            DataType::Boolean | DataType::Null => self.fixed_size(),
842            d if d.is_primitive() => self.fixed_size(),
843            _ => self.full_size(),
844        }
845    }
846}
847
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use std::sync::Arc;

    /// Builds a dictionary array whose keys are all non-null while the
    /// dictionary itself contains a null entry:
    /// - dictionary values: ["a", null, "c"]
    /// - keys: [0, 1, 2, 0, 1]
    /// - logical rows: "a", null, "c", "a", null
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let dictionary = StringArray::from(vec![Some("a"), None, Some("c")]);
        let indices = Int32Array::from(vec![0, 1, 2, 0, 1]);
        let array =
            DictionaryArray::<Int32Type>::try_new(indices, Arc::new(dictionary))?;
        Ok(array)
    }

    /// Feeds `dict` as a single batch into a fresh `DistinctCountAccumulator`
    /// typed after the array and returns what the accumulator evaluates to.
    fn distinct_count_of(dict: DictionaryArray<Int32Type>) -> Result<ScalarValue> {
        let mut acc = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict.data_type().clone(),
        };
        let input: ArrayRef = Arc::new(dict);
        acc.update_batch(&[input])?;
        acc.evaluate()
    }

    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        // A batch made only of nulls must not contribute to the count.
        let all_nulls: ArrayRef = Arc::new(NullArray::new(10));
        let mut acc = CountAccumulator::new();
        acc.update_batch(&[all_nulls])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    #[test]
    fn test_nested_dictionary() -> Result<()> {
        // Column type: Dictionary(Int32, Dictionary(Int32, Utf8))
        let string_dict_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        let nested_type = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(string_dict_type),
        );
        let nested_field = Field::new("dict_col", nested_type, true);
        let schema = Arc::new(Schema::new(vec![nested_field]));

        // Inner dictionary rows: "a", "b", "c", "d", "a", "b"
        let inner =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);
        // Outer keys index the rows of `inner`: "a", "b", "c", "a", "d", "b"
        let outer_keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let nested =
            DictionaryArray::<Int32Type>::try_new(outer_keys, Arc::new(inner))?;

        // Obtain the accumulator through the Count UDAF itself.
        let column = Arc::new(Column::new("dict_col", 0));
        let args = AccumulatorArgs {
            name: "count",
            schema: &schema,
            exprs: &[column],
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
            is_distinct: true,
            ignore_nulls: false,
            is_reversed: false,
        };
        let mut acc = Count::new().accumulator(args)?;

        acc.update_batch(&[Arc::new(nested)])?;
        // Distinct strings seen: "a", "b", "c", "d"
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        // Only non-null values are counted: "a" and "c" (keys 0 and 2).
        let observed = distinct_count_of(create_dictionary_with_null_values()?)?;
        assert_eq!(observed, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let input: ArrayRef = Arc::new(create_dictionary_with_null_values()?);

        let mut acc = CountAccumulator::new();
        acc.update_batch(&[input])?;

        // Of the 5 rows, the two with key 1 point at the null dictionary entry,
        // leaving 3 non-null rows to count.
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        // Every key is 0, and dictionary entry 0 is null, so every row is null.
        let dictionary = StringArray::from(vec![None, Some("abc")]);
        let keys_all_zero = Int32Array::from(vec![0; 5]);
        let all_null_rows =
            DictionaryArray::<Int32Type>::try_new(keys_all_zero, Arc::new(dictionary))?;

        // No non-null value is referenced, so count(distinct) is 0.
        assert_eq!(
            distinct_count_of(all_null_rows)?,
            ScalarValue::Int64(Some(0))
        );
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        // Plain update_batch followed by evaluate.
        // Input: [1, 2, 2, 3, null]
        let input =
            Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(2), Some(3), None]))
                as ArrayRef;
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        acc.update_batch(&[input])?;
        // Distinct non-null values {1, 2, 3} → 3
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        // retract_batch must decrement the per-value counts.
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;

        // Added rows: ["a", "b", "a"]
        let added: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]));
        acc.update_batch(&[added])?;
        // Distinct so far: {"a", "b"}
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));

        // Retracted rows: ["a", null, "b"]
        let removed: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")]));
        acc.retract_batch(&[removed])?;
        // Counts go from a→2, b→1 to a→1, b→0; "b" drops out, leaving {"a"}
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        // Two partial accumulators are merged, via merge_batch, into a fresh one.
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;

        // The first partial accumulator sees [1, 2], the second sees [2, 3].
        for seen in [vec![Some(1), Some(2)], vec![Some(2), Some(3)]] {
            let mut partial =
                SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
            partial.update_batch(&[Arc::new(Int32Array::from(seen))])?;

            // Turn each state scalar into an array, propagating conversion
            // errors, so the state can be handed to merge_batch.
            let state_arrays = partial
                .state()?
                .into_iter()
                .map(|scalar| scalar.to_array())
                .collect::<Result<Vec<ArrayRef>>>()?;
            merged.merge_batch(&state_arrays)?;
        }

        // Union of both partial sets: {1, 2, 3} → 3
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}