1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
22use arrow::compute;
23use arrow::datatypes::{
24 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, UInt32Type,
25 UInt64Type,
26};
27use datafusion_common::types::{
28 NativeType, logical_int8, logical_int16, logical_int32, logical_int64, logical_uint8,
29 logical_uint16, logical_uint32, logical_uint64,
30};
31use datafusion_common::utils::take_function_args;
32use datafusion_common::{Result, internal_err};
33use datafusion_expr::{
34 Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
35 Signature, TypeSignature, TypeSignatureClass, Volatility,
36};
37use datafusion_functions::utils::make_scalar_function;
38
39fn shift_left<T>(
43 value: &PrimitiveArray<T>,
44 shift: &Int32Array,
45) -> Result<PrimitiveArray<T>>
46where
47 T: ArrowPrimitiveType,
48 T::Native: std::ops::Shl<i32, Output = T::Native>,
49{
50 let bit_num = (T::Native::get_byte_width() * 8) as i32;
51 let result = compute::binary::<_, Int32Type, _, _>(
52 value,
53 shift,
54 |value: T::Native, shift: i32| {
55 let shift = ((shift % bit_num) + bit_num) % bit_num;
56 value << shift
57 },
58 )?;
59 Ok(result)
60}
61
62fn shift_right<T>(
66 value: &PrimitiveArray<T>,
67 shift: &Int32Array,
68) -> Result<PrimitiveArray<T>>
69where
70 T: ArrowPrimitiveType,
71 T::Native: std::ops::Shr<i32, Output = T::Native>,
72{
73 let bit_num = (T::Native::get_byte_width() * 8) as i32;
74 let result = compute::binary::<_, Int32Type, _, _>(
75 value,
76 shift,
77 |value: T::Native, shift: i32| {
78 let shift = ((shift % bit_num) + bit_num) % bit_num;
79 value >> shift
80 },
81 )?;
82 Ok(result)
83}
84
/// Logical ("unsigned") right shift: vacated high bits are always filled with
/// zeros, whatever the sign of the value. For signed types this differs from
/// `>>`, which sign-extends.
///
/// `rhs` is not range-checked here. `shift_right_unsigned` reduces it into
/// `0..BITS` before calling.
trait UShr {
    fn ushr(self, rhs: i32) -> Self;
}

impl UShr for u32 {
    fn ushr(self, rhs: i32) -> Self {
        // For unsigned integers `>>` is already a logical shift.
        std::ops::Shr::shr(self, rhs)
    }
}

impl UShr for u64 {
    fn ushr(self, rhs: i32) -> Self {
        std::ops::Shr::shr(self, rhs)
    }
}

impl UShr for i32 {
    fn ushr(self, rhs: i32) -> Self {
        // Reinterpret the bit pattern as the same-width unsigned type, shift
        // logically, then reinterpret back.
        (self as u32).ushr(rhs) as i32
    }
}

