datafusion_functions_aggregate_common/aggregate/count_distinct/
dict.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{ArrayRef, BooleanArray};
19use arrow::downcast_dictionary_array;
20use datafusion_common::{arrow_datafusion_err, ScalarValue};
21use datafusion_common::{internal_err, DataFusionError};
22use datafusion_expr_common::accumulator::Accumulator;
23
24#[derive(Debug)]
25pub struct DictionaryCountAccumulator {
26    inner: Box<dyn Accumulator>,
27}
28
29impl DictionaryCountAccumulator {
30    pub fn new(inner: Box<dyn Accumulator>) -> Self {
31        Self { inner }
32    }
33}
34
35impl Accumulator for DictionaryCountAccumulator {
36    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
37        let values: Vec<_> = values
38            .iter()
39            .map(|dict| {
40                downcast_dictionary_array! {
41                    dict => {
42                        let buff: BooleanArray = dict.occupancy().into();
43                        arrow::compute::filter(
44                            dict.values(),
45                            &buff
46                        ).map_err(|e| arrow_datafusion_err!(e))
47                    },
48                    _ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays")
49                }
50            })
51            .collect::<Result<Vec<_>, _>>()?;
52        self.inner.update_batch(values.as_slice())
53    }
54
55    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
56        self.inner.evaluate()
57    }
58
59    fn size(&self) -> usize {
60        self.inner.size()
61    }
62
63    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
64        self.inner.state()
65    }
66
67    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
68        self.inner.merge_batch(states)
69    }
70}