Skip to main content

datafusion_spark/function/bitwise/
bit_count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Int32Array};
22use arrow::datatypes::{
23    DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
24    UInt32Type, UInt64Type,
25};
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::{Result, internal_err, plan_err};
28use datafusion_expr::{
29    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30    Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkBitCount {
36    signature: Signature,
37}
38
39impl Default for SparkBitCount {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkBitCount {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::one_of(
49                vec![
50                    TypeSignature::Exact(vec![DataType::Boolean]),
51                    TypeSignature::Exact(vec![DataType::Int8]),
52                    TypeSignature::Exact(vec![DataType::Int16]),
53                    TypeSignature::Exact(vec![DataType::Int32]),
54                    TypeSignature::Exact(vec![DataType::Int64]),
55                    TypeSignature::Exact(vec![DataType::UInt8]),
56                    TypeSignature::Exact(vec![DataType::UInt16]),
57                    TypeSignature::Exact(vec![DataType::UInt32]),
58                    TypeSignature::Exact(vec![DataType::UInt64]),
59                ],
60                Volatility::Immutable,
61            ),
62        }
63    }
64}
65
66impl ScalarUDFImpl for SparkBitCount {
67    fn as_any(&self) -> &dyn Any {
68        self
69    }
70
71    fn name(&self) -> &str {
72        "bit_count"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80        internal_err!("return_field_from_args should be used instead")
81    }
82
83    fn return_field_from_args(
84        &self,
85        args: datafusion_expr::ReturnFieldArgs,
86    ) -> Result<FieldRef> {
87        use arrow::datatypes::Field;
88        // bit_count returns Int32 with the same nullability as the input
89        Ok(Arc::new(Field::new(
90            args.arg_fields[0].name(),
91            DataType::Int32,
92            args.arg_fields[0].is_nullable(),
93        )))
94    }
95
96    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97        if args.args.len() != 1 {
98            return plan_err!("bit_count expects exactly 1 argument");
99        }
100
101        make_scalar_function(spark_bit_count, vec![])(&args.args)
102    }
103}
104
105fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
106    let value_array = value_array[0].as_ref();
107    match value_array.data_type() {
108        DataType::Boolean => {
109            let result: Int32Array = as_boolean_array(value_array)?
110                .iter()
111                .map(|x| x.map(|y| y as i32))
112                .collect();
113            Ok(Arc::new(result))
114        }
115        DataType::Int8 => {
116            let result: Int32Array = value_array
117                .as_primitive::<Int8Type>()
118                .unary(|v| (v as i64).count_ones() as i32);
119            Ok(Arc::new(result))
120        }
121        DataType::Int16 => {
122            let result: Int32Array = value_array
123                .as_primitive::<Int16Type>()
124                .unary(|v| (v as i64).count_ones() as i32);
125            Ok(Arc::new(result))
126        }
127        DataType::Int32 => {
128            let result: Int32Array = value_array
129                .as_primitive::<Int32Type>()
130                .unary(|v| (v as i64).count_ones() as i32);
131            Ok(Arc::new(result))
132        }
133        DataType::Int64 => {
134            let result: Int32Array = value_array
135                .as_primitive::<Int64Type>()
136                .unary(|v| v.count_ones() as i32);
137            Ok(Arc::new(result))
138        }
139        DataType::UInt8 => {
140            let result: Int32Array = value_array
141                .as_primitive::<UInt8Type>()
142                .unary(|v| v.count_ones() as i32);
143            Ok(Arc::new(result))
144        }
145        DataType::UInt16 => {
146            let result: Int32Array = value_array
147                .as_primitive::<UInt16Type>()
148                .unary(|v| v.count_ones() as i32);
149            Ok(Arc::new(result))
150        }
151        DataType::UInt32 => {
152            let result: Int32Array = value_array
153                .as_primitive::<UInt32Type>()
154                .unary(|v| v.count_ones() as i32);
155            Ok(Arc::new(result))
156        }
157        DataType::UInt64 => {
158            let result: Int32Array = value_array
159                .as_primitive::<UInt64Type>()
160                .unary(|v| v.count_ones() as i32);
161            Ok(Arc::new(result))
162        }
163        _ => {
164            plan_err!(
165                "bit_count function does not support data type: {}",
166                value_array.data_type()
167            )
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BooleanArray, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
        UInt8Array, UInt16Array, UInt32Array, UInt64Array,
    };
    use arrow::datatypes::{Field, Int32Type};

    /// Runs `spark_bit_count` on a single array and collects the Int32
    /// result, with `None` standing for a null slot.
    fn bit_counts(input: ArrayRef) -> Vec<Option<i32>> {
        let output = spark_bit_count(&[input]).unwrap();
        output.as_primitive::<Int32Type>().iter().collect()
    }

    /// Asserts that `input` produces exactly `expected`, with no nulls.
    fn assert_bit_counts(input: ArrayRef, expected: &[i32]) {
        let expected: Vec<Option<i32>> = expected.iter().copied().map(Some).collect();
        assert_eq!(bit_counts(input), expected);
    }

    #[test]
    fn test_bit_count_basic() {
        // 0 has no bits set, 1 has one, 7 = 0b111 has three, 15 = 0b1111 has four.
        assert_bit_counts(Arc::new(Int32Array::from(vec![0, 1, 7, 15])), &[0, 1, 3, 4]);
    }

    #[test]
    fn test_bit_count_int8() {
        // -1 is sign-extended to 64 bits before counting, hence 64 and not 8.
        assert_bit_counts(
            Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1])),
            &[0, 1, 2, 3, 4, 64],
        );
    }

    #[test]
    fn test_bit_count_boolean() {
        // true counts as one set bit, false as none.
        assert_bit_counts(Arc::new(BooleanArray::from(vec![true, false])), &[1, 0]);
    }

    #[test]
    fn test_bit_count_int16() {
        // 255 = 0xFF has 8 bits set and 1023 = 0x3FF has 10; -1 is
        // sign-extended to 64 bits before counting, hence 64 and not 16.
        assert_bit_counts(
            Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1])),
            &[0, 1, 8, 10, 64],
        );
    }

    #[test]
    fn test_bit_count_int32() {
        // -1 is sign-extended to 64 bits before counting, hence 64 and not 32.
        assert_bit_counts(
            Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1])),
            &[0, 1, 8, 10, 64],
        );
    }

    #[test]
    fn test_bit_count_int64() {
        // -1 in two's complement has all 64 bits set.
        assert_bit_counts(
            Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1])),
            &[0, 1, 8, 10, 64],
        );
    }

    #[test]
    fn test_bit_count_uint8() {
        // u8::MAX = 0b1111_1111 has 8 bits set.
        assert_bit_counts(Arc::new(UInt8Array::from(vec![0u8, 1, u8::MAX])), &[0, 1, 8]);
    }

    #[test]
    fn test_bit_count_uint16() {
        // Unsigned values are not widened: u16::MAX has 16 bits set.
        assert_bit_counts(
            Arc::new(UInt16Array::from(vec![0u16, 1, 255, u16::MAX])),
            &[0, 1, 8, 16],
        );
    }

    #[test]
    fn test_bit_count_uint32() {
        // Unsigned values are not widened: u32::MAX has 32 bits set.
        assert_bit_counts(
            Arc::new(UInt32Array::from(vec![0u32, 1, 255, u32::MAX])),
            &[0, 1, 8, 32],
        );
    }

    #[test]
    fn test_bit_count_uint64() {
        // 256 = 0b1_0000_0000 has a single bit set; u64::MAX has all 64.
        assert_bit_counts(
            Arc::new(UInt64Array::from(vec![0u64, 1, 255, 256, u64::MAX])),
            &[0, 1, 8, 1, 64],
        );
    }

    #[test]
    fn test_bit_count_nulls() {
        // Null inputs stay null; 3 = 0b11 and 7 = 0b111.
        let input = Int32Array::from(vec![Some(3), None, Some(7)]);
        assert_eq!(bit_counts(Arc::new(input)), [Some(2), None, Some(3)]);
    }

    #[test]
    fn test_bit_count_unsupported_type() {
        // Types outside the supported list are rejected with an error.
        let input: ArrayRef = Arc::new(Float64Array::from(vec![1.0]));
        let error = spark_bit_count(&[input]).unwrap_err();
        assert!(
            error
                .to_string()
                .contains("bit_count function does not support data type")
        );
    }

    #[test]
    fn test_bit_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bit_count = SparkBitCount::new();

        // The output field is always Int32 and mirrors the input nullability.
        for nullable in [false, true] {
            let input = Arc::new(Field::new("num", DataType::Int32, nullable));

            let output = bit_count.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                scalar_arguments: &[None],
            })?;

            assert_eq!(output.is_nullable(), nullable);
            assert_eq!(output.data_type(), &DataType::Int32);
        }

        Ok(())
    }
}