impl UShr for i64 {
    fn ushr(self, rhs: i32) -> Self {
        (self as u64).ushr(rhs) as i64
    }
}
116
117fn shift_right_unsigned<T>(
121 value: &PrimitiveArray<T>,
122 shift: &Int32Array,
123) -> Result<PrimitiveArray<T>>
124where
125 T: ArrowPrimitiveType,
126 T::Native: UShr,
127{
128 let bit_num = (T::Native::get_byte_width() * 8) as i32;
129 let result = compute::binary::<_, Int32Type, _, _>(
130 value,
131 shift,
132 |value: T::Native, shift: i32| {
133 let shift = ((shift % bit_num) + bit_num) % bit_num;
134 value.ushr(shift)
135 },
136 )?;
137 Ok(result)
138}
139
140fn shift_inner(
141 arrays: &[ArrayRef],
142 name: &str,
143 bit_shift_type: BitShiftType,
144) -> Result<ArrayRef> {
145 let [value_array, shift_array] = take_function_args(name, arrays)?;
146 let shift_array = shift_array.as_primitive::<Int32Type>();
147
148 fn shift<T>(
149 value: &PrimitiveArray<T>,
150 shift: &Int32Array,
151 bit_shift_type: BitShiftType,
152 ) -> Result<PrimitiveArray<T>>
153 where
154 T: ArrowPrimitiveType,
155 T::Native: std::ops::Shl<i32, Output = T::Native>
156 + std::ops::Shr<i32, Output = T::Native>
157 + UShr,
158 {
159 match bit_shift_type {
160 BitShiftType::Left => shift_left(value, shift),
161 BitShiftType::Right => shift_right(value, shift),
162 BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
163 }
164 }
165
166 match value_array.data_type() {
167 DataType::Int32 => {
168 let value_array = value_array.as_primitive::<Int32Type>();
169 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
170 }
171 DataType::Int64 => {
172 let value_array = value_array.as_primitive::<Int64Type>();
173 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
174 }
175 DataType::UInt32 => {
176 let value_array = value_array.as_primitive::<UInt32Type>();
177 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
178 }
179 DataType::UInt64 => {
180 let value_array = value_array.as_primitive::<UInt64Type>();
181 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
182 }
183 dt => {
184 internal_err!("{name} function does not support data type: {dt}")
185 }
186 }
187}
188
/// Which shift a [`SparkBitShift`] instance performs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum BitShiftType {
    /// Backs `shiftleft`: bits move towards the most significant end.
    Left,
    /// Backs `shiftright`: `>>` on the value type (sign-extending for signed integers).
    Right,
    /// Backs `shiftrightunsigned`: logical right shift, vacated high bits are zero.
    RightUnsigned,
}
195
196#[derive(Debug, Hash, Eq, PartialEq)]
197pub struct SparkBitShift {
198 signature: Signature,
199 name: &'static str,
200 bit_shift_type: BitShiftType,
201}
202
203impl SparkBitShift {
204 fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
205 let shift_amount = Coercion::new_implicit(
206 TypeSignatureClass::Native(logical_int32()),
207 vec![TypeSignatureClass::Integer],
208 NativeType::Int32,
209 );
210 Self {
211 signature: Signature::one_of(
212 vec![
213 TypeSignature::Coercible(vec![
215 Coercion::new_implicit(
216 TypeSignatureClass::Native(logical_int32()),
217 vec![
218 TypeSignatureClass::Native(logical_int8()),
219 TypeSignatureClass::Native(logical_int16()),
220 ],
221 NativeType::Int32,
222 ),
223 shift_amount.clone(),
224 ]),
225 TypeSignature::Coercible(vec![
226 Coercion::new_implicit(
227 TypeSignatureClass::Native(logical_uint32()),
228 vec![
229 TypeSignatureClass::Native(logical_uint8()),
230 TypeSignatureClass::Native(logical_uint16()),
231 ],
232 NativeType::UInt32,
233 ),
234 shift_amount.clone(),
235 ]),
236 TypeSignature::Coercible(vec![
238 Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
239 shift_amount.clone(),
240 ]),
241 TypeSignature::Coercible(vec![
242 Coercion::new_exact(TypeSignatureClass::Native(logical_uint64())),
243 shift_amount.clone(),
244 ]),
245 ],
246 Volatility::Immutable,
247 ),
248 name,
249 bit_shift_type,
250 }
251 }
252
253 pub fn left() -> Self {
254 Self::new("shiftleft", BitShiftType::Left)
255 }
256
257 pub fn right() -> Self {
258 Self::new("shiftright", BitShiftType::Right)
259 }
260
261 pub fn right_unsigned() -> Self {
262 Self::new("shiftrightunsigned", BitShiftType::RightUnsigned)
263 }
264}
265
266impl ScalarUDFImpl for SparkBitShift {
267 fn as_any(&self) -> &dyn Any {
268 self
269 }
270
271 fn name(&self) -> &str {
272 self.name
273 }
274
275 fn signature(&self) -> &Signature {
276 &self.signature
277 }
278
279 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
280 internal_err!("return_field_from_args should be used instead")
281 }
282
283 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
284 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
285 let data_type = args.arg_fields[0].data_type().clone();
286 Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
287 }
288
289 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
290 let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> {
291 shift_inner(arr, self.name(), self.bit_shift_type)
292 };
293 make_scalar_function(inner, vec![])(&args.args)
294 }
295}
296
#[cfg(test)]
mod tests {
    // `Field`, `FieldRef`, `ReturnFieldArgs`, `Arc`, `DataType` and `Result` all
    // come in through the parent module's imports.
    use super::*;

    /// Calls `return_field_from_args` with an Int64 value field and an Int32
    /// shift field of the given nullability.
    fn return_field(
        func: &SparkBitShift,
        value_nullable: bool,
        shift_nullable: bool,
    ) -> Result<FieldRef> {
        let value: FieldRef = Arc::new(Field::new("value", DataType::Int64, value_nullable));
        let shift: FieldRef = Arc::new(Field::new("shift", DataType::Int32, shift_nullable));
        func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[value, shift],
            scalar_arguments: &[None, None],
        })
    }

    #[test]
    fn test_bit_shift_nullability() -> Result<()> {
        // All three functions share the same return-field logic; check each one.
        for func in [
            SparkBitShift::left(),
            SparkBitShift::right(),
            SparkBitShift::right_unsigned(),
        ] {
            let name = func.name();

            let out = return_field(&func, false, false)?;
            assert_eq!(out.data_type(), &DataType::Int64);
            assert!(
                !out.is_nullable(),
                "{name}: result should be non-nullable when both inputs are non-nullable"
            );

            assert!(
                return_field(&func, true, false)?.is_nullable(),
                "{name}: result should be nullable when value is nullable"
            );

            assert!(
                return_field(&func, false, true)?.is_nullable(),
                "{name}: result should be nullable when shift is nullable"
            );
        }

        Ok(())
    }
}