datafusion_spark/function/bitwise/
bit_count.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Int32Array};
22use arrow::datatypes::{
23    DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
24    UInt64Type, UInt8Type,
25};
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::{plan_err, Result};
28use datafusion_expr::{
29    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30    Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkBitCount {
36    signature: Signature,
37}
38
39impl Default for SparkBitCount {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkBitCount {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::one_of(
49                vec![
50                    TypeSignature::Exact(vec![DataType::Boolean]),
51                    TypeSignature::Exact(vec![DataType::Int8]),
52                    TypeSignature::Exact(vec![DataType::Int16]),
53                    TypeSignature::Exact(vec![DataType::Int32]),
54                    TypeSignature::Exact(vec![DataType::Int64]),
55                    TypeSignature::Exact(vec![DataType::UInt8]),
56                    TypeSignature::Exact(vec![DataType::UInt16]),
57                    TypeSignature::Exact(vec![DataType::UInt32]),
58                    TypeSignature::Exact(vec![DataType::UInt64]),
59                ],
60                Volatility::Immutable,
61            ),
62        }
63    }
64}
65
66impl ScalarUDFImpl for SparkBitCount {
67    fn as_any(&self) -> &dyn Any {
68        self
69    }
70
71    fn name(&self) -> &str {
72        "bit_count"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80        Ok(DataType::Int32) // Spark returns int (Int32)
81    }
82
83    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84        if args.args.len() != 1 {
85            return plan_err!("bit_count expects exactly 1 argument");
86        }
87
88        make_scalar_function(spark_bit_count, vec![])(&args.args)
89    }
90}
91
92fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
93    let value_array = value_array[0].as_ref();
94    match value_array.data_type() {
95        DataType::Boolean => {
96            let result: Int32Array = as_boolean_array(value_array)?
97                .iter()
98                .map(|x| x.map(|y| y as i32))
99                .collect();
100            Ok(Arc::new(result))
101        }
102        DataType::Int8 => {
103            let result: Int32Array = value_array
104                .as_primitive::<Int8Type>()
105                .unary(|v| bit_count(v.into()));
106            Ok(Arc::new(result))
107        }
108        DataType::Int16 => {
109            let result: Int32Array = value_array
110                .as_primitive::<Int16Type>()
111                .unary(|v| bit_count(v.into()));
112            Ok(Arc::new(result))
113        }
114        DataType::Int32 => {
115            let result: Int32Array = value_array
116                .as_primitive::<Int32Type>()
117                .unary(|v| bit_count(v.into()));
118            Ok(Arc::new(result))
119        }
120        DataType::Int64 => {
121            let result: Int32Array =
122                value_array.as_primitive::<Int64Type>().unary(bit_count);
123            Ok(Arc::new(result))
124        }
125        DataType::UInt8 => {
126            let result: Int32Array = value_array
127                .as_primitive::<UInt8Type>()
128                .unary(|v| v.count_ones() as i32);
129            Ok(Arc::new(result))
130        }
131        DataType::UInt16 => {
132            let result: Int32Array = value_array
133                .as_primitive::<UInt16Type>()
134                .unary(|v| v.count_ones() as i32);
135            Ok(Arc::new(result))
136        }
137        DataType::UInt32 => {
138            let result: Int32Array = value_array
139                .as_primitive::<UInt32Type>()
140                .unary(|v| v.count_ones() as i32);
141            Ok(Arc::new(result))
142        }
143        DataType::UInt64 => {
144            let result: Int32Array = value_array
145                .as_primitive::<UInt64Type>()
146                .unary(|v| v.count_ones() as i32);
147            Ok(Arc::new(result))
148        }
149        _ => {
150            plan_err!(
151                "bit_count function does not support data type: {}",
152                value_array.data_type()
153            )
154        }
155    }
156}
157
/// Counts the set bits in the 64-bit two's-complement representation of `i`.
///
/// This matches Spark's `bit_count` for `LongType`, which delegates to Java's
/// `Long.bitCount`. Negative values therefore count their sign-extension bits,
/// so for example `bit_count(-1) == 64`.
///
/// Spark: https://github.com/apache/spark/blob/ac717dd7aec665de578d7c6b0070e8fcdde3cea9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala#L243
/// Java impl: https://github.com/openjdk/jdk/blob/d226023643f90027a8980d161ec6d423887ae3ce/src/java.base/share/classes/java/lang/Long.java#L1584
fn bit_count(i: i64) -> i32 {
    // `count_ones` is the same population count that Java's hand-written SWAR
    // routine computes, and it can lower to a native popcount instruction.
    // The result is at most 64, so the cast to i32 is lossless.
    i.count_ones() as i32
}
171
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, BooleanArray, Float32Array, Int16Array, Int32Array, Int64Array, Int8Array,
        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
    };
    use arrow::datatypes::Int32Type;

    #[test]
    fn test_bit_count_basic() {
        // Test bit_count(0) - no bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 0);

        // Test bit_count(1) - 1 bit set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);

        // Test bit_count(7) - 7 = 111 in binary, 3 bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 3);

        // Test bit_count(15) - 15 = 1111 in binary, 4 bits set
        let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap();

        assert_eq!(result.as_primitive::<Int32Type>().value(0), 4);
    }

    #[test]
    fn test_bit_count_int8() {
        // Test bit_count on Int8Array
        let result =
            spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0);
        assert_eq!(arr.value(1), 1);
        assert_eq!(arr.value(2), 2);
        assert_eq!(arr.value(3), 3);
        assert_eq!(arr.value(4), 4);
        // -1 is sign-extended to 64 bits (as Spark does), so all 64 bits are set
        assert_eq!(arr.value(5), 64);
    }

    #[test]
    fn test_bit_count_boolean() {
        // Test bit_count on BooleanArray
        let result =
            spark_bit_count(&[Arc::new(BooleanArray::from(vec![true, false]))]).unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 1);
        assert_eq!(arr.value(1), 0);
    }

    #[test]
    fn test_bit_count_int16() {
        // Test bit_count on Int16Array
        let result =
            spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0);
        assert_eq!(arr.value(1), 1);
        assert_eq!(arr.value(2), 8);
        assert_eq!(arr.value(3), 10);
        // -1 is sign-extended to 64 bits (as Spark does), so all 64 bits are set
        assert_eq!(arr.value(4), 64);
    }

    #[test]
    fn test_bit_count_int32() {
        // Test bit_count on Int32Array
        let result =
            spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
        assert_eq!(arr.value(4), 64); // -1 is sign-extended to 64 bits, so all 64 bits are set
    }

    #[test]
    fn test_bit_count_int64() {
        // Test bit_count on Int64Array
        let result =
            spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b0000000000000000000000000000000000000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b0000000000000000000000000000000000000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b0000000000000000000000000000000000000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 10); // 0b0000000000000000000000000000000000000000000000000000001111111111 = 10
        assert_eq!(arr.value(4), 64); // -1 in two's complement = all 64 bits set
    }

    #[test]
    fn test_bit_count_signed_min_values() {
        // Each MIN value has only the sign bit set at its native width. Sign
        // extension to 64 bits also sets every bit above it:
        // i8::MIN -> 56 + 1, i16::MIN -> 48 + 1, i32::MIN -> 32 + 1, i64::MIN -> 1.
        let cases: Vec<(ArrayRef, i32)> = vec![
            (Arc::new(Int8Array::from(vec![i8::MIN])), 57),
            (Arc::new(Int16Array::from(vec![i16::MIN])), 49),
            (Arc::new(Int32Array::from(vec![i32::MIN])), 33),
            (Arc::new(Int64Array::from(vec![i64::MIN])), 1),
        ];
        for (input, expected) in cases {
            let result = spark_bit_count(&[input]).unwrap();
            assert_eq!(result.as_primitive::<Int32Type>().value(0), expected);
        }
    }

    #[test]
    fn test_bit_count_uint8() {
        // Test bit_count on UInt8Array
        let result =
            spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000001 = 1
        assert_eq!(arr.value(2), 8); // 0b11111111 = 8
    }

    #[test]
    fn test_bit_count_uint16() {
        // Test bit_count on UInt16Array
        let result =
            spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))])
                .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b0000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b0000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b0000000011111111 = 8
        assert_eq!(arr.value(3), 16); // 0b1111111111111111 = 16
    }

    #[test]
    fn test_bit_count_uint32() {
        // Test bit_count on UInt32Array
        let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![
            0u32, 1, 255, 4294967295,
        ]))])
        .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0
        assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
        assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
        assert_eq!(arr.value(3), 32); // 0b11111111111111111111111111111111 = 32
    }

    #[test]
    fn test_bit_count_uint64() {
        // Test bit_count on UInt64Array
        let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![
            0u64,
            1,
            255,
            256,
            u64::MAX,
        ]))])
        .unwrap();

        let arr = result.as_primitive::<Int32Type>();
        // 0b0 = 0
        assert_eq!(arr.value(0), 0);
        // 0b1 = 1
        assert_eq!(arr.value(1), 1);
        // 0b11111111 = 8
        assert_eq!(arr.value(2), 8);
        // 0b100000000 = 1
        assert_eq!(arr.value(3), 1);
        // u64::MAX = all 64 bits set
        assert_eq!(arr.value(4), 64);
    }

    #[test]
    fn test_bit_count_nulls() {
        // Test bit_count with nulls
        let arr = Int32Array::from(vec![Some(3), None, Some(7)]);
        let result = spark_bit_count(&[Arc::new(arr)]).unwrap();
        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 2); // 0b11
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), 3); // 0b111
    }

    #[test]
    fn test_bit_count_boolean_nulls() {
        // Null booleans must stay null rather than being counted as 0
        let result = spark_bit_count(&[Arc::new(BooleanArray::from(vec![
            Some(true),
            None,
            Some(false),
        ]))])
        .unwrap();
        let arr = result.as_primitive::<Int32Type>();
        assert_eq!(arr.value(0), 1);
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), 0);
    }

    #[test]
    fn test_bit_count_unsupported_type() {
        // Floating-point input is not part of the signature and must be rejected
        let result = spark_bit_count(&[Arc::new(Float32Array::from(vec![1.0f32]))]);
        assert!(result.is_err());
    }
}