datafusion_functions_aggregate/
count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use ahash::RandomState;
19use arrow::{
20    array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21    buffer::BooleanBuffer,
22    compute,
23    datatypes::{
24        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25        FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
26        Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29        UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30    },
31};
32use datafusion_common::{
33    HashMap, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
34    stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38    ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39    WindowFunctionDefinition,
40    expr::WindowFunction,
41    function::{AccumulatorArgs, StateFieldsArgs},
42    utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45    count_distinct::BytesDistinctCountAccumulator,
46    count_distinct::BytesViewDistinctCountAccumulator,
47    count_distinct::DictionaryCountAccumulator,
48    count_distinct::FloatDistinctCountAccumulator,
49    count_distinct::PrimitiveDistinctCountAccumulator,
50    groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56    collections::HashSet,
57    fmt::Debug,
58    mem::{size_of, size_of_val},
59    ops::BitAnd,
60    sync::Arc,
61};
62
// Generates the `count()` expression-builder function and the singleton
// `count_udaf()` accessor for the `Count` UDAF defined below.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73        count_udaf(),
74        vec![expr],
75        true,
76        None,
77        vec![],
78        None,
79    ))
80}
81
82/// Creates aggregation to count all rows.
83///
84/// In SQL this is `SELECT COUNT(*) ... `
85///
86/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
87/// aliased to a column named `"count(*)"` for backward compatibility.
88///
89/// Example
90/// ```
91/// # use datafusion_functions_aggregate::count::count_all;
92/// # use datafusion_expr::col;
93/// // create `count(*)` expression
94/// let expr = count_all();
95/// assert_eq!(expr.schema_name().to_string(), "count(*)");
96/// // if you need to refer to this column, use the `schema_name` function
97/// let expr = col(expr.schema_name().to_string());
98/// ```
99pub fn count_all() -> Expr {
100    count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103/// Creates window aggregation to count all rows.
104///
105/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
106///
107/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
108///
109/// Example
110/// ```
111/// # use datafusion_functions_aggregate::count::count_all_window;
112/// # use datafusion_expr::col;
113/// // create `count(*)` OVER ... window function expression
114/// let expr = count_all_window();
115/// assert_eq!(
116///     expr.schema_name().to_string(),
117///     "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
118/// );
119/// // if you need to refer to this column, use the `schema_name` function
120/// let expr = col(expr.schema_name().to_string());
121/// ```
122pub fn count_all_window() -> Expr {
123    Expr::from(WindowFunction::new(
124        WindowFunctionDefinition::AggregateUDF(count_udaf()),
125        vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126    ))
127}
128
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name)     |
+-----------------------+
| 100                   |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*)         |
+------------------+
| 120              |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `COUNT` aggregate function.
///
/// Counts non-null values; `COUNT(*)` / `COUNT(1)` count all rows. The
/// behavior lives in the [`AggregateUDFImpl`] implementation below; this
/// struct only holds the accepted [`Signature`].
#[derive(PartialEq, Eq, Hash)]
pub struct Count {
    // Any number of arguments of any type, or zero arguments (see `new`).
    signature: Signature,
}
154
155impl Debug for Count {
156    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157        f.debug_struct("Count")
158            .field("name", &self.name())
159            .field("signature", &self.signature)
160            .finish()
161    }
162}
163
164impl Default for Count {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl Count {
171    pub fn new() -> Self {
172        Self {
173            signature: Signature::one_of(
174                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
175                Volatility::Immutable,
176            ),
177        }
178    }
179}
/// Returns the best [`Accumulator`] for `COUNT(DISTINCT col)` over values of
/// `data_type`.
///
/// Dispatches to a specialized accumulator (primitive / float / bytes /
/// byte-view) where one exists; all other types fall back to the generic
/// [`DistinctCountAccumulator`], which stores values as [`ScalarValue`] and is
/// much slower. Dictionary types are unwrapped by the caller (see
/// `Count::accumulator`), so `data_type` is expected to be a value type here.
fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
    match data_type {
        // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
        DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
            data_type,
        )),
        DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
            data_type,
        )),
        DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
            data_type,
        )),
        DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
            data_type,
        )),
        DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
            data_type,
        )),
        DataType::UInt16 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
        ),
        DataType::UInt32 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
        ),
        DataType::UInt64 => Box::new(
            PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
        ),
        DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal128Type,
        >::new(data_type)),
        DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
            Decimal256Type,
        >::new(data_type)),

        // Date and time types are primitives under the hood, so they get the
        // primitive specialization as well.
        DataType::Date32 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
        ),
        DataType::Date64 => Box::new(
            PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Millisecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
        ),
        DataType::Time32(TimeUnit::Second) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Microsecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
        ),
        DataType::Time64(TimeUnit::Nanosecond) => Box::new(
            PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
        ),
        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
            PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
        ),

        // Floats need their own accumulator to hash NaN/-0.0 consistently.
        DataType::Float16 => {
            Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
        }
        DataType::Float32 => {
            Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
        }
        DataType::Float64 => {
            Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
        }

        // String/binary variants share byte-oriented accumulators,
        // parameterized by offset width (i32/i64) and output type.
        DataType::Utf8 => {
            Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
        }
        DataType::Utf8View => {
            Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
        }
        DataType::LargeUtf8 => {
            Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
        }
        DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
            OutputType::Binary,
        )),
        DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
            OutputType::BinaryView,
        )),
        DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
            OutputType::Binary,
        )),

        // Use the generic accumulator based on `ScalarValue` for all other types
        _ => Box::new(DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: data_type.clone(),
        }),
    }
}
281
impl AggregateUDFImpl for Count {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "count"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `COUNT` always returns `Int64` regardless of input type.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    /// `COUNT` never produces NULL: empty input yields 0.
    fn is_nullable(&self) -> bool {
        false
    }

