datafusion_spark/function/bitwise/
bit_get.rs1use std::any::Any;
19use std::mem::size_of;
20use std::sync::Arc;
21
22use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
23use arrow::compute::try_binary;
24use arrow::datatypes::DataType::{
25 Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
26};
27use arrow::datatypes::{
28 ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
29 UInt32Type, UInt64Type, UInt8Type,
30};
31use datafusion_common::{exec_err, Result};
32use datafusion_expr::{
33 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
34};
35use datafusion_functions::utils::make_scalar_function;
36
37use crate::function::error_utils::{
38 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
39};
40
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkBitGet {
43 signature: Signature,
44 aliases: Vec<String>,
45}
46
47impl Default for SparkBitGet {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkBitGet {
54 pub fn new() -> Self {
55 Self {
56 signature: Signature::user_defined(Volatility::Immutable),
57 aliases: vec!["getbit".to_string()],
58 }
59 }
60}
61
62impl ScalarUDFImpl for SparkBitGet {
63 fn as_any(&self) -> &dyn Any {
64 self
65 }
66
67 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
68 if arg_types.len() != 2 {
69 return Err(invalid_arg_count_exec_err(
70 "bit_get",
71 (2, 2),
72 arg_types.len(),
73 ));
74 }
75 if !arg_types[0].is_integer() && !arg_types[0].is_null() {
76 return Err(unsupported_data_type_exec_err(
77 "bit_get",
78 "Integer Type",
79 &arg_types[0],
80 ));
81 }
82 if !arg_types[1].is_integer() && !arg_types[1].is_null() {
83 return Err(unsupported_data_type_exec_err(
84 "bit_get",
85 "Integer Type",
86 &arg_types[1],
87 ));
88 }
89 if arg_types[0].is_null() {
90 return Ok(vec![Int8, Int32]);
91 }
92 Ok(vec![arg_types[0].clone(), Int32])
93 }
94
95 fn name(&self) -> &str {
96 "bit_get"
97 }
98
99 fn aliases(&self) -> &[String] {
100 &self.aliases
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108 Ok(Int8)
109 }
110
111 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
112 make_scalar_function(spark_bit_get, vec![])(&args.args)
113 }
114}
115
116fn spark_bit_get_inner<T: ArrowPrimitiveType>(
117 value: &PrimitiveArray<T>,
118 pos: &PrimitiveArray<Int32Type>,
119) -> Result<PrimitiveArray<Int8Type>> {
120 let bit_length = (size_of::<T::Native>() * 8) as i32;
121
122 let result: PrimitiveArray<Int8Type> = try_binary(value, pos, |value, pos| {
123 if pos < 0 || pos >= bit_length {
124 return Err(arrow::error::ArrowError::ComputeError(format!(
125 "bit_get: position {pos} is out of bounds. Expected pos < {bit_length} and pos >= 0"
126 )));
127 }
128 Ok(((value.to_i64().unwrap() >> pos) & 1) as i8)
129 })?;
130 Ok(result)
131}
132
133pub fn spark_bit_get(args: &[ArrayRef]) -> Result<ArrayRef> {
134 if args.len() != 2 {
135 return exec_err!("`bit_get` expects exactly two arguments");
136 }
137
138 if args[1].data_type() != &Int32 {
139 return exec_err!("`bit_get` expects Int32 as the second argument");
140 }
141
142 let pos_arg = args[1].as_primitive::<Int32Type>();
143
144 let ret = match &args[0].data_type() {
145 Int64 => {
146 let value_arg = args[0].as_primitive::<Int64Type>();
147 spark_bit_get_inner(value_arg, pos_arg)
148 }
149 Int32 => {
150 let value_arg = args[0].as_primitive::<Int32Type>();
151 spark_bit_get_inner(value_arg, pos_arg)
152 }
153 Int16 => {
154 let value_arg = args[0].as_primitive::<Int16Type>();
155 spark_bit_get_inner(value_arg, pos_arg)
156 }
157 Int8 => {
158 let value_arg = args[0].as_primitive::<Int8Type>();
159 spark_bit_get_inner(value_arg, pos_arg)
160 }
161 UInt64 => {
162 let value_arg = args[0].as_primitive::<UInt64Type>();
163 spark_bit_get_inner(value_arg, pos_arg)
164 }
165 UInt32 => {
166 let value_arg = args[0].as_primitive::<UInt32Type>();
167 spark_bit_get_inner(value_arg, pos_arg)
168 }
169 UInt16 => {
170 let value_arg = args[0].as_primitive::<UInt16Type>();
171 spark_bit_get_inner(value_arg, pos_arg)
172 }
173 UInt8 => {
174 let value_arg = args[0].as_primitive::<UInt8Type>();
175 spark_bit_get_inner(value_arg, pos_arg)
176 }
177 _ => {
178 exec_err!(
179 "`bit_get` expects Int64, Int32, Int16, or Int8 as the first argument"
180 )
181 }
182 }?;
183 Ok(Arc::new(ret))
184}
185
#[cfg(test)]
mod tests {
    use arrow::array::{Int32Array, Int64Array};

    use super::*;

    /// Runs `spark_bit_get` on an `Int64` value column and an `Int32` position column.
    fn bit_get(values: Int64Array, positions: Int32Array) -> Result<ArrayRef> {
        spark_bit_get(&[Arc::new(values), Arc::new(positions)])
    }

    /// Asserts that bit `pos` of the single value `value` equals `expected`.
    fn assert_bit(value: i64, pos: i32, expected: i8) {
        let result =
            bit_get(Int64Array::from(vec![value]), Int32Array::from(vec![pos])).unwrap();
        assert_eq!(result.as_primitive::<Int8Type>().value(0), expected);
    }

    #[test]
    fn test_bit_get_basic() {
        // 11 == 0b1011
        assert_bit(11, 0, 1);
        assert_bit(11, 2, 0);
        assert_bit(11, 3, 1);
    }

    #[test]
    fn test_bit_get_edge_cases() {
        assert_bit(0, 0, 0);

        let result = bit_get(Int64Array::from(vec![11]), Int32Array::from(vec![-1]));
        assert_eq!(
            result.unwrap_err().message().lines().next().unwrap(),
            "Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0"
        );

        let result = bit_get(Int64Array::from(vec![11]), Int32Array::from(vec![64]));
        assert_eq!(
            result.unwrap_err().message().lines().next().unwrap(),
            "Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0"
        );
    }

    #[test]
    fn test_bit_get_null_inputs() {
        // A null on either side must yield a null result. Checking the slot's
        // validity (not the zeroed backing value) is what verifies that.
        let result =
            bit_get(Int64Array::from(vec![None]), Int32Array::from(vec![0])).unwrap();
        assert!(result.is_null(0));

        let result =
            bit_get(Int64Array::from(vec![11]), Int32Array::from(vec![None])).unwrap();
        assert!(result.is_null(0));
    }

    #[test]
    fn test_bit_get_large_numbers() {
        // 255 == 0b1111_1111: bit 7 is the highest set bit, bit 8 is clear.
        assert_bit(255, 7, 1);
        assert_bit(255, 8, 0);
    }
}