1use ahash::RandomState;
19use arrow::{
20 array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21 buffer::BooleanBuffer,
22 compute,
23 datatypes::{
24 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25 FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
26 Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29 UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30 },
31};
32use datafusion_common::{
33 HashMap, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
34 stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38 ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39 WindowFunctionDefinition,
40 expr::WindowFunction,
41 function::{AccumulatorArgs, StateFieldsArgs},
42 utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45 count_distinct::BytesDistinctCountAccumulator,
46 count_distinct::BytesViewDistinctCountAccumulator,
47 count_distinct::DictionaryCountAccumulator,
48 count_distinct::FloatDistinctCountAccumulator,
49 count_distinct::PrimitiveDistinctCountAccumulator,
50 groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56 collections::HashSet,
57 fmt::Debug,
58 mem::{size_of, size_of_val},
59 ops::BitAnd,
60 sync::Arc,
61};
62
// Registers the `Count` aggregate with the `make_udaf_expr_and_func!` helper
// macro. The `count(expr)` builder and the `count_udaf()` accessor used
// further down in this file are presumably generated by this invocation
// (the macro definition is not part of this file).
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73 count_udaf(),
74 vec![expr],
75 true,
76 None,
77 vec![],
78 None,
79 ))
80}
81
82pub fn count_all() -> Expr {
100 count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103pub fn count_all_window() -> Expr {
123 Expr::from(WindowFunction::new(
124 WindowFunctionDefinition::AggregateUDF(count_udaf()),
125 vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126 ))
127}
128
/// The `count` aggregate function.
///
/// The only state carried by the type is its [`Signature`]; the user-facing
/// documentation is supplied through the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
    syntax_example = "count(expression)",
    sql_example = r#"```sql
> SELECT count(column_name) FROM table_name;
+-----------------------+
| count(column_name) |
+-----------------------+
| 100 |
+-----------------------+

> SELECT count(*) FROM table_name;
+------------------+
| count(*) |
+------------------+
| 120 |
+------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(PartialEq, Eq, Hash)]
pub struct Count {
    // Accepted argument shapes; built in `Count::new`.
    signature: Signature,
}
154
155impl Debug for Count {
156 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157 f.debug_struct("Count")
158 .field("name", &self.name())
159 .field("signature", &self.signature)
160 .finish()
161 }
162}
163
164impl Default for Count {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl Count {
171 pub fn new() -> Self {
172 Self {
173 signature: Signature::one_of(
174 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
175 Volatility::Immutable,
176 ),
177 }
178 }
179}
180fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
181 match data_type {
182 DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
184 data_type,
185 )),
186 DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
187 data_type,
188 )),
189 DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
190 data_type,
191 )),
192 DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
193 data_type,
194 )),
195 DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
196 data_type,
197 )),
198 DataType::UInt16 => Box::new(
199 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
200 ),
201 DataType::UInt32 => Box::new(
202 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
203 ),
204 DataType::UInt64 => Box::new(
205 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
206 ),
207 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
208 Decimal128Type,
209 >::new(data_type)),
210 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
211 Decimal256Type,
212 >::new(data_type)),
213
214 DataType::Date32 => Box::new(
215 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
216 ),
217 DataType::Date64 => Box::new(
218 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
219 ),
220 DataType::Time32(TimeUnit::Millisecond) => Box::new(
221 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
222 ),
223 DataType::Time32(TimeUnit::Second) => Box::new(
224 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
225 ),
226 DataType::Time64(TimeUnit::Microsecond) => Box::new(
227 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
228 ),
229 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
230 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
231 ),
232 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
233 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
234 ),
235 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
236 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
237 ),
238 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
239 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
240 ),
241 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
242 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
243 ),
244
245 DataType::Float16 => {
246 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
247 }
248 DataType::Float32 => {
249 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
250 }
251 DataType::Float64 => {
252 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
253 }
254
255 DataType::Utf8 => {
256 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
257 }
258 DataType::Utf8View => {
259 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
260 }
261 DataType::LargeUtf8 => {
262 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
263 }
264 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
265 OutputType::Binary,
266 )),
267 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
268 OutputType::BinaryView,
269 )),
270 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
271 OutputType::Binary,
272 )),
273
274 _ => Box::new(DistinctCountAccumulator {
276 values: HashSet::default(),
277 state_data_type: data_type.clone(),
278 }),
279 }
280}
281
282impl AggregateUDFImpl for Count {
283 fn as_any(&self) -> &dyn std::any::Any {
284 self
285 }
286
287 fn name(&self) -> &str {
288 "count"
289 }
290
291 fn signature(&self) -> &Signature {
292 &self.signature
293 }
294
295 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
296 Ok(DataType::Int64)
297 }
298
299 fn is_nullable(&self) -> bool {
300 false
301 }
302
303 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
304 if args.is_distinct {
305 let dtype: DataType = match &args.input_fields[0].data_type() {
306 DataType::Dictionary(_, values_type) => (**values_type).clone(),
307 &dtype => dtype.clone(),
308 };
309
310 Ok(vec![
311 Field::new_list(
312 format_state_name(args.name, "count distinct"),
313 Field::new_list_field(dtype, true),
315 false,
316 )
317 .into(),
318 ])
319 } else {
320 Ok(vec![
321 Field::new(
322 format_state_name(args.name, "count"),
323 DataType::Int64,
324 false,
325 )
326 .into(),
327 ])
328 }
329 }
330
331 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
332 if !acc_args.is_distinct {
333 return Ok(Box::new(CountAccumulator::new()));
334 }
335
336 if acc_args.exprs.len() > 1 {
337 return not_impl_err!("COUNT DISTINCT with multiple arguments");
338 }
339
340 let data_type = acc_args.expr_fields[0].data_type();
341
342 Ok(match data_type {
343 DataType::Dictionary(_, values_type) => {
344 let inner = get_count_accumulator(values_type);
345 Box::new(DictionaryCountAccumulator::new(inner))
346 }
347 _ => get_count_accumulator(data_type),
348 })
349 }
350
351 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
352 if args.is_distinct {
355 return false;
356 }
357 args.exprs.len() == 1
358 }
359
360 fn create_groups_accumulator(
361 &self,
362 _args: AccumulatorArgs,
363 ) -> Result<Box<dyn GroupsAccumulator>> {
364 Ok(Box::new(CountGroupsAccumulator::new()))
366 }
367
368 fn reverse_expr(&self) -> ReversedUDAF {
369 ReversedUDAF::Identical
370 }
371
372 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
373 Ok(ScalarValue::Int64(Some(0)))
374 }
375
376 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
377 if statistics_args.is_distinct {
378 return None;
379 }
380 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
381 && statistics_args.exprs.len() == 1
382 {
383 if let Some(col_expr) = statistics_args.exprs[0]
385 .as_any()
386 .downcast_ref::<expressions::Column>()
387 {
388 let current_val = &statistics_args.statistics.column_statistics
389 [col_expr.index()]
390 .null_count;
391 if let &Precision::Exact(val) = current_val {
392 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
393 }
394 } else if let Some(lit_expr) = statistics_args.exprs[0]
395 .as_any()
396 .downcast_ref::<expressions::Literal>()
397 && lit_expr.value() == &COUNT_STAR_EXPANSION
398 {
399 return Some(ScalarValue::Int64(Some(num_rows as i64)));
400 }
401 }
402 None
403 }
404
405 fn documentation(&self) -> Option<&Documentation> {
406 self.doc()
407 }
408
409 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
410 SetMonotonicity::Increasing
413 }
414
415 fn create_sliding_accumulator(
416 &self,
417 args: AccumulatorArgs,
418 ) -> Result<Box<dyn Accumulator>> {
419 if args.is_distinct {
420 let acc =
421 SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
422 Ok(Box::new(acc))
423 } else {
424 let acc = CountAccumulator::new();
425 Ok(Box::new(acc))
426 }
427 }
428}
429
/// Distinct-count accumulator that supports retraction.
///
/// It keeps an occurrence count per distinct value so that values can be
/// removed again in `retract_batch`.
#[derive(Debug)]
pub struct SlidingDistinctCountAccumulator {
    /// Number of occurrences currently tracked for each distinct non-null value.
    counts: HashMap<ScalarValue, usize, RandomState>,
    /// Data type passed to `ScalarValue::new_list_nullable` when building the state.
    data_type: DataType,
}
438
439impl SlidingDistinctCountAccumulator {
440 pub fn try_new(data_type: &DataType) -> Result<Self> {
441 Ok(Self {
442 counts: HashMap::default(),
443 data_type: data_type.clone(),
444 })
445 }
446}
447
448impl Accumulator for SlidingDistinctCountAccumulator {
449 fn state(&mut self) -> Result<Vec<ScalarValue>> {
450 let keys = self.counts.keys().cloned().collect::<Vec<_>>();
451 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
452 keys.as_slice(),
453 &self.data_type,
454 ))])
455 }
456
457 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
458 let arr = &values[0];
459 for i in 0..arr.len() {
460 let v = ScalarValue::try_from_array(arr, i)?;
461 if !v.is_null() {
462 *self.counts.entry(v).or_default() += 1;
463 }
464 }
465 Ok(())
466 }
467
468 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469 let arr = &values[0];
470 for i in 0..arr.len() {
471 let v = ScalarValue::try_from_array(arr, i)?;
472 if !v.is_null()
473 && let Some(cnt) = self.counts.get_mut(&v)
474 {
475 *cnt -= 1;
476 if *cnt == 0 {
477 self.counts.remove(&v);
478 }
479 }
480 }
481 Ok(())
482 }
483
484 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
485 let list_arr = states[0].as_list::<i32>();
486 for inner in list_arr.iter().flatten() {
487 for j in 0..inner.len() {
488 let v = ScalarValue::try_from_array(&*inner, j)?;
489 *self.counts.entry(v).or_default() += 1;
490 }
491 }
492 Ok(())
493 }
494
495 fn evaluate(&mut self) -> Result<ScalarValue> {
496 Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
497 }
498
499 fn supports_retract_batch(&self) -> bool {
500 true
501 }
502
503 fn size(&self) -> usize {
504 size_of_val(self)
505 }
506}
507
/// Accumulator for the non-distinct `count`: a single running total.
#[derive(Debug)]
struct CountAccumulator {
    /// Running total; increased by `update_batch`/`merge_batch` and decreased
    /// by `retract_batch`.
    count: i64,
}
512
513impl CountAccumulator {
514 pub fn new() -> Self {
516 Self { count: 0 }
517 }
518}
519
520impl Accumulator for CountAccumulator {
521 fn state(&mut self) -> Result<Vec<ScalarValue>> {
522 Ok(vec![ScalarValue::Int64(Some(self.count))])
523 }
524
525 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
526 let array = &values[0];
527 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
528 Ok(())
529 }
530
531 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
532 let array = &values[0];
533 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
534 Ok(())
535 }
536
537 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
538 let counts = downcast_value!(states[0], Int64Array);
539 let delta = &compute::sum(counts);
540 if let Some(d) = delta {
541 self.count += *d;
542 }
543 Ok(())
544 }
545
546 fn evaluate(&mut self) -> Result<ScalarValue> {
547 Ok(ScalarValue::Int64(Some(self.count)))
548 }
549
550 fn supports_retract_batch(&self) -> bool {
551 true
552 }
553
554 fn size(&self) -> usize {
555 size_of_val(self)
556 }
557}
558
/// Groups accumulator for the non-distinct `count`: one running total per
/// group.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// Count per group, indexed by group index; resized to the number of
    /// groups on every `update_batch`/`merge_batch` call.
    counts: Vec<i64>,
}
575
576impl CountGroupsAccumulator {
577 pub fn new() -> Self {
578 Self { counts: vec![] }
579 }
580}
581
582impl GroupsAccumulator for CountGroupsAccumulator {
583 fn update_batch(
584 &mut self,
585 values: &[ArrayRef],
586 group_indices: &[usize],
587 opt_filter: Option<&BooleanArray>,
588 total_num_groups: usize,
589 ) -> Result<()> {
590 assert_eq!(values.len(), 1, "single argument to update_batch");
591 let values = &values[0];
592
593 self.counts.resize(total_num_groups, 0);
596 accumulate_indices(
597 group_indices,
598 values.logical_nulls().as_ref(),
599 opt_filter,
600 |group_index| {
601 self.counts[group_index] += 1;
602 },
603 );
604
605 Ok(())
606 }
607
608 fn merge_batch(
609 &mut self,
610 values: &[ArrayRef],
611 group_indices: &[usize],
612 _opt_filter: Option<&BooleanArray>,
614 total_num_groups: usize,
615 ) -> Result<()> {
616 assert_eq!(values.len(), 1, "one argument to merge_batch");
617 let partial_counts = values[0].as_primitive::<Int64Type>();
619
620 assert_eq!(partial_counts.null_count(), 0);
622 let partial_counts = partial_counts.values();
623
624 self.counts.resize(total_num_groups, 0);
626 group_indices.iter().zip(partial_counts.iter()).for_each(
627 |(&group_index, partial_count)| {
628 self.counts[group_index] += partial_count;
629 },
630 );
631
632 Ok(())
633 }
634
635 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
636 let counts = emit_to.take_needed(&mut self.counts);
637
638 let nulls = None;
640 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
641
642 Ok(Arc::new(array))
643 }
644
645 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
647 let counts = emit_to.take_needed(&mut self.counts);
648 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
650 }
651
652 fn convert_to_state(
658 &self,
659 values: &[ArrayRef],
660 opt_filter: Option<&BooleanArray>,
661 ) -> Result<Vec<ArrayRef>> {
662 let values = &values[0];
663
664 let state_array = match (values.logical_nulls(), opt_filter) {
665 (None, None) => {
666 Arc::new(Int64Array::from_value(1, values.len()))
668 }
669 (Some(nulls), None) => {
670 let nulls = BooleanArray::new(nulls.into_inner(), None);
673 compute::cast(&nulls, &DataType::Int64)?
674 }
675 (None, Some(filter)) => {
676 let (filter_values, filter_nulls) = filter.clone().into_parts();
681
682 let state_buf = match filter_nulls {
683 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
684 None => filter_values,
685 };
686
687 let boolean_state = BooleanArray::new(state_buf, None);
688 compute::cast(&boolean_state, &DataType::Int64)?
689 }
690 (Some(nulls), Some(filter)) => {
691 let (filter_values, filter_nulls) = filter.clone().into_parts();
698
699 let filter_buf = match filter_nulls {
700 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
701 None => filter_values,
702 };
703 let state_buf = &filter_buf & nulls.inner();
704
705 let boolean_state = BooleanArray::new(state_buf, None);
706 compute::cast(&boolean_state, &DataType::Int64)?
707 }
708 };
709
710 Ok(vec![state_array])
711 }
712
713 fn supports_convert_to_state(&self) -> bool {
714 true
715 }
716
717 fn size(&self) -> usize {
718 self.counts.capacity() * size_of::<usize>()
719 }
720}
721
722fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
725 if values.len() > 1 {
726 let result_bool_buf: Option<BooleanBuffer> = values
727 .iter()
728 .map(|a| a.logical_nulls())
729 .fold(None, |acc, b| match (acc, b) {
730 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
731 (Some(acc), None) => Some(acc),
732 (None, Some(b)) => Some(b.into_inner()),
733 _ => None,
734 });
735 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
736 } else {
737 values[0]
738 .logical_nulls()
739 .map_or(0, |nulls| nulls.null_count())
740 }
741}
742
/// Generic distinct-count accumulator that stores every distinct non-null
/// value as a `ScalarValue`. `get_count_accumulator` uses it as the fallback
/// for types without a specialised accumulator.
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// The distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    /// Data type used when serialising the state; also selects the size
    /// estimation strategy in `size`.
    state_data_type: DataType,
}
756
757impl DistinctCountAccumulator {
758 fn fixed_size(&self) -> usize {
762 size_of_val(self)
763 + (size_of::<ScalarValue>() * self.values.capacity())
764 + self
765 .values
766 .iter()
767 .next()
768 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
769 .unwrap_or(0)
770 + size_of::<DataType>()
771 }
772
773 fn full_size(&self) -> usize {
776 size_of_val(self)
777 + (size_of::<ScalarValue>() * self.values.capacity())
778 + self
779 .values
780 .iter()
781 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
782 .sum::<usize>()
783 + size_of::<DataType>()
784 }
785}
786
787impl Accumulator for DistinctCountAccumulator {
788 fn state(&mut self) -> Result<Vec<ScalarValue>> {
790 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
791 let arr =
792 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
793 Ok(vec![ScalarValue::List(arr)])
794 }
795
796 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
797 if values.is_empty() {
798 return Ok(());
799 }
800
801 let arr = &values[0];
802 if arr.data_type() == &DataType::Null {
803 return Ok(());
804 }
805
806 (0..arr.len()).try_for_each(|index| {
807 let scalar = ScalarValue::try_from_array(arr, index)?;
808 if !scalar.is_null() {
809 self.values.insert(scalar);
810 }
811 Ok(())
812 })
813 }
814
815 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
821 if states.is_empty() {
822 return Ok(());
823 }
824 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
825 let array = &states[0];
826 let list_array = array.as_list::<i32>();
827 for inner_array in list_array.iter() {
828 let Some(inner_array) = inner_array else {
829 return internal_err!(
830 "Intermediate results of COUNT DISTINCT should always be non null"
831 );
832 };
833 self.update_batch(&[inner_array])?;
834 }
835 Ok(())
836 }
837
838 fn evaluate(&mut self) -> Result<ScalarValue> {
839 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
840 }
841
842 fn size(&self) -> usize {
843 match &self.state_data_type {
844 DataType::Boolean | DataType::Null => self.fixed_size(),
845 d if d.is_primitive() => self.fixed_size(),
846 _ => self.full_size(),
847 }
848 }
849}
850
// Unit tests for the count accumulators defined in this file.
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
    use std::sync::Arc;

    /// Builds a dictionary array whose *values* contain a null: the values
    /// are `["a", null, "c"]` and the keys are `[0, 1, 2, 0, 1]`, so the
    /// logical content is `a, null, c, a, null`.
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]);
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    // A `NullArray` contributes nothing to a plain count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // COUNT(DISTINCT ..) over a dictionary whose values are themselves a
    // dictionary; the input holds four distinct strings.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let expr_field = expr.return_field(&schema)?;
        let args = AccumulatorArgs {
            schema: &schema,
            expr_fields: &[expr_field],
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    // Nulls stored in the dictionary values are not counted as a distinct
    // value: only "a" and "c" remain.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    // The plain count skips the two rows whose key points at the null
    // dictionary value, leaving 3 of the 5 rows.
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // Every key points at the null dictionary value, so there is no distinct
    // non-null value at all.
    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // Duplicates and nulls are ignored: {1, 2, 3}.
    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(2),
            Some(3),
            None,
        ]));
        acc.update_batch(&[values])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // After the update the counts are a=2, b=1. Retracting one "a" and one
    // "b" leaves only "a".
    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
            as ArrayRef;
        acc.update_batch(&[arr1])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));
        let arr2 =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
        acc.retract_batch(&[arr2])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    // Merging the states {1, 2} and {2, 3} yields three distinct values.
    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
        let state_sv1 = acc1.state()?;
        let state_sv2 = acc2.state()?;
        // Turn the scalar states into arrays, the form `merge_batch` expects.
        let state_arr1: Vec<ArrayRef> = state_sv1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state_arr2: Vec<ArrayRef> = state_sv2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        merged.merge_batch(&state_arr1)?;
        merged.merge_batch(&state_arr2)?;
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}