datafusion_spark/function/bitwise/
bit_count.rs1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray, Int32Array};
21use arrow::datatypes::{
22 DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
23 UInt32Type, UInt64Type,
24};
25use datafusion_common::cast::as_boolean_array;
26use datafusion_common::{Result, internal_err, plan_err};
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
29 Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkBitCount {
35 signature: Signature,
36}
37
38impl Default for SparkBitCount {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkBitCount {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::one_of(
48 vec![
49 TypeSignature::Exact(vec![DataType::Boolean]),
50 TypeSignature::Exact(vec![DataType::Int8]),
51 TypeSignature::Exact(vec![DataType::Int16]),
52 TypeSignature::Exact(vec![DataType::Int32]),
53 TypeSignature::Exact(vec![DataType::Int64]),
54 TypeSignature::Exact(vec![DataType::UInt8]),
55 TypeSignature::Exact(vec![DataType::UInt16]),
56 TypeSignature::Exact(vec![DataType::UInt32]),
57 TypeSignature::Exact(vec![DataType::UInt64]),
58 ],
59 Volatility::Immutable,
60 ),
61 }
62 }
63}
64
65impl ScalarUDFImpl for SparkBitCount {
66 fn name(&self) -> &str {
67 "bit_count"
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75 internal_err!("return_field_from_args should be used instead")
76 }
77
78 fn return_field_from_args(
79 &self,
80 args: datafusion_expr::ReturnFieldArgs,
81 ) -> Result<FieldRef> {
82 use arrow::datatypes::Field;
83 Ok(Arc::new(Field::new(
85 args.arg_fields[0].name(),
86 DataType::Int32,
87 args.arg_fields[0].is_nullable(),
88 )))
89 }
90
91 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 if args.args.len() != 1 {
93 return plan_err!("bit_count expects exactly 1 argument");
94 }
95
96 make_scalar_function(spark_bit_count, vec![])(&args.args)
97 }
98}
99
100fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
101 let value_array = value_array[0].as_ref();
102 match value_array.data_type() {
103 DataType::Boolean => {
104 let result: Int32Array = as_boolean_array(value_array)?
105 .iter()
106 .map(|x| x.map(|y| y as i32))
107 .collect();
108 Ok(Arc::new(result))
109 }
110 DataType::Int8 => {
111 let result: Int32Array = value_array
112 .as_primitive::<Int8Type>()
113 .unary(|v| (v as i64).count_ones() as i32);
114 Ok(Arc::new(result))
115 }
116 DataType::Int16 => {
117 let result: Int32Array = value_array
118 .as_primitive::<Int16Type>()
119 .unary(|v| (v as i64).count_ones() as i32);
120 Ok(Arc::new(result))
121 }
122 DataType::Int32 => {
123 let result: Int32Array = value_array
124 .as_primitive::<Int32Type>()
125 .unary(|v| (v as i64).count_ones() as i32);
126 Ok(Arc::new(result))
127 }
128 DataType::Int64 => {
129 let result: Int32Array = value_array
130 .as_primitive::<Int64Type>()
131 .unary(|v| v.count_ones() as i32);
132 Ok(Arc::new(result))
133 }
134 DataType::UInt8 => {
135 let result: Int32Array = value_array
136 .as_primitive::<UInt8Type>()
137 .unary(|v| v.count_ones() as i32);
138 Ok(Arc::new(result))
139 }
140 DataType::UInt16 => {
141 let result: Int32Array = value_array
142 .as_primitive::<UInt16Type>()
143 .unary(|v| v.count_ones() as i32);
144 Ok(Arc::new(result))
145 }
146 DataType::UInt32 => {
147 let result: Int32Array = value_array
148 .as_primitive::<UInt32Type>()
149 .unary(|v| v.count_ones() as i32);
150 Ok(Arc::new(result))
151 }
152 DataType::UInt64 => {
153 let result: Int32Array = value_array
154 .as_primitive::<UInt64Type>()
155 .unary(|v| v.count_ones() as i32);
156 Ok(Arc::new(result))
157 }
158 _ => {
159 plan_err!(
160 "bit_count function does not support data type: {}",
161 value_array.data_type()
162 )
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Array, BooleanArray, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
        UInt8Array, UInt16Array, UInt32Array, UInt64Array,
    };
    use arrow::datatypes::Field;

    /// Runs [`spark_bit_count`] on a single input array and returns the counts as
    /// `Option`s, so one comparison checks output length, values and null positions
    /// together; also asserts the output type is `Int32`.
    fn bit_counts(input: ArrayRef) -> Vec<Option<i32>> {
        let result = spark_bit_count(&[input]).unwrap();
        assert_eq!(result.data_type(), &DataType::Int32);
        result.as_primitive::<Int32Type>().iter().collect()
    }

    #[test]
    fn test_bit_count_basic() {
        let counts = bit_counts(Arc::new(Int32Array::from(vec![0, 1, 7, 15])));
        assert_eq!(counts, vec![Some(0), Some(1), Some(3), Some(4)]);
    }

    #[test]
    fn test_bit_count_empty() {
        let counts = bit_counts(Arc::new(Int32Array::from(Vec::<i32>::new())));
        assert!(counts.is_empty());
    }

    #[test]
    fn test_bit_count_int8() {
        // Signed inputs are sign-extended to 64 bits before counting (as Spark does
        // via `Long.bitCount`), so -1 yields 64 and i8::MIN (0x80) yields 56 + 1 = 57.
        let counts = bit_counts(Arc::new(Int8Array::from(vec![
            0i8,
            1,
            3,
            7,
            15,
            -1,
            i8::MIN,
            i8::MAX,
        ])));
        assert_eq!(
            counts,
            vec![
                Some(0),
                Some(1),
                Some(2),
                Some(3),
                Some(4),
                Some(64),
                Some(57),
                Some(7)
            ]
        );
    }

    #[test]
    fn test_bit_count_boolean() {
        let counts = bit_counts(Arc::new(BooleanArray::from(vec![
            Some(true),
            Some(false),
            None,
        ])));
        assert_eq!(counts, vec![Some(1), Some(0), None]);
    }

    #[test]
    fn test_bit_count_int16() {
        // i16::MIN (0x8000) sign-extends to 48 leading ones plus its own top bit.
        let counts = bit_counts(Arc::new(Int16Array::from(vec![
            0i16,
            1,
            255,
            1023,
            -1,
            i16::MIN,
        ])));
        assert_eq!(
            counts,
            vec![Some(0), Some(1), Some(8), Some(10), Some(64), Some(49)]
        );
    }

    #[test]
    fn test_bit_count_int32() {
        // i32::MIN sign-extends to 32 leading ones plus its own top bit.
        let counts = bit_counts(Arc::new(Int32Array::from(vec![
            0i32,
            1,
            255,
            1023,
            -1,
            i32::MIN,
        ])));
        assert_eq!(
            counts,
            vec![Some(0), Some(1), Some(8), Some(10), Some(64), Some(33)]
        );
    }

    #[test]
    fn test_bit_count_int64() {
        let counts = bit_counts(Arc::new(Int64Array::from(vec![
            0i64,
            1,
            255,
            1023,
            -1,
            i64::MIN,
            i64::MAX,
        ])));
        assert_eq!(
            counts,
            vec![
                Some(0),
                Some(1),
                Some(8),
                Some(10),
                Some(64),
                Some(1),
                Some(63)
            ]
        );
    }

    #[test]
    fn test_bit_count_uint8() {
        // Unsigned inputs are not sign-extended: u8::MAX has only 8 set bits.
        let counts = bit_counts(Arc::new(UInt8Array::from(vec![0u8, 1, u8::MAX])));
        assert_eq!(counts, vec![Some(0), Some(1), Some(8)]);
    }

    #[test]
    fn test_bit_count_uint16() {
        let counts =
            bit_counts(Arc::new(UInt16Array::from(vec![0u16, 1, 255, u16::MAX])));
        assert_eq!(counts, vec![Some(0), Some(1), Some(8), Some(16)]);
    }

    #[test]
    fn test_bit_count_uint32() {
        let counts =
            bit_counts(Arc::new(UInt32Array::from(vec![0u32, 1, 255, u32::MAX])));
        assert_eq!(counts, vec![Some(0), Some(1), Some(8), Some(32)]);
    }

    #[test]
    fn test_bit_count_uint64() {
        let counts = bit_counts(Arc::new(UInt64Array::from(vec![
            0u64,
            1,
            255,
            256,
            u64::MAX,
        ])));
        assert_eq!(counts, vec![Some(0), Some(1), Some(8), Some(1), Some(64)]);
    }

    #[test]
    fn test_bit_count_nulls() {
        let counts = bit_counts(Arc::new(Int32Array::from(vec![Some(3), None, Some(7)])));
        assert_eq!(counts, vec![Some(2), None, Some(3)]);
    }

    #[test]
    fn test_bit_count_unsupported_type() {
        let result = spark_bit_count(&[Arc::new(Float64Array::from(vec![1.5]))]);
        assert!(result.is_err());
    }

    #[test]
    fn test_bit_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bit_count = SparkBitCount::new();

        // The output field is Int32 and mirrors the input field's nullability.
        for nullable in [false, true] {
            let input_field = Arc::new(Field::new("num", DataType::Int32, nullable));

            let result = bit_count.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&input_field)],
                scalar_arguments: &[None],
            })?;

            assert_eq!(result.is_nullable(), nullable);
            assert_eq!(result.data_type(), &DataType::Int32);
        }

        Ok(())
    }
}