datafusion_spark/function/bitwise/
bit_count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, Int32Array};
use arrow::datatypes::{
    ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
    UInt32Type, UInt64Type, UInt8Type,
};
use datafusion_common::{plan_err, Result};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkBitCount {
35    signature: Signature,
36}
37
38impl Default for SparkBitCount {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkBitCount {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::one_of(
48                vec![
49                    TypeSignature::Exact(vec![DataType::Int8]),
50                    TypeSignature::Exact(vec![DataType::Int16]),
51                    TypeSignature::Exact(vec![DataType::Int32]),
52                    TypeSignature::Exact(vec![DataType::Int64]),
53                    TypeSignature::Exact(vec![DataType::UInt8]),
54                    TypeSignature::Exact(vec![DataType::UInt16]),
55                    TypeSignature::Exact(vec![DataType::UInt32]),
56                    TypeSignature::Exact(vec![DataType::UInt64]),
57                ],
58                Volatility::Immutable,
59            ),
60        }
61    }
62}
63
64impl ScalarUDFImpl for SparkBitCount {
65    fn as_any(&self) -> &dyn Any {
66        self
67    }
68
69    fn name(&self) -> &str {
70        "bit_count"
71    }
72
73    fn signature(&self) -> &Signature {
74        &self.signature
75    }
76
77    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78        Ok(DataType::Int32) // Spark returns int (Int32)
79    }
80
81    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82        if args.args.len() != 1 {
83            return plan_err!("bit_count expects exactly 1 argument");
84        }
85
86        make_scalar_function(spark_bit_count, vec![])(&args.args)
87    }
88}
89
90fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
91    let value_array = value_array[0].as_ref();
92    match value_array.data_type() {
93        DataType::Int8 => {
94            let result: Int32Array = value_array
95                .as_primitive::<Int8Type>()
96                .unary(|v| v.count_ones() as i32);
97            Ok(Arc::new(result))
98        }
99        DataType::Int16 => {
100            let result: Int32Array = value_array
101                .as_primitive::<Int16Type>()
102                .unary(|v| v.count_ones() as i32);
103            Ok(Arc::new(result))
104        }
105        DataType::Int32 => {
106            let result: Int32Array = value_array
107                .as_primitive::<Int32Type>()
108                .unary(|v| v.count_ones() as i32);
109            Ok(Arc::new(result))
110        }
111        DataType::Int64 => {
112            let result: Int32Array = value_array
113                .as_primitive::<Int64Type>()
114                .unary(|v| v.count_ones() as i32);
115            Ok(Arc::new(result))
116        }
117        DataType::UInt8 => {
118            let result: Int32Array = value_array
119                .as_primitive::<UInt8Type>()
120                .unary(|v| v.count_ones() as i32);
121            Ok(Arc::new(result))
122        }
123        DataType::UInt16 => {
124            let result: Int32Array = value_array
125                .as_primitive::<UInt16Type>()
126                .unary(|v| v.count_ones() as i32);
127            Ok(Arc::new(result))
128        }
129        DataType::UInt32 => {
130            let result: Int32Array = value_array
131                .as_primitive::<UInt32Type>()
132                .unary(|v| v.count_ones() as i32);
133            Ok(Arc::new(result))
134        }
135        DataType::UInt64 => {
136            let result: Int32Array = value_array
137                .as_primitive::<UInt64Type>()
138                .unary(|v| v.count_ones() as i32);
139            Ok(Arc::new(result))
140        }
141        _ => {
142            plan_err!(
143                "bit_count function does not support data type: {:?}",
144                value_array.data_type()
145            )
146        }
147    }
148}
149
150#[cfg(test)]
151mod tests {
152    use super::*;
153    use arrow::array::{
154        Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array,
155        UInt64Array, UInt8Array,
156    };
157    use arrow::datatypes::Int32Type;
158
159    #[test]
160    fn test_bit_count_basic() {
161        // Test bit_count(0) - no bits set
162        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap();
163
164        assert_eq!(result.as_primitive::<Int32Type>().value(0), 0);
165
166        // Test bit_count(1) - 1 bit set
167        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap();
168
169        assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);
170
171        // Test bit_count(7) - 7 = 111 in binary, 3 bits set
172        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap();
173
174        assert_eq!(result.as_primitive::<Int32Type>().value(0), 3);
175
176        // Test bit_count(15) - 15 = 1111 in binary, 4 bits set
177        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap();
178
179        assert_eq!(result.as_primitive::<Int32Type>().value(0), 4);
180    }
181
182    #[test]
183    fn test_bit_count_int8() {
184        // Test bit_count on Int8Array
185        let result =
186            spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))])
187                .unwrap();
188
189        let arr = result.as_primitive::<Int32Type>();
190        assert_eq!(arr.value(0), 0);
191        assert_eq!(arr.value(1), 1);
192        assert_eq!(arr.value(2), 2);
193        assert_eq!(arr.value(3), 3);
194        assert_eq!(arr.value(4), 4);
195        assert_eq!(arr.value(5), 8);
196    }
197
198    #[test]
199    fn test_bit_count_int16() {
200        // Test bit_count on Int16Array
201        let result =
202            spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))])
203                .unwrap();
204
205        let arr = result.as_primitive::<Int32Type>();
206        assert_eq!(arr.value(0), 0);
207        assert_eq!(arr.value(1), 1);
208        assert_eq!(arr.value(2), 8);
209        assert_eq!(arr.value(3), 10);
210        assert_eq!(arr.value(4), 16);
211    }
212
213    #[test]
214    fn test_bit_count_int32() {
215        // Test bit_count on Int32Array
216        let result =
217            spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))])
218                .unwrap();
219
220        let arr = result.as_primitive::<Int32Type>();
221        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
222        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
223        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
224        assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
225        assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set
226    }
227
228    #[test]
229    fn test_bit_count_int64() {
230        // Test bit_count on Int64Array
231        let result =
232            spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))])
233                .unwrap();
234
235        let arr = result.as_primitive::<Int32Type>();
236        assert_eq!(arr.value(0), 0); // 0b0000000000000000000000000000000000000000000000000000000000000000 = 0
237        assert_eq!(arr.value(1), 1); // 0b0000000000000000000000000000000000000000000000000000000000000001 = 1
238        assert_eq!(arr.value(2), 8); // 0b0000000000000000000000000000000000000000000000000000000011111111 = 8
239        assert_eq!(arr.value(3), 10); // 0b0000000000000000000000000000000000000000000000000000001111111111 = 10
240        assert_eq!(arr.value(4), 64); // -1 in two's complement = all 64 bits set
241    }
242
243    #[test]
244    fn test_bit_count_uint8() {
245        // Test bit_count on UInt8Array
246        let result =
247            spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap();
248
249        let arr = result.as_primitive::<Int32Type>();
250        assert_eq!(arr.value(0), 0); // 0b00000000 = 0
251        assert_eq!(arr.value(1), 1); // 0b00000001 = 1
252        assert_eq!(arr.value(2), 8); // 0b11111111 = 8
253    }
254
255    #[test]
256    fn test_bit_count_uint16() {
257        // Test bit_count on UInt16Array
258        let result =
259            spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))])
260                .unwrap();
261
262        let arr = result.as_primitive::<Int32Type>();
263        assert_eq!(arr.value(0), 0); // 0b0000000000000000 = 0
264        assert_eq!(arr.value(1), 1); // 0b0000000000000001 = 1
265        assert_eq!(arr.value(2), 8); // 0b0000000011111111 = 8
266        assert_eq!(arr.value(3), 16); // 0b1111111111111111 = 16
267    }
268
269    #[test]
270    fn test_bit_count_uint32() {
271        // Test bit_count on UInt32Array
272        let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![
273            0u32, 1, 255, 4294967295,
274        ]))])
275        .unwrap();
276
277        let arr = result.as_primitive::<Int32Type>();
278        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
279        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
280        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
281        assert_eq!(arr.value(3), 32); // 0b11111111111111111111111111111111 = 32
282    }
283
284    #[test]
285    fn test_bit_count_uint64() {
286        // Test bit_count on UInt64Array
287        let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![
288            0u64,
289            1,
290            255,
291            256,
292            u64::MAX,
293        ]))])
294        .unwrap();
295
296        let arr = result.as_primitive::<Int32Type>();
297        // 0b0 = 0
298        assert_eq!(arr.value(0), 0);
299        // 0b1 = 1
300        assert_eq!(arr.value(1), 1);
301        // 0b11111111 = 8
302        assert_eq!(arr.value(2), 8);
303        // 0b100000000 = 1
304        assert_eq!(arr.value(3), 1);
305        // u64::MAX = all 64 bits set
306        assert_eq!(arr.value(4), 64);
307    }
308
309    #[test]
310    fn test_bit_count_nulls() {
311        // Test bit_count with nulls
312        let arr = Int32Array::from(vec![Some(3), None, Some(7)]);
313        let result = spark_bit_count(&[Arc::new(arr)]).unwrap();
314        let arr = result.as_primitive::<Int32Type>();
315        assert_eq!(arr.value(0), 2); // 0b11
316        assert!(arr.is_null(1));
317        assert_eq!(arr.value(2), 3); // 0b111
318    }
319}