Skip to main content

datafusion_spark/function/bitmap/
bitmap_count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    Array, ArrayRef, BinaryArray, BinaryViewArray, FixedSizeBinaryArray, Int64Array,
22    LargeBinaryArray, as_dictionary_array,
23};
24use arrow::datatypes::DataType::{
25    Binary, BinaryView, Dictionary, FixedSizeBinary, LargeBinary,
26};
27use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{Result, internal_err};
30use datafusion_expr::{
31    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    TypeSignatureClass, Volatility,
33};
34use datafusion_functions::downcast_arg;
35use datafusion_functions::utils::make_scalar_function;
36
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct BitmapCount {
39    signature: Signature,
40}
41
42impl Default for BitmapCount {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl BitmapCount {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::coercible(
52                vec![Coercion::new_exact(TypeSignatureClass::Binary)],
53                Volatility::Immutable,
54            ),
55        }
56    }
57}
58
59impl ScalarUDFImpl for BitmapCount {
60    fn name(&self) -> &str {
61        "bitmap_count"
62    }
63
64    fn signature(&self) -> &Signature {
65        &self.signature
66    }
67
68    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
69        internal_err!("return_field_from_args should be used instead")
70    }
71
72    fn return_field_from_args(
73        &self,
74        args: datafusion_expr::ReturnFieldArgs,
75    ) -> Result<FieldRef> {
76        use arrow::datatypes::Field;
77        // bitmap_count returns Int64 with the same nullability as the input
78        Ok(Arc::new(Field::new(
79            args.arg_fields[0].name(),
80            DataType::Int64,
81            args.arg_fields[0].is_nullable(),
82        )))
83    }
84
85    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86        make_scalar_function(bitmap_count_inner, vec![])(&args.args)
87    }
88}
89
/// Sums the set bits over every byte of `opt`.
///
/// Returns `None` when `opt` is `None` (a null input); an empty slice yields
/// `Some(0)`.
fn binary_count_ones(opt: Option<&[u8]>) -> Option<i64> {
    let bytes = opt?;
    let total = bytes
        .iter()
        .fold(0_i64, |acc, byte| acc + i64::from(byte.count_ones()));
    Some(total)
}
93
/// Downcasts `$input_array` to the concrete `$array_type` (propagating a
/// downcast failure via `downcast_arg!`) and evaluates to
/// `Ok(Int64Array)` holding the per-row set-bit counts.
macro_rules! downcast_and_count_ones {
    ($input_array:expr, $array_type:ident) => {{
        let typed_array = downcast_arg!($input_array, $array_type);
        let counts: Int64Array = typed_array.iter().map(binary_count_ones).collect();
        Ok(counts)
    }};
}
100
/// Views `$input_dict` as a dictionary array keyed by `$key_array_type` with
/// `Binary` values and evaluates to `Ok(Int64Array)` holding the per-row
/// set-bit counts.
///
/// If the dictionary's values are not a `BinaryArray`, the enclosing function
/// returns an internal error instead of panicking.
macro_rules! downcast_dict_and_count_ones {
    ($input_dict:expr, $key_array_type:ident) => {{
        let dict_array = as_dictionary_array::<$key_array_type>($input_dict);
        let Some(array) = dict_array.downcast_dict::<BinaryArray>() else {
            return internal_err!(
                "bitmap_count expected Dictionary values of type Binary, got {}",
                dict_array.values().data_type()
            );
        };
        Ok(array
            .into_iter()
            .map(binary_count_ones)
            .collect::<Int64Array>())
    }};
}
111
112pub fn bitmap_count_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
113    let [input_array] = take_function_args("bitmap_count", arg)?;
114
115    let res: Result<Int64Array> = match &input_array.data_type() {
116        Binary => downcast_and_count_ones!(input_array, BinaryArray),
117        BinaryView => downcast_and_count_ones!(input_array, BinaryViewArray),
118        LargeBinary => downcast_and_count_ones!(input_array, LargeBinaryArray),
119        FixedSizeBinary(_size) => {
120            downcast_and_count_ones!(input_array, FixedSizeBinaryArray)
121        }
122        Dictionary(k, v) if v.as_ref() == &Binary => match k.as_ref() {
123            DataType::Int8 => downcast_dict_and_count_ones!(input_array, Int8Type),
124            DataType::Int16 => downcast_dict_and_count_ones!(input_array, Int16Type),
125            DataType::Int32 => downcast_dict_and_count_ones!(input_array, Int32Type),
126            DataType::Int64 => downcast_dict_and_count_ones!(input_array, Int64Type),
127            data_type => {
128                internal_err!(
129                    "bitmap_count does not support Dictionary({data_type}, Binary)"
130                )
131            }
132        },
133        data_type => {
134            internal_err!("bitmap_count does not support {data_type}")
135        }
136    };
137
138    Ok(Arc::new(res?))
139}
140
/// Unit tests for `BitmapCount`: plain binary scalar inputs, a
/// dictionary-encoded input, and return-field nullability propagation.
#[cfg(test)]
mod tests {
    use crate::function::bitmap::bitmap_count::BitmapCount;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Int64Array};
    use arrow::datatypes::DataType::Int64;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue::Scalar;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // Invokes `BitmapCount` on the same `$INPUT` bytes wrapped as each
    // non-dictionary binary scalar type (Binary, LargeBinary, BinaryView,
    // FixedSizeBinary) and checks every result against `$EXPECTED` using
    // the shared `test_scalar_function!` helper.
    macro_rules! test_bitmap_count_binary_invoke {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            // FixedSizeBinary requires an explicit width: use the input's byte
            // length, or 0 when the input is None.
            test_scalar_function!(
                BitmapCount::new(),
                vec![ColumnarValue::Scalar(ScalarValue::FixedSizeBinary(
                    $INPUT.map(|a| a.len()).unwrap_or(0) as i32,
                    $INPUT
                ))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );
        };
    }

    #[test]
    fn test_bitmap_count_invoke() -> Result<()> {
        // A null input produces a null count.
        test_bitmap_count_binary_invoke!(None::<Vec<u8>>, Ok(None));
        // 0x0A = 0b0000_1010 -> 2 set bits.
        test_bitmap_count_binary_invoke!(Some(vec![0x0Au8]), Ok(Some(2)));
        // Two fully-set bytes -> 16 set bits.
        test_bitmap_count_binary_invoke!(Some(vec![0xFFu8, 0xFFu8]), Ok(Some(16)));
        // 0x0A (2) + 0xB0 = 0b1011_0000 (3) + 0xCD = 0b1100_1101 (5) -> 10.
        test_bitmap_count_binary_invoke!(
            Some(vec![0x0Au8, 0xB0u8, 0xCDu8]),
            Ok(Some(10))
        );
        Ok(())
    }

    // Checks that a Dictionary(Int32, Binary) scalar is counted through the
    // dictionary code path.
    #[test]
    fn test_dictionary_encoded_bitmap_count_invoke() -> Result<()> {
        let dict = Scalar(ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::Binary(Some(vec![0xFFu8, 0xFFu8]))),
        ));

        let arg_fields = vec![
            Field::new(
                "a",
                DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Binary),
                ),
                true,
            )
            .into(),
        ];
        let args = ScalarFunctionArgs {
            args: vec![dict.clone()],
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", Int64, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        let udf = BitmapCount::new();
        let actual = udf.invoke_with_args(args)?;
        // 0xFF, 0xFF -> 16 set bits.
        let expect = Scalar(ScalarValue::Int64(Some(16)));
        assert_eq!(*actual.into_array(1)?, *expect.into_array(1)?);
        Ok(())
    }

    // Checks that the return field is Int64 and mirrors the input's nullability.
    #[test]
    fn test_bitmap_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bitmap_count = BitmapCount::new();

        // Test with non-nullable binary field
        let non_nullable_field = Arc::new(Field::new("bin", DataType::Binary, false));

        let result = bitmap_count.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&non_nullable_field)],
            scalar_arguments: &[None],
        })?;

        // The result should not be nullable (same as input)
        assert!(!result.is_nullable());
        assert_eq!(result.data_type(), &Int64);

        // Test with nullable binary field
        let nullable_field = Arc::new(Field::new("bin", DataType::Binary, true));

        let result = bitmap_count.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&nullable_field)],
            scalar_arguments: &[None],
        })?;

        // The result should be nullable (same as input)
        assert!(result.is_nullable());
        assert_eq!(result.data_type(), &Int64);

        Ok(())
    }
}