1use ahash::RandomState;
19use datafusion_common::stats::Precision;
20use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
21use datafusion_macros::user_doc;
22use datafusion_physical_expr::expressions;
23use std::collections::HashSet;
24use std::fmt::Debug;
25use std::mem::{size_of, size_of_val};
26use std::ops::BitAnd;
27use std::sync::Arc;
28
29use arrow::{
30 array::{ArrayRef, AsArray},
31 compute,
32 datatypes::{
33 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
34 Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
35 Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
36 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
38 UInt16Type, UInt32Type, UInt64Type, UInt8Type,
39 },
40};
41
42use arrow::{
43 array::{Array, BooleanArray, Int64Array, PrimitiveArray},
44 buffer::BooleanBuffer,
45};
46use datafusion_common::{
47 downcast_value, internal_err, not_impl_err, Result, ScalarValue,
48};
49use datafusion_expr::function::StateFieldsArgs;
50use datafusion_expr::{
51 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
52 Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
53};
54use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
55use datafusion_functions_aggregate_common::aggregate::count_distinct::{
56 BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
57 PrimitiveDistinctCountAccumulator,
58};
59use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
60use datafusion_physical_expr_common::binary_map::OutputType;
61
62use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
// Registers the `Count` UDAF. Elsewhere in this file `count(..)` is called as an
// expression builder and `count_udaf()` as the UDAF accessor; both are presumably
// generated by this macro invocation (macro definition not visible here — confirm).
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73 count_udaf(),
74 vec![expr],
75 true,
76 None,
77 None,
78 None,
79 ))
80}
81
82pub fn count_all() -> Expr {
84 count(Expr::Literal(COUNT_STAR_EXPANSION))
85}
86
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name) |
+-----------------------+
| 100 |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*) |
+------------------+
| 120 |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
// The `count` aggregate function. Its user-facing documentation comes from the
// `user_doc` attribute above.
pub struct Count {
    // Accepted argument signature; set up in `Count::new` and returned by `signature()`.
    signature: Signature,
}
111
112impl Debug for Count {
113 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114 f.debug_struct("Count")
115 .field("name", &self.name())
116 .field("signature", &self.signature)
117 .finish()
118 }
119}
120
121impl Default for Count {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl Count {
128 pub fn new() -> Self {
129 Self {
130 signature: Signature::one_of(
131 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
132 Volatility::Immutable,
133 ),
134 }
135 }
136}
137
138impl AggregateUDFImpl for Count {
139 fn as_any(&self) -> &dyn std::any::Any {
140 self
141 }
142
143 fn name(&self) -> &str {
144 "count"
145 }
146
147 fn signature(&self) -> &Signature {
148 &self.signature
149 }
150
151 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
152 Ok(DataType::Int64)
153 }
154
155 fn is_nullable(&self) -> bool {
156 false
157 }
158
159 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
160 if args.is_distinct {
161 Ok(vec![Field::new_list(
162 format_state_name(args.name, "count distinct"),
163 Field::new_list_field(args.input_types[0].clone(), true),
165 false,
166 )])
167 } else {
168 Ok(vec![Field::new(
169 format_state_name(args.name, "count"),
170 DataType::Int64,
171 false,
172 )])
173 }
174 }
175
176 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
177 if !acc_args.is_distinct {
178 return Ok(Box::new(CountAccumulator::new()));
179 }
180
181 if acc_args.exprs.len() > 1 {
182 return not_impl_err!("COUNT DISTINCT with multiple arguments");
183 }
184
185 let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
186 Ok(match data_type {
187 DataType::Int8 => Box::new(
189 PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
190 ),
191 DataType::Int16 => Box::new(
192 PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
193 ),
194 DataType::Int32 => Box::new(
195 PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
196 ),
197 DataType::Int64 => Box::new(
198 PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
199 ),
200 DataType::UInt8 => Box::new(
201 PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
202 ),
203 DataType::UInt16 => Box::new(
204 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
205 ),
206 DataType::UInt32 => Box::new(
207 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
208 ),
209 DataType::UInt64 => Box::new(
210 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
211 ),
212 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
213 Decimal128Type,
214 >::new(data_type)),
215 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
216 Decimal256Type,
217 >::new(data_type)),
218
219 DataType::Date32 => Box::new(
220 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
221 ),
222 DataType::Date64 => Box::new(
223 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
224 ),
225 DataType::Time32(TimeUnit::Millisecond) => Box::new(
226 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
227 data_type,
228 ),
229 ),
230 DataType::Time32(TimeUnit::Second) => Box::new(
231 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
232 ),
233 DataType::Time64(TimeUnit::Microsecond) => Box::new(
234 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
235 data_type,
236 ),
237 ),
238 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
239 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
240 ),
241 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
242 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
243 data_type,
244 ),
245 ),
246 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
247 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
248 data_type,
249 ),
250 ),
251 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
252 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
253 data_type,
254 ),
255 ),
256 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
257 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
258 ),
259
260 DataType::Float16 => {
261 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
262 }
263 DataType::Float32 => {
264 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
265 }
266 DataType::Float64 => {
267 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
268 }
269
270 DataType::Utf8 => {
271 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
272 }
273 DataType::Utf8View => {
274 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
275 }
276 DataType::LargeUtf8 => {
277 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
278 }
279 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
280 OutputType::Binary,
281 )),
282 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
283 OutputType::BinaryView,
284 )),
285 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
286 OutputType::Binary,
287 )),
288
289 _ => Box::new(DistinctCountAccumulator {
291 values: HashSet::default(),
292 state_data_type: data_type.clone(),
293 }),
294 })
295 }
296
297 fn aliases(&self) -> &[String] {
298 &[]
299 }
300
301 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
302 if args.is_distinct {
305 return false;
306 }
307 args.exprs.len() == 1
308 }
309
310 fn create_groups_accumulator(
311 &self,
312 _args: AccumulatorArgs,
313 ) -> Result<Box<dyn GroupsAccumulator>> {
314 Ok(Box::new(CountGroupsAccumulator::new()))
316 }
317
318 fn reverse_expr(&self) -> ReversedUDAF {
319 ReversedUDAF::Identical
320 }
321
322 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
323 Ok(ScalarValue::Int64(Some(0)))
324 }
325
326 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
327 if statistics_args.is_distinct {
328 return None;
329 }
330 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
331 if statistics_args.exprs.len() == 1 {
332 if let Some(col_expr) = statistics_args.exprs[0]
334 .as_any()
335 .downcast_ref::<expressions::Column>()
336 {
337 let current_val = &statistics_args.statistics.column_statistics
338 [col_expr.index()]
339 .null_count;
340 if let &Precision::Exact(val) = current_val {
341 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
342 }
343 } else if let Some(lit_expr) = statistics_args.exprs[0]
344 .as_any()
345 .downcast_ref::<expressions::Literal>()
346 {
347 if lit_expr.value() == &COUNT_STAR_EXPANSION {
348 return Some(ScalarValue::Int64(Some(num_rows as i64)));
349 }
350 }
351 }
352 }
353 None
354 }
355
356 fn documentation(&self) -> Option<&Documentation> {
357 self.doc()
358 }
359
360 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
361 SetMonotonicity::Increasing
364 }
365}
366
// Accumulator used for non-distinct `count`.
#[derive(Debug)]
struct CountAccumulator {
    // Running total; incremented in `update_batch`/`merge_batch` and
    // decremented in `retract_batch`.
    count: i64,
}
371
372impl CountAccumulator {
373 pub fn new() -> Self {
375 Self { count: 0 }
376 }
377}
378
379impl Accumulator for CountAccumulator {
380 fn state(&mut self) -> Result<Vec<ScalarValue>> {
381 Ok(vec![ScalarValue::Int64(Some(self.count))])
382 }
383
384 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
385 let array = &values[0];
386 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
387 Ok(())
388 }
389
390 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
391 let array = &values[0];
392 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
393 Ok(())
394 }
395
396 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
397 let counts = downcast_value!(states[0], Int64Array);
398 let delta = &compute::sum(counts);
399 if let Some(d) = delta {
400 self.count += *d;
401 }
402 Ok(())
403 }
404
405 fn evaluate(&mut self) -> Result<ScalarValue> {
406 Ok(ScalarValue::Int64(Some(self.count)))
407 }
408
409 fn supports_retract_batch(&self) -> bool {
410 true
411 }
412
413 fn size(&self) -> usize {
414 size_of_val(self)
415 }
416}
417
// Groups accumulator for `count`: keeps one counter per group.
#[derive(Debug)]
struct CountGroupsAccumulator {
    // Per-group counts, indexed by group index; resized to `total_num_groups`
    // on every update/merge.
    counts: Vec<i64>,
}
434
435impl CountGroupsAccumulator {
436 pub fn new() -> Self {
437 Self { counts: vec![] }
438 }
439}
440
441impl GroupsAccumulator for CountGroupsAccumulator {
442 fn update_batch(
443 &mut self,
444 values: &[ArrayRef],
445 group_indices: &[usize],
446 opt_filter: Option<&BooleanArray>,
447 total_num_groups: usize,
448 ) -> Result<()> {
449 assert_eq!(values.len(), 1, "single argument to update_batch");
450 let values = &values[0];
451
452 self.counts.resize(total_num_groups, 0);
455 accumulate_indices(
456 group_indices,
457 values.logical_nulls().as_ref(),
458 opt_filter,
459 |group_index| {
460 self.counts[group_index] += 1;
461 },
462 );
463
464 Ok(())
465 }
466
467 fn merge_batch(
468 &mut self,
469 values: &[ArrayRef],
470 group_indices: &[usize],
471 _opt_filter: Option<&BooleanArray>,
473 total_num_groups: usize,
474 ) -> Result<()> {
475 assert_eq!(values.len(), 1, "one argument to merge_batch");
476 let partial_counts = values[0].as_primitive::<Int64Type>();
478
479 assert_eq!(partial_counts.null_count(), 0);
481 let partial_counts = partial_counts.values();
482
483 self.counts.resize(total_num_groups, 0);
485 group_indices.iter().zip(partial_counts.iter()).for_each(
486 |(&group_index, partial_count)| {
487 self.counts[group_index] += partial_count;
488 },
489 );
490
491 Ok(())
492 }
493
494 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
495 let counts = emit_to.take_needed(&mut self.counts);
496
497 let nulls = None;
499 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
500
501 Ok(Arc::new(array))
502 }
503
504 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
506 let counts = emit_to.take_needed(&mut self.counts);
507 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
509 }
510
511 fn convert_to_state(
517 &self,
518 values: &[ArrayRef],
519 opt_filter: Option<&BooleanArray>,
520 ) -> Result<Vec<ArrayRef>> {
521 let values = &values[0];
522
523 let state_array = match (values.logical_nulls(), opt_filter) {
524 (None, None) => {
525 Arc::new(Int64Array::from_value(1, values.len()))
527 }
528 (Some(nulls), None) => {
529 let nulls = BooleanArray::new(nulls.into_inner(), None);
532 compute::cast(&nulls, &DataType::Int64)?
533 }
534 (None, Some(filter)) => {
535 let (filter_values, filter_nulls) = filter.clone().into_parts();
540
541 let state_buf = match filter_nulls {
542 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
543 None => filter_values,
544 };
545
546 let boolean_state = BooleanArray::new(state_buf, None);
547 compute::cast(&boolean_state, &DataType::Int64)?
548 }
549 (Some(nulls), Some(filter)) => {
550 let (filter_values, filter_nulls) = filter.clone().into_parts();
557
558 let filter_buf = match filter_nulls {
559 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
560 None => filter_values,
561 };
562 let state_buf = &filter_buf & nulls.inner();
563
564 let boolean_state = BooleanArray::new(state_buf, None);
565 compute::cast(&boolean_state, &DataType::Int64)?
566 }
567 };
568
569 Ok(vec![state_array])
570 }
571
572 fn supports_convert_to_state(&self) -> bool {
573 true
574 }
575
576 fn size(&self) -> usize {
577 self.counts.capacity() * size_of::<usize>()
578 }
579}
580
581fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
584 if values.len() > 1 {
585 let result_bool_buf: Option<BooleanBuffer> = values
586 .iter()
587 .map(|a| a.logical_nulls())
588 .fold(None, |acc, b| match (acc, b) {
589 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
590 (Some(acc), None) => Some(acc),
591 (None, Some(b)) => Some(b.into_inner()),
592 _ => None,
593 });
594 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
595 } else {
596 values[0]
597 .logical_nulls()
598 .map_or(0, |nulls| nulls.null_count())
599 }
600}
601
// Generic distinct-count accumulator; `Count::accumulator` falls back to it for
// data types that have no specialised accumulator.
#[derive(Debug)]
struct DistinctCountAccumulator {
    // Set of the distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    // Data type of the counted values; used to build the list state and to
    // choose the size-estimation strategy.
    state_data_type: DataType,
}
615
616impl DistinctCountAccumulator {
617 fn fixed_size(&self) -> usize {
621 size_of_val(self)
622 + (size_of::<ScalarValue>() * self.values.capacity())
623 + self
624 .values
625 .iter()
626 .next()
627 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
628 .unwrap_or(0)
629 + size_of::<DataType>()
630 }
631
632 fn full_size(&self) -> usize {
635 size_of_val(self)
636 + (size_of::<ScalarValue>() * self.values.capacity())
637 + self
638 .values
639 .iter()
640 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
641 .sum::<usize>()
642 + size_of::<DataType>()
643 }
644}
645
646impl Accumulator for DistinctCountAccumulator {
647 fn state(&mut self) -> Result<Vec<ScalarValue>> {
649 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
650 let arr =
651 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
652 Ok(vec![ScalarValue::List(arr)])
653 }
654
655 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
656 if values.is_empty() {
657 return Ok(());
658 }
659
660 let arr = &values[0];
661 if arr.data_type() == &DataType::Null {
662 return Ok(());
663 }
664
665 (0..arr.len()).try_for_each(|index| {
666 if !arr.is_null(index) {
667 let scalar = ScalarValue::try_from_array(arr, index)?;
668 self.values.insert(scalar);
669 }
670 Ok(())
671 })
672 }
673
674 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
680 if states.is_empty() {
681 return Ok(());
682 }
683 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
684 let array = &states[0];
685 let list_array = array.as_list::<i32>();
686 for inner_array in list_array.iter() {
687 let Some(inner_array) = inner_array else {
688 return internal_err!(
689 "Intermediate results of COUNT DISTINCT should always be non null"
690 );
691 };
692 self.update_batch(&[inner_array])?;
693 }
694 Ok(())
695 }
696
697 fn evaluate(&mut self) -> Result<ScalarValue> {
698 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
699 }
700
701 fn size(&self) -> usize {
702 match &self.state_data_type {
703 DataType::Boolean | DataType::Null => self.fixed_size(),
704 d if d.is_primitive() => self.fixed_size(),
705 _ => self.full_size(),
706 }
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;

    /// Feeding a 10-element `NullArray` to the plain accumulator must leave
    /// the count at zero.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let all_nulls: ArrayRef = Arc::new(NullArray::new(10));
        let mut acc = CountAccumulator::new();
        acc.update_batch(&[all_nulls])?;

        let expected = ScalarValue::Int64(Some(0));
        assert_eq!(acc.evaluate()?, expected);
        Ok(())
    }
}