datafusion_comet_spark_expr/bitwise_funcs/
bitwise_get.rs1use arrow::{array::*, datatypes::DataType};
19use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
20use datafusion::logical_expr::ColumnarValue;
21use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
22use std::any::Any;
23use std::sync::Arc;
24
25#[derive(Debug)]
26pub struct SparkBitwiseGet {
27 signature: Signature,
28 aliases: Vec<String>,
29}
30
31impl Default for SparkBitwiseGet {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl SparkBitwiseGet {
38 pub fn new() -> Self {
39 Self {
40 signature: Signature::user_defined(Volatility::Immutable),
41 aliases: vec![],
42 }
43 }
44}
45
46impl ScalarUDFImpl for SparkBitwiseGet {
47 fn as_any(&self) -> &dyn Any {
48 self
49 }
50
51 fn name(&self) -> &str {
52 "bit_get"
53 }
54
55 fn signature(&self) -> &Signature {
56 &self.signature
57 }
58
59 fn aliases(&self) -> &[String] {
60 &self.aliases
61 }
62
63 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
64 Ok(DataType::Int8)
65 }
66
67 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
68 let args: [ColumnarValue; 2] = args
69 .args
70 .try_into()
71 .map_err(|_| internal_datafusion_err!("bit_get expects exactly two arguments"))?;
72 spark_bit_get(&args)
73 }
74}
75
/// Evaluates `bit_get` for a value array and a single scalar position.
///
/// Arguments:
/// * `$args` - the value array (anything offering `as_any()`), which must
///   downcast to `$array_type`.
/// * `$array_type` - the concrete primitive array type of `$args`.
/// * `$pos` - a reference to the optional `i32` bit position.
/// * `$bit_size` - the bit width of the value type; valid positions are
///   `0..$bit_size`.
///
/// Expands to a `Result` wrapping an `Arc<Int8Array>`. A row is NULL when the
/// value is NULL or when the position itself is NULL.
///
/// The position is validated only when at least one value is non-null. Rows
/// with a NULL value never read a bit, so an empty or all-NULL input yields
/// NULLs instead of failing on an out-of-range position.
///
/// On an invalid position the expansion uses `?`, returning the error from
/// the enclosing function.
///
/// # Panics
///
/// Panics if `$args` is not a `$array_type`.
macro_rules! bit_get_scalar_position {
    ($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
        let args = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_scalar_position failed to downcast array");

        if let Some(pos) = $pos {
            // Only reject the position if some row will actually read a bit.
            if args.null_count() < args.len() {
                check_position(*pos, $bit_size as i32)?;
            }
        }

        let result: Int8Array = args
            .iter()
            .map(|x| x.and_then(|x| $pos.map(|pos| bit_get(x.into(), pos))))
            .collect();

        Ok(Arc::new(result))
    }};
}
94
/// Evaluates `bit_get` row by row for a value array and a position array.
///
/// Arguments:
/// * `$args` - the value array (anything offering `as_any()`), which must
///   downcast to `$array_type`.
/// * `$array_type` - the concrete primitive array type of `$args`.
/// * `$positions` - the position array, which must downcast to `Int32Array`.
/// * `$bit_size` - the bit width of the value type; valid positions are
///   `0..$bit_size`.
///
/// Expands to a `Result<ArrayRef>` holding an `Int8Array`. A row is NULL when
/// either its value or its position is NULL.
///
/// A position is validated only on rows where both the value and the position
/// are non-null. A row with a NULL value never reads a bit, so it yields NULL
/// even if its position is out of range. The first invalid position on a
/// fully non-null row is returned as the error.
///
/// # Panics
///
/// Panics if `$args` is not a `$array_type` or `$positions` is not an
/// `Int32Array`.
macro_rules! bit_get_array_positions {
    ($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
        let args = $args
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("bit_get_array_positions failed to downcast args array");

        let positions = $positions
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("bit_get_array_positions failed to downcast positions array");

        let result: Result<Int8Array> = args
            .iter()
            .zip(positions.iter())
            .map(|(value, pos)| -> Result<Option<i8>> {
                match (value, pos) {
                    (Some(value), Some(pos)) => {
                        check_position(pos, $bit_size as i32)?;
                        Ok(Some(bit_get(value.into(), pos)))
                    }
                    // A NULL operand short-circuits to NULL without validation.
                    _ => Ok(None),
                }
            })
            .collect();

        result.map(|array| Arc::new(array) as ArrayRef)
    }};
}
120
121pub fn spark_bit_get(args: &[ColumnarValue; 2]) -> Result<ColumnarValue> {
122 match args {
123 [ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))] => {
124 let result: Result<ArrayRef> = match args.data_type() {
125 DataType::Int8 => bit_get_scalar_position!(args, Int8Array, pos, i8::BITS),
126 DataType::Int16 => bit_get_scalar_position!(args, Int16Array, pos, i16::BITS),
127 DataType::Int32 => bit_get_scalar_position!(args, Int32Array, pos, i32::BITS),
128 DataType::Int64 => bit_get_scalar_position!(args, Int64Array, pos, i64::BITS),
129 _ => exec_err!(
130 "Can't be evaluated because the expression's type is {:?}, not signed int",
131 args.data_type()
132 ),
133 };
134 result.map(ColumnarValue::Array)
135 },
136 [ColumnarValue::Array(args), ColumnarValue::Array(positions)] => {
137 if args.len() != positions.len() {
138 return exec_err!(
139 "Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
140 positions.len(), args.len()
141 );
142 }
143 if !matches!(positions.data_type(), DataType::Int32) {
144 return exec_err!(
145 "Invalid data type for positions array: expected `Int32`, found `{}`",
146 positions.data_type()
147 );
148 }
149 let result: Result<ArrayRef> = match args.data_type() {
150 DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
151 DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
152 DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
153 DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
154 _ => exec_err!(
155 "Can't be evaluated because the expression's type is {:?}, not signed int",
156 args.data_type()
157 ),
158 };
159 result.map(ColumnarValue::Array)
160 }
161 _ => exec_err!(
162 "Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or (IntegralType array, Int32Array)"
163 ),
164 }
165}
166
/// Returns the bit of `arg` at index `pos` (0 = least significant) as `0` or `1`.
///
/// The shift is arithmetic, so for negative `arg` the sign bit fills the
/// upper positions.
fn bit_get(arg: i64, pos: i32) -> i8 {
    let shifted = arg >> pos;
    i8::from(shifted & 1 != 0)
}
170
171fn check_position(pos: i32, bit_size: i32) -> Result<()> {
172 if pos < 0 {
173 return exec_err!("Invalid bit position: {:?} is less than zero", pos);
174 }
175 if bit_size <= pos {
176 return exec_err!(
177 "Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
178 pos,
179 bit_size
180 );
181 }
182 Ok(())
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::common::cast::as_int8_array;

    /// Value column shared by every test: `[1, NULL, 1234553454]`.
    fn sample_values() -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(1234553454),
        ])))
    }

    /// Wraps a single non-null position as an `Int32` scalar argument.
    fn scalar_position(pos: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(pos)))
    }

    /// Wraps per-row positions as an `Int32` array argument.
    fn array_positions(positions: Vec<Option<i32>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Int32Array::from(positions)))
    }

    /// Runs `spark_bit_get` and compares the resulting `Int8Array` with `expected`.
    fn assert_bits(args: [ColumnarValue; 2], expected: Vec<Option<i8>>) -> Result<()> {
        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
            unreachable!()
        };
        let actual = as_int8_array(&result).expect("failed to downcast to Int8Array");

        assert_eq!(actual, &Int8Array::from(expected));

        Ok(())
    }

    /// Runs `spark_bit_get`, expecting a failure whose message equals `expected`.
    fn assert_error(args: [ColumnarValue; 2], expected: &str) {
        let message = spark_bit_get(&args).err().unwrap().to_string();

        assert_eq!(message, String::from(expected));
    }

    #[test]
    fn bitwise_get_scalar_position() -> Result<()> {
        assert_bits(
            [sample_values(), scalar_position(1)],
            vec![Some(0), None, Some(1)],
        )
    }

    #[test]
    fn bitwise_get_scalar_negative_position() -> Result<()> {
        assert_error(
            [sample_values(), scalar_position(-1)],
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_scalar_overflow_position() -> Result<()> {
        assert_error(
            [sample_values(), scalar_position(33)],
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions() -> Result<()> {
        assert_bits(
            [
                sample_values(),
                array_positions(vec![Some(1), None, Some(1)]),
            ],
            vec![Some(0), None, Some(1)],
        )
    }

    #[test]
    fn bitwise_get_array_positions_contains_negative() -> Result<()> {
        assert_error(
            [
                sample_values(),
                array_positions(vec![Some(-1), None, Some(1)]),
            ],
            "Execution error: Invalid bit position: -1 is less than zero",
        );

        Ok(())
    }

    #[test]
    fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
        assert_error(
            [
                sample_values(),
                array_positions(vec![Some(33), None, Some(1)]),
            ],
            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
        );

        Ok(())
    }
}