datafusion_spark/function/bitwise/
bit_count.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Int32Array};
22use arrow::datatypes::{
23 DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
24 UInt32Type, UInt64Type,
25};
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::{Result, internal_err, plan_err};
28use datafusion_expr::{
29 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30 Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkBitCount {
36 signature: Signature,
37}
38
39impl Default for SparkBitCount {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkBitCount {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::one_of(
49 vec![
50 TypeSignature::Exact(vec![DataType::Boolean]),
51 TypeSignature::Exact(vec![DataType::Int8]),
52 TypeSignature::Exact(vec![DataType::Int16]),
53 TypeSignature::Exact(vec![DataType::Int32]),
54 TypeSignature::Exact(vec![DataType::Int64]),
55 TypeSignature::Exact(vec![DataType::UInt8]),
56 TypeSignature::Exact(vec![DataType::UInt16]),
57 TypeSignature::Exact(vec![DataType::UInt32]),
58 TypeSignature::Exact(vec![DataType::UInt64]),
59 ],
60 Volatility::Immutable,
61 ),
62 }
63 }
64}
65
66impl ScalarUDFImpl for SparkBitCount {
67 fn as_any(&self) -> &dyn Any {
68 self
69 }
70
71 fn name(&self) -> &str {
72 "bit_count"
73 }
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80 internal_err!("return_field_from_args should be used instead")
81 }
82
83 fn return_field_from_args(
84 &self,
85 args: datafusion_expr::ReturnFieldArgs,
86 ) -> Result<FieldRef> {
87 use arrow::datatypes::Field;
88 Ok(Arc::new(Field::new(
90 args.arg_fields[0].name(),
91 DataType::Int32,
92 args.arg_fields[0].is_nullable(),
93 )))
94 }
95
96 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
97 if args.args.len() != 1 {
98 return plan_err!("bit_count expects exactly 1 argument");
99 }
100
101 make_scalar_function(spark_bit_count, vec![])(&args.args)
102 }
103}
104
105fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
106 let value_array = value_array[0].as_ref();
107 match value_array.data_type() {
108 DataType::Boolean => {
109 let result: Int32Array = as_boolean_array(value_array)?
110 .iter()
111 .map(|x| x.map(|y| y as i32))
112 .collect();
113 Ok(Arc::new(result))
114 }
115 DataType::Int8 => {
116 let result: Int32Array = value_array
117 .as_primitive::<Int8Type>()
118 .unary(|v| (v as i64).count_ones() as i32);
119 Ok(Arc::new(result))
120 }
121 DataType::Int16 => {
122 let result: Int32Array = value_array
123 .as_primitive::<Int16Type>()
124 .unary(|v| (v as i64).count_ones() as i32);
125 Ok(Arc::new(result))
126 }
127 DataType::Int32 => {
128 let result: Int32Array = value_array
129 .as_primitive::<Int32Type>()
130 .unary(|v| (v as i64).count_ones() as i32);
131 Ok(Arc::new(result))
132 }
133 DataType::Int64 => {
134 let result: Int32Array = value_array
135 .as_primitive::<Int64Type>()
136 .unary(|v| v.count_ones() as i32);
137 Ok(Arc::new(result))
138 }
139 DataType::UInt8 => {
140 let result: Int32Array = value_array
141 .as_primitive::<UInt8Type>()
142 .unary(|v| v.count_ones() as i32);
143 Ok(Arc::new(result))
144 }
145 DataType::UInt16 => {
146 let result: Int32Array = value_array
147 .as_primitive::<UInt16Type>()
148 .unary(|v| v.count_ones() as i32);
149 Ok(Arc::new(result))
150 }
151 DataType::UInt32 => {
152 let result: Int32Array = value_array
153 .as_primitive::<UInt32Type>()
154 .unary(|v| v.count_ones() as i32);
155 Ok(Arc::new(result))
156 }
157 DataType::UInt64 => {
158 let result: Int32Array = value_array
159 .as_primitive::<UInt64Type>()
160 .unary(|v| v.count_ones() as i32);
161 Ok(Arc::new(result))
162 }
163 _ => {
164 plan_err!(
165 "bit_count function does not support data type: {}",
166 value_array.data_type()
167 )
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BooleanArray, Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array,
        UInt16Array, UInt32Array, UInt64Array,
    };
    use arrow::datatypes::{Field, Int32Type};

    /// Runs the kernel on `input` and returns one entry per row:
    /// `Some(count)` for valid rows, `None` for null rows.
    fn bit_counts(input: ArrayRef) -> Vec<Option<i32>> {
        let output = spark_bit_count(&[input]).unwrap();
        output.as_primitive::<Int32Type>().iter().collect()
    }

    /// Expected output for an input with no nulls.
    fn valid(counts: &[i32]) -> Vec<Option<i32>> {
        counts.iter().copied().map(Some).collect()
    }

    #[test]
    fn test_bit_count_basic() {
        for (input, expected) in [(0, 0), (1, 1), (7, 3), (15, 4)] {
            let counts = bit_counts(Arc::new(Int32Array::from(vec![input])));
            assert_eq!(counts, valid(&[expected]), "bit_count({input})");
        }
    }

    #[test]
    fn test_bit_count_int8() {
        let input = Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 2, 3, 4, 64]));
    }

    #[test]
    fn test_bit_count_boolean() {
        let input = BooleanArray::from(vec![true, false]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[1, 0]));
    }

    #[test]
    fn test_bit_count_int16() {
        let input = Int16Array::from(vec![0i16, 1, 255, 1023, -1]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 10, 64]));
    }

    #[test]
    fn test_bit_count_int32() {
        let input = Int32Array::from(vec![0i32, 1, 255, 1023, -1]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 10, 64]));
    }

    #[test]
    fn test_bit_count_int64() {
        let input = Int64Array::from(vec![0i64, 1, 255, 1023, -1]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 10, 64]));
    }

    #[test]
    fn test_bit_count_uint8() {
        let input = UInt8Array::from(vec![0u8, 1, 255]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8]));
    }

    #[test]
    fn test_bit_count_uint16() {
        let input = UInt16Array::from(vec![0u16, 1, 255, u16::MAX]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 16]));
    }

    #[test]
    fn test_bit_count_uint32() {
        let input = UInt32Array::from(vec![0u32, 1, 255, u32::MAX]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 32]));
    }

    #[test]
    fn test_bit_count_uint64() {
        let input = UInt64Array::from(vec![0u64, 1, 255, 256, u64::MAX]);
        assert_eq!(bit_counts(Arc::new(input)), valid(&[0, 1, 8, 1, 64]));
    }

    #[test]
    fn test_bit_count_nulls() {
        let input = Int32Array::from(vec![Some(3), None, Some(7)]);
        assert_eq!(bit_counts(Arc::new(input)), vec![Some(2), None, Some(3)]);
    }

    #[test]
    fn test_bit_count_nullability() -> Result<()> {
        use datafusion_expr::ReturnFieldArgs;

        let bit_count = SparkBitCount::new();

        // The result field must mirror the input's nullability and always be Int32.
        for nullable in [false, true] {
            let input_field = Arc::new(Field::new("num", DataType::Int32, nullable));

            let output_field = bit_count.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(&input_field)],
                scalar_arguments: &[None],
            })?;

            assert_eq!(output_field.is_nullable(), nullable);
            assert_eq!(output_field.data_type(), &DataType::Int32);
        }

        Ok(())
    }
}