datafusion_spark/function/bitwise/
bit_count.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Int32Array};
22use arrow::datatypes::{
23 DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
24 UInt64Type, UInt8Type,
25};
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::{plan_err, Result};
28use datafusion_expr::{
29 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30 Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
/// Scalar UDF `bit_count`: counts the set bits of a single boolean or
/// integer argument and returns the count as `Int32` (null in, null out).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkBitCount {
    /// Accepted argument types; built once by `SparkBitCount::new`.
    signature: Signature,
}
38
39impl Default for SparkBitCount {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkBitCount {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::one_of(
49 vec![
50 TypeSignature::Exact(vec![DataType::Boolean]),
51 TypeSignature::Exact(vec![DataType::Int8]),
52 TypeSignature::Exact(vec![DataType::Int16]),
53 TypeSignature::Exact(vec![DataType::Int32]),
54 TypeSignature::Exact(vec![DataType::Int64]),
55 TypeSignature::Exact(vec![DataType::UInt8]),
56 TypeSignature::Exact(vec![DataType::UInt16]),
57 TypeSignature::Exact(vec![DataType::UInt32]),
58 TypeSignature::Exact(vec![DataType::UInt64]),
59 ],
60 Volatility::Immutable,
61 ),
62 }
63 }
64}
65
66impl ScalarUDFImpl for SparkBitCount {
67 fn as_any(&self) -> &dyn Any {
68 self
69 }
70
71 fn name(&self) -> &str {
72 "bit_count"
73 }
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80 Ok(DataType::Int32) }
82
83 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84 if args.args.len() != 1 {
85 return plan_err!("bit_count expects exactly 1 argument");
86 }
87
88 make_scalar_function(spark_bit_count, vec![])(&args.args)
89 }
90}
91
92fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
93 let value_array = value_array[0].as_ref();
94 match value_array.data_type() {
95 DataType::Boolean => {
96 let result: Int32Array = as_boolean_array(value_array)?
97 .iter()
98 .map(|x| x.map(|y| y as i32))
99 .collect();
100 Ok(Arc::new(result))
101 }
102 DataType::Int8 => {
103 let result: Int32Array = value_array
104 .as_primitive::<Int8Type>()
105 .unary(|v| bit_count(v.into()));
106 Ok(Arc::new(result))
107 }
108 DataType::Int16 => {
109 let result: Int32Array = value_array
110 .as_primitive::<Int16Type>()
111 .unary(|v| bit_count(v.into()));
112 Ok(Arc::new(result))
113 }
114 DataType::Int32 => {
115 let result: Int32Array = value_array
116 .as_primitive::<Int32Type>()
117 .unary(|v| bit_count(v.into()));
118 Ok(Arc::new(result))
119 }
120 DataType::Int64 => {
121 let result: Int32Array =
122 value_array.as_primitive::<Int64Type>().unary(bit_count);
123 Ok(Arc::new(result))
124 }
125 DataType::UInt8 => {
126 let result: Int32Array = value_array
127 .as_primitive::<UInt8Type>()
128 .unary(|v| v.count_ones() as i32);
129 Ok(Arc::new(result))
130 }
131 DataType::UInt16 => {
132 let result: Int32Array = value_array
133 .as_primitive::<UInt16Type>()
134 .unary(|v| v.count_ones() as i32);
135 Ok(Arc::new(result))
136 }
137 DataType::UInt32 => {
138 let result: Int32Array = value_array
139 .as_primitive::<UInt32Type>()
140 .unary(|v| v.count_ones() as i32);
141 Ok(Arc::new(result))
142 }
143 DataType::UInt64 => {
144 let result: Int32Array = value_array
145 .as_primitive::<UInt64Type>()
146 .unary(|v| v.count_ones() as i32);
147 Ok(Arc::new(result))
148 }
149 _ => {
150 plan_err!(
151 "bit_count function does not support data type: {}",
152 value_array.data_type()
153 )
154 }
155 }
156}
157
/// Returns the number of one bits in the 64-bit two's-complement
/// representation of `i`.
///
/// Narrower signed values are sign-extended to `i64` by the caller, so
/// negative inputs also count the extended sign bits
/// (e.g. `bit_count(-1) == 64`). The result is always in `0..=64`.
fn bit_count(i: i64) -> i32 {
    // The std popcount gives the same result as the previous hand-rolled
    // SWAR version and can lower to a single hardware instruction.
    // `count_ones` is at most 64, so the cast to `i32` is lossless.
    i.count_ones() as i32
}
171
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BooleanArray, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
    };
    use arrow::datatypes::Int32Type;

    /// Runs `spark_bit_count` on one array and returns the `Int32` results,
    /// with nulls as `None`.
    fn bit_counts(input: ArrayRef) -> Vec<Option<i32>> {
        let result = spark_bit_count(&[input]).unwrap();
        result.as_primitive::<Int32Type>().iter().collect()
    }

    #[test]
    fn test_bit_count_basic() {
        assert_eq!(
            bit_counts(Arc::new(Int32Array::from(vec![0, 1, 7, 15]))),
            vec![Some(0), Some(1), Some(3), Some(4)]
        );
    }

    #[test]
    fn test_bit_count_int8() {
        assert_eq!(
            bit_counts(Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))),
            vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(64)]
        );
    }

    #[test]
    fn test_bit_count_boolean() {
        assert_eq!(
            bit_counts(Arc::new(BooleanArray::from(vec![true, false]))),
            vec![Some(1), Some(0)]
        );
    }

    #[test]
    fn test_bit_count_boolean_nulls() {
        assert_eq!(
            bit_counts(Arc::new(BooleanArray::from(vec![
                Some(true),
                None,
                Some(false),
            ]))),
            vec![Some(1), None, Some(0)]
        );
    }

    #[test]
    fn test_bit_count_int16() {
        assert_eq!(
            bit_counts(Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))),
            vec![Some(0), Some(1), Some(8), Some(10), Some(64)]
        );
    }

    #[test]
    fn test_bit_count_int32() {
        assert_eq!(
            bit_counts(Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))),
            vec![Some(0), Some(1), Some(8), Some(10), Some(64)]
        );
    }

    #[test]
    fn test_bit_count_int64() {
        assert_eq!(
            bit_counts(Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))),
            vec![Some(0), Some(1), Some(8), Some(10), Some(64)]
        );
    }

    /// Signed minimums are sign-extended to 64 bits before counting.
    #[test]
    fn test_bit_count_signed_extremes() {
        assert_eq!(
            bit_counts(Arc::new(Int8Array::from(vec![i8::MIN, i8::MAX]))),
            vec![Some(57), Some(7)]
        );
        assert_eq!(
            bit_counts(Arc::new(Int16Array::from(vec![i16::MIN, i16::MAX]))),
            vec![Some(49), Some(15)]
        );
        assert_eq!(
            bit_counts(Arc::new(Int32Array::from(vec![i32::MIN, i32::MAX]))),
            vec![Some(33), Some(31)]
        );
        assert_eq!(
            bit_counts(Arc::new(Int64Array::from(vec![i64::MIN, i64::MAX]))),
            vec![Some(1), Some(63)]
        );
    }

    #[test]
    fn test_bit_count_uint8() {
        assert_eq!(
            bit_counts(Arc::new(UInt8Array::from(vec![0u8, 1, 255]))),
            vec![Some(0), Some(1), Some(8)]
        );
    }

    #[test]
    fn test_bit_count_uint16() {
        assert_eq!(
            bit_counts(Arc::new(UInt16Array::from(vec![0u16, 1, 255, u16::MAX]))),
            vec![Some(0), Some(1), Some(8), Some(16)]
        );
    }

    #[test]
    fn test_bit_count_uint32() {
        assert_eq!(
            bit_counts(Arc::new(UInt32Array::from(vec![0u32, 1, 255, u32::MAX]))),
            vec![Some(0), Some(1), Some(8), Some(32)]
        );
    }

    #[test]
    fn test_bit_count_uint64() {
        assert_eq!(
            bit_counts(Arc::new(UInt64Array::from(vec![
                0u64,
                1,
                255,
                256,
                u64::MAX,
            ]))),
            vec![Some(0), Some(1), Some(8), Some(1), Some(64)]
        );
    }

    #[test]
    fn test_bit_count_nulls() {
        assert_eq!(
            bit_counts(Arc::new(Int32Array::from(vec![Some(3), None, Some(7)]))),
            vec![Some(2), None, Some(3)]
        );
    }

    #[test]
    fn test_bit_count_unsupported_type() {
        let err = spark_bit_count(&[Arc::new(Float64Array::from(vec![1.0]))]).unwrap_err();
        assert!(err
            .to_string()
            .contains("bit_count function does not support data type: Float64"));
    }

    #[test]
    fn test_bit_count_udf_metadata() {
        let udf = SparkBitCount::new();
        assert_eq!(udf.name(), "bit_count");
        assert_eq!(
            udf.return_type(&[DataType::Int64]).unwrap(),
            DataType::Int32
        );
    }
}