1use ahash::RandomState;
19use arrow::{
20 array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
21 buffer::BooleanBuffer,
22 compute,
23 datatypes::{
24 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
25 FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
26 Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
27 Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
28 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
29 UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30 },
31};
32use datafusion_common::{
33 HashMap, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
34 stats::Precision, utils::expr::COUNT_STAR_EXPANSION,
35};
36use datafusion_expr::{
37 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
38 ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility,
39 WindowFunctionDefinition,
40 expr::WindowFunction,
41 function::{AccumulatorArgs, StateFieldsArgs},
42 utils::format_state_name,
43};
44use datafusion_functions_aggregate_common::aggregate::{
45 count_distinct::BytesDistinctCountAccumulator,
46 count_distinct::BytesViewDistinctCountAccumulator,
47 count_distinct::DictionaryCountAccumulator,
48 count_distinct::FloatDistinctCountAccumulator,
49 count_distinct::PrimitiveDistinctCountAccumulator,
50 groups_accumulator::accumulate::accumulate_indices,
51};
52use datafusion_macros::user_doc;
53use datafusion_physical_expr::expressions;
54use datafusion_physical_expr_common::binary_map::OutputType;
55use std::{
56 collections::HashSet,
57 fmt::Debug,
58 mem::{size_of, size_of_val},
59 ops::BitAnd,
60 sync::Arc,
61};
62
// Invokes `make_udaf_expr_and_func!` for the `Count` UDAF defined in this file.
// The `count(..)` expression builder and the `count_udaf()` accessor used by
// the helper functions below are presumably generated by this invocation; the
// macro itself is defined elsewhere, so confirm the details there.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
70
71pub fn count_distinct(expr: Expr) -> Expr {
72 Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
73 count_udaf(),
74 vec![expr],
75 true,
76 None,
77 vec![],
78 None,
79 ))
80}
81
82pub fn count_all() -> Expr {
100 count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)")
101}
102
103pub fn count_all_window() -> Expr {
123 Expr::from(WindowFunction::new(
124 WindowFunctionDefinition::AggregateUDF(count_udaf()),
125 vec![Expr::Literal(COUNT_STAR_EXPANSION, None)],
126 ))
127}
128
129#[user_doc(
130 doc_section(label = "General Functions"),
131 description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",
132 syntax_example = "count(expression)",
133 sql_example = r#"```sql
134> SELECT count(column_name) FROM table_name;
135+-----------------------+
136| count(column_name) |
137+-----------------------+
138| 100 |
139+-----------------------+
140
141> SELECT count(*) FROM table_name;
142+------------------+
143| count(*) |
144+------------------+
145| 120 |
146+------------------+
147```"#,
148 standard_argument(name = "expression",)
149)]
// The `count` aggregate function. Its only state is the `Signature` built in
// `Count::new`; the behaviour lives in the `AggregateUDFImpl` implementation.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct Count {
    // Accepted argument shapes, returned from `AggregateUDFImpl::signature`.
    signature: Signature,
}
154
155impl Default for Count {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl Count {
162 pub fn new() -> Self {
163 Self {
164 signature: Signature::one_of(
165 vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
166 Volatility::Immutable,
167 ),
168 }
169 }
170}
171fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> {
172 match data_type {
173 DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
175 data_type,
176 )),
177 DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
178 data_type,
179 )),
180 DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
181 data_type,
182 )),
183 DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
184 data_type,
185 )),
186 DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
187 data_type,
188 )),
189 DataType::UInt16 => Box::new(
190 PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
191 ),
192 DataType::UInt32 => Box::new(
193 PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
194 ),
195 DataType::UInt64 => Box::new(
196 PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
197 ),
198 DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
199 Decimal128Type,
200 >::new(data_type)),
201 DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
202 Decimal256Type,
203 >::new(data_type)),
204
205 DataType::Date32 => Box::new(
206 PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
207 ),
208 DataType::Date64 => Box::new(
209 PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
210 ),
211 DataType::Time32(TimeUnit::Millisecond) => Box::new(
212 PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
213 ),
214 DataType::Time32(TimeUnit::Second) => Box::new(
215 PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
216 ),
217 DataType::Time64(TimeUnit::Microsecond) => Box::new(
218 PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
219 ),
220 DataType::Time64(TimeUnit::Nanosecond) => Box::new(
221 PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
222 ),
223 DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
224 PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
225 ),
226 DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
227 PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
228 ),
229 DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
230 PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
231 ),
232 DataType::Timestamp(TimeUnit::Second, _) => Box::new(
233 PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
234 ),
235
236 DataType::Float16 => {
237 Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
238 }
239 DataType::Float32 => {
240 Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
241 }
242 DataType::Float64 => {
243 Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
244 }
245
246 DataType::Utf8 => {
247 Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
248 }
249 DataType::Utf8View => {
250 Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
251 }
252 DataType::LargeUtf8 => {
253 Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
254 }
255 DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
256 OutputType::Binary,
257 )),
258 DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
259 OutputType::BinaryView,
260 )),
261 DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
262 OutputType::Binary,
263 )),
264
265 _ => Box::new(DistinctCountAccumulator {
267 values: HashSet::default(),
268 state_data_type: data_type.clone(),
269 }),
270 }
271}
272
273impl AggregateUDFImpl for Count {
274 fn as_any(&self) -> &dyn std::any::Any {
275 self
276 }
277
278 fn name(&self) -> &str {
279 "count"
280 }
281
282 fn signature(&self) -> &Signature {
283 &self.signature
284 }
285
286 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
287 Ok(DataType::Int64)
288 }
289
290 fn is_nullable(&self) -> bool {
291 false
292 }
293
294 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
295 if args.is_distinct {
296 let dtype: DataType = match &args.input_fields[0].data_type() {
297 DataType::Dictionary(_, values_type) => (**values_type).clone(),
298 &dtype => dtype.clone(),
299 };
300
301 Ok(vec![
302 Field::new_list(
303 format_state_name(args.name, "count distinct"),
304 Field::new_list_field(dtype, true),
306 false,
307 )
308 .into(),
309 ])
310 } else {
311 Ok(vec![
312 Field::new(
313 format_state_name(args.name, "count"),
314 DataType::Int64,
315 false,
316 )
317 .into(),
318 ])
319 }
320 }
321
322 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
323 if !acc_args.is_distinct {
324 return Ok(Box::new(CountAccumulator::new()));
325 }
326
327 if acc_args.exprs.len() > 1 {
328 return not_impl_err!("COUNT DISTINCT with multiple arguments");
329 }
330
331 let data_type = acc_args.expr_fields[0].data_type();
332
333 Ok(match data_type {
334 DataType::Dictionary(_, values_type) => {
335 let inner = get_count_accumulator(values_type);
336 Box::new(DictionaryCountAccumulator::new(inner))
337 }
338 _ => get_count_accumulator(data_type),
339 })
340 }
341
342 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
343 if args.is_distinct {
346 return false;
347 }
348 args.exprs.len() == 1
349 }
350
351 fn create_groups_accumulator(
352 &self,
353 _args: AccumulatorArgs,
354 ) -> Result<Box<dyn GroupsAccumulator>> {
355 Ok(Box::new(CountGroupsAccumulator::new()))
357 }
358
359 fn reverse_expr(&self) -> ReversedUDAF {
360 ReversedUDAF::Identical
361 }
362
363 fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
364 Ok(ScalarValue::Int64(Some(0)))
365 }
366
367 fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
368 if statistics_args.is_distinct {
369 return None;
370 }
371 if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
372 && statistics_args.exprs.len() == 1
373 {
374 if let Some(col_expr) = statistics_args.exprs[0]
376 .as_any()
377 .downcast_ref::<expressions::Column>()
378 {
379 let current_val = &statistics_args.statistics.column_statistics
380 [col_expr.index()]
381 .null_count;
382 if let &Precision::Exact(val) = current_val {
383 return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
384 }
385 } else if let Some(lit_expr) = statistics_args.exprs[0]
386 .as_any()
387 .downcast_ref::<expressions::Literal>()
388 && lit_expr.value() == &COUNT_STAR_EXPANSION
389 {
390 return Some(ScalarValue::Int64(Some(num_rows as i64)));
391 }
392 }
393 None
394 }
395
396 fn documentation(&self) -> Option<&Documentation> {
397 self.doc()
398 }
399
400 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
401 SetMonotonicity::Increasing
404 }
405
406 fn create_sliding_accumulator(
407 &self,
408 args: AccumulatorArgs,
409 ) -> Result<Box<dyn Accumulator>> {
410 if args.is_distinct {
411 let acc =
412 SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?;
413 Ok(Box::new(acc))
414 } else {
415 let acc = CountAccumulator::new();
416 Ok(Box::new(acc))
417 }
418 }
419}
420
/// Distinct-count accumulator that supports retraction (see its
/// `Accumulator` implementation), created by
/// `Count::create_sliding_accumulator` for distinct counts.
#[derive(Debug)]
pub struct SlidingDistinctCountAccumulator {
    /// Number of occurrences currently recorded for each non-null value.
    /// A key is removed once its count drops back to zero.
    counts: HashMap<ScalarValue, usize, RandomState>,
    /// Element type passed to `ScalarValue::new_list_nullable` when the
    /// accumulator state is produced.
    data_type: DataType,
}
429
430impl SlidingDistinctCountAccumulator {
431 pub fn try_new(data_type: &DataType) -> Result<Self> {
432 Ok(Self {
433 counts: HashMap::default(),
434 data_type: data_type.clone(),
435 })
436 }
437}
438
439impl Accumulator for SlidingDistinctCountAccumulator {
440 fn state(&mut self) -> Result<Vec<ScalarValue>> {
441 let keys = self.counts.keys().cloned().collect::<Vec<_>>();
442 Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable(
443 keys.as_slice(),
444 &self.data_type,
445 ))])
446 }
447
448 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
449 let arr = &values[0];
450 for i in 0..arr.len() {
451 let v = ScalarValue::try_from_array(arr, i)?;
452 if !v.is_null() {
453 *self.counts.entry(v).or_default() += 1;
454 }
455 }
456 Ok(())
457 }
458
459 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
460 let arr = &values[0];
461 for i in 0..arr.len() {
462 let v = ScalarValue::try_from_array(arr, i)?;
463 if !v.is_null()
464 && let Some(cnt) = self.counts.get_mut(&v)
465 {
466 *cnt -= 1;
467 if *cnt == 0 {
468 self.counts.remove(&v);
469 }
470 }
471 }
472 Ok(())
473 }
474
475 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
476 let list_arr = states[0].as_list::<i32>();
477 for inner in list_arr.iter().flatten() {
478 for j in 0..inner.len() {
479 let v = ScalarValue::try_from_array(&*inner, j)?;
480 *self.counts.entry(v).or_default() += 1;
481 }
482 }
483 Ok(())
484 }
485
486 fn evaluate(&mut self) -> Result<ScalarValue> {
487 Ok(ScalarValue::Int64(Some(self.counts.len() as i64)))
488 }
489
490 fn supports_retract_batch(&self) -> bool {
491 true
492 }
493
494 fn size(&self) -> usize {
495 size_of_val(self)
496 }
497}
498
/// Accumulator for a plain (non-distinct) `count`; also used for the
/// non-distinct sliding-window case.
#[derive(Debug)]
struct CountAccumulator {
    /// Running count; increased by `update_batch` / `merge_batch` and
    /// decreased by `retract_batch`.
    count: i64,
}
503
504impl CountAccumulator {
505 pub fn new() -> Self {
507 Self { count: 0 }
508 }
509}
510
511impl Accumulator for CountAccumulator {
512 fn state(&mut self) -> Result<Vec<ScalarValue>> {
513 Ok(vec![ScalarValue::Int64(Some(self.count))])
514 }
515
516 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517 let array = &values[0];
518 self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
519 Ok(())
520 }
521
522 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
523 let array = &values[0];
524 self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
525 Ok(())
526 }
527
528 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
529 let counts = downcast_value!(states[0], Int64Array);
530 let delta = &compute::sum(counts);
531 if let Some(d) = delta {
532 self.count += *d;
533 }
534 Ok(())
535 }
536
537 fn evaluate(&mut self) -> Result<ScalarValue> {
538 Ok(ScalarValue::Int64(Some(self.count)))
539 }
540
541 fn supports_retract_batch(&self) -> bool {
542 true
543 }
544
545 fn size(&self) -> usize {
546 size_of_val(self)
547 }
548}
549
/// Grouped accumulator for non-distinct `count`, created by
/// `Count::create_groups_accumulator`.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// One running count per group, indexed by group index.
    counts: Vec<i64>,
}
566
567impl CountGroupsAccumulator {
568 pub fn new() -> Self {
569 Self { counts: vec![] }
570 }
571}
572
573impl GroupsAccumulator for CountGroupsAccumulator {
574 fn update_batch(
575 &mut self,
576 values: &[ArrayRef],
577 group_indices: &[usize],
578 opt_filter: Option<&BooleanArray>,
579 total_num_groups: usize,
580 ) -> Result<()> {
581 assert_eq!(values.len(), 1, "single argument to update_batch");
582 let values = &values[0];
583
584 self.counts.resize(total_num_groups, 0);
587 accumulate_indices(
588 group_indices,
589 values.logical_nulls().as_ref(),
590 opt_filter,
591 |group_index| {
592 let count = unsafe { self.counts.get_unchecked_mut(group_index) };
594 *count += 1;
595 },
596 );
597
598 Ok(())
599 }
600
601 fn merge_batch(
602 &mut self,
603 values: &[ArrayRef],
604 group_indices: &[usize],
605 _opt_filter: Option<&BooleanArray>,
607 total_num_groups: usize,
608 ) -> Result<()> {
609 assert_eq!(values.len(), 1, "one argument to merge_batch");
610 let partial_counts = values[0].as_primitive::<Int64Type>();
612
613 assert_eq!(partial_counts.null_count(), 0);
615 let partial_counts = partial_counts.values();
616
617 self.counts.resize(total_num_groups, 0);
619 group_indices.iter().zip(partial_counts.iter()).for_each(
620 |(&group_index, partial_count)| {
621 self.counts[group_index] += partial_count;
622 },
623 );
624
625 Ok(())
626 }
627
628 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
629 let counts = emit_to.take_needed(&mut self.counts);
630
631 let nulls = None;
633 let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
634
635 Ok(Arc::new(array))
636 }
637
638 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
640 let counts = emit_to.take_needed(&mut self.counts);
641 let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
643 }
644
645 fn convert_to_state(
651 &self,
652 values: &[ArrayRef],
653 opt_filter: Option<&BooleanArray>,
654 ) -> Result<Vec<ArrayRef>> {
655 let values = &values[0];
656
657 let state_array = match (values.logical_nulls(), opt_filter) {
658 (None, None) => {
659 Arc::new(Int64Array::from_value(1, values.len()))
661 }
662 (Some(nulls), None) => {
663 let nulls = BooleanArray::new(nulls.into_inner(), None);
666 compute::cast(&nulls, &DataType::Int64)?
667 }
668 (None, Some(filter)) => {
669 let (filter_values, filter_nulls) = filter.clone().into_parts();
674
675 let state_buf = match filter_nulls {
676 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
677 None => filter_values,
678 };
679
680 let boolean_state = BooleanArray::new(state_buf, None);
681 compute::cast(&boolean_state, &DataType::Int64)?
682 }
683 (Some(nulls), Some(filter)) => {
684 let (filter_values, filter_nulls) = filter.clone().into_parts();
691
692 let filter_buf = match filter_nulls {
693 Some(filter_nulls) => &filter_values & filter_nulls.inner(),
694 None => filter_values,
695 };
696 let state_buf = &filter_buf & nulls.inner();
697
698 let boolean_state = BooleanArray::new(state_buf, None);
699 compute::cast(&boolean_state, &DataType::Int64)?
700 }
701 };
702
703 Ok(vec![state_array])
704 }
705
706 fn supports_convert_to_state(&self) -> bool {
707 true
708 }
709
710 fn size(&self) -> usize {
711 self.counts.capacity() * size_of::<usize>()
712 }
713}
714
715fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
718 if values.len() > 1 {
719 let result_bool_buf: Option<BooleanBuffer> = values
720 .iter()
721 .map(|a| a.logical_nulls())
722 .fold(None, |acc, b| match (acc, b) {
723 (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
724 (Some(acc), None) => Some(acc),
725 (None, Some(b)) => Some(b.into_inner()),
726 _ => None,
727 });
728 result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
729 } else {
730 values[0]
731 .logical_nulls()
732 .map_or(0, |nulls| nulls.null_count())
733 }
734}
735
/// Generic `COUNT(DISTINCT ..)` accumulator that stores every distinct
/// non-null value as a `ScalarValue`. `get_count_accumulator` uses it for
/// data types without a specialised accumulator.
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// The distinct non-null values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    /// Element type used when the values are emitted as list state; also
    /// selects how `size` estimates memory use.
    state_data_type: DataType,
}
749
750impl DistinctCountAccumulator {
751 fn fixed_size(&self) -> usize {
755 size_of_val(self)
756 + (size_of::<ScalarValue>() * self.values.capacity())
757 + self
758 .values
759 .iter()
760 .next()
761 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
762 .unwrap_or(0)
763 + size_of::<DataType>()
764 }
765
766 fn full_size(&self) -> usize {
769 size_of_val(self)
770 + (size_of::<ScalarValue>() * self.values.capacity())
771 + self
772 .values
773 .iter()
774 .map(|vals| ScalarValue::size(vals) - size_of_val(vals))
775 .sum::<usize>()
776 + size_of::<DataType>()
777 }
778}
779
780impl Accumulator for DistinctCountAccumulator {
781 fn state(&mut self) -> Result<Vec<ScalarValue>> {
783 let scalars = self.values.iter().cloned().collect::<Vec<_>>();
784 let arr =
785 ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
786 Ok(vec![ScalarValue::List(arr)])
787 }
788
789 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
790 if values.is_empty() {
791 return Ok(());
792 }
793
794 let arr = &values[0];
795 if arr.data_type() == &DataType::Null {
796 return Ok(());
797 }
798
799 (0..arr.len()).try_for_each(|index| {
800 let scalar = ScalarValue::try_from_array(arr, index)?;
801 if !scalar.is_null() {
802 self.values.insert(scalar);
803 }
804 Ok(())
805 })
806 }
807
808 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
814 if states.is_empty() {
815 return Ok(());
816 }
817 assert_eq!(states.len(), 1, "array_agg states must be singleton!");
818 let array = &states[0];
819 let list_array = array.as_list::<i32>();
820 for inner_array in list_array.iter() {
821 let Some(inner_array) = inner_array else {
822 return internal_err!(
823 "Intermediate results of COUNT DISTINCT should always be non null"
824 );
825 };
826 self.update_batch(&[inner_array])?;
827 }
828 Ok(())
829 }
830
831 fn evaluate(&mut self) -> Result<ScalarValue> {
832 Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
833 }
834
835 fn size(&self) -> usize {
836 match &self.state_data_type {
837 DataType::Boolean | DataType::Null => self.fixed_size(),
838 d if d.is_primitive() => self.fixed_size(),
839 _ => self.full_size(),
840 }
841 }
842}
843
// Unit tests for the accumulators defined in this file.
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::{DataType, Field, Int32Type, Schema},
    };
    use datafusion_expr::function::AccumulatorArgs;
    use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
    use std::sync::Arc;

    /// Builds a dictionary array whose *values* contain a null:
    /// values `["a", null, "c"]` with keys `[0, 1, 2, 0, 1]`, i.e. the logical
    /// rows `a, null, c, a, null`. The keys themselves are all valid.
    fn create_dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let keys = Int32Array::from(vec![0, 1, 2, 0, 1]);
        Ok(DictionaryArray::<Int32Type>::try_new(
            keys,
            Arc::new(values),
        )?)
    }

    /// A `NullArray` contributes nothing to a plain count.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// A distinct count over a dictionary whose values are themselves a
    /// dictionary still counts the underlying distinct strings.
    #[test]
    fn test_nested_dictionary() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_col",
            DataType::Dictionary(
                Box::new(DataType::Int32),
                Box::new(DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                )),
            ),
            true,
        )]));

        // Build the accumulator through the public `Count` entry point.
        let count = Count::new();
        let expr = Arc::new(Column::new("dict_col", 0));
        let expr_field = expr.return_field(&schema)?;
        let args = AccumulatorArgs {
            schema: &schema,
            expr_fields: &[expr_field],
            exprs: &[expr],
            is_distinct: true,
            name: "count",
            ignore_nulls: false,
            is_reversed: false,
            return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
            order_bys: &[],
        };

        // Inner dictionary rows: a, b, c, d, a, b.
        let inner_dict =
            DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "d", "a", "b"]);

        // Outer keys select inner rows 0, 1, 2, 0, 3, 1 -> a, b, c, a, d, b.
        let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]);
        let dict_of_dict =
            DictionaryArray::<Int32Type>::try_new(keys, Arc::new(inner_dict))?;

        let mut acc = count.accumulator(args)?;
        acc.update_batch(&[Arc::new(dict_of_dict)])?;
        // Four distinct strings: a, b, c, d.
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4)));

        Ok(())
    }

    /// Null dictionary *values* are not counted as a distinct value.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Distinct non-null values: "a" and "c".
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    /// Rows pointing at a null dictionary value are excluded from a plain count.
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = create_dictionary_with_null_values()?;

        let mut accumulator = CountAccumulator::new();

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Rows are a, null, c, a, null -> three non-null rows.
        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    /// When every key points at the null dictionary value, the distinct count
    /// is zero even though the dictionary also holds a non-null value.
    #[test]
    fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> {
        // Index 0 is the null value; all five keys reference it.
        let dict_values = StringArray::from(vec![None, Some("abc")]);
        let dict_indices = Int32Array::from(vec![0; 5]);
        let dict_array =
            DictionaryArray::<Int32Type>::try_new(dict_indices, Arc::new(dict_values))?;

        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };

        accumulator.update_batch(&[Arc::new(dict_array)])?;

        assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// Duplicates and nulls are ignored by the sliding distinct count.
    #[test]
    fn sliding_distinct_count_accumulator_basic() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let values: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(2),
            Some(3),
            None,
        ]));
        acc.update_batch(&[values])?;
        // Distinct non-null values: 1, 2, 3.
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    /// Retracting removes a value only once all of its occurrences are gone.
    #[test]
    fn sliding_distinct_count_accumulator_retract() -> Result<()> {
        let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?;
        // After the update: "a" seen twice, "b" once.
        let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]))
            as ArrayRef;
        acc.update_batch(&[arr1])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));
        // Retract one "a" and the single "b"; the null is ignored. One "a" remains.
        let arr2 =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef;
        acc.retract_batch(&[arr2])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1)));
        Ok(())
    }

    /// Merging the states of two accumulators yields the distinct count of
    /// the union of their values.
    #[test]
    fn sliding_distinct_count_accumulator_merge_states() -> Result<()> {
        let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?;
        acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?;
        let state_sv1 = acc1.state()?;
        let state_sv2 = acc2.state()?;
        // Turn the scalar states into arrays, the form `merge_batch` takes.
        let state_arr1: Vec<ArrayRef> = state_sv1
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let state_arr2: Vec<ArrayRef> = state_sv2
            .into_iter()
            .map(|sv| sv.to_array())
            .collect::<Result<_>>()?;
        let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?;
        merged.merge_batch(&state_arr1)?;
        merged.merge_batch(&state_arr2)?;
        // Union of {1, 2} and {2, 3} is {1, 2, 3}.
        assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}