    /// Intermediate state: a non-null `Int64` partial count, or for
    /// `COUNT(DISTINCT ...)` a list of the distinct values seen so far.
    /// Dictionary inputs store the *value* type, matching the unwrapping
    /// done in [`Self::accumulator`].
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        if args.is_distinct {
            let dtype: DataType = match &args.input_fields[0].data_type() {
                DataType::Dictionary(_, values_type) => (**values_type).clone(),
                &dtype => dtype.clone(),
            };

            Ok(vec![
                Field::new_list(
                    format_state_name(args.name, "count distinct"),
                    // See COMMENTS.md to understand why nullable is set to true
                    Field::new_list_field(dtype, true),
                    false,
                )
                .into(),
            ])
        } else {
            Ok(vec![
                Field::new(
                    format_state_name(args.name, "count"),
                    DataType::Int64,
                    false,
                )
                .into(),
            ])
        }
    }

    /// Selects an accumulator: the simple [`CountAccumulator`] for plain
    /// `COUNT`, otherwise a type-specialized distinct accumulator via
    /// [`get_count_accumulator`].
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if !acc_args.is_distinct {
            return Ok(Box::new(CountAccumulator::new()));
        }

        if acc_args.exprs.len() > 1 {
            return not_impl_err!("COUNT DISTINCT with multiple arguments");
        }

        let data_type = acc_args.expr_fields[0].data_type();

        Ok(match data_type {
            // Dictionaries are unwrapped: distinct-count the value type and
            // let DictionaryCountAccumulator resolve keys to values.
            DataType::Dictionary(_, values_type) => {
                let inner = get_count_accumulator(values_type);
                Box::new(DictionaryCountAccumulator::new(inner))
            }
            _ => get_count_accumulator(data_type),
        })
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        // groups accumulator only supports `COUNT(c1)`, not
        // `COUNT(c1, c2)`, etc
        if args.is_distinct {
            return false;
        }
        args.exprs.len() == 1
    }

    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        // instantiate specialized accumulator
        Ok(Box::new(CountGroupsAccumulator::new()))
    }

    /// `COUNT` yields the same result on reversed input.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    /// The value of `COUNT` over no rows is 0, never NULL.
    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(0)))
    }

    /// Answers `COUNT` directly from exact table statistics when possible:
    /// `COUNT(col)` = exact row count minus the column's exact null count;
    /// `COUNT(*)`-style literal arguments use the exact row count alone.
    /// Returns `None` whenever statistics are not exact.
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        if statistics_args.is_distinct {
            return None;
        }
        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
            && statistics_args.exprs.len() == 1
        {
            // TODO optimize with exprs other than Column
            if let Some(col_expr) = statistics_args.exprs[0]
                .as_any()
                .downcast_ref::<expressions::Column>()
            {
                let current_val = &statistics_args.statistics.column_statistics
                    [col_expr.index()]
                .null_count;
                if let &Precision::Exact(val) = current_val {
                    return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
                }
            } else if let Some(lit_expr) = statistics_args.exprs[0]
                .as_any()
                .downcast_ref::<expressions::Literal>()
                && lit_expr.value() == &COUNT_STAR_EXPANSION
            {
                return Some(ScalarValue::Int64(Some(num_rows as i64)));
            }
        }
        None
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        // `COUNT` is monotonically increasing as it always increases or stays
        // the same as new values are seen.
        SetMonotonicity::Increasing
    }

    /// Sliding windows require `retract_batch`; the distinct case uses
    /// [`SlidingDistinctCountAccumulator`] which supports it.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            // NOTE(review): this passes the *return* field's type (Int64) as
            // the value type, not the input expression's type — confirm this
            // is intended for state serialization of the distinct values.
            let acc =
                SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
            Ok(Box::new(acc))
        } else {
            let acc = CountAccumulator::new();
            Ok(Box::new(acc))
        }
    }
}
429
/// A distinct-count accumulator that supports `retract_batch`, and therefore
/// sliding windows.
///
/// [`DistinctCountAccumulator`] cannot retract values, so this accumulator
/// keeps a reference count per distinct value: retracting decrements the
/// count and removes the entry once it reaches zero.
#[derive(Debug)]
pub struct SlidingDistinctCountAccumulator {
    // Live occurrence count for each distinct non-null value in the window.
    counts: HashMap<ScalarValue, usize, RandomState>,
    // Type of the counted values, used when serializing state as a list.
    data_type: DataType,
}
438
impl SlidingDistinctCountAccumulator {
    /// Creates an empty accumulator for values of `data_type`.
    ///
    /// Currently infallible; returns `Result` for consistency with other
    /// accumulator constructors.
    pub fn try_new(data_type: &DataType) -> Result<Self> {
        Ok(Self {
            counts: HashMap::default(),
            data_type: data_type.clone(),
        })
    }
}
447
448impl Accumulator for SlidingDistinctCountAccumulator {
449    fn state(&mut self) -> Result<Vec<ScalarValue>> {
450        let keys = self.counts.keys().cloned().collect::<Vec<_>>();
451        Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
452            keys.as_slice(),
453            &self.data_type,
454        ))])
455    }
456
457    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
458        let arr = &values[0];
459        for i in 0..arr.len() {
460            let v = ScalarValue::try_from_array(arr, i)?;
461            if !v.is_null() {
462                *self.counts.entry(v).or_default() += 1;
463            }
464        }
465        Ok(())
466    }
467
468    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469        let arr = &values[0];
470        for i in 0..arr.len() {
471            let v = ScalarValue::try_from_array(arr, i)?;
472            if !v.is_null()
473                && let Some(cnt) = self.counts.get_mut(&v)
474            {
475                *cnt -= 1;
476                if *cnt == 0 {
477                    self.counts.remove(&v);
478                }
479            }
480        }
481        Ok(())
482    }
483
484    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
485        let list_arr = states[0].as_list::<i32>();
486        for inner in list_arr.iter().flatten() {
487            for j in 0..inner.len() {
488                let v = ScalarValue::try_from_array(&*inner, j)?;
489                *self.counts.entry(v).or_default() += 1;
490            }
491        }
492        Ok(())
493    }
494
495    fn evaluate(&mut self) -> Result<ScalarValue> {
496        Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
497    }
498
499    fn supports_retract_batch(&self) -> bool {
500        true
501    }
502
503    fn size(&self) -> usize {
504        size_of_val(self)
505    }
506}
507
/// Accumulator for non-distinct `COUNT`: a single running count of rows in
/// which every argument is non-null.
#[derive(Debug)]
struct CountAccumulator {
    // i64 (not u64/usize) because COUNT's output type is Int64.
    count: i64,
}
512
impl CountAccumulator {
    /// Creates a new accumulator with an initial count of zero.
    pub fn new() -> Self {
        Self { count: 0 }
    }
}
519
520impl Accumulator for CountAccumulator {
521    fn state(&mut self) -> Result<Vec<ScalarValue>> {
522        Ok(vec![ScalarValue::Int64(Some(self.count))])
523    }
524
525    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
526        let array = &values[0];
527        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
528        Ok(())
529    }
530
531    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
532        let array = &values[0];
533        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
534        Ok(())
535    }
536
537    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
538        let counts = downcast_value!(states[0], Int64Array);
539        let delta = &compute::sum(counts);
540        if let Some(d) = delta {
541            self.count += *d;
542        }
543        Ok(())
544    }
545
546    fn evaluate(&mut self) -> Result<ScalarValue> {
547        Ok(ScalarValue::Int64(Some(self.count)))
548    }
549
550    fn supports_retract_batch(&self) -> bool {
551        true
552    }
553
554    fn size(&self) -> usize {
555        size_of_val(self)
556    }
557}
558
/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
/// Stores values as native types, and does overflow checking
///
/// Unlike most other accumulators, COUNT never produces NULLs. If no
/// non-null values are seen in any group the output is 0. Thus, this
/// accumulator has no additional null or seen filter tracking.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index.
    ///
    /// Note this is an i64 and not a u64 (or usize) because the
    /// output type of count is `DataType::Int64`. Thus by using `i64`
    /// for the counts, the output [`Int64Array`] can be created
    /// without copy.
    counts: Vec<i64>,
}
575
576impl CountGroupsAccumulator {
577    pub fn new() -> Self {
578        Self { counts: vec![] }
579    }
580}
581
582impl GroupsAccumulator for CountGroupsAccumulator {
583    fn update_batch(
584        &mut self,
585        values: &[ArrayRef],
586        group_indices: &[usize],
587        opt_filter: Option<&BooleanArray>,
588        total_num_groups: usize,
589    ) -> Result<()> {
590        assert_eq!(values.len(), 1, "single argument to update_batch");
591        let values = &values[0];
592
593        // Add one to each group's counter for each non null, non
594        // filtered value
595        self.counts.resize(total_num_groups, 0);
596        accumulate_indices(
597            group_indices,
598            values.logical_nulls().as_ref(),
599            opt_filter,
600            |group_index| {
601                self.counts[group_index] += 1;
602            },
603        );
604
605        Ok(())
606    }
607
608    fn merge_batch(
609        &mut self,
610        values: &[ArrayRef],
611        group_indices: &[usize],
612        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
613        _opt_filter: Option<&BooleanArray>,
614        total_num_groups: usize,
615    ) -> Result<()> {
616        assert_eq!(values.len(), 1, "one argument to merge_batch");
617        // first batch is counts, second is partial sums
618        let partial_counts = values[0].as_primitive::<Int64Type>();
619
620        // intermediate counts are always created as non null
621        assert_eq!(partial_counts.null_count(), 0);
622        let partial_counts = partial_counts.values();
623
624        // Adds the counts with the partial counts
625        self.counts.resize(total_num_groups, 0);
626        group_indices.iter().zip(partial_counts.iter()).for_each(
627            |(&group_index, partial_count)| {
628                self.counts[group_index] += partial_count;
629            },
630        );
631
632        Ok(())
633    }
634
635    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
636        let counts = emit_to.take_needed(&mut self.counts);
637
638        // Count is always non null (null inputs just don't contribute to the overall values)
639        let nulls = None;
640        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
641
642        Ok(Arc::new(array))
643    }
644
645    // return arrays for counts
646    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
647        let counts = emit_to.take_needed(&mut self.counts);
648        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
649        Ok(vec![Arc::new(counts) as ArrayRef])
650    }
651
652    /// Converts an input batch directly to a state batch
653    ///
654    /// The state of `COUNT` is always a single Int64Array:
655    /// * `1` (for non-null, non filtered values)
656    /// * `0` (for null values)
657    fn convert_to_state(
658        &self,
659        values: &[ArrayRef],
660        opt_filter: Option<&BooleanArray>,
661    ) -> Result<Vec<ArrayRef>> {
662        let values = &values[0];
663
664        let state_array = match (values.logical_nulls(), opt_filter) {
665            (None, None) => {
666                // In case there is no nulls in input and no filter, returning array of 1
667                Arc::new(Int64Array::from_value(1, values.len()))
668            }
669            (Some(nulls), None) => {
670                // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls)
671                // of input array to Int64
672                let nulls = BooleanArray::new(nulls.into_inner(), None);
673                compute::cast(&nulls, &DataType::Int64)?
674            }
675            (None, Some(filter)) => {
676                // If there is only filter
677                // - applying filter null mask to filter values by bitand filter values and nulls buffers
678                //   (using buffers guarantees absence of nulls in result)
679                // - casting result of bitand to Int64 array
680                let (filter_values, filter_nulls) = filter.clone().into_parts();
681
682                let state_buf = match filter_nulls {
683                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
684                    None => filter_values,
685                };
686
687                let boolean_state = BooleanArray::new(state_buf, None);
688                compute::cast(&boolean_state, &DataType::Int64)?
689            }
690            (Some(nulls), Some(filter)) => {
691                // For both input nulls and filter
692                // - applying filter null mask to filter values by bitand filter values and nulls buffers
693                //   (using buffers guarantees absence of nulls in result)
694                // - applying values null mask to filter buffer by another bitand on filter result and
695                //   nulls from input values
696                // - casting result to Int64 array
697                let (filter_values, filter_nulls) = filter.clone().into_parts();
698
699                let filter_buf = match filter_nulls {
700                    Some(filter_nulls) => &filter_values & filter_nulls.inner(),
701                    None => filter_values,
702                };
703                let state_buf = &filter_buf & nulls.inner();
704
705                let boolean_state = BooleanArray::new(state_buf, None);
706                compute::cast(&boolean_state, &DataType::Int64)?
707            }
708        };
709
710        Ok(vec![state_array])
711    }
712
713    fn supports_convert_to_state(&self) -> bool {
714        true
715    }
716
717    fn size(&self) -> usize {
718        self.counts.capacity() * size_of::<usize>()
719    }
720}
721
/// Counts rows that must be excluded from a multi-column `COUNT`.
///
/// For `COUNT(c1, c2, ...)` a row is counted only when *every* column is
/// non-null, so this returns the number of rows where at least one column is
/// null. For a single column it is simply that column's null count.
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
    if values.len() > 1 {
        // Fold the validity (non-null) buffers of all columns together with
        // bitwise AND: a bit stays set only when the row is valid in every
        // column. Columns without a null buffer (all valid) are skipped.
        let result_bool_buf: Option<BooleanBuffer> = values
            .iter()
            .map(|a| a.logical_nulls())
            .fold(None, |acc, b| match (acc, b) {
                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
                (Some(acc), None) => Some(acc),
                (None, Some(b)) => Some(b.into_inner()),
                _ => None,
            });
        // Rows with any null = total rows - rows valid everywhere.
        // `None` means no column had nulls, so the answer is 0.
        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
    } else {
        values[0]
            .logical_nulls()
            .map_or(0, |nulls| nulls.null_count())
    }
}
742
/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
    // Distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    // Type of the values, used when serializing state as a list.
    state_data_type: DataType,
}
756
impl DistinctCountAccumulator {
    /// Calculates the size for fixed length values, taking first batch size *
    /// number of batches. This method is faster than `full_size()`, however it
    /// is not suitable for variable length values like strings or complex
    /// types, where each element's size may differ.
    fn fixed_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            // extrapolate from the first element: all elements of a
            // fixed-width type have the same extra (heap) size
            + self
                .values
                .iter()
                .next()
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .unwrap_or(0)
            + size_of::<DataType>()
    }

    /// Calculates the size as accurately as possible by summing the extra
    /// (heap) size of every stored value. Note that calling this method is
    /// expensive: it iterates the entire set.
    fn full_size(&self) -> usize {
        size_of_val(self)
            + (size_of::<ScalarValue>() * self.values.capacity())
            + self
                .values
                .iter()
                .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
                .sum::<usize>()
            + size_of::<DataType>()
    }
}
786
787impl Accumulator for DistinctCountAccumulator {
788    /// Returns the distinct values seen so far as (one element) ListArray.
789    fn state(&mut self) -> Result<Vec<ScalarValue>> {
790        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
791        let arr =
792            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
793        Ok(vec![ScalarValue::List(arr)])
794    }
795
796    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
797        if values.is_empty() {
798            return Ok(());
799        }
800
801        let arr = &values[0];
802        if arr.data_type() == &DataType::Null {
803            return Ok(());
804        }
805
806        (0..arr.len()).try_for_each(|index| {
807            let scalar = ScalarValue::try_from_array(arr, index)?;
808            if !scalar.is_null() {
809                self.values.insert(scalar);
810            }
811            Ok(())
812        })
813    }
814
815    /// Merges multiple sets of distinct values into the current set.
816    ///
817    /// The input to this function is a `ListArray` with **multiple** rows,
818    /// where each row contains the values from a partial aggregate's phase (e.g.
819    /// the result of calling `Self::state` on multiple accumulators).
820    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
821        if states.is_empty() {
822            return Ok(());
823        }
824        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
825        let array = &states[0];
826        let list_array = array.as_list::<i32>();
827        for inner_array in list_array.iter() {
828            let Some(inner_array) = inner_array else {
829                return internal_err!(
830                    "Intermediate results of COUNT DISTINCT should always be non null"
831                );
832            };
833            self.update_batch(&[inner_array])?;
834        }
835        Ok(())
836    }
837
838    fn evaluate(&mut self) -> Result<ScalarValue> {
839        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
840    }
841
842    fn size(&self) -> usize {
843        match &self.state_data_type {
844            DataType::Boolean | DataType::Null => self.fixed_size(),
845            d if d.is_primitive() => self.fixed_size(),
846            _ => self.full_size(),
847        }
848    }
849}
850
851#[cfg(test)]
852mod tests {
853
854    use super::*;
855    use arrow::{
856        array::{DictionaryArray, Int32Array, NullArray, StringArray},
857        datatypes::{DataType, Field, Int32Type, Schema},
858    };
859    use datafusion_expr::function::AccumulatorArgs;
860    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
861    use std::sync::Arc;
    /// Helper function to create a dictionary array with non-null keys but some null values
    /// Returns a dictionary array where:
    /// - keys are [0, 1, 2, 0, 1] (all non-null)
    /// - values are ["a", null, "c"]
    /// - so the keys reference: "a", null, "c", "a", null
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
        // try_new validates that every key is in-bounds for `values`
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }
875
876    #[test]
877    fn count_accumulator_nulls() -> Result<()> {
878        let mut accumulator = CountAccumulator::new();
879        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
880        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
881        Ok(())
882    }
883
884    #[test]
885    fn test_nested_dictionary() -> Result<()> {
886        let schema = Arc::new(Schema::new(vec![Field::new(
887            "dict_col",
888            DataType::Dictionary(
889                Box::new(DataType::Int32),
890                Box::new(DataType::Dictionary(
891                    Box::new(DataType::Int32),
892                    Box::new(DataType::Utf8),
893                )),
894            ),
895            true,
896        )]));
897
898        // Using Count UDAF's accumulator
899        let count = Count::new();
900        let expr = Arc::new(Column::new("dict_col", 0));
901        let expr_field = expr.return_field(&schema)?;
902        let args = AccumulatorArgs {
903            schema: &schema,
904            expr_fields: &[expr_field],
905            exprs: &[expr],
906            is_distinct: true,
907            name: "count",
908            ignore_nulls: false,
909            is_reversed: false,
910            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
911            order_bys: &[],
912        };
913
914        let inner_dict =
915            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);
916
917        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
918        let dict_of_dict =
919            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;
920
921        let mut acc = count.accumulator(args)?;
922        acc.update_batch(&[Arc::new(dict_of_dict)])?;
923        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));
924
925        Ok(())
926    }
927
928    #[test]
929    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
930        let dict_array = create_dictionary_with_null_values()?;
931
932        // The expected behavior is that count_distinct should count only non-null values
933        // which in this case are "a" and "c" (appearing as 0 and 2 in keys)
934        let mut accumulator = DistinctCountAccumulator {
935            values: HashSet::default(),
936            state_data_type: dict_array.data_type().clone(),
937        };
938
939        accumulator.update_batch(&[Arc::new(dict_array)])?;
940
941        // Should have 2 distinct non-null values ("a" and "c")
942        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
943        Ok(())
944    }
945
946    #[test]
947    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
948        let dict_array = create_dictionary_with_null_values()?;
949
950        // The expected behavior is that count should only count non-null values
951        let mut accumulator = CountAccumulator::new();
952
953        accumulator.update_batch(&[Arc::new(dict_array)])?;
954
955        // 5 elements in the array, of which 2 reference null values (the two 1s in the keys)
956        // So we should count 3 non-null values
957        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
958        Ok(())
959    }
960
961    #[test]
962    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
963        // Create a dictionary array that only contains null values
964        let dict_values = StringArray::from(vec![None, Some("abc")]);
965        let dict_indices = Int32Array::from(vec![0; 5]);
966        let dict_array =
967            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;
968
969        let mut accumulator = DistinctCountAccumulator {
970            values: HashSet::default(),
971            state_data_type: dict_array.data_type().clone(),
972        };
973
974        accumulator.update_batch(&[Arc::new(dict_array)])?;
975
976        // All referenced values are null so count(distinct) should be 0
977        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
978        Ok(())
979    }
980
981    #[test]
982    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
983        // Basic update_batch + evaluate functionality
984        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
985        // Create an Int32Array: [1, 2, 2, 3, null]
986        let values: ArrayRef = Arc::new(Int32Array::from(vec![
987            Some(1),
988            Some(2),
989            Some(2),
990            Some(3),
991            None,
992        ]));
993        acc.update_batch(&[values])?;
994        // Expect distinct values {1,2,3} → count = 3
995        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
996        Ok(())
997    }
998
999    #[test]
1000    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
1001        // Test that retract_batch properly decrements counts
1002        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
1003        // Initial batch: ["a", "b", "a"]
1004        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
1005            as ArrayRef;
1006        acc.update_batch(&[arr1])?;
1007        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"}
1008
1009        // Retract batch: ["a", null, "b"]
1010        let arr2 =
1011            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
1012        acc.retract_batch(&[arr2])?;
1013        // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"}
1014        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
1015        Ok(())
1016    }
1017
1018    #[test]
1019    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
1020        // Test merging multiple accumulator states with merge_batch
1021        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1022        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1023        // acc1 sees [1, 2]
1024        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
1025        // acc2 sees [2, 3]
1026        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
1027        // Extract their states as Vec<ScalarValue>
1028        let state_sv1 = acc1.state()?;
1029        let state_sv2 = acc2.state()?;
1030        // Convert ScalarValue states into Vec<ArrayRef>, propagating errors
1031        // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray
1032        let state_arr1: Vec<ArrayRef> = state_sv1
1033            .into_iter()
1034            .map(|sv| sv.to_array())
1035            .collect::<Result<_>>()?;
1036        let state_arr2: Vec<ArrayRef> = state_sv2
1037            .into_iter()
1038            .map(|sv| sv.to_array())
1039            .collect::<Result<_>>()?;
1040        // Merge both states into a fresh accumulator
1041        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
1042        merged.merge_batch(&state_arr1)?;
1043        merged.merge_batch(&state_arr2)?;
1044        // Expect distinct {1,2,3} → count = 3
1045        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
1046        Ok(())
1047    }
1048}