datafusion_spark/function/bitwise/bit_shift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray};
22use arrow::compute;
23use arrow::datatypes::{
24    ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, UInt32Type,
25    UInt64Type,
26};
27use datafusion_common::types::{
28    NativeType, logical_int8, logical_int16, logical_int32, logical_int64, logical_uint8,
29    logical_uint16, logical_uint32, logical_uint64,
30};
31use datafusion_common::utils::take_function_args;
32use datafusion_common::{Result, internal_err};
33use datafusion_expr::{
34    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
35    Signature, TypeSignature, TypeSignatureClass, Volatility,
36};
37use datafusion_functions::utils::make_scalar_function;
38
39/// Bitwise left shift on elements in `value` by corresponding `shift` amount.
40/// The shift amount is normalized to the bit width of the type, matching Spark/Java
41/// semantics for negative and large shifts.
42fn shift_left<T>(
43    value: &PrimitiveArray<T>,
44    shift: &Int32Array,
45) -> Result<PrimitiveArray<T>>
46where
47    T: ArrowPrimitiveType,
48    T::Native: std::ops::Shl<i32, Output = T::Native>,
49{
50    let bit_num = (T::Native::get_byte_width() * 8) as i32;
51    let result = compute::binary::<_, Int32Type, _, _>(
52        value,
53        shift,
54        |value: T::Native, shift: i32| {
55            let shift = ((shift % bit_num) + bit_num) % bit_num;
56            value << shift
57        },
58    )?;
59    Ok(result)
60}
61
62/// Bitwise right shift on elements in `value` by corresponding `shift` amount.
63/// The shift amount is normalized to the bit width of the type, matching Spark/Java
64/// semantics for negative and large shifts.
65fn shift_right<T>(
66    value: &PrimitiveArray<T>,
67    shift: &Int32Array,
68) -> Result<PrimitiveArray<T>>
69where
70    T: ArrowPrimitiveType,
71    T::Native: std::ops::Shr<i32, Output = T::Native>,
72{
73    let bit_num = (T::Native::get_byte_width() * 8) as i32;
74    let result = compute::binary::<_, Int32Type, _, _>(
75        value,
76        shift,
77        |value: T::Native, shift: i32| {
78            let shift = ((shift % bit_num) + bit_num) % bit_num;
79            value >> shift
80        },
81    )?;
82    Ok(result)
83}
84
/// Logical ("unsigned") right shift, mirroring Java's `>>>` operator. Rust has no such
/// operator for signed integers.
///
/// Zero bits are always shifted in from the left, whatever the sign of the value.
/// As in Java, only the low bits of `rhs` are used: the effective shift is
/// `rhs & (bits - 1)`. The call therefore never panics, and negative or oversized
/// shift amounts wrap instead of overflowing.
trait UShr {
    fn ushr(self, rhs: i32) -> Self;
}

// Unsigned types only need a plain shift. `wrapping_shr` masks the shift amount to
// the type's bit width, which is exactly Java's rule for out-of-range shift counts.
// Reinterpreting `rhs` with `as u32` keeps its low bits, so e.g. `-1` becomes a shift
// by `bits - 1`, as in Java.
impl UShr for u32 {
    fn ushr(self, rhs: i32) -> Self {
        self.wrapping_shr(rhs as u32)
    }
}

impl UShr for u64 {
    fn ushr(self, rhs: i32) -> Self {
        self.wrapping_shr(rhs as u32)
    }
}

// Signed types: reinterpret the bits as unsigned so the shift fills with zeros
// instead of copying the sign bit, then reinterpret the result back.
impl UShr for i32 {
    fn ushr(self, rhs: i32) -> Self {
        (self as u32).wrapping_shr(rhs as u32) as i32
    }
}

impl UShr for i64 {
    fn ushr(self, rhs: i32) -> Self {
        (self as u64).wrapping_shr(rhs as u32) as i64
    }
}
116
117/// Bitwise unsigned right shift on elements in `value` by corresponding `shift`
118/// amount. The shift amount is normalized to the bit width of the type, matching
119/// Spark/Java semantics for negative and large shifts.
120fn shift_right_unsigned<T>(
121    value: &PrimitiveArray<T>,
122    shift: &Int32Array,
123) -> Result<PrimitiveArray<T>>
124where
125    T: ArrowPrimitiveType,
126    T::Native: UShr,
127{
128    let bit_num = (T::Native::get_byte_width() * 8) as i32;
129    let result = compute::binary::<_, Int32Type, _, _>(
130        value,
131        shift,
132        |value: T::Native, shift: i32| {
133            let shift = ((shift % bit_num) + bit_num) % bit_num;
134            value.ushr(shift)
135        },
136    )?;
137    Ok(result)
138}
139
140fn shift_inner(
141    arrays: &[ArrayRef],
142    name: &str,
143    bit_shift_type: BitShiftType,
144) -> Result<ArrayRef> {
145    let [value_array, shift_array] = take_function_args(name, arrays)?;
146    let shift_array = shift_array.as_primitive::<Int32Type>();
147
148    fn shift<T>(
149        value: &PrimitiveArray<T>,
150        shift: &Int32Array,
151        bit_shift_type: BitShiftType,
152    ) -> Result<PrimitiveArray<T>>
153    where
154        T: ArrowPrimitiveType,
155        T::Native: std::ops::Shl<i32, Output = T::Native>
156            + std::ops::Shr<i32, Output = T::Native>
157            + UShr,
158    {
159        match bit_shift_type {
160            BitShiftType::Left => shift_left(value, shift),
161            BitShiftType::Right => shift_right(value, shift),
162            BitShiftType::RightUnsigned => shift_right_unsigned(value, shift),
163        }
164    }
165
166    match value_array.data_type() {
167        DataType::Int32 => {
168            let value_array = value_array.as_primitive::<Int32Type>();
169            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
170        }
171        DataType::Int64 => {
172            let value_array = value_array.as_primitive::<Int64Type>();
173            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
174        }
175        DataType::UInt32 => {
176            let value_array = value_array.as_primitive::<UInt32Type>();
177            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
178        }
179        DataType::UInt64 => {
180            let value_array = value_array.as_primitive::<UInt64Type>();
181            Ok(Arc::new(shift(value_array, shift_array, bit_shift_type)?))
182        }
183        dt => {
184            internal_err!("{name} function does not support data type: {dt}")
185        }
186    }
187}
188
/// The kind of shift a `SparkBitShift` instance performs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum BitShiftType {
    /// `shiftleft`: `value << shift`.
    Left,
    /// `shiftright`: `value >> shift` (sign-preserving for signed value types).
    Right,
    /// `shiftrightunsigned`: zero-filling right shift, like Java's `>>>`.
    RightUnsigned,
}
195
196#[derive(Debug, Hash, Eq, PartialEq)]
197pub struct SparkBitShift {
198    signature: Signature,
199    name: &'static str,
200    bit_shift_type: BitShiftType,
201}
202
203impl SparkBitShift {
204    fn new(name: &'static str, bit_shift_type: BitShiftType) -> Self {
205        let shift_amount = Coercion::new_implicit(
206            TypeSignatureClass::Native(logical_int32()),
207            vec![TypeSignatureClass::Integer],
208            NativeType::Int32,
209        );
210        Self {
211            signature: Signature::one_of(
212                vec![
213                    // Upcast small ints to 32bit
214                    TypeSignature::Coercible(vec![
215                        Coercion::new_implicit(
216                            TypeSignatureClass::Native(logical_int32()),
217                            vec![
218                                TypeSignatureClass::Native(logical_int8()),
219                                TypeSignatureClass::Native(logical_int16()),
220                            ],
221                            NativeType::Int32,
222                        ),
223                        shift_amount.clone(),
224                    ]),
225                    TypeSignature::Coercible(vec![
226                        Coercion::new_implicit(
227                            TypeSignatureClass::Native(logical_uint32()),
228                            vec![
229                                TypeSignatureClass::Native(logical_uint8()),
230                                TypeSignatureClass::Native(logical_uint16()),
231                            ],
232                            NativeType::UInt32,
233                        ),
234                        shift_amount.clone(),
235                    ]),
236                    // Otherwise accept direct 64 bit integers
237                    TypeSignature::Coercible(vec![
238                        Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
239                        shift_amount.clone(),
240                    ]),
241                    TypeSignature::Coercible(vec![
242                        Coercion::new_exact(TypeSignatureClass::Native(logical_uint64())),
243                        shift_amount.clone(),
244                    ]),
245                ],
246                Volatility::Immutable,
247            ),
248            name,
249            bit_shift_type,
250        }
251    }
252
253    pub fn left() -> Self {
254        Self::new("shiftleft", BitShiftType::Left)
255    }
256
257    pub fn right() -> Self {
258        Self::new("shiftright", BitShiftType::Right)
259    }
260
261    pub fn right_unsigned() -> Self {
262        Self::new("shiftrightunsigned", BitShiftType::RightUnsigned)
263    }
264}
265
266impl ScalarUDFImpl for SparkBitShift {
267    fn as_any(&self) -> &dyn Any {
268        self
269    }
270
271    fn name(&self) -> &str {
272        self.name
273    }
274
275    fn signature(&self) -> &Signature {
276        &self.signature
277    }
278
279    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
280        internal_err!("return_field_from_args should be used instead")
281    }
282
283    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
284        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
285        let data_type = args.arg_fields[0].data_type().clone();
286        Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
287    }
288
289    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
290        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> {
291            shift_inner(arr, self.name(), self.bit_shift_type)
292        };
293        make_scalar_function(inner, vec![])(&args.args)
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    #[test]
    fn test_bit_shift_nullability() -> Result<()> {
        let func = SparkBitShift::left();

        let non_nullable_value: FieldRef =
            Arc::new(Field::new("value", DataType::Int64, false));
        let non_nullable_shift: FieldRef =
            Arc::new(Field::new("shift", DataType::Int32, false));

        let out = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[
                Arc::clone(&non_nullable_value),
                Arc::clone(&non_nullable_shift),
            ],
            scalar_arguments: &[None, None],
        })?;

        assert_eq!(out.data_type(), non_nullable_value.data_type());
        assert!(
            !out.is_nullable(),
            "shift result should be non-nullable when both inputs are non-nullable"
        );

        let nullable_value: FieldRef =
            Arc::new(Field::new("value", DataType::Int64, true));
        let out_nullable_value = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&nullable_value), Arc::clone(&non_nullable_shift)],
            scalar_arguments: &[None, None],
        })?;
        assert!(
            out_nullable_value.is_nullable(),
            "shift result should be nullable when value is nullable"
        );

        let nullable_shift: FieldRef =
            Arc::new(Field::new("shift", DataType::Int32, true));
        let out_nullable_shift = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[non_nullable_value, nullable_shift],
            scalar_arguments: &[None, None],
        })?;
        assert!(
            out_nullable_shift.is_nullable(),
            "shift result should be nullable when shift is nullable"
        );

        Ok(())
    }

    /// Shift amounts wrap modulo the bit width, as in Java/Spark:
    /// 33 -> 1, -1 -> 31 and 32 -> 0 for 32-bit values.
    #[test]
    fn test_shift_amount_is_normalized_like_java() -> Result<()> {
        let values = Int32Array::from(vec![1, 1, 1, 1]);
        let shifts = Int32Array::from(vec![1, 33, -1, 32]);
        let shifted = shift_left(&values, &shifts)?;
        assert_eq!(shifted, Int32Array::from(vec![2, 2, i32::MIN, 1]));
        Ok(())
    }

    /// `shiftright` preserves the sign; `shiftrightunsigned` fills with zeros.
    #[test]
    fn test_signed_vs_unsigned_right_shift() -> Result<()> {
        let values = Int32Array::from(vec![-8, -8]);
        let shifts = Int32Array::from(vec![1, 33]);
        assert_eq!(shift_right(&values, &shifts)?, Int32Array::from(vec![-4, -4]));
        assert_eq!(
            shift_right_unsigned(&values, &shifts)?,
            Int32Array::from(vec![2147483644, 2147483644])
        );

        let values = Int64Array::from(vec![-1_i64]);
        let shifts = Int32Array::from(vec![63]);
        assert_eq!(
            shift_right_unsigned(&values, &shifts)?,
            Int64Array::from(vec![1_i64])
        );
        Ok(())
    }

    /// A null in either the value or the shift amount yields null.
    #[test]
    fn test_shift_propagates_nulls() -> Result<()> {
        let values = Int32Array::from(vec![Some(1), None, Some(4)]);
        let shifts = Int32Array::from(vec![None, Some(1), Some(2)]);
        let shifted = shift_left(&values, &shifts)?;
        assert_eq!(shifted, Int32Array::from(vec![None, None, Some(16)]));
        Ok(())
    }
}