1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
21use arrow::compute;
22use arrow::datatypes::{
23 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, UInt32Type,
24 UInt64Type,
25};
26use datafusion_common::types::{
27 NativeType, logical_int8, logical_int16, logical_int32, logical_int64, logical_uint8,
28 logical_uint16, logical_uint32, logical_uint64,
29};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{Result, internal_err};
32use datafusion_expr::{
33 Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
34 Signature, TypeSignature, TypeSignatureClass, Volatility,
35};
36use datafusion_functions::utils::make_scalar_function;
37
38fn shift_left<T>(
42 value: &PrimitiveArray<T>,
43 shift: &Int32Array,
44) -> Result<PrimitiveArray<T>>
45where
46 T: ArrowPrimitiveType,
47 T::Native: std::ops::Shl<i32, Output = T::Native>,
48{
49 let bit_num = (T::Native::get_byte_width() * 8) as i32;
50 let result = compute::binary::<_, Int32Type, _, _>(
51 value,
52 shift,
53 |value: T::Native, shift: i32| {
54 let shift = ((shift % bit_num) + bit_num) % bit_num;
55 value << shift
56 },
57 )?;
58 Ok(result)
59}
60
61fn shift_right<T>(
65 value: &PrimitiveArray<T>,
66 shift: &Int32Array,
67) -> Result<PrimitiveArray<T>>
68where
69 T: ArrowPrimitiveType,
70 T::Native: std::ops::Shr<i32, Output = T::Native>,
71{
72 let bit_num = (T::Native::get_byte_width() * 8) as i32;
73 let result = compute::binary::<_, Int32Type, _, _>(
74 value,
75 shift,
76 |value: T::Native, shift: i32| {
77 let shift = ((shift % bit_num) + bit_num) % bit_num;
78 value >> shift
79 },
80 )?;
81 Ok(result)
82}
83
/// Logical ("unsigned") right shift: the vacated high bits are always
/// zero-filled, whatever the sign of the value.
///
/// Implementations do not normalize the shift amount. Callers must pass a
/// value already in `0..bit_width`; anything else overflows the shift, which
/// panics in debug builds.
trait UShr {
    /// Shifts `self` right by `shift` bits, filling with zeros.
    fn ushr(self, shift: i32) -> Self;
}

// Unsigned types: `>>` is already a logical shift.
impl UShr for u32 {
    fn ushr(self, shift: i32) -> Self {
        self >> shift
    }
}

impl UShr for u64 {
    fn ushr(self, shift: i32) -> Self {
        self >> shift
    }
}

// Signed types: `>>` would sign-extend, so shift the bit pattern as the
// same-width unsigned type and reinterpret the result.
impl UShr for i32 {
    fn ushr(self, shift: i32) -> Self {
        let bits = self as u32;
        (bits >> shift) as Self
    }
}

