datafusion_spark/function/bitmap/
bitmap_count.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, BinaryArray, BinaryViewArray, FixedSizeBinaryArray, Int64Array,
23    LargeBinaryArray, as_dictionary_array,
24};
25use arrow::datatypes::DataType::{
26    Binary, BinaryView, Dictionary, FixedSizeBinary, LargeBinary,
27};
28use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type};
29use datafusion_common::utils::take_function_args;
30use datafusion_common::{Result, internal_err};
31use datafusion_expr::{
32    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33    TypeSignatureClass, Volatility,
34};
35use datafusion_functions::downcast_arg;
36use datafusion_functions::utils::make_scalar_function;
37
38#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct BitmapCount {
40    signature: Signature,
41}
42
43impl Default for BitmapCount {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl BitmapCount {
50    pub fn new() -> Self {
51        Self {
52            signature: Signature::coercible(
53                vec![Coercion::new_exact(TypeSignatureClass::Binary)],
54                Volatility::Immutable,
55            ),
56        }
57    }
58}
59
60impl ScalarUDFImpl for BitmapCount {
61    fn as_any(&self) -> &dyn Any {
62        self
63    }
64
65    fn name(&self) -> &str {
66        "bitmap_count"
67    }
68
69    fn signature(&self) -> &Signature {
70        &self.signature
71    }
72
73    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74        internal_err!("return_field_from_args should be used instead")
75    }
76
77    fn return_field_from_args(
78        &self,
79        args: datafusion_expr::ReturnFieldArgs,
80    ) -> Result<FieldRef> {
81        use arrow::datatypes::Field;
82        // bitmap_count returns Int64 with the same nullability as the input
83        Ok(Arc::new(Field::new(
84            args.arg_fields[0].name(),
85            DataType::Int64,
86            args.arg_fields[0].is_nullable(),
87        )))
88    }
89
90    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
91        make_scalar_function(bitmap_count_inner, vec![])(&args.args)
92    }
93}
94
/// Counts the bits set to 1 across every byte of `opt`.
///
/// Returns `None` for a `None` input and `Some(0)` for an empty slice.
fn binary_count_ones(opt: Option<&[u8]>) -> Option<i64> {
    let bytes = opt?;
    let set_bits = bytes
        .iter()
        // `count_ones` yields a `u32`, which always fits in an `i64`.
        .fold(0_i64, |acc, byte| acc + i64::from(byte.count_ones()));
    Some(set_bits)
}
98
/// Downcasts `$input_array` to the concrete `$array_type` and evaluates to
/// `Ok(Int64Array)` holding the per-row set-bit counts (null rows stay null).
///
/// NOTE(review): `downcast_arg!` is defined in `datafusion_functions`; it
/// presumably propagates a failed downcast out of the enclosing function --
/// confirm against that macro.
macro_rules! downcast_and_count_ones {
    ($input_array:expr, $array_type:ident) => {{
        let typed_array = downcast_arg!($input_array, $array_type);
        let counts: Int64Array = typed_array.iter().map(binary_count_ones).collect();
        Ok(counts)
    }};
}
105
/// Interprets `$input_dict` as a dictionary array keyed by `$key_array_type`
/// with `BinaryArray` values, and evaluates to a `Result<Int64Array>` holding
/// the per-row set-bit counts (null rows stay null).
///
/// If the dictionary values are not a `BinaryArray`, the macro evaluates to an
/// internal error instead of panicking, so a misuse at a new call site
/// surfaces as a regular `Err`.
macro_rules! downcast_dict_and_count_ones {
    ($input_dict:expr, $key_array_type:ident) => {{
        let dict_array = as_dictionary_array::<$key_array_type>($input_dict);
        match dict_array.downcast_dict::<BinaryArray>() {
            Some(typed_dict) => Ok(typed_dict
                .into_iter()
                .map(binary_count_ones)
                .collect::<Int64Array>()),
            None => internal_err!(
                "bitmap_count expected dictionary values of type Binary"
            ),
        }
    }};
}
116
117pub fn bitmap_count_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
118    let [input_array] = take_function_args("bitmap_count", arg)?;
119
120    let res: Result<Int64Array> = match &input_array.data_type() {
121        Binary => downcast_and_count_ones!(input_array, BinaryArray),
122        BinaryView => downcast_and_count_ones!(input_array, BinaryViewArray),
123        LargeBinary => downcast_and_count_ones!(input_array, LargeBinaryArray),
124        FixedSizeBinary(_size) => {
125            downcast_and_count_ones!(input_array, FixedSizeBinaryArray)
126        }
127        Dictionary(k, v) if v.as_ref() == &Binary => match k.as_ref() {
128            DataType::Int8 => downcast_dict_and_count_ones!(input_array, Int8Type),
129            DataType::Int16 => downcast_dict_and_count_ones!(input_array, Int16Type),
130            DataType::Int32 => downcast_dict_and_count_ones!(input_array, Int32Type),
131            DataType::Int64 => downcast_dict_and_count_ones!(input_array, Int64Type),
132            data_type => {
133                internal_err!(
134                    "bitmap_count does not support Dictionary({data_type}, Binary)"
135                )
136            }
137        },
138        data_type => {
139            internal_err!("bitmap_count does not support {data_type}")
140        }
141    };
142
143    Ok(Arc::new(res?))
144}
145
#[cfg(test)]
mod tests {
    use crate::function::bitmap::bitmap_count::BitmapCount;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Int64Array};
    use arrow::datatypes::DataType::Int64;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue::Scalar;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Runs `bitmap_count` on the same optional byte vector wrapped as each
    /// supported scalar binary flavour (`Binary`, `LargeBinary`, `BinaryView`,
    /// `FixedSizeBinary`) and checks every call against `$EXPECTED`.
    macro_rules! test_bitmap_count_binary_invoke {
        ($INPUT:expr, $EXPECTED:expr) => {
            // Variable-length binary.
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            // Variable-length binary with 64-bit offsets.
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            // Binary view.
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            // Fixed-size binary: the width is the input length (0 for `None`).
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::FixedSizeBinary(
                    $INPUT.map_or(0, |bytes| bytes.len()) as i32,
                    $INPUT
                ))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );
        };
    }

    /// Null input yields null; otherwise the set bits of all bytes are summed.
    #[test]
    fn test_bitmap_count_invoke() -> Result<()> {
        test_bitmap_count_binary_invoke!(None::<Vec<u8>>, Ok(None));
        test_bitmap_count_binary_invoke!(Some(vec![0x0Au8]), Ok(Some(2)));
        test_bitmap_count_binary_invoke!(Some(vec![0xFFu8, 0xFFu8]), Ok(Some(16)));
        test_bitmap_count_binary_invoke!(
            Some(vec![0x0Au8, 0xB0u8, 0xCDu8]),
            Ok(Some(10))
        );
        Ok(())
    }

    /// A dictionary-encoded (Int32 keys, Binary values) scalar is counted like
    /// the plain binary value it wraps.
    #[test]
    fn test_dictionary_encoded_bitmap_count_invoke() -> Result<()> {
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Binary));
        let dict_value = Scalar(ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::Binary(Some(vec![0xFFu8, 0xFFu8]))),
        ));

        let invocation = ScalarFunctionArgs {
            args: vec![dict_value],
            arg_fields: vec![Arc::new(Field::new("a", dict_type, true))],
            number_rows: 1,
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let actual = BitmapCount::new().invoke_with_args(invocation)?;
        let expect = Scalar(ScalarValue::Int64(Some(16)));
        assert_eq!(*actual.into_array(1)?, *expect.into_array(1)?);
        Ok(())
    }

    /// The returned field is always `Int64` and mirrors the input field's
    /// nullability.
    #[test]
    fn test_bitmap_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bitmap_count = BitmapCount::new();

        // First a non-nullable input field, then a nullable one.
        for nullable in [false, true] {
            let input_field = Arc::new(Field::new("bin", DataType::Binary, nullable));

            let result = bitmap_count.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&input_field)],
                scalar_arguments: &[None],
            })?;

            // The result nullability should be the same as the input's.
            assert_eq!(result.is_nullable(), nullable);
            assert_eq!(result.data_type(), &Int64);
        }

        Ok(())
    }
}