1use ahash::RandomState;
19use datafusion_common::stats::Precision;
20use datafusion_expr::expr::WindowFunction;
21use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
22use datafusion_macros::user_doc;
23use datafusion_physical_expr::expressions;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::mem::{size_of, size_of_val};
27use std::ops::BitAnd;
28use std::sync::Arc;
29
30use arrow::{
31 array::{ArrayRef, AsArray},
32 compute,
33 datatypes::{
34 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
35 Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
36 Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
37 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
38 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
39 UInt16Type, UInt32Type, UInt64Type, UInt8Type,
40 },
41};
42
43use arrow::datatypes::FieldRef;
44use arrow::{
45 array::{Array, BooleanArray, Int64Array, PrimitiveArray},
46 buffer::BooleanBuffer,
47};
48use datafusion_common::{
49 downcast_value, internal_err, not_impl_err, Result, ScalarValue,
50};
51use datafusion_expr::function::StateFieldsArgs;
52use datafusion_expr::{
53 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
54 Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
55};
56use datafusion_expr::{
57 Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
58};
59use datafusion_functions_aggregate_common::aggregate::count_distinct::{
60 BytesDistinctCountAccumulator, DictionaryCountAccumulator,
61 FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
62};
63use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
64use datafusion_physical_expr_common::binary_map::OutputType;
65
66use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
// Generates the `count` expression helper (called by `count_all` below) and the
// `count_udaf` accessor (called by `count_distinct` / `count_all_window` below)
// for the `Count` UDAF. The string literal is the description attached to the
// generated helper. NOTE(review): presumably `count_udaf` returns a shared
// `Arc<AggregateUDF>` instance — confirm against the macro definition.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
74
75pub fn count_distinct(expr: Expr) -> Expr {
76 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
77 count_udaf(),
78 vec![expr],
79 true,
80 None,
81 None,
82 None,
83 ))
84}
85
86pub fn count_all() -> Expr {
104 count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
105}
106
107pub fn count_all_window() -> Expr {
127 Expr::from(WindowFunction::new(
128 WindowFunctionDefinition::AggregateUDF(count_udaf()),
129 vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
130 ))
131}
132
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name) |
+-----------------------+
| 100 |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*) |
+------------------+
| 120 |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
/// The `count` aggregate function (UDAF).
///
/// Non-distinct counts are computed by `CountAccumulator` /
/// `CountGroupsAccumulator`; `count(DISTINCT ...)` dispatches on the input type
/// via `get_count_accumulator`. The result type is always non-nullable `Int64`.
pub struct Count {
    // Accepted argument shapes (any number of arguments, or none); set in `Count::new`.
    signature: Signature,
}
157
158impl Debug for Count {
159 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
160 f.debug_struct("Count")
161 .field("name", &self.name())
162 .field("signature", &self.signature)
163 .finish()
164 }
165}
166
167impl Default for Count {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl Count {
174 pub fn new() -> Self {
175 Self {
176 signature: Signature::one_of(
177 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
178 Volatility::Immutable,
179 ),
180 }
181 }
182}
183fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
184 match data_type {
185 DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
187 data_type,
188 )),
189 DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
190 data_type,
191 )),
192 DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
193 data_type,
194 )),
195 DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
196 data_type,
197 )),
198 DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
199 data_type,
200 )),
201 DataType::UInt16 => Box::new(
202 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
203 ),
204 DataType::UInt32 => Box::new(
205 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
206 ),
207 DataType::UInt64 => Box::new(
208 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
209 ),
210 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
211 Decimal128Type,
212 >::new(data_type)),
213 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
214 Decimal256Type,
215 >::new(data_type)),
216
217 DataType::Date32 => Box::new(
218 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
219 ),
220 DataType::Date64 => Box::new(
221 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
222 ),
223 DataType::Time32(TimeUnit::Millisecond) => Box::new(
224 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
225 ),
226 DataType::Time32(TimeUnit::Second) => Box::new(
227 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
228 ),
229 DataType::Time64(TimeUnit::Microsecond) => Box::new(
230 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
231 ),
232 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
233 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
234 ),
235 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
236 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
237 ),
238 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
239 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
240 ),
241 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
242 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
243 ),
244 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
245 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
246 ),
247
248 DataType::Float16 => {
249 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
250 }
251 DataType::Float32 => {
252 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
253 }
254 DataType::Float64 => {
255 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
256 }
257
258 DataType::Utf8 => {
259 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
260 }
261 DataType::Utf8View => {
262 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
263 }
264 DataType::LargeUtf8 => {
265 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
266 }
267 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
268 OutputType::Binary,
269 )),
270 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
271 OutputType::BinaryView,
272 )),
273 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
274 OutputType::Binary,
275 )),
276
277 _ => Box::new(DistinctCountAccumulator {
279 values: HashSet::default(),
280 state_data_type: data_type.clone(),
281 }),
282 }
283}
284
285impl AggregateUDFImpl for Count {
286 fn as_any(&self) -> &dyn std::any::Any {
287 self
288 }
289
290 fn name(&self) -> &str {
291 "count"
292 }
293
294 fn signature(&self) -> &Signature {
295 &self.signature
296 }
297
298 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
299 Ok(DataType::Int64)
300 }
301
302 fn is_nullable(&self) -> bool {
303 false
304 }
305
306 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
307 if args.is_distinct {
308 let dtype: DataType = match &args.input_fields[0].data_type() {
309 DataType::Dictionary(_, values_type) => (**values_type).clone(),
310 &dtype => dtype.clone(),
311 };
312
313 Ok(vec![Field::new_list(
314 format_state_name(args.name, "count distinct"),
315 Field::new_list_field(dtype, true),
317 false,
318 )
319 .into()])
320 } else {
321 Ok(vec![Field::new(
322 format_state_name(args.name, "count"),
323 DataType::Int64,
324 false,
325 )
326 .into()])
327 }
328 }
329
330 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
331 if !acc_args.is_distinct {
332 return Ok(Box::new(CountAccumulator::new()));
333 }
334
335 if acc_args.exprs.len() > 1 {
336 return not_impl_err!("COUNT DISTINCT with multiple arguments");
337 }
338
339 let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
340
341 Ok(match data_type {
342 DataType::Dictionary(_, values_type) => {
343 let inner = get_count_accumulator(values_type);
344 Box::new(DictionaryCountAccumulator::new(inner))
345 }
346 _ => get_count_accumulator(data_type),
347 })
348 }
349
350 fn aliases(&self) -> &[String] {
351 &[]
352 }
353
354 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
355 if args.is_distinct {
358 return false;
359 }
360 args.exprs.len() == 1
361 }
362
363 fn create_groups_accumulator(
364 &self,
365 _args: AccumulatorArgs,
366 ) -> Result<Box<dyn GroupsAccumulator>> {
367 Ok(Box::new(CountGroupsAccumulator::new()))
369 }
370
371 fn reverse_expr(&self) -> ReversedUDAF {
372 ReversedUDAF::Identical
373 }
374
375 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
376 Ok(ScalarValue::Int64(Some(0)))
377 }
378
379 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
380 if statistics_args.is_distinct {
381 return None;
382 }
383 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
384 if statistics_args.exprs.len() == 1 {
385 if let Some(col_expr) = statistics_args.exprs[0]
387 .as_any()
388 .downcast_ref::<expressions::Column>()
389 {
390 let current_val = &statistics_args.statistics.column_statistics
391 [col_expr.index()]
392 .null_count;
393 if let &Precision::Exact(val) = current_val {
394 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
395 }
396 } else if let Some(lit_expr) = statistics_args.exprs[0]
397 .as_any()
398 .downcast_ref::<expressions::Literal>()
399 {
400 if lit_expr.value() == &COUNT_STAR_EXPANSION {
401 return Some(ScalarValue::Int64(Some(num_rows as i64)));
402 }
403 }
404 }
405 }
406 None
407 }
408
409 fn documentation(&self) -> Option<&Documentation> {
410 self.doc()
411 }
412
413 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
414 SetMonotonicity::Increasing
417 }
418}
419
420#[derive(Debug)]
421struct CountAccumulator {
422 count: i64,
423}
424
425impl CountAccumulator {
426 pub fn new() -> Self {
428 Self { count: 0 }
429 }
430}
431
432impl Accumulator for CountAccumulator {
433 fn state(&mut self) -> Result<Vec<ScalarValue>> {
434 Ok(vec![ScalarValue::Int64(Some(self.count))])
435 }
436
437 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
438 let array = &values[0];
439 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
440 Ok(())
441 }
442
443 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
444 let array = &values[0];
445 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
446 Ok(())
447 }
448
449 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
450 let counts = downcast_value!(states[0], Int64Array);
451 let delta = &compute::sum(counts);
452 if let Some(d) = delta {
453 self.count += *d;
454 }
455 Ok(())
456 }
457
458 fn evaluate(&mut self) -> Result<ScalarValue> {
459 Ok(ScalarValue::Int64(Some(self.count)))
460 }
461
462 fn supports_retract_batch(&self) -> bool {
463 true
464 }
465
466 fn size(&self) -> usize {
467 size_of_val(self)
468 }
469}
470
471#[derive(Debug)]
478struct CountGroupsAccumulator {
479 counts: Vec<i64>,
486}
487
488impl CountGroupsAccumulator {
489 pub fn new() -> Self {
490 Self { counts: vec![] }
491 }
492}
493
494impl GroupsAccumulator for CountGroupsAccumulator {
495 fn update_batch(
496 &mut self,
497 values: &[ArrayRef],
498 group_indices: &[usize],
499 opt_filter: Option<&BooleanArray>,
500 total_num_groups: usize,
501 ) -> Result<()> {
502 assert_eq!(values.len(), 1, "single argument to update_batch");
503 let values = &values[0];
504
505 self.counts.resize(total_num_groups, 0);
508 accumulate_indices(
509 group_indices,
510 values.logical_nulls().as_ref(),
511 opt_filter,
512 |group_index| {
513 self.counts[group_index] += 1;
514 },
515 );
516
517 Ok(())
518 }
519
520 fn merge_batch(
521 &mut self,
522 values: &[ArrayRef],
523 group_indices: &[usize],
524 _opt_filter: Option<&BooleanArray>,
526 total_num_groups: usize,
527 ) -> Result<()> {
528 assert_eq!(values.len(), 1, "one argument to merge_batch");
529 let partial_counts = values[0].as_primitive::<Int64Type>();
531
532 assert_eq!(partial_counts.null_count(), 0);
534 let partial_counts = partial_counts.values();
535
536 self.counts.resize(total_num_groups, 0);
538 group_indices.iter().zip(partial_counts.iter()).for_each(
539 |(&group_index, partial_count)| {
540 self.counts[group_index] += partial_count;
541 },
542 );
543
544 Ok(())
545 }
546
547 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
548 let counts = emit_to.take_needed(&mut self.counts);
549
550 let nulls = None;
552 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
553
554 Ok(Arc::new(array))
555 }
556
557 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
559 let counts = emit_to.take_needed(&mut self.counts);
560 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
562 }
563
564 fn convert_to_state(
570 &self,
571 values: &[ArrayRef],
572 opt_filter: Option<&BooleanArray>,
573 ) -> Result<Vec<ArrayRef>> {
574 let values = &values[0];
575
576 let state_array = match (values.logical_nulls(), opt_filter) {
577 (None, None) => {
578 Arc::new(Int64Array::from_value(1, values.len()))
580 }
581 (Some(nulls), None) => {
582 let nulls = BooleanArray::new(nulls.into_inner(), None);
585 compute::cast(&nulls, &DataType::Int64)?
586 }
587 (None, Some(filter)) => {
588 let (filter_values, filter_nulls) = filter.clone().into_parts();
593
594 let state_buf = match filter_nulls {
595 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
596 None => filter_values,
597 };
598
599 let boolean_state = BooleanArray::new(state_buf, None);
600 compute::cast(&boolean_state, &DataType::Int64)?
601 }
602 (Some(nulls), Some(filter)) => {
603 let (filter_values, filter_nulls) = filter.clone().into_parts();
610
611 let filter_buf = match filter_nulls {
612 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
613 None => filter_values,
614 };
615 let state_buf = &filter_buf & nulls.inner();
616
617 let boolean_state = BooleanArray::new(state_buf, None);
618 compute::cast(&boolean_state, &DataType::Int64)?
619 }
620 };
621
622 Ok(vec![state_array])
623 }
624
625 fn supports_convert_to_state(&self) -> bool {
626 true
627 }
628
629 fn size(&self) -> usize {
630 self.counts.capacity() * size_of::<usize>()
631 }
632}
633
634fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
637 if values.len() > 1 {
638 let result_bool_buf: Option<BooleanBuffer> = values
639 .iter()
640 .map(|a| a.logical_nulls())
641 .fold(None, |acc, b| match (acc, b) {
642 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
643 (Some(acc), None) => Some(acc),
644 (None, Some(b)) => Some(b.into_inner()),
645 _ => None,
646 });
647 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
648 } else {
649 values[0]
650 .logical_nulls()
651 .map_or(0, |nulls| nulls.null_count())
652 }
653}
654
655#[derive(Debug)]
664struct DistinctCountAccumulator {
665 values: HashSet<ScalarValue, RandomState>,
666 state_data_type: DataType,
667}
668
669impl DistinctCountAccumulator {
670 fn fixed_size(&self) -> usize {
674 size_of_val(self)
675 + (size_of::<ScalarValue>() * self.values.capacity())
676 + self
677 .values
678 .iter()
679 .next()
680 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
681 .unwrap_or(0)
682 + size_of::<DataType>()
683 }
684
685 fn full_size(&self) -> usize {
688 size_of_val(self)
689 + (size_of::<ScalarValue>() * self.values.capacity())
690 + self
691 .values
692 .iter()
693 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
694 .sum::<usize>()
695 + size_of::<DataType>()
696 }
697}
698
699impl Accumulator for DistinctCountAccumulator {
700 fn state(&mut self) -> Result<Vec<ScalarValue>> {
702 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
703 let arr =
704 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
705 Ok(vec![ScalarValue::List(arr)])
706 }
707
708 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
709 if values.is_empty() {
710 return Ok(());
711 }
712
713 let arr = &values[0];
714 if arr.data_type() == &DataType::Null {
715 return Ok(());
716 }
717
718 (0..arr.len()).try_for_each(|index| {
719 if !arr.is_null(index) {
720 let scalar = ScalarValue::try_from_array(arr, index)?;
721 self.values.insert(scalar);
722 }
723 Ok(())
724 })
725 }
726
727 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
733 if states.is_empty() {
734 return Ok(());
735 }
736 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
737 let array = &states[0];
738 let list_array = array.as_list::<i32>();
739 for inner_array in list_array.iter() {
740 let Some(inner_array) = inner_array else {
741 return internal_err!(
742 "Intermediate results of COUNT DISTINCT should always be non null"
743 );
744 };
745 self.update_batch(&[inner_array])?;
746 }
747 Ok(())
748 }
749
750 fn evaluate(&mut self) -> Result<ScalarValue> {
751 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
752 }
753
754 fn size(&self) -> usize {
755 match &self.state_data_type {
756 DataType::Boolean | DataType::Null => self.fixed_size(),
757 d if d.is_primitive() => self.fixed_size(),
758 _ => self.full_size(),
759 }
760 }
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, NullArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::LexOrdering;
    use std::sync::Arc;

    /// A column made entirely of nulls contributes nothing to a non-distinct count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// `count(DISTINCT)` over a dictionary whose values are themselves a
    /// dictionary counts the distinct underlying strings.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let args = AccumulatorArgs {
            schema: &schema,
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            ordering_req: &LexOrdering::default(),
        };

        // Inner dictionary rows: a, b, c, d, a, b.
        let inner_dict = arrow::array::DictionaryArray::<Int32Type>::from_iter([
            "a", "b", "c", "d", "a", "b",
        ]);

        // Outer keys select inner rows 0, 1, 2, 0, 3, 1 -> a, b, c, a, d, b.
        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict = arrow::array::DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(inner_dict),
        )?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        // Distinct values: a, b, c, d.
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }
}