impl UShr for i64 {
    fn ushr(self, shift: i32) -> Self {
        let bits = self as u64;
        (bits >> shift) as Self
    }
}
115
116fn shift_right_unsigned<T>(
120 value: &PrimitiveArray<T>,
121 shift: &Int32Array,
122) -> Result<PrimitiveArray<T>>
123where
124 T: ArrowPrimitiveType,
125 T::Native: UShr,
126{
127 let bit_num = (T::Native::get_byte_width() * 8) as i32;
128 let result = compute::binary::<_, Int32Type, _, _>(
129 value,
130 shift,
131 |value: T::Native, shift: i32| {
132 let shift = ((shift % bit_num) + bit_num) % bit_num;
133 value.ushr(shift)
134 },
135 )?;
136 Ok(result)
137}
138
139fn shift_inner(
140 arrays: &[ArrayRef],
141 name: &str,
142 bit_shift_type: BitShiftType,
143) -> Result<ArrayRef> {
144 let [value_array, shift_array] = take_function_args(name, arrays)?;
145 let shift_array = shift_array.as_primitive::<Int32Type>();
146
147 fn shift<T>(
148 value: &PrimitiveArray<T>,
149 shift: &Int32Array,
150 bit_shift_type: BitShiftType,
151 ) -> Result<PrimitiveArray<T>>
152 where
153 T: ArrowPrimitiveType,
154 T::Native: std::ops::Shl<i32, Output = T::Native>
155 + std::ops::Shr<i32, Output = T::Native>
156 + UShr,
157 {
158 match bit_shift_type {
159 BitShiftType::Left => shift_left(value, shift),
160 BitShiftType::Right => shift_right(value, shift),
161 BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
162 }
163 }
164
165 match value_array.data_type() {
166 DataType::Int32 => {
167 let value_array = value_array.as_primitive::<Int32Type>();
168 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
169 }
170 DataType::Int64 => {
171 let value_array = value_array.as_primitive::<Int64Type>();
172 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
173 }
174 DataType::UInt32 => {
175 let value_array = value_array.as_primitive::<UInt32Type>();
176 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
177 }
178 DataType::UInt64 => {
179 let value_array = value_array.as_primitive::<UInt64Type>();
180 Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
181 }
182 dt => {
183 internal_err!("{name} function does not support data type: {dt}")
184 }
185 }
186}
187
/// The shift operation performed by a [`SparkBitShift`] instance.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum BitShiftType {
    /// `shiftleft`: shift bits towards the most significant end.
    Left,
    /// `shiftright`: Rust `>>`, sign-extending for signed values.
    Right,
    /// `shiftrightunsigned`: zero-filling right shift, even for signed values.
    RightUnsigned,
}
194
195#[derive(Debug, Hash, Eq, PartialEq)]
196pub struct SparkBitShift {
197 signature: Signature,
198 name: &'static str,
199 bit_shift_type: BitShiftType,
200}
201
202impl SparkBitShift {
203 fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
204 let shift_amount = Coercion::new_implicit(
205 TypeSignatureClass::Native(logical_int32()),
206 vec![TypeSignatureClass::Integer],
207 NativeType::Int32,
208 );
209 Self {
210 signature: Signature::one_of(
211 vec![
212 TypeSignature::Coercible(vec![
214 Coercion::new_implicit(
215 TypeSignatureClass::Native(logical_int32()),
216 vec![
217 TypeSignatureClass::Native(logical_int8()),
218 TypeSignatureClass::Native(logical_int16()),
219 ],
220 NativeType::Int32,
221 ),
222 shift_amount.clone(),
223 ]),
224 TypeSignature::Coercible(vec![
225 Coercion::new_implicit(
226 TypeSignatureClass::Native(logical_uint32()),
227 vec![
228 TypeSignatureClass::Native(logical_uint8()),
229 TypeSignatureClass::Native(logical_uint16()),
230 ],
231 NativeType::UInt32,
232 ),
233 shift_amount.clone(),
234 ]),
235 TypeSignature::Coercible(vec![
237 Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
238 shift_amount.clone(),
239 ]),
240 TypeSignature::Coercible(vec![
241 Coercion::new_exact(TypeSignatureClass::Native(logical_uint64())),
242 shift_amount.clone(),
243 ]),
244 ],
245 Volatility::Immutable,
246 ),
247 name,
248 bit_shift_type,
249 }
250 }
251
252 pub fn left() -> Self {
253 Self::new("shiftleft", BitShiftType::Left)
254 }
255
256 pub fn right() -> Self {
257 Self::new("shiftright", BitShiftType::Right)
258 }
259
260 pub fn right_unsigned() -> Self {
261 Self::new("shiftrightunsigned", BitShiftType::RightUnsigned)
262 }
263}
264
265impl ScalarUDFImpl for SparkBitShift {
266 fn name(&self) -> &str {
267 self.name
268 }
269
270 fn signature(&self) -> &Signature {
271 &self.signature
272 }
273
274 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
275 internal_err!("return_field_from_args should be used instead")
276 }
277
278 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
279 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
280 let data_type = args.arg_fields[0].data_type().clone();
281 Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
282 }
283
284 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
285 let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> {
286 shift_inner(arr, self.name(), self.bit_shift_type)
287 };
288 make_scalar_function(inner, vec![])(&args.args)
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a field with the given name, type and nullability.
    fn field(name: &str, data_type: DataType, nullable: bool) -> FieldRef {
        Arc::new(Field::new(name, data_type, nullable))
    }

    /// Resolves the return field of `func` for `arg_fields`, with no scalar
    /// arguments.
    fn resolve(func: &SparkBitShift, arg_fields: &[FieldRef]) -> Result<FieldRef> {
        func.return_field_from_args(ReturnFieldArgs {
            arg_fields,
            scalar_arguments: &[None, None],
        })
    }

    #[test]
    fn test_bit_shift_nullability() -> Result<()> {
        let func = SparkBitShift::left();

        let value = field("value", DataType::Int64, false);
        let shift = field("shift", DataType::Int32, false);
        let nullable_value = field("value", DataType::Int64, true);
        let nullable_shift = field("shift", DataType::Int32, true);

        // Both inputs non-nullable: keeps the value type, stays non-nullable.
        let both_required = resolve(&func, &[Arc::clone(&value), Arc::clone(&shift)])?;
        assert_eq!(both_required.data_type(), value.data_type());
        assert!(
            !both_required.is_nullable(),
            "shift result should be non-nullable when both inputs are non-nullable"
        );

        let with_nullable_value = resolve(&func, &[nullable_value, Arc::clone(&shift)])?;
        assert!(
            with_nullable_value.is_nullable(),
            "shift result should be nullable when value is nullable"
        );

        let with_nullable_shift = resolve(&func, &[value, nullable_shift])?;
        assert!(
            with_nullable_shift.is_nullable(),
            "shift result should be nullable when shift is nullable"
        );

        Ok(())
    }
}