1use arrow::{
19 array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
20 buffer::BooleanBuffer,
21 compute,
22 datatypes::{
23 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
24 FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
25 Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
26 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
27 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28 UInt8Type, UInt16Type, UInt32Type, UInt64Type,
29 },
30};
31use datafusion_common::hash_utils::RandomState;
32use datafusion_common::{
33 HashMap, Result, ScalarValue, downcast_value, exec_err, internal_err, not_impl_err,
34 stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38 ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39 WindowFunctionDefinition,
40 expr::WindowFunction,
41 function::{AccumulatorArgs, StateFieldsArgs},
42 utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
45use datafusion_functions_aggregate_common::aggregate::{
46 count_distinct::Bitmap65536DistinctCountAccumulator,
47 count_distinct::Bitmap65536DistinctCountAccumulatorI16,
48 count_distinct::BoolArray256DistinctCountAccumulator,
49 count_distinct::BoolArray256DistinctCountAccumulatorI8,
50 count_distinct::BytesDistinctCountAccumulator,
51 count_distinct::BytesViewDistinctCountAccumulator,
52 count_distinct::DictionaryCountAccumulator,
53 count_distinct::FloatDistinctCountAccumulator,
54 count_distinct::PrimitiveDistinctCountAccumulator,
55 groups_accumulator::accumulate::accumulate_indices,
56};
57use datafusion_macros::user_doc;
58use datafusion_physical_expr::expressions;
59use datafusion_physical_expr_common::binary_map::OutputType;
60use std::{
61 collections::HashSet,
62 fmt::Debug,
63 mem::{size_of, size_of_val},
64 ops::BitAnd,
65 sync::Arc,
66};
67
// Generates the `count(expr)` expression-builder function and the
// `count_udaf()` accessor (both used further down in this file) for the
// `Count` aggregate UDF.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
75
76pub fn count_distinct(expr: Expr) -> Expr {
77 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
78 count_udaf(),
79 vec![expr],
80 true,
81 None,
82 vec![],
83 None,
84 ))
85}
86
87pub fn count_all() -> Expr {
105 count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
106}
107
108pub fn count_all_window() -> Expr {
128 Expr::from(WindowFunction::new(
129 WindowFunctionDefinition::AggregateUDF(count_udaf()),
130 vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
131 ))
132}
133
// The `count` aggregate UDF. The `user_doc` attribute below supplies the
// user-facing documentation (section, description, syntax and SQL examples)
// exposed through `AggregateUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name) |
+-----------------------+
| 100 |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*) |
+------------------+
| 120 |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Count {
    // Accepted argument signatures; built in `Count::new`.
    signature: Signature,
}
159
160impl Default for Count {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
166impl Count {
167 pub fn new() -> Self {
168 Self {
169 signature: Signature::one_of(
170 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
171 Volatility::Immutable,
172 ),
173 }
174 }
175}
176fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
177 match data_type {
178 DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
180 data_type,
181 )),
182 DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
183 data_type,
184 )),
185 DataType::UInt32 => Box::new(
186 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
187 ),
188 DataType::UInt64 => Box::new(
189 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
190 ),
191 DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => {
193 get_small_int_accumulator(data_type).unwrap()
194 }
195 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
196 Decimal128Type,
197 >::new(data_type)),
198 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
199 Decimal256Type,
200 >::new(data_type)),
201
202 DataType::Date32 => Box::new(
203 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
204 ),
205 DataType::Date64 => Box::new(
206 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
207 ),
208 DataType::Time32(TimeUnit::Millisecond) => Box::new(
209 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
210 ),
211 DataType::Time32(TimeUnit::Second) => Box::new(
212 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
213 ),
214 DataType::Time64(TimeUnit::Microsecond) => Box::new(
215 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
216 ),
217 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
218 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
219 ),
220 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
221 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
222 ),
223 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
224 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
225 ),
226 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
227 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
228 ),
229 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
230 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
231 ),
232
233 DataType::Float16 => {
234 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
235 }
236 DataType::Float32 => {
237 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
238 }
239 DataType::Float64 => {
240 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
241 }
242
243 DataType::Utf8 => {
244 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
245 }
246 DataType::Utf8View => {
247 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
248 }
249 DataType::LargeUtf8 => {
250 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
251 }
252 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
253 OutputType::Binary,
254 )),
255 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
256 OutputType::BinaryView,
257 )),
258 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
259 OutputType::Binary,
260 )),
261
262 _ => Box::new(DistinctCountAccumulator {
264 values: HashSet::default(),
265 state_data_type: data_type.clone(),
266 }),
267 }
268}
269
270#[cold]
272fn get_small_int_accumulator(data_type: &DataType) -> Result<Box<dyn Accumulator>> {
273 match data_type {
274 DataType::UInt8 => Ok(Box::new(BoolArray256DistinctCountAccumulator::new())),
275 DataType::Int8 => Ok(Box::new(BoolArray256DistinctCountAccumulatorI8::new())),
276 DataType::UInt16 => Ok(Box::new(Bitmap65536DistinctCountAccumulator::new())),
277 DataType::Int16 => Ok(Box::new(Bitmap65536DistinctCountAccumulatorI16::new())),
278 _ => exec_err!("unsupported accumulator for datatype: {}", data_type),
279 }
280}
281
282impl AggregateUDFImpl for Count {
283 fn name(&self) -> &str {
284 "count"
285 }
286
287 fn signature(&self) -> &Signature {
288 &self.signature
289 }
290
291 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
292 Ok(DataType::Int64)
293 }
294
295 fn is_nullable(&self) -> bool {
296 false
297 }
298
299 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
300 if args.is_distinct {
301 let dtype: DataType = match &args.input_fields[0].data_type() {
302 DataType::Dictionary(_, values_type) => (**values_type).clone(),
303 &dtype => dtype.clone(),
304 };
305
306 Ok(vec![
307 Field::new_list(
308 format_state_name(args.name, "count distinct"),
309 Field::new_list_field(dtype, true),
311 false,
312 )
313 .into(),
314 ])
315 } else {
316 Ok(vec![
317 Field::new(
318 format_state_name(args.name, "count"),
319 DataType::Int64,
320 false,
321 )
322 .into(),
323 ])
324 }
325 }
326
327 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
328 if !acc_args.is_distinct {
329 return Ok(Box::new(CountAccumulator::new()));
330 }
331
332 if acc_args.exprs.len() > 1 {
333 return not_impl_err!("COUNT DISTINCT with multiple arguments");
334 }
335
336 let data_type = acc_args.expr_fields[0].data_type();
337
338 Ok(match data_type {
339 DataType::Dictionary(_, values_type) => {
340 let inner = get_count_accumulator(values_type);
341 Box::new(DictionaryCountAccumulator::new(inner))
342 }
343 _ => get_count_accumulator(data_type),
344 })
345 }
346
347 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348 if args.exprs.len() != 1 {
349 return false;
350 }
351 if !args.is_distinct {
352 return true;
353 }
354 matches!(
355 args.expr_fields[0].data_type(),
356 DataType::Int8
357 | DataType::Int16
358 | DataType::Int32
359 | DataType::Int64
360 | DataType::UInt8
361 | DataType::UInt16
362 | DataType::UInt32
363 | DataType::UInt64
364 )
365 }
366
367 fn create_groups_accumulator(
368 &self,
369 args: AccumulatorArgs,
370 ) -> Result<Box<dyn GroupsAccumulator>> {
371 if !args.is_distinct {
372 return Ok(Box::new(CountGroupsAccumulator::new()));
373 }
374 create_distinct_count_groups_accumulator(&args)
375 }
376
377 fn reverse_expr(&self) -> ReversedUDAF {
378 ReversedUDAF::Identical
379 }
380
381 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
382 Ok(ScalarValue::Int64(Some(0)))
383 }
384
385 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
386 let [expr] = statistics_args.exprs else {
387 return None;
388 };
389 let col_stats = &statistics_args.statistics.column_statistics;
390
391 if statistics_args.is_distinct {
392 let col_expr = expr.downcast_ref::<expressions::Column>()?;
395 if let Precision::Exact(dc) = col_stats[col_expr.index()].distinct_count {
396 let dc = i64::try_from(dc).ok()?;
397 return Some(ScalarValue::Int64(Some(dc)));
398 }
399 return None;
400 }
401
402 let Precision::Exact(num_rows) = statistics_args.statistics.num_rows else {
403 return None;
404 };
405
406 if let Some(col_expr) = expr.downcast_ref::<expressions::Column>() {
408 if let Precision::Exact(val) = col_stats[col_expr.index()].null_count {
409 let count = i64::try_from(num_rows - val).ok()?;
410 return Some(ScalarValue::Int64(Some(count)));
411 }
412 } else if let Some(lit_expr) = expr.downcast_ref::<expressions::Literal>()
413 && lit_expr.value() == &COUNT_STAR_EXPANSION
414 {
415 let num_rows = i64::try_from(num_rows).ok()?;
416 return Some(ScalarValue::Int64(Some(num_rows)));
417 }
418
419 None
420 }
421
422 fn documentation(&self) -> Option<&Documentation> {
423 self.doc()
424 }
425
426 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
427 SetMonotonicity::Increasing
430 }
431
432 fn create_sliding_accumulator(
433 &self,
434 args: AccumulatorArgs,
435 ) -> Result<Box<dyn Accumulator>> {
436 if args.is_distinct {
437 let acc =
438 SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
439 Ok(Box::new(acc))
440 } else {
441 let acc = CountAccumulator::new();
442 Ok(Box::new(acc))
443 }
444 }
445}
446
447#[cold]
448fn create_distinct_count_groups_accumulator(
449 args: &AccumulatorArgs,
450) -> Result<Box<dyn GroupsAccumulator>> {
451 let data_type = args.expr_fields[0].data_type();
452 match data_type {
453 DataType::Int8 => Ok(Box::new(
454 PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
455 )),
456 DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
457 Int16Type,
458 >::new())),
459 DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
460 Int32Type,
461 >::new())),
462 DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
463 Int64Type,
464 >::new())),
465 DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
466 UInt8Type,
467 >::new())),
468 DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
469 UInt16Type,
470 >::new())),
471 DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
472 UInt32Type,
473 >::new())),
474 DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
475 UInt64Type,
476 >::new())),
477 _ => not_impl_err!(
478 "GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
479 data_type
480 ),
481 }
482}
483
484#[derive(Debug)]
488pub struct SlidingDistinctCountAccumulator {
489 counts: HashMap<ScalarValue, usize, RandomState>,
490 data_type: DataType,
491}
492
493impl SlidingDistinctCountAccumulator {
494 pub fn try_new(data_type: &DataType) -> Result<Self> {
495 Ok(Self {
496 counts: HashMap::default(),
497 data_type: data_type.clone(),
498 })
499 }
500}
501
502impl Accumulator for SlidingDistinctCountAccumulator {
503 fn state(&mut self) -> Result<Vec<ScalarValue>> {
504 let keys = self.counts.keys().cloned().collect::<Vec<_>>();
505 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
506 keys.as_slice(),
507 &self.data_type,
508 ))])
509 }
510
511 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
512 let arr = &values[0];
513 for i in 0..arr.len() {
514 let v = ScalarValue::try_from_array(arr, i)?;
515 if !v.is_null() {
516 *self.counts.entry(v).or_default() += 1;
517 }
518 }
519 Ok(())
520 }
521
522 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523 let arr = &values[0];
524 for i in 0..arr.len() {
525 let v = ScalarValue::try_from_array(arr, i)?;
526 if !v.is_null()
527 && let Some(cnt) = self.counts.get_mut(&v)
528 {
529 *cnt -= 1;
530 if *cnt == 0 {
531 self.counts.remove(&v);
532 }
533 }
534 }
535 Ok(())
536 }
537
538 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
539 let list_arr = states[0].as_list::<i32>();
540 for inner in list_arr.iter().flatten() {
541 for j in 0..inner.len() {
542 let v = ScalarValue::try_from_array(&*inner, j)?;
543 *self.counts.entry(v).or_default() += 1;
544 }
545 }
546 Ok(())
547 }
548
549 fn evaluate(&mut self) -> Result<ScalarValue> {
550 Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
551 }
552
553 fn supports_retract_batch(&self) -> bool {
554 true
555 }
556
557 fn size(&self) -> usize {
558 size_of_val(self)
559 }
560}
561
562#[derive(Debug)]
563struct CountAccumulator {
564 count: i64,
565}
566
567impl CountAccumulator {
568 pub fn new() -> Self {
570 Self { count: 0 }
571 }
572}
573
574impl Accumulator for CountAccumulator {
575 fn state(&mut self) -> Result<Vec<ScalarValue>> {
576 Ok(vec![ScalarValue::Int64(Some(self.count))])
577 }
578
579 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
580 let array = &values[0];
581 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
582 Ok(())
583 }
584
585 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
586 let array = &values[0];
587 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
588 Ok(())
589 }
590
591 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
592 let counts = downcast_value!(states[0], Int64Array);
593 let delta = &compute::sum(counts);
594 if let Some(d) = delta {
595 self.count += *d;
596 }
597 Ok(())
598 }
599
600 fn evaluate(&mut self) -> Result<ScalarValue> {
601 Ok(ScalarValue::Int64(Some(self.count)))
602 }
603
604 fn supports_retract_batch(&self) -> bool {
605 true
606 }
607
608 fn size(&self) -> usize {
609 size_of_val(self)
610 }
611}
612
613#[derive(Debug)]
620struct CountGroupsAccumulator {
621 counts: Vec<i64>,
628}
629
630impl CountGroupsAccumulator {
631 pub fn new() -> Self {
632 Self { counts: vec![] }
633 }
634}
635
636impl GroupsAccumulator for CountGroupsAccumulator {
637 fn update_batch(
638 &mut self,
639 values: &[ArrayRef],
640 group_indices: &[usize],
641 opt_filter: Option<&BooleanArray>,
642 total_num_groups: usize,
643 ) -> Result<()> {
644 assert_eq!(values.len(), 1, "single argument to update_batch");
645 let values = &values[0];
646
647 self.counts.resize(total_num_groups, 0);
650 accumulate_indices(
651 group_indices,
652 values.logical_nulls().as_ref(),
653 opt_filter,
654 |group_index| {
655 let count = unsafe { self.counts.get_unchecked_mut(group_index) };
657 *count += 1;
658 },
659 );
660
661 Ok(())
662 }
663
664 fn merge_batch(
665 &mut self,
666 values: &[ArrayRef],
667 group_indices: &[usize],
668 _opt_filter: Option<&BooleanArray>,
670 total_num_groups: usize,
671 ) -> Result<()> {
672 assert_eq!(values.len(), 1, "one argument to merge_batch");
673 let partial_counts = values[0].as_primitive::<Int64Type>();
675
676 assert_eq!(partial_counts.null_count(), 0);
678 let partial_counts = partial_counts.values();
679
680 self.counts.resize(total_num_groups, 0);
682 group_indices.iter().zip(partial_counts.iter()).for_each(
683 |(&group_index, partial_count)| {
684 self.counts[group_index] += partial_count;
685 },
686 );
687
688 Ok(())
689 }
690
691 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
692 let counts = emit_to.take_needed(&mut self.counts);
693
694 let nulls = None;
696 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
697
698 Ok(Arc::new(array))
699 }
700
701 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
703 let counts = emit_to.take_needed(&mut self.counts);
704 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
706 }
707
708 fn convert_to_state(
714 &self,
715 values: &[ArrayRef],
716 opt_filter: Option<&BooleanArray>,
717 ) -> Result<Vec<ArrayRef>> {
718 let values = &values[0];
719
720 let state_array = match (values.logical_nulls(), opt_filter) {
721 (None, None) => {
722 Arc::new(Int64Array::from_value(1, values.len()))
724 }
725 (Some(nulls), None) => {
726 let nulls = BooleanArray::new(nulls.into_inner(), None);
729 compute::cast(&nulls, &DataType::Int64)?
730 }
731 (None, Some(filter)) => {
732 let (filter_values, filter_nulls) = filter.clone().into_parts();
737
738 let state_buf = match filter_nulls {
739 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
740 None => filter_values,
741 };
742
743 let boolean_state = BooleanArray::new(state_buf, None);
744 compute::cast(&boolean_state, &DataType::Int64)?
745 }
746 (Some(nulls), Some(filter)) => {
747 let (filter_values, filter_nulls) = filter.clone().into_parts();
754
755 let filter_buf = match filter_nulls {
756 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
757 None => filter_values,
758 };
759 let state_buf = &filter_buf & nulls.inner();
760
761 let boolean_state = BooleanArray::new(state_buf, None);
762 compute::cast(&boolean_state, &DataType::Int64)?
763 }
764 };
765
766 Ok(vec![state_array])
767 }
768
769 fn supports_convert_to_state(&self) -> bool {
770 true
771 }
772
773 fn size(&self) -> usize {
774 self.counts.capacity() * size_of::<usize>()
775 }
776}
777
778fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
781 if values.len() > 1 {
782 let result_bool_buf: Option<BooleanBuffer> = values
783 .iter()
784 .map(|a| a.logical_nulls())
785 .fold(None, |acc, b| match (acc, b) {
786 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
787 (Some(acc), None) => Some(acc),
788 (None, Some(b)) => Some(b.into_inner()),
789 _ => None,
790 });
791 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
792 } else {
793 values[0]
794 .logical_nulls()
795 .map_or(0, |nulls| nulls.null_count())
796 }
797}
798
799#[derive(Debug)]
808struct DistinctCountAccumulator {
809 values: HashSet<ScalarValue, RandomState>,
810 state_data_type: DataType,
811}
812
813impl DistinctCountAccumulator {
814 fn fixed_size(&self) -> usize {
818 size_of_val(self)
819 + (size_of::<ScalarValue>() * self.values.capacity())
820 + self
821 .values
822 .iter()
823 .next()
824 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
825 .unwrap_or(0)
826 + size_of::<DataType>()
827 }
828
829 fn full_size(&self) -> usize {
832 size_of_val(self)
833 + (size_of::<ScalarValue>() * self.values.capacity())
834 + self
835 .values
836 .iter()
837 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
838 .sum::<usize>()
839 + size_of::<DataType>()
840 }
841}
842
843impl Accumulator for DistinctCountAccumulator {
844 fn state(&mut self) -> Result<Vec<ScalarValue>> {
846 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
847 let arr =
848 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
849 Ok(vec![ScalarValue::List(arr)])
850 }
851
852 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
853 if values.is_empty() {
854 return Ok(());
855 }
856
857 let arr = &values[0];
858 if arr.data_type() == &DataType::Null {
859 return Ok(());
860 }
861
862 (0..arr.len()).try_for_each(|index| {
863 let scalar = ScalarValue::try_from_array(arr, index)?;
864 if !scalar.is_null() {
865 self.values.insert(scalar);
866 }
867 Ok(())
868 })
869 }
870
871 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
877 if states.is_empty() {
878 return Ok(());
879 }
880 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
881 let array = &states[0];
882 let list_array = array.as_list::<i32>();
883 for inner_array in list_array.iter() {
884 let Some(inner_array) = inner_array else {
885 return internal_err!(
886 "Intermediate results of COUNT DISTINCT should always be non null"
887 );
888 };
889 self.update_batch(&[inner_array])?;
890 }
891 Ok(())
892 }
893
894 fn evaluate(&mut self) -> Result<ScalarValue> {
895 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
896 }
897
898 fn size(&self) -> usize {
899 match &self.state_data_type {
900 DataType::Boolean | DataType::Null => self.fixed_size(),
901 d if d.is_primitive() => self.fixed_size(),
902 _ => self.full_size(),
903 }
904 }
905}
906
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
    use std::sync::Arc;
    // Dictionary over values ["a", NULL, "c"] with keys [0, 1, 2, 0, 1],
    // i.e. logical values: a, NULL, c, a, NULL.
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]);
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    // A NullArray contains only nulls, so COUNT sees no non-null rows.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // COUNT(DISTINCT) over a dictionary whose values are themselves a
    // dictionary. The logical values are a, b, c, a, d, b, so 4 are distinct.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let expr_field = expr.return_field(&schema)?;
        let args = AccumulatorArgs {
            schema: &schema,
            expr_fields: &[expr_field],
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    // Null dictionary values are ignored by the generic distinct
    // accumulator; the distinct non-null values are {a, c}.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    // Plain COUNT uses logical nulls, so the two rows whose key points at a
    // NULL dictionary value are not counted (3 of 5).
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // Every key points at the NULL dictionary value, so nothing is counted.
    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // 1, 2, 2, 3, NULL -> distinct non-null values {1, 2, 3}.
    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(2),
            Some(3),
            None,
        ]));
        acc.update_batch(&[values])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // After adding a, b, a and retracting a, NULL, b: "a" still has one
    // occurrence and "b" drops to zero, leaving 1 distinct value.
    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
            as ArrayRef;
        acc.update_batch(&[arr1])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));
        let arr2 =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
        acc.retract_batch(&[arr2])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    // Merging the states of {1, 2} and {2, 3} yields 3 distinct values.
    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
        let state_sv1 = acc1.state()?;
        let state_sv2 = acc2.state()?;
        // Convert the scalar states into arrays, as merge_batch expects.
        let state_arr1: Vec<ArrayRef> = state_sv1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state_arr2: Vec<ArrayRef> = state_sv2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        merged.merge_batch(&state_arr1)?;
        merged.merge_batch(&state_arr2)?;
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}