datafusion_spark/function/bitwise/
bit_get.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::mem::size_of;
20use std::sync::Arc;
21
22use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
23use arrow::compute::try_binary;
24use arrow::datatypes::DataType::{
25    Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
26};
27use arrow::datatypes::{
28    ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
29    UInt32Type, UInt64Type, UInt8Type,
30};
31use datafusion_common::{exec_err, Result};
32use datafusion_expr::{
33    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
34};
35use datafusion_functions::utils::make_scalar_function;
36
37use crate::function::error_utils::{
38    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
39};
40
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkBitGet {
43    signature: Signature,
44    aliases: Vec<String>,
45}
46
47impl Default for SparkBitGet {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkBitGet {
54    pub fn new() -> Self {
55        Self {
56            signature: Signature::user_defined(Volatility::Immutable),
57            aliases: vec!["getbit".to_string()],
58        }
59    }
60}
61
62impl ScalarUDFImpl for SparkBitGet {
63    fn as_any(&self) -> &dyn Any {
64        self
65    }
66
67    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
68        if arg_types.len() != 2 {
69            return Err(invalid_arg_count_exec_err(
70                "bit_get",
71                (2, 2),
72                arg_types.len(),
73            ));
74        }
75        if !arg_types[0].is_integer() && !arg_types[0].is_null() {
76            return Err(unsupported_data_type_exec_err(
77                "bit_get",
78                "Integer Type",
79                &arg_types[0],
80            ));
81        }
82        if !arg_types[1].is_integer() && !arg_types[1].is_null() {
83            return Err(unsupported_data_type_exec_err(
84                "bit_get",
85                "Integer Type",
86                &arg_types[1],
87            ));
88        }
89        if arg_types[0].is_null() {
90            return Ok(vec![Int8, Int32]);
91        }
92        Ok(vec![arg_types[0].clone(), Int32])
93    }
94
95    fn name(&self) -> &str {
96        "bit_get"
97    }
98
99    fn aliases(&self) -> &[String] {
100        &self.aliases
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108        Ok(Int8)
109    }
110
111    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
112        make_scalar_function(spark_bit_get, vec![])(&args.args)
113    }
114}
115
116fn spark_bit_get_inner<T: ArrowPrimitiveType>(
117    value: &PrimitiveArray<T>,
118    pos: &PrimitiveArray<Int32Type>,
119) -> Result<PrimitiveArray<Int8Type>> {
120    let bit_length = (size_of::<T::Native>() * 8) as i32;
121
122    let result: PrimitiveArray<Int8Type> = try_binary(value, pos, |value, pos| {
123        if pos < 0 || pos >= bit_length {
124            return Err(arrow::error::ArrowError::ComputeError(format!(
125                "bit_get: position {pos} is out of bounds. Expected pos < {bit_length} and pos >= 0"
126            )));
127        }
128        Ok(((value.to_i64().unwrap() >> pos) & 1) as i8)
129    })?;
130    Ok(result)
131}
132
133pub fn spark_bit_get(args: &[ArrayRef]) -> Result<ArrayRef> {
134    if args.len() != 2 {
135        return exec_err!("`bit_get` expects exactly two arguments");
136    }
137
138    if args[1].data_type() != &Int32 {
139        return exec_err!("`bit_get` expects Int32 as the second argument");
140    }
141
142    let pos_arg = args[1].as_primitive::<Int32Type>();
143
144    let ret = match &args[0].data_type() {
145        Int64 => {
146            let value_arg = args[0].as_primitive::<Int64Type>();
147            spark_bit_get_inner(value_arg, pos_arg)
148        }
149        Int32 => {
150            let value_arg = args[0].as_primitive::<Int32Type>();
151            spark_bit_get_inner(value_arg, pos_arg)
152        }
153        Int16 => {
154            let value_arg = args[0].as_primitive::<Int16Type>();
155            spark_bit_get_inner(value_arg, pos_arg)
156        }
157        Int8 => {
158            let value_arg = args[0].as_primitive::<Int8Type>();
159            spark_bit_get_inner(value_arg, pos_arg)
160        }
161        UInt64 => {
162            let value_arg = args[0].as_primitive::<UInt64Type>();
163            spark_bit_get_inner(value_arg, pos_arg)
164        }
165        UInt32 => {
166            let value_arg = args[0].as_primitive::<UInt32Type>();
167            spark_bit_get_inner(value_arg, pos_arg)
168        }
169        UInt16 => {
170            let value_arg = args[0].as_primitive::<UInt16Type>();
171            spark_bit_get_inner(value_arg, pos_arg)
172        }
173        UInt8 => {
174            let value_arg = args[0].as_primitive::<UInt8Type>();
175            spark_bit_get_inner(value_arg, pos_arg)
176        }
177        _ => {
178            exec_err!(
179                "`bit_get` expects Int64, Int32, Int16, or Int8 as the first argument"
180            )
181        }
182    }?;
183    Ok(Arc::new(ret))
184}
185
#[cfg(test)]
mod tests {
    use arrow::array::{Int32Array, Int64Array};

    use super::*;

    /// Runs `spark_bit_get` on a single-row `Int64` value and `Int32` position.
    fn bit_get_i64(value: Option<i64>, pos: Option<i32>) -> Result<ArrayRef> {
        spark_bit_get(&[
            Arc::new(Int64Array::from(vec![value])),
            Arc::new(Int32Array::from(vec![pos])),
        ])
    }

    /// Returns the `Int8` value stored in the first row of `result`.
    fn first_bit(result: &ArrayRef) -> i8 {
        result.as_primitive::<Int8Type>().value(0)
    }

    #[test]
    fn test_bit_get_basic() {
        // 11 = 1011 in binary: bit 0 = 1, bit 2 = 0, bit 3 = 1
        for (pos, expected) in [(0, 1), (2, 0), (3, 1)] {
            let result = bit_get_i64(Some(11), Some(pos)).unwrap();
            assert_eq!(first_bit(&result), expected);
        }
    }

    #[test]
    fn test_bit_get_edge_cases() {
        // Test with 0
        let result = bit_get_i64(Some(0), Some(0)).unwrap();
        assert_eq!(first_bit(&result), 0);

        // Negative position is rejected
        let result = bit_get_i64(Some(11), Some(-1));
        assert_eq!(
            result.unwrap_err().message().lines().next().unwrap(),
            "Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0"
        );

        // Position equal to the bit width is rejected
        let result = bit_get_i64(Some(11), Some(64));
        assert_eq!(
            result.unwrap_err().message().lines().next().unwrap(),
            "Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0"
        );
    }

    #[test]
    fn test_bit_get_null_inputs() {
        // A NULL value must produce a NULL result. Check the validity of the
        // slot rather than the placeholder stored in the values buffer.
        let result = bit_get_i64(None, Some(0)).unwrap();
        assert!(result.is_null(0));

        // A NULL position must produce a NULL result as well.
        let result = bit_get_i64(Some(11), None).unwrap();
        assert!(result.is_null(0));
    }

    #[test]
    fn test_bit_get_large_numbers() {
        // 255 = 11111111 in binary: bit 7 = 1
        let result = bit_get_i64(Some(255), Some(7)).unwrap();
        assert_eq!(first_bit(&result), 1);

        // 255 = 11111111 in binary: bit 8 = 0
        let result = bit_get_i64(Some(255), Some(8)).unwrap();
        assert_eq!(first_bit(&result), 0);
    }
}