1use ahash::RandomState;
19use arrow::{
20 array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21 buffer::BooleanBuffer,
22 compute,
23 datatypes::{
24 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25 FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
26 Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29 UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30 },
31};
32use datafusion_common::{
33 downcast_value, internal_err, not_impl_err, stats::Precision,
34 utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue,
35};
36use datafusion_expr::{
37 expr::WindowFunction,
38 function::{AccumulatorArgs, StateFieldsArgs},
39 utils::format_state_name,
40 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
41 ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
42 WindowFunctionDefinition,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45 count_distinct::BytesDistinctCountAccumulator,
46 count_distinct::BytesViewDistinctCountAccumulator,
47 count_distinct::DictionaryCountAccumulator,
48 count_distinct::FloatDistinctCountAccumulator,
49 count_distinct::PrimitiveDistinctCountAccumulator,
50 groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56 collections::HashSet,
57 fmt::Debug,
58 mem::{size_of, size_of_val},
59 ops::BitAnd,
60 sync::Arc,
61};
62
// Declares the `count` expression builder and the `count_udaf` accessor for
// the `Count` aggregate; both names are used further down in this file
// (`count(..)` in `count_all`, `count_udaf()` in `count_distinct`).
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73 count_udaf(),
74 vec![expr],
75 true,
76 None,
77 vec![],
78 None,
79 ))
80}
81
82pub fn count_all() -> Expr {
100 count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103pub fn count_all_window() -> Expr {
123 Expr::from(WindowFunction::new(
124 WindowFunctionDefinition::AggregateUDF(count_udaf()),
125 vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126 ))
127}
128
129#[user_doc(
130 doc_section(label = "General Functions"),
131 description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
132 syntax_example = "count(expression)",
133 sql_example = r#"```sql
134> SELECT count(column_name) FROM table_name;
135+-----------------------+
136| count(column_name) |
137+-----------------------+
138| 100 |
139+-----------------------+
140
141> SELECT count(*) FROM table_name;
142+------------------+
143| count(*) |
144+------------------+
145| 120 |
146+------------------+
147```"#,
148 standard_argument(name = "expression",)
149)]
#[derive(PartialEq, Eq, Hash)]
/// The `count` aggregate function.
///
/// Holds only the function signature; see `Count::new` for the accepted
/// argument shapes.
pub struct Count {
    // Signature reported through `AggregateUDFImpl::signature`.
    signature: Signature,
}
154
155impl Debug for Count {
156 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157 f.debug_struct("Count")
158 .field("name", &self.name())
159 .field("signature", &self.signature)
160 .finish()
161 }
162}
163
164impl Default for Count {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl Count {
171 pub fn new() -> Self {
172 Self {
173 signature: Signature::one_of(
174 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
175 Volatility::Immutable,
176 ),
177 }
178 }
179}
180fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
181 match data_type {
182 DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
184 data_type,
185 )),
186 DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
187 data_type,
188 )),
189 DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
190 data_type,
191 )),
192 DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
193 data_type,
194 )),
195 DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
196 data_type,
197 )),
198 DataType::UInt16 => Box::new(
199 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
200 ),
201 DataType::UInt32 => Box::new(
202 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
203 ),
204 DataType::UInt64 => Box::new(
205 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
206 ),
207 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
208 Decimal128Type,
209 >::new(data_type)),
210 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
211 Decimal256Type,
212 >::new(data_type)),
213
214 DataType::Date32 => Box::new(
215 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
216 ),
217 DataType::Date64 => Box::new(
218 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
219 ),
220 DataType::Time32(TimeUnit::Millisecond) => Box::new(
221 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
222 ),
223 DataType::Time32(TimeUnit::Second) => Box::new(
224 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
225 ),
226 DataType::Time64(TimeUnit::Microsecond) => Box::new(
227 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
228 ),
229 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
230 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
231 ),
232 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
233 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
234 ),
235 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
236 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
237 ),
238 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
239 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
240 ),
241 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
242 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
243 ),
244
245 DataType::Float16 => {
246 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
247 }
248 DataType::Float32 => {
249 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
250 }
251 DataType::Float64 => {
252 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
253 }
254
255 DataType::Utf8 => {
256 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
257 }
258 DataType::Utf8View => {
259 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
260 }
261 DataType::LargeUtf8 => {
262 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
263 }
264 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
265 OutputType::Binary,
266 )),
267 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
268 OutputType::BinaryView,
269 )),
270 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
271 OutputType::Binary,
272 )),
273
274 _ => Box::new(DistinctCountAccumulator {
276 values: HashSet::default(),
277 state_data_type: data_type.clone(),
278 }),
279 }
280}
281
282impl AggregateUDFImpl for Count {
283 fn as_any(&self) -> &dyn std::any::Any {
284 self
285 }
286
287 fn name(&self) -> &str {
288 "count"
289 }
290
291 fn signature(&self) -> &Signature {
292 &self.signature
293 }
294
295 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
296 Ok(DataType::Int64)
297 }
298
299 fn is_nullable(&self) -> bool {
300 false
301 }
302
303 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
304 if args.is_distinct {
305 let dtype: DataType = match &args.input_fields[0].data_type() {
306 DataType::Dictionary(_, values_type) => (**values_type).clone(),
307 &dtype => dtype.clone(),
308 };
309
310 Ok(vec![Field::new_list(
311 format_state_name(args.name, "count distinct"),
312 Field::new_list_field(dtype, true),
314 false,
315 )
316 .into()])
317 } else {
318 Ok(vec![Field::new(
319 format_state_name(args.name, "count"),
320 DataType::Int64,
321 false,
322 )
323 .into()])
324 }
325 }
326
327 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
328 if !acc_args.is_distinct {
329 return Ok(Box::new(CountAccumulator::new()));
330 }
331
332 if acc_args.exprs.len() > 1 {
333 return not_impl_err!("COUNT DISTINCT with multiple arguments");
334 }
335
336 let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?;
337
338 Ok(match data_type {
339 DataType::Dictionary(_, values_type) => {
340 let inner = get_count_accumulator(values_type);
341 Box::new(DictionaryCountAccumulator::new(inner))
342 }
343 _ => get_count_accumulator(data_type),
344 })
345 }
346
347 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
348 if args.is_distinct {
351 return false;
352 }
353 args.exprs.len() == 1
354 }
355
356 fn create_groups_accumulator(
357 &self,
358 _args: AccumulatorArgs,
359 ) -> Result<Box<dyn GroupsAccumulator>> {
360 Ok(Box::new(CountGroupsAccumulator::new()))
362 }
363
364 fn reverse_expr(&self) -> ReversedUDAF {
365 ReversedUDAF::Identical
366 }
367
368 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
369 Ok(ScalarValue::Int64(Some(0)))
370 }
371
372 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
373 if statistics_args.is_distinct {
374 return None;
375 }
376 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
377 if statistics_args.exprs.len() == 1 {
378 if let Some(col_expr) = statistics_args.exprs[0]
380 .as_any()
381 .downcast_ref::<expressions::Column>()
382 {
383 let current_val = &statistics_args.statistics.column_statistics
384 [col_expr.index()]
385 .null_count;
386 if let &Precision::Exact(val) = current_val {
387 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
388 }
389 } else if let Some(lit_expr) = statistics_args.exprs[0]
390 .as_any()
391 .downcast_ref::<expressions::Literal>()
392 {
393 if lit_expr.value() == &COUNT_STAR_EXPANSION {
394 return Some(ScalarValue::Int64(Some(num_rows as i64)));
395 }
396 }
397 }
398 }
399 None
400 }
401
402 fn documentation(&self) -> Option<&Documentation> {
403 self.doc()
404 }
405
406 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
407 SetMonotonicity::Increasing
410 }
411
412 fn create_sliding_accumulator(
413 &self,
414 args: AccumulatorArgs,
415 ) -> Result<Box<dyn Accumulator>> {
416 if args.is_distinct {
417 let acc =
418 SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
419 Ok(Box::new(acc))
420 } else {
421 let acc = CountAccumulator::new();
422 Ok(Box::new(acc))
423 }
424 }
425}
426
/// Distinct-count accumulator that supports `retract_batch`.
///
/// It tracks how many times each non-null value has been added, so a value
/// stops being counted only once every occurrence has been retracted.
#[derive(Debug)]
pub struct SlidingDistinctCountAccumulator {
    // Occurrence count per distinct value; `retract_batch` removes an entry
    // when its count drops to zero.
    counts: HashMap<ScalarValue, usize, RandomState>,
    // Data type handed to `ScalarValue::new_list_nullable` when building the
    // state list.
    data_type: DataType,
}
435
436impl SlidingDistinctCountAccumulator {
437 pub fn try_new(data_type: &DataType) -> Result<Self> {
438 Ok(Self {
439 counts: HashMap::default(),
440 data_type: data_type.clone(),
441 })
442 }
443}
444
445impl Accumulator for SlidingDistinctCountAccumulator {
446 fn state(&mut self) -> Result<Vec<ScalarValue>> {
447 let keys = self.counts.keys().cloned().collect::<Vec<_>>();
448 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
449 keys.as_slice(),
450 &self.data_type,
451 ))])
452 }
453
454 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
455 let arr = &values[0];
456 for i in 0..arr.len() {
457 let v = ScalarValue::try_from_array(arr, i)?;
458 if !v.is_null() {
459 *self.counts.entry(v).or_default() += 1;
460 }
461 }
462 Ok(())
463 }
464
465 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
466 let arr = &values[0];
467 for i in 0..arr.len() {
468 let v = ScalarValue::try_from_array(arr, i)?;
469 if !v.is_null() {
470 if let Some(cnt) = self.counts.get_mut(&v) {
471 *cnt -= 1;
472 if *cnt == 0 {
473 self.counts.remove(&v);
474 }
475 }
476 }
477 }
478 Ok(())
479 }
480
481 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
482 let list_arr = states[0].as_list::<i32>();
483 for inner in list_arr.iter().flatten() {
484 for j in 0..inner.len() {
485 let v = ScalarValue::try_from_array(&*inner, j)?;
486 *self.counts.entry(v).or_default() += 1;
487 }
488 }
489 Ok(())
490 }
491
492 fn evaluate(&mut self) -> Result<ScalarValue> {
493 Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
494 }
495
496 fn supports_retract_batch(&self) -> bool {
497 true
498 }
499
500 fn size(&self) -> usize {
501 size_of_val(self)
502 }
503}
504
/// Accumulator for a plain (non-distinct) `count`.
#[derive(Debug)]
struct CountAccumulator {
    // Running number of counted rows; increased by `update_batch` and
    // `merge_batch`, decreased by `retract_batch`.
    count: i64,
}
509
510impl CountAccumulator {
511 pub fn new() -> Self {
513 Self { count: 0 }
514 }
515}
516
517impl Accumulator for CountAccumulator {
518 fn state(&mut self) -> Result<Vec<ScalarValue>> {
519 Ok(vec![ScalarValue::Int64(Some(self.count))])
520 }
521
522 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523 let array = &values[0];
524 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
525 Ok(())
526 }
527
528 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
529 let array = &values[0];
530 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
531 Ok(())
532 }
533
534 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
535 let counts = downcast_value!(states[0], Int64Array);
536 let delta = &compute::sum(counts);
537 if let Some(d) = delta {
538 self.count += *d;
539 }
540 Ok(())
541 }
542
543 fn evaluate(&mut self) -> Result<ScalarValue> {
544 Ok(ScalarValue::Int64(Some(self.count)))
545 }
546
547 fn supports_retract_batch(&self) -> bool {
548 true
549 }
550
551 fn size(&self) -> usize {
552 size_of_val(self)
553 }
554}
555
/// Grouped accumulator for a plain (non-distinct) `count`.
#[derive(Debug)]
struct CountGroupsAccumulator {
    // One running count per group, indexed by group index (see
    // `update_batch` and `merge_batch`).
    counts: Vec<i64>,
}
572
573impl CountGroupsAccumulator {
574 pub fn new() -> Self {
575 Self { counts: vec![] }
576 }
577}
578
579impl GroupsAccumulator for CountGroupsAccumulator {
580 fn update_batch(
581 &mut self,
582 values: &[ArrayRef],
583 group_indices: &[usize],
584 opt_filter: Option<&BooleanArray>,
585 total_num_groups: usize,
586 ) -> Result<()> {
587 assert_eq!(values.len(), 1, "single argument to update_batch");
588 let values = &values[0];
589
590 self.counts.resize(total_num_groups, 0);
593 accumulate_indices(
594 group_indices,
595 values.logical_nulls().as_ref(),
596 opt_filter,
597 |group_index| {
598 self.counts[group_index] += 1;
599 },
600 );
601
602 Ok(())
603 }
604
605 fn merge_batch(
606 &mut self,
607 values: &[ArrayRef],
608 group_indices: &[usize],
609 _opt_filter: Option<&BooleanArray>,
611 total_num_groups: usize,
612 ) -> Result<()> {
613 assert_eq!(values.len(), 1, "one argument to merge_batch");
614 let partial_counts = values[0].as_primitive::<Int64Type>();
616
617 assert_eq!(partial_counts.null_count(), 0);
619 let partial_counts = partial_counts.values();
620
621 self.counts.resize(total_num_groups, 0);
623 group_indices.iter().zip(partial_counts.iter()).for_each(
624 |(&group_index, partial_count)| {
625 self.counts[group_index] += partial_count;
626 },
627 );
628
629 Ok(())
630 }
631
632 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
633 let counts = emit_to.take_needed(&mut self.counts);
634
635 let nulls = None;
637 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
638
639 Ok(Arc::new(array))
640 }
641
642 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
644 let counts = emit_to.take_needed(&mut self.counts);
645 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
647 }
648
649 fn convert_to_state(
655 &self,
656 values: &[ArrayRef],
657 opt_filter: Option<&BooleanArray>,
658 ) -> Result<Vec<ArrayRef>> {
659 let values = &values[0];
660
661 let state_array = match (values.logical_nulls(), opt_filter) {
662 (None, None) => {
663 Arc::new(Int64Array::from_value(1, values.len()))
665 }
666 (Some(nulls), None) => {
667 let nulls = BooleanArray::new(nulls.into_inner(), None);
670 compute::cast(&nulls, &DataType::Int64)?
671 }
672 (None, Some(filter)) => {
673 let (filter_values, filter_nulls) = filter.clone().into_parts();
678
679 let state_buf = match filter_nulls {
680 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
681 None => filter_values,
682 };
683
684 let boolean_state = BooleanArray::new(state_buf, None);
685 compute::cast(&boolean_state, &DataType::Int64)?
686 }
687 (Some(nulls), Some(filter)) => {
688 let (filter_values, filter_nulls) = filter.clone().into_parts();
695
696 let filter_buf = match filter_nulls {
697 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
698 None => filter_values,
699 };
700 let state_buf = &filter_buf & nulls.inner();
701
702 let boolean_state = BooleanArray::new(state_buf, None);
703 compute::cast(&boolean_state, &DataType::Int64)?
704 }
705 };
706
707 Ok(vec![state_array])
708 }
709
710 fn supports_convert_to_state(&self) -> bool {
711 true
712 }
713
714 fn size(&self) -> usize {
715 self.counts.capacity() * size_of::<usize>()
716 }
717}
718
719fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
722 if values.len() > 1 {
723 let result_bool_buf: Option<BooleanBuffer> = values
724 .iter()
725 .map(|a| a.logical_nulls())
726 .fold(None, |acc, b| match (acc, b) {
727 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
728 (Some(acc), None) => Some(acc),
729 (None, Some(b)) => Some(b.into_inner()),
730 _ => None,
731 });
732 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
733 } else {
734 values[0]
735 .logical_nulls()
736 .map_or(0, |nulls| nulls.null_count())
737 }
738}
739
/// General purpose distinct-count accumulator that stores every distinct
/// non-null value as a `ScalarValue`.
///
/// `get_count_accumulator` uses it for data types that have no specialised
/// accumulator.
#[derive(Debug)]
struct DistinctCountAccumulator {
    // Distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    // Data type handed to `ScalarValue::new_list_nullable` when building the
    // state list; also selects the size estimate used by `size`.
    state_data_type: DataType,
}
753
754impl DistinctCountAccumulator {
755 fn fixed_size(&self) -> usize {
759 size_of_val(self)
760 + (size_of::<ScalarValue>() * self.values.capacity())
761 + self
762 .values
763 .iter()
764 .next()
765 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
766 .unwrap_or(0)
767 + size_of::<DataType>()
768 }
769
770 fn full_size(&self) -> usize {
773 size_of_val(self)
774 + (size_of::<ScalarValue>() * self.values.capacity())
775 + self
776 .values
777 .iter()
778 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
779 .sum::<usize>()
780 + size_of::<DataType>()
781 }
782}
783
784impl Accumulator for DistinctCountAccumulator {
785 fn state(&mut self) -> Result<Vec<ScalarValue>> {
787 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
788 let arr =
789 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
790 Ok(vec![ScalarValue::List(arr)])
791 }
792
793 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
794 if values.is_empty() {
795 return Ok(());
796 }
797
798 let arr = &values[0];
799 if arr.data_type() == &DataType::Null {
800 return Ok(());
801 }
802
803 (0..arr.len()).try_for_each(|index| {
804 let scalar = ScalarValue::try_from_array(arr, index)?;
805 if !scalar.is_null() {
806 self.values.insert(scalar);
807 }
808 Ok(())
809 })
810 }
811
812 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
818 if states.is_empty() {
819 return Ok(());
820 }
821 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
822 let array = &states[0];
823 let list_array = array.as_list::<i32>();
824 for inner_array in list_array.iter() {
825 let Some(inner_array) = inner_array else {
826 return internal_err!(
827 "Intermediate results of COUNT DISTINCT should always be non null"
828 );
829 };
830 self.update_batch(&[inner_array])?;
831 }
832 Ok(())
833 }
834
835 fn evaluate(&mut self) -> Result<ScalarValue> {
836 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
837 }
838
839 fn size(&self) -> usize {
840 match &self.state_data_type {
841 DataType::Boolean | DataType::Null => self.fixed_size(),
842 d if d.is_primitive() => self.fixed_size(),
843 _ => self.full_size(),
844 }
845 }
846}
847
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::expressions::Column;
    use std::sync::Arc;
    // Builds a dictionary array whose values are ["a", null, "c"] and whose
    // keys are [0, 1, 2, 0, 1], i.e. the logical rows are
    // "a", null, "c", "a", null.
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]);
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    // A `NullArray` contributes nothing to a plain count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // COUNT(DISTINCT) over a dictionary whose values are themselves a
    // dictionary must count the underlying strings.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        // Ask `Count` for its distinct accumulator over that column.
        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let args = AccumulatorArgs {
            schema: &schema,
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        // Inner dictionary rows: "a", "b", "c", "d", "a", "b".
        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        // Outer keys select inner rows 0, 1, 2, 0, 3, 1 -> a, b, c, a, d, b.
        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        // Four distinct strings: a, b, c, d.
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    // Null dictionary values are not counted as a distinct value.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Distinct non-null values: "a" and "c".
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    // Rows whose key points at a null dictionary value are not counted.
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Non-null rows: "a", "c", "a".
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // Every key points at the null dictionary value, so nothing is counted
    // even though the dictionary also contains "abc".
    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    // Duplicates and nulls do not increase the distinct count.
    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(2),
            Some(3),
            None,
        ]));
        acc.update_batch(&[values])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    // A value stays counted until all of its occurrences are retracted.
    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
        // After the update: "a" occurs twice, "b" once.
        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
            as ArrayRef;
        acc.update_batch(&[arr1])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));
        // Retracting one "a" and one "b" leaves a single "a".
        let arr2 =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
        acc.retract_batch(&[arr2])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    // Merging the states {1, 2} and {2, 3} yields three distinct values.
    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
        let state_sv1 = acc1.state()?;
        let state_sv2 = acc2.state()?;
        // Turn the scalar states into arrays, as `merge_batch` expects.
        let state_arr1: Vec<ArrayRef> = state_sv1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state_arr2: Vec<ArrayRef> = state_sv2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        merged.merge_batch(&state_arr1)?;
        merged.merge_batch(&state_arr2)?;
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}