datafusion_comet_spark_expr/bitwise_funcs/
bitwise_get.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{array::*, datatypes::DataType};
19use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
20use datafusion::logical_expr::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
22use std::any::Any;
23use std::sync::Arc;
24
/// Scalar UDF implementation registered under the name `bit_get`.
///
/// Evaluation is delegated to [`spark_bit_get`]; the reported return type is
/// always `Int8`.
#[derive(Debug)]
pub struct SparkBitwiseGet {
    /// Signature handed out by `ScalarUDFImpl::signature`; `new` builds it as
    /// a user-defined, immutable signature.
    signature: Signature,
    /// Alternative names handed out by `ScalarUDFImpl::aliases`; `new` leaves
    /// this list empty.
    aliases: Vec<String>,
}
30
31impl Default for SparkBitwiseGet {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl SparkBitwiseGet {
38    pub fn new() -> Self {
39        Self {
40            signature: Signature::user_defined(Volatility::Immutable),
41            aliases: vec![],
42        }
43    }
44}
45
46impl ScalarUDFImpl for SparkBitwiseGet {
47    fn as_any(&self) -> &dyn Any {
48        self
49    }
50
51    fn name(&self) -> &str {
52        "bit_get"
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn aliases(&self) -> &[String] {
60        &self.aliases
61    }
62
63    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
64        Ok(DataType::Int8)
65    }
66
67    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
68        let args: [ColumnarValue; 2] = args
69            .args
70            .try_into()
71            .map_err(|_| internal_datafusion_err!("bit_get expects exactly two arguments"))?;
72        spark_bit_get(&args)
73    }
74}
75
/// Evaluates `bit_get` for every element of an integer array against one
/// scalar position.
///
/// Arguments:
/// * `$args` - array value that is downcast to `$array_type` (panics if the
///   downcast fails).
/// * `$array_type` - concrete Arrow array type of `$args`.
/// * `$pos` - reference to the `Option<i32>` scalar position.
/// * `$bit_size` - bit width of the element type, used as the exclusive upper
///   bound for the position.
///
/// Expands to an `Ok(Arc<Int8Array>)` expression. A row is null when either
/// the value or the position is null. The expansion uses `?`, so it must be
/// placed in a function whose error type accepts the error returned by
/// `check_position`.
macro_rules! bit_get_scalar_position {
    ($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
        let values = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_scalar_position failed to downcast array");
        let position: Option<i32> = *$pos;

        if let Some(position) = position {
            // Spark only validates the position for rows where both operands
            // are non-null. With a scalar position that means: validate once,
            // and only if at least one value is non-null. An empty or all-null
            // input therefore yields nulls instead of an error.
            if values.null_count() < values.len() {
                check_position(position, $bit_size as i32)?;
            }
        }

        let result: Int8Array = values
            .iter()
            .map(|value| match (value, position) {
                (Some(value), Some(position)) => Some(bit_get(value.into(), position)),
                _ => None,
            })
            .collect();

        Ok(Arc::new(result))
    }};
}
94
/// Evaluates `bit_get` element-wise for an integer array and an `Int32`
/// positions array.
///
/// Arguments:
/// * `$args` - array value that is downcast to `$array_type` (panics if the
///   downcast fails).
/// * `$array_type` - concrete Arrow array type of `$args`.
/// * `$positions` - array value that is downcast to `Int32Array` (panics if
///   the downcast fails).
/// * `$bit_size` - bit width of the element type, used as the exclusive upper
///   bound for each position.
///
/// Expands to an `Ok(Arc<Int8Array>)` expression. A row is null when either
/// the value or the position is null. The expansion uses `?`, so it must be
/// placed in a function returning the DataFusion `Result` type.
macro_rules! bit_get_array_positions {
    ($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
        let values = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_array_positions failed to downcast args array");

        let positions = $positions
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("bit_get_array_positions failed to downcast positions array");

        let bit_size = $bit_size as i32;

        // Spark only validates the position for rows where both operands are
        // non-null, so an out-of-range position paired with a null value must
        // produce a null row rather than fail the whole batch.
        let result = values
            .iter()
            .zip(positions.iter())
            .map(|(value, position)| -> Result<Option<i8>> {
                match (value, position) {
                    (Some(value), Some(position)) => {
                        check_position(position, bit_size)?;
                        Ok(Some(bit_get(value.into(), position)))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<Int8Array>>()?;

        Ok(Arc::new(result))
    }};
}
120
121pub fn spark_bit_get(args: &[ColumnarValue; 2]) -> Result<ColumnarValue> {
122    match args {
123        [ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))] => {
124            let result: Result<ArrayRef> = match args.data_type() {
125                DataType::Int8 => bit_get_scalar_position!(args, Int8Array, pos, i8::BITS),
126                DataType::Int16 => bit_get_scalar_position!(args, Int16Array, pos, i16::BITS),
127                DataType::Int32 => bit_get_scalar_position!(args, Int32Array, pos, i32::BITS),
128                DataType::Int64 => bit_get_scalar_position!(args, Int64Array, pos, i64::BITS),
129                _ => exec_err!(
130                    "Can't be evaluated because the expression's type is {:?}, not signed int",
131                    args.data_type()
132                ),
133            };
134            result.map(ColumnarValue::Array)
135        },
136        [ColumnarValue::Array(args), ColumnarValue::Array(positions)] => {
137            if args.len() != positions.len() {
138                return exec_err!(
139                    "Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
140                    positions.len(), args.len()
141                );
142            }
143            if !matches!(positions.data_type(), DataType::Int32) {
144                return exec_err!(
145                    "Invalid data type for positions array: expected `Int32`, found `{}`",
146                    positions.data_type()
147                );
148            }
149            let result: Result<ArrayRef> = match args.data_type() {
150                DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
151                DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
152                DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
153                DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
154                _ => exec_err!(
155                    "Can't be evaluated because the expression's type is {:?}, not signed int",
156                    args.data_type()
157                ),
158            };
159            result.map(ColumnarValue::Array)
160        }
161        _ => exec_err!(
162            "Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or (IntegralType array, Int32Array)"
163        ),
164    }
165}
166
/// Returns the bit of `arg` at index `pos` (0 = least significant) as `0` or `1`.
///
/// The shift is arithmetic, so for negative values the sign bit is replicated
/// into the upper positions. Callers validate `pos` with `check_position`
/// before calling.
fn bit_get(arg: i64, pos: i32) -> i8 {
    let shifted = arg >> pos;
    i8::from((shifted & 1) != 0)
}
170
171fn check_position(pos: i32, bit_size: i32) -> Result<()> {
172    if pos < 0 {
173        return exec_err!("Invalid bit position: {:?} is less than zero", pos);
174    }
175    if bit_size <= pos {
176        return exec_err!(
177            "Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
178            pos,
179            bit_size
180        );
181    }
182    Ok(())
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::common::cast::as_int8_array;

    /// Value column shared by every test: a small value, a null and a large value.
    fn sample_values() -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(1234553454),
        ])))
    }

    /// Wraps a single position as an `Int32` scalar argument.
    fn scalar_position(position: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(position)))
    }

    /// Wraps per-row positions as an `Int32` array argument.
    fn array_positions(positions: Vec<Option<i32>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Int32Array::from(positions)))
    }

    /// Runs `spark_bit_get` on the sample values and compares the produced bits.
    fn assert_bits(position: ColumnarValue, expected: Vec<Option<i8>>) -> Result<()> {
        let ColumnarValue::Array(output) = spark_bit_get(&[sample_values(), position])? else {
            unreachable!()
        };

        let actual = as_int8_array(&output).expect("failed to downcast to Int8Array");
        assert_eq!(actual, &Int8Array::from(expected));

        Ok(())
    }

    /// Runs `spark_bit_get` on the sample values and compares the error text.
    fn assert_error(position: ColumnarValue, expected: &str) {
        let message = spark_bit_get(&[sample_values(), position])
            .err()
            .unwrap()
            .to_string();

        assert_eq!(message, String::from(expected));
    }

    #[test]
    fn bitwise_get_scalar_position() -> Result<()> {
        assert_bits(scalar_position(1), vec![Some(0), None, Some(1)])
    }

    #[test]
    fn bitwise_get_scalar_negative_position() -> Result<()> {
        assert_error(
            scalar_position(-1),
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_scalar_overflow_position() -> Result<()> {
        assert_error(
            scalar_position(33),
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions() -> Result<()> {
        assert_bits(
            array_positions(vec![Some(1), None, Some(1)]),
            vec![Some(0), None, Some(1)],
        )
    }

    #[test]
    fn bitwise_get_array_positions_contains_negative() -> Result<()> {
        assert_error(
            array_positions(vec![Some(-1), None, Some(1)]),
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
        assert_error(
            array_positions(vec![Some(33), None, Some(1)]),
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }
}