Skip to main content

datafusion_spark/function/bitwise/
bit_shift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
21use arrow::compute;
22use arrow::datatypes::{
23    ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, UInt32Type,
24    UInt64Type,
25};
26use datafusion_common::types::{
27    NativeType, logical_int8, logical_int16, logical_int32, logical_int64, logical_uint8,
28    logical_uint16, logical_uint32, logical_uint64,
29};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{Result, internal_err};
32use datafusion_expr::{
33    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
34    Signature, TypeSignature, TypeSignatureClass, Volatility,
35};
36use datafusion_functions::utils::make_scalar_function;
37
38/// Bitwise left shift on elements in `value` by corresponding `shift` amount.
39/// The shift amount is normalized to the bit width of the type, matching Spark/Java
40/// semantics for negative and large shifts.
41fn shift_left<T>(
42    value: &PrimitiveArray<T>,
43    shift: &Int32Array,
44) -> Result<PrimitiveArray<T>>
45where
46    T: ArrowPrimitiveType,
47    T::Native: std::ops::Shl<i32, Output = T::Native>,
48{
49    let bit_num = (T::Native::get_byte_width() * 8) as i32;
50    let result = compute::binary::<_, Int32Type, _, _>(
51        value,
52        shift,
53        |value: T::Native, shift: i32| {
54            let shift = ((shift % bit_num) + bit_num) % bit_num;
55            value << shift
56        },
57    )?;
58    Ok(result)
59}
60
61/// Bitwise right shift on elements in `value` by corresponding `shift` amount.
62/// The shift amount is normalized to the bit width of the type, matching Spark/Java
63/// semantics for negative and large shifts.
64fn shift_right<T>(
65    value: &PrimitiveArray<T>,
66    shift: &Int32Array,
67) -> Result<PrimitiveArray<T>>
68where
69    T: ArrowPrimitiveType,
70    T::Native: std::ops::Shr<i32, Output = T::Native>,
71{
72    let bit_num = (T::Native::get_byte_width() * 8) as i32;
73    let result = compute::binary::<_, Int32Type, _, _>(
74        value,
75        shift,
76        |value: T::Native, shift: i32| {
77            let shift = ((shift % bit_num) + bit_num) % bit_num;
78            value >> shift
79        },
80    )?;
81    Ok(result)
82}
83
/// Trait for performing an unsigned right shift (logical shift right).
///
/// This is used to mimic Java's `>>>` operator, which does not exist in Rust.
/// For unsigned types, this is just the normal right shift.
/// For signed types, this casts to the unsigned type, shifts, then casts back.
///
/// Like Java, only the low bits of `rhs` are used (5 bits for 32-bit values,
/// 6 bits for 64-bit values), so negative or oversized shift amounts wrap
/// instead of panicking in overflow-checked builds.
trait UShr {
    /// Shifts `self` right by `rhs` bits, filling the vacated high bits with zeros.
    fn ushr(self, rhs: i32) -> Self;
}

impl UShr for u32 {
    fn ushr(self, rhs: i32) -> Self {
        // `rhs as u32` is a deliberate bit reinterpretation; `wrapping_shr`
        // then masks the amount to `rhs & 31`.
        self.wrapping_shr(rhs as u32)
    }
}

impl UShr for u64 {
    fn ushr(self, rhs: i32) -> Self {
        // `wrapping_shr` masks the amount to `rhs & 63`.
        self.wrapping_shr(rhs as u32)
    }
}

impl UShr for i32 {
    fn ushr(self, rhs: i32) -> Self {
        // Shift the unsigned bit pattern so the sign bit is not propagated.
        (self as u32).wrapping_shr(rhs as u32) as i32
    }
}

impl UShr for i64 {
    fn ushr(self, rhs: i32) -> Self {
        // Shift the unsigned bit pattern so the sign bit is not propagated.
        (self as u64).wrapping_shr(rhs as u32) as i64
    }
}
115
116/// Bitwise unsigned right shift on elements in `value` by corresponding `shift`
117/// amount. The shift amount is normalized to the bit width of the type, matching
118/// Spark/Java semantics for negative and large shifts.
119fn shift_right_unsigned<T>(
120    value: &PrimitiveArray<T>,
121    shift: &Int32Array,
122) -> Result<PrimitiveArray<T>>
123where
124    T: ArrowPrimitiveType,
125    T::Native: UShr,
126{
127    let bit_num = (T::Native::get_byte_width() * 8) as i32;
128    let result = compute::binary::<_, Int32Type, _, _>(
129        value,
130        shift,
131        |value: T::Native, shift: i32| {
132            let shift = ((shift % bit_num) + bit_num) % bit_num;
133            value.ushr(shift)
134        },
135    )?;
136    Ok(result)
137}
138
139fn shift_inner(
140    arrays: &[ArrayRef],
141    name: &str,
142    bit_shift_type: BitShiftType,
143) -> Result<ArrayRef> {
144    let [value_array, shift_array] = take_function_args(name, arrays)?;
145    let shift_array = shift_array.as_primitive::<Int32Type>();
146
147    fn shift<T>(
148        value: &PrimitiveArray<T>,
149        shift: &Int32Array,
150        bit_shift_type: BitShiftType,
151    ) -> Result<PrimitiveArray<T>>
152    where
153        T: ArrowPrimitiveType,
154        T::Native: std::ops::Shl<i32, Output = T::Native>
155            + std::ops::Shr<i32, Output = T::Native>
156            + UShr,
157    {
158        match bit_shift_type {
159            BitShiftType::Left => shift_left(value, shift),
160            BitShiftType::Right => shift_right(value, shift),
161            BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
162        }
163    }
164
165    match value_array.data_type() {
166        DataType::Int32 => {
167            let value_array = value_array.as_primitive::<Int32Type>();
168            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
169        }
170        DataType::Int64 => {
171            let value_array = value_array.as_primitive::<Int64Type>();
172            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
173        }
174        DataType::UInt32 => {
175            let value_array = value_array.as_primitive::<UInt32Type>();
176            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
177        }
178        DataType::UInt64 => {
179            let value_array = value_array.as_primitive::<UInt64Type>();
180            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
181        }
182        dt => {
183            internal_err!("{name} function does not support data type: {dt}")
184        }
185    }
186}
187
/// Selects which shift kernel a [`SparkBitShift`] instance runs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum BitShiftType {
    /// Left shift (`<<`).
    Left,
    /// Right shift using the value type's native `>>`.
    Right,
    /// Logical right shift that zero-fills the high bits (Java's `>>>`).
    RightUnsigned,
}
194
195#[derive(Debug, Hash, Eq, PartialEq)]
196pub struct SparkBitShift {
197    signature: Signature,
198    name: &'static str,
199    bit_shift_type: BitShiftType,
200}
201
202impl SparkBitShift {
203    fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
204        let shift_amount = Coercion::new_implicit(
205            TypeSignatureClass::Native(logical_int32()),
206            vec![TypeSignatureClass::Integer],
207            NativeType::Int32,
208        );
209        Self {
210            signature: Signature::one_of(
211                vec![
212                    // Upcast small ints to 32bit
213                    TypeSignature::Coercible(vec![
214                        Coercion::new_implicit(
215                            TypeSignatureClass::Native(logical_int32()),
216                            vec![
217                                TypeSignatureClass::Native(logical_int8()),
218                                TypeSignatureClass::Native(logical_int16()),
219                            ],
220                            NativeType::Int32,
221                        ),
222                        shift_amount.clone(),
223                    ]),
224                    TypeSignature::Coercible(vec![
225                        Coercion::new_implicit(
226                            TypeSignatureClass::Native(logical_uint32()),
227                            vec![
228                                TypeSignatureClass::Native(logical_uint8()),
229                                TypeSignatureClass::Native(logical_uint16()),
230                            ],
231                            NativeType::UInt32,
232                        ),
233                        shift_amount.clone(),
234                    ]),
235                    // Otherwise accept direct 64 bit integers
236                    TypeSignature::Coercible(vec![
237                        Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
238                        shift_amount.clone(),
239                    ]),
240                    TypeSignature::Coercible(vec![
241                        Coercion::new_exact(TypeSignatureClass::Native(logical_uint64())),
242                        shift_amount.clone(),
243                    ]),
244                ],
245                Volatility::Immutable,
246            ),
247            name,
248            bit_shift_type,
249        }
250    }
251
252    pub fn left() -> Self {
253        Self::new("shiftleft", BitShiftType::Left)
254    }
255
256    pub fn right() -> Self {
257        Self::new("shiftright", BitShiftType::Right)
258    }
259
260    pub fn right_unsigned() -> Self {
261        Self::new("shiftrightunsigned", BitShiftType::RightUnsigned)
262    }
263}
264
265impl ScalarUDFImpl for SparkBitShift {
266    fn name(&self) -> &str {
267        self.name
268    }
269
270    fn signature(&self) -> &Signature {
271        &self.signature
272    }
273
274    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
275        internal_err!("return_field_from_args should be used instead")
276    }
277
278    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
279        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
280        let data_type = args.arg_fields[0].data_type().clone();
281        Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
282    }
283
284    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
285        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> {
286            shift_inner(arr, self.name(), self.bit_shift_type)
287        };
288        make_scalar_function(inner, vec![])(&args.args)
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a shared field.
    fn field(name: &str, data_type: DataType, nullable: bool) -> FieldRef {
        Arc::new(Field::new(name, data_type, nullable))
    }

    /// The result field keeps the value's data type and is nullable exactly
    /// when at least one of the two inputs is nullable.
    #[test]
    fn test_bit_shift_nullability() -> Result<()> {
        let func = SparkBitShift::left();

        // Resolves the output field for a `(value, shift)` pair of input fields.
        let resolve = |value: &FieldRef, shift: &FieldRef| -> Result<FieldRef> {
            func.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::clone(value), Arc::clone(shift)],
                scalar_arguments: &[None, None],
            })
        };

        let value_required = field("value", DataType::Int64, false);
        let shift_required = field("shift", DataType::Int32, false);
        let value_optional = field("value", DataType::Int64, true);
        let shift_optional = field("shift", DataType::Int32, true);

        let both_required = resolve(&value_required, &shift_required)?;
        assert_eq!(both_required.data_type(), value_required.data_type());
        assert!(
            !both_required.is_nullable(),
            "shift result should be non-nullable when both inputs are non-nullable"
        );

        let with_optional_value = resolve(&value_optional, &shift_required)?;
        assert!(
            with_optional_value.is_nullable(),
            "shift result should be nullable when value is nullable"
        );

        let with_optional_shift = resolve(&value_required, &shift_optional)?;
        assert!(
            with_optional_shift.is_nullable(),
            "shift result should be nullable when shift is nullable"
        );

        Ok(())
    }
}