// datafusion_spark/function/bitwise/bit_count.rs
use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, Int32Array};
22use arrow::datatypes::{
23 DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
24 UInt64Type, UInt8Type,
25};
26use datafusion_common::{plan_err, Result};
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
29 Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkBitCount {
35 signature: Signature,
36}
37
38impl Default for SparkBitCount {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkBitCount {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::one_of(
48 vec![
49 TypeSignature::Exact(vec![DataType::Int8]),
50 TypeSignature::Exact(vec![DataType::Int16]),
51 TypeSignature::Exact(vec![DataType::Int32]),
52 TypeSignature::Exact(vec![DataType::Int64]),
53 TypeSignature::Exact(vec![DataType::UInt8]),
54 TypeSignature::Exact(vec![DataType::UInt16]),
55 TypeSignature::Exact(vec![DataType::UInt32]),
56 TypeSignature::Exact(vec![DataType::UInt64]),
57 ],
58 Volatility::Immutable,
59 ),
60 }
61 }
62}
63
64impl ScalarUDFImpl for SparkBitCount {
65 fn as_any(&self) -> &dyn Any {
66 self
67 }
68
69 fn name(&self) -> &str {
70 "bit_count"
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
78 Ok(DataType::Int32) }
80
81 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82 if args.args.len() != 1 {
83 return plan_err!("bit_count expects exactly 1 argument");
84 }
85
86 make_scalar_function(spark_bit_count, vec![])(&args.args)
87 }
88}
89
90fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
91 let value_array = value_array[0].as_ref();
92 match value_array.data_type() {
93 DataType::Int8 => {
94 let result: Int32Array = value_array
95 .as_primitive::<Int8Type>()
96 .unary(|v| v.count_ones() as i32);
97 Ok(Arc::new(result))
98 }
99 DataType::Int16 => {
100 let result: Int32Array = value_array
101 .as_primitive::<Int16Type>()
102 .unary(|v| v.count_ones() as i32);
103 Ok(Arc::new(result))
104 }
105 DataType::Int32 => {
106 let result: Int32Array = value_array
107 .as_primitive::<Int32Type>()
108 .unary(|v| v.count_ones() as i32);
109 Ok(Arc::new(result))
110 }
111 DataType::Int64 => {
112 let result: Int32Array = value_array
113 .as_primitive::<Int64Type>()
114 .unary(|v| v.count_ones() as i32);
115 Ok(Arc::new(result))
116 }
117 DataType::UInt8 => {
118 let result: Int32Array = value_array
119 .as_primitive::<UInt8Type>()
120 .unary(|v| v.count_ones() as i32);
121 Ok(Arc::new(result))
122 }
123 DataType::UInt16 => {
124 let result: Int32Array = value_array
125 .as_primitive::<UInt16Type>()
126 .unary(|v| v.count_ones() as i32);
127 Ok(Arc::new(result))
128 }
129 DataType::UInt32 => {
130 let result: Int32Array = value_array
131 .as_primitive::<UInt32Type>()
132 .unary(|v| v.count_ones() as i32);
133 Ok(Arc::new(result))
134 }
135 DataType::UInt64 => {
136 let result: Int32Array = value_array
137 .as_primitive::<UInt64Type>()
138 .unary(|v| v.count_ones() as i32);
139 Ok(Arc::new(result))
140 }
141 _ => {
142 plan_err!(
143 "bit_count function does not support data type: {:?}",
144 value_array.data_type()
145 )
146 }
147 }
148}
149
150#[cfg(test)]
151mod tests {
152 use super::*;
153 use arrow::array::{
154 Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array,
155 UInt64Array, UInt8Array,
156 };
157 use arrow::datatypes::Int32Type;
158
159 #[test]
160 fn test_bit_count_basic() {
161 let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap();
163
164 assert_eq!(result.as_primitive::<Int32Type>().value(0), 0);
165
166 let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap();
168
169 assert_eq!(result.as_primitive::<Int32Type>().value(0), 1);
170
171 let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap();
173
174 assert_eq!(result.as_primitive::<Int32Type>().value(0), 3);
175
176 let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap();
178
179 assert_eq!(result.as_primitive::<Int32Type>().value(0), 4);
180 }
181
182 #[test]
183 fn test_bit_count_int8() {
184 let result =
186 spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))])
187 .unwrap();
188
189 let arr = result.as_primitive::<Int32Type>();
190 assert_eq!(arr.value(0), 0);
191 assert_eq!(arr.value(1), 1);
192 assert_eq!(arr.value(2), 2);
193 assert_eq!(arr.value(3), 3);
194 assert_eq!(arr.value(4), 4);
195 assert_eq!(arr.value(5), 8);
196 }
197
198 #[test]
199 fn test_bit_count_int16() {
200 let result =
202 spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))])
203 .unwrap();
204
205 let arr = result.as_primitive::<Int32Type>();
206 assert_eq!(arr.value(0), 0);
207 assert_eq!(arr.value(1), 1);
208 assert_eq!(arr.value(2), 8);
209 assert_eq!(arr.value(3), 10);
210 assert_eq!(arr.value(4), 16);
211 }
212
213 #[test]
214 fn test_bit_count_int32() {
215 let result =
217 spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))])
218 .unwrap();
219
220 let arr = result.as_primitive::<Int32Type>();
221 assert_eq!(arr.value(0), 0); assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); assert_eq!(arr.value(3), 10); assert_eq!(arr.value(4), 32); }
227
228 #[test]
229 fn test_bit_count_int64() {
230 let result =
232 spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))])
233 .unwrap();
234
235 let arr = result.as_primitive::<Int32Type>();
236 assert_eq!(arr.value(0), 0); assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); assert_eq!(arr.value(3), 10); assert_eq!(arr.value(4), 64); }
242
243 #[test]
244 fn test_bit_count_uint8() {
245 let result =
247 spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap();
248
249 let arr = result.as_primitive::<Int32Type>();
250 assert_eq!(arr.value(0), 0); assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); }
254
255 #[test]
256 fn test_bit_count_uint16() {
257 let result =
259 spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))])
260 .unwrap();
261
262 let arr = result.as_primitive::<Int32Type>();
263 assert_eq!(arr.value(0), 0); assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); assert_eq!(arr.value(3), 16); }
268
269 #[test]
270 fn test_bit_count_uint32() {
271 let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![
273 0u32, 1, 255, 4294967295,
274 ]))])
275 .unwrap();
276
277 let arr = result.as_primitive::<Int32Type>();
278 assert_eq!(arr.value(0), 0); assert_eq!(arr.value(1), 1); assert_eq!(arr.value(2), 8); assert_eq!(arr.value(3), 32); }
283
284 #[test]
285 fn test_bit_count_uint64() {
286 let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![
288 0u64,
289 1,
290 255,
291 256,
292 u64::MAX,
293 ]))])
294 .unwrap();
295
296 let arr = result.as_primitive::<Int32Type>();
297 assert_eq!(arr.value(0), 0);
299 assert_eq!(arr.value(1), 1);
301 assert_eq!(arr.value(2), 8);
303 assert_eq!(arr.value(3), 1);
305 assert_eq!(arr.value(4), 64);
307 }
308
309 #[test]
310 fn test_bit_count_nulls() {
311 let arr = Int32Array::from(vec![Some(3), None, Some(7)]);
313 let result = spark_bit_count(&[Arc::new(arr)]).unwrap();
314 let arr = result.as_primitive::<Int32Type>();
315 assert_eq!(arr.value(0), 2); assert!(arr.is_null(1));
317 assert_eq!(arr.value(2), 3); }
319}