Skip to main content

datafusion_spark/function/bitwise/
bit_count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray, Int32Array};
21use arrow::datatypes::{
22    DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
23    UInt32Type, UInt64Type,
24};
25use datafusion_common::cast::as_boolean_array;
26use datafusion_common::{Result, internal_err, plan_err};
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
29    Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkBitCount {
35    signature: Signature,
36}
37
38impl Default for SparkBitCount {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkBitCount {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::one_of(
48                vec![
49                    TypeSignature::Exact(vec![DataType::Boolean]),
50                    TypeSignature::Exact(vec![DataType::Int8]),
51                    TypeSignature::Exact(vec![DataType::Int16]),
52                    TypeSignature::Exact(vec![DataType::Int32]),
53                    TypeSignature::Exact(vec![DataType::Int64]),
54                    TypeSignature::Exact(vec![DataType::UInt8]),
55                    TypeSignature::Exact(vec![DataType::UInt16]),
56                    TypeSignature::Exact(vec![DataType::UInt32]),
57                    TypeSignature::Exact(vec![DataType::UInt64]),
58                ],
59                Volatility::Immutable,
60            ),
61        }
62    }
63}
64
65impl ScalarUDFImpl for SparkBitCount {
66    fn name(&self) -> &str {
67        "bit_count"
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75        internal_err!("return_field_from_args should be used instead")
76    }
77
78    fn return_field_from_args(
79        &self,
80        args: datafusion_expr::ReturnFieldArgs,
81    ) -> Result<FieldRef> {
82        use arrow::datatypes::Field;
83        // bit_count returns Int32 with the same nullability as the input
84        Ok(Arc::new(Field::new(
85            args.arg_fields[0].name(),
86            DataType::Int32,
87            args.arg_fields[0].is_nullable(),
88        )))
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        if args.args.len() != 1 {
93            return plan_err!("bit_count expects exactly 1 argument");
94        }
95
96        make_scalar_function(spark_bit_count, vec![])(&args.args)
97    }
98}
99
100fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
101    let value_array = value_array[0].as_ref();
102    match value_array.data_type() {
103        DataType::Boolean => {
104            let result: Int32Array = as_boolean_array(value_array)?
105                .iter()
106                .map(|x| x.map(|y| y as i32))
107                .collect();
108            Ok(Arc::new(result))
109        }
110        DataType::Int8 => {
111            let result: Int32Array = value_array
112                .as_primitive::<Int8Type>()
113                .unary(|v| (v as i64).count_ones() as i32);
114            Ok(Arc::new(result))
115        }
116        DataType::Int16 => {
117            let result: Int32Array = value_array
118                .as_primitive::<Int16Type>()
119                .unary(|v| (v as i64).count_ones() as i32);
120            Ok(Arc::new(result))
121        }
122        DataType::Int32 => {
123            let result: Int32Array = value_array
124                .as_primitive::<Int32Type>()
125                .unary(|v| (v as i64).count_ones() as i32);
126            Ok(Arc::new(result))
127        }
128        DataType::Int64 => {
129            let result: Int32Array = value_array
130                .as_primitive::<Int64Type>()
131                .unary(|v| v.count_ones() as i32);
132            Ok(Arc::new(result))
133        }
134        DataType::UInt8 => {
135            let result: Int32Array = value_array
136                .as_primitive::<UInt8Type>()
137                .unary(|v| v.count_ones() as i32);
138            Ok(Arc::new(result))
139        }
140        DataType::UInt16 => {
141            let result: Int32Array = value_array
142                .as_primitive::<UInt16Type>()
143                .unary(|v| v.count_ones() as i32);
144            Ok(Arc::new(result))
145        }
146        DataType::UInt32 => {
147            let result: Int32Array = value_array
148                .as_primitive::<UInt32Type>()
149                .unary(|v| v.count_ones() as i32);
150            Ok(Arc::new(result))
151        }
152        DataType::UInt64 => {
153            let result: Int32Array = value_array
154                .as_primitive::<UInt64Type>()
155                .unary(|v| v.count_ones() as i32);
156            Ok(Arc::new(result))
157        }
158        _ => {
159            plan_err!(
160                "bit_count function does not support data type: {}",
161                value_array.data_type()
162            )
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, BooleanArray, Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array,
        UInt16Array, UInt32Array, UInt64Array,
    };
    use arrow::datatypes::Field;

    #[test]
    fn test_bit_count_basic() {
        // Test bit_count(0) - no bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 0);

        // Test bit_count(1) - 1 bit set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);

        // Test bit_count(7) - 7 = 111 in binary, 3 bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 3);

        // Test bit_count(15) - 15 = 1111 in binary, 4 bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 4);
    }

    #[test]
    fn test_bit_count_int8() {
        // Test bit_count on Int8Array
        let result =
            spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0);
        assert_eq!(arr.value(1), 1);
        assert_eq!(arr.value(2), 2);
        assert_eq!(arr.value(3), 3);
        assert_eq!(arr.value(4), 4);
        // -1 is sign-extended to i64 before counting (Spark's Long.bitCount)
        assert_eq!(arr.value(5), 64);
    }

    #[test]
    fn test_bit_count_boolean() {
        // Test bit_count on BooleanArray
        let result =
            spark_bit_count(&[Arc::new(BooleanArray::from(vec![true, false]))]).unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 1);
        assert_eq!(arr.value(1), 0);
    }

    #[test]
    fn test_bit_count_boolean_nulls() {
        // Null booleans must stay null rather than counting as 0
        let result = spark_bit_count(&[Arc::new(BooleanArray::from(vec![
            Some(true),
            None,
            Some(false),
        ]))])
        .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 1);
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), 0);
    }

    #[test]
    fn test_bit_count_int16() {
        // Test bit_count on Int16Array
        let result =
            spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0);
        assert_eq!(arr.value(1), 1);
        assert_eq!(arr.value(2), 8);
        assert_eq!(arr.value(3), 10);
        // -1 is sign-extended to i64 before counting (Spark's Long.bitCount)
        assert_eq!(arr.value(4), 64);
    }

    #[test]
    fn test_bit_count_int32() {
        // Test bit_count on Int32Array
        let result =
            spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
        // -1 is sign-extended to i64 before counting, so all 64 bits are set
        assert_eq!(arr.value(4), 64);
    }

    #[test]
    fn test_bit_count_int64() {
        // Test bit_count on Int64Array
        let result =
            spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b0000000000000000000000000000000000000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b0000000000000000000000000000000000000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b0000000000000000000000000000000000000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 10); // 0b0000000000000000000000000000000000000000000000000000001111111111 = 10
        assert_eq!(arr.value(4), 64); // -1 in two's complement = all 64 bits set
    }

    #[test]
    fn test_bit_count_signed_min_values() {
        // Sign extension to 64 bits sets every bit above the narrow type's
        // sign bit, so MIN counts 64 - width + 1 ones.
        let result = spark_bit_count(&[Arc::new(Int8Array::from(vec![i8::MIN]))]).unwrap();
        assert_eq!(result.as_primitive::<Int32Type>().value(0), 57);

        let result =
            spark_bit_count(&[Arc::new(Int16Array::from(vec![i16::MIN]))]).unwrap();
        assert_eq!(result.as_primitive::<Int32Type>().value(0), 49);

        let result =
            spark_bit_count(&[Arc::new(Int32Array::from(vec![i32::MIN]))]).unwrap();
        assert_eq!(result.as_primitive::<Int32Type>().value(0), 33);

        let result =
            spark_bit_count(&[Arc::new(Int64Array::from(vec![i64::MIN]))]).unwrap();
        assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);
    }

    #[test]
    fn test_bit_count_uint8() {
        // Test bit_count on UInt8Array
        let result =
            spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000001 = 1
        assert_eq!(arr.value(2), 8); // 0b11111111 = 8
    }

    #[test]
    fn test_bit_count_uint16() {
        // Test bit_count on UInt16Array
        let result =
            spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b0000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b0000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b0000000011111111 = 8
        assert_eq!(arr.value(3), 16); // 0b1111111111111111 = 16
    }

    #[test]
    fn test_bit_count_uint32() {
        // Test bit_count on UInt32Array
        let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![
            0u32, 1, 255, 4294967295,
        ]))])
        .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 32); // 0b11111111111111111111111111111111 = 32
    }

    #[test]
    fn test_bit_count_uint64() {
        // Test bit_count on UInt64Array
        let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![
            0u64,
            1,
            255,
            256,
            u64::MAX,
        ]))])
        .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        // 0b0 = 0
        assert_eq!(arr.value(0), 0);
        // 0b1 = 1
        assert_eq!(arr.value(1), 1);
        // 0b11111111 = 8
        assert_eq!(arr.value(2), 8);
        // 0b100000000 = 1
        assert_eq!(arr.value(3), 1);
        // u64::MAX = all 64 bits set
        assert_eq!(arr.value(4), 64);
    }

    #[test]
    fn test_bit_count_nulls() {
        // Test bit_count with nulls
        let arr = Int32Array::from(vec![Some(3), None, Some(7)]);
        let result = spark_bit_count(&[Arc::new(arr)]).unwrap();
        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 2); // 0b11
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), 3); // 0b111
    }

    #[test]
    fn test_bit_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bit_count = SparkBitCount::new();

        // Test with non-nullable Int32 field
        let non_nullable_field = Arc::new(Field::new("num", DataType::Int32, false));

        let result = bit_count.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&non_nullable_field)],
            scalar_arguments: &[None],
        })?;

        // The result should not be nullable (same as input)
        assert!(!result.is_nullable());
        assert_eq!(result.data_type(), &DataType::Int32);

        // Test with nullable Int32 field
        let nullable_field = Arc::new(Field::new("num", DataType::Int32, true));

        let result = bit_count.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&nullable_field)],
            scalar_arguments: &[None],
        })?;

        // The result should be nullable (same as input)
        assert!(result.is_nullable());
        assert_eq!(result.data_type(), &DataType::Int32);

        Ok(())
    }
}