datafusion_spark/function/bitwise/
bit_shift.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
22use arrow::compute;
23use arrow::datatypes::{
24    ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
25};
26use datafusion_common::{plan_err, Result};
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_functions::utils::make_scalar_function;
31
32use crate::function::error_utils::{
33    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
34};
35
36/// Performs a bitwise left shift on each element of the `value` array by the corresponding amount in the `shift` array.
37/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
38///
39/// # Arguments
40/// * `value` - The array of values to shift.
41/// * `shift` - The array of shift amounts (must be Int32).
42///
43/// # Returns
44/// A new array with the shifted values.
45fn shift_left<T: ArrowPrimitiveType>(
46    value: &PrimitiveArray<T>,
47    shift: &PrimitiveArray<Int32Type>,
48) -> Result<PrimitiveArray<T>>
49where
50    T::Native: ArrowNativeType + std::ops::Shl<i32, Output = T::Native>,
51{
52    let bit_num = (T::Native::get_byte_width() * 8) as i32;
53    let result = compute::binary::<_, Int32Type, _, _>(
54        value,
55        shift,
56        |value: T::Native, shift: i32| {
57            let shift = ((shift % bit_num) + bit_num) % bit_num;
58            value << shift
59        },
60    )?;
61    Ok(result)
62}
63
64/// Performs a bitwise right shift on each element of the `value` array by the corresponding amount in the `shift` array.
65/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
66///
67/// # Arguments
68/// * `value` - The array of values to shift.
69/// * `shift` - The array of shift amounts (must be Int32).
70///
71/// # Returns
72/// A new array with the shifted values.
73fn shift_right<T: ArrowPrimitiveType>(
74    value: &PrimitiveArray<T>,
75    shift: &PrimitiveArray<Int32Type>,
76) -> Result<PrimitiveArray<T>>
77where
78    T::Native: ArrowNativeType + std::ops::Shr<i32, Output = T::Native>,
79{
80    let bit_num = (T::Native::get_byte_width() * 8) as i32;
81    let result = compute::binary::<_, Int32Type, _, _>(
82        value,
83        shift,
84        |value: T::Native, shift: i32| {
85            let shift = ((shift % bit_num) + bit_num) % bit_num;
86            value >> shift
87        },
88    )?;
89    Ok(result)
90}
91
/// Trait for performing an unsigned right shift (logical shift right).
///
/// This mimics Java's `>>>` operator, which does not exist in Rust:
/// * for unsigned types it is the ordinary right shift;
/// * for signed types the value is reinterpreted as the unsigned type of the
///   same width, shifted, and reinterpreted back, so vacated high bits are
///   filled with zeros instead of copies of the sign bit.
///
/// Like Java, only the low bits of the shift count are used (5 bits for
/// 32-bit values, 6 bits for 64-bit values). A negative or oversized `rhs`
/// therefore wraps around rather than triggering an overflow panic.
trait UShr<Rhs> {
    /// Returns `self` logically shifted right by `rhs` bits.
    fn ushr(self, rhs: Rhs) -> Self;
}

impl UShr<i32> for u32 {
    fn ushr(self, rhs: i32) -> Self {
        // `rhs as u32` keeps the two's-complement bit pattern; `wrapping_shr`
        // then masks the count to `0..32`, exactly like Java's `>>>` on `int`.
        self.wrapping_shr(rhs as u32)
    }
}

impl UShr<i32> for u64 {
    fn ushr(self, rhs: i32) -> Self {
        // The count is masked to `0..64`, like Java's `>>>` on `long`.
        self.wrapping_shr(rhs as u32)
    }
}

impl UShr<i32> for i32 {
    fn ushr(self, rhs: i32) -> Self {
        // Shift the unsigned reinterpretation so zeros, not sign bits, enter
        // from the left; the count is masked to `0..32`.
        (self as u32).wrapping_shr(rhs as u32) as i32
    }
}

impl UShr<i32> for i64 {
    fn ushr(self, rhs: i32) -> Self {
        // Shift the unsigned reinterpretation so zeros, not sign bits, enter
        // from the left; the count is masked to `0..64`.
        (self as u64).wrapping_shr(rhs as u32) as i64
    }
}
123
124/// Performs a bitwise unsigned right shift on each element of the `value` array by the corresponding amount in the `shift` array.
125/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts.
126///
127/// # Arguments
128/// * `value` - The array of values to shift.
129/// * `shift` - The array of shift amounts (must be Int32).
130///
131/// # Returns
132/// A new array with the shifted values.
133fn shift_right_unsigned<T: ArrowPrimitiveType>(
134    value: &PrimitiveArray<T>,
135    shift: &PrimitiveArray<Int32Type>,
136) -> Result<PrimitiveArray<T>>
137where
138    T::Native: ArrowNativeType + UShr<i32>,
139{
140    let bit_num = (T::Native::get_byte_width() * 8) as i32;
141    let result = compute::binary::<_, Int32Type, _, _>(
142        value,
143        shift,
144        |value: T::Native, shift: i32| {
145            let shift = ((shift % bit_num) + bit_num) % bit_num;
146            value.ushr(shift)
147        },
148    )?;
149    Ok(result)
150}
151
152trait BitShiftUDF: ScalarUDFImpl {
153    fn shift<T: ArrowPrimitiveType>(
154        &self,
155        value: &PrimitiveArray<T>,
156        shift: &PrimitiveArray<Int32Type>,
157    ) -> Result<PrimitiveArray<T>>
158    where
159        T::Native: ArrowNativeType
160            + std::ops::Shl<i32, Output = T::Native>
161            + std::ops::Shr<i32, Output = T::Native>
162            + UShr<i32>;
163
164    fn spark_shift(&self, arrays: &[ArrayRef]) -> Result<ArrayRef> {
165        let value_array = arrays[0].as_ref();
166        let shift_array = arrays[1].as_ref();
167
168        // Ensure shift array is Int32
169        let shift_array = if shift_array.data_type() != &DataType::Int32 {
170            return plan_err!("{} shift amount must be Int32", self.name());
171        } else {
172            shift_array.as_primitive::<Int32Type>()
173        };
174
175        match value_array.data_type() {
176            DataType::Int32 => {
177                let value_array = value_array.as_primitive::<Int32Type>();
178                Ok(Arc::new(self.shift(value_array, shift_array)?))
179            }
180            DataType::Int64 => {
181                let value_array = value_array.as_primitive::<Int64Type>();
182                Ok(Arc::new(self.shift(value_array, shift_array)?))
183            }
184            DataType::UInt32 => {
185                let value_array = value_array.as_primitive::<UInt32Type>();
186                Ok(Arc::new(self.shift(value_array, shift_array)?))
187            }
188            DataType::UInt64 => {
189                let value_array = value_array.as_primitive::<UInt64Type>();
190                Ok(Arc::new(self.shift(value_array, shift_array)?))
191            }
192            _ => {
193                plan_err!(
194                    "{} function does not support data type: {}",
195                    self.name(),
196                    value_array.data_type()
197                )
198            }
199        }
200    }
201}
202
203fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result<Vec<DataType>> {
204    if arg_types.len() != 2 {
205        return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len()));
206    }
207    if !arg_types[0].is_integer() && !arg_types[0].is_null() {
208        return Err(unsupported_data_type_exec_err(
209            func,
210            "Integer Type",
211            &arg_types[0],
212        ));
213    }
214    if !arg_types[1].is_integer() && !arg_types[1].is_null() {
215        return Err(unsupported_data_type_exec_err(
216            func,
217            "Integer Type",
218            &arg_types[1],
219        ));
220    }
221
222    // Coerce smaller integer types to Int32
223    let coerced_first = match &arg_types[0] {
224        DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32,
225        DataType::UInt8 | DataType::UInt16 => DataType::UInt32,
226        _ => arg_types[0].clone(),
227    };
228
229    Ok(vec![coerced_first, DataType::Int32])
230}
231
232#[derive(Debug, Hash, Eq, PartialEq)]
233pub struct SparkShiftLeft {
234    signature: Signature,
235}
236
237impl Default for SparkShiftLeft {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243impl SparkShiftLeft {
244    pub fn new() -> Self {
245        Self {
246            signature: Signature::user_defined(Volatility::Immutable),
247        }
248    }
249}
250
251impl BitShiftUDF for SparkShiftLeft {
252    fn shift<T: ArrowPrimitiveType>(
253        &self,
254        value: &PrimitiveArray<T>,
255        shift: &PrimitiveArray<Int32Type>,
256    ) -> Result<PrimitiveArray<T>>
257    where
258        T::Native: ArrowNativeType
259            + std::ops::Shl<i32, Output = T::Native>
260            + std::ops::Shr<i32, Output = T::Native>
261            + UShr<i32>,
262    {
263        shift_left(value, shift)
264    }
265}
266
267impl ScalarUDFImpl for SparkShiftLeft {
268    fn as_any(&self) -> &dyn Any {
269        self
270    }
271
272    fn name(&self) -> &str {
273        "shiftleft"
274    }
275
276    fn signature(&self) -> &Signature {
277        &self.signature
278    }
279
280    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
281        bit_shift_coerce_types(arg_types, "shiftleft")
282    }
283
284    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
285        if arg_types.len() != 2 {
286            return plan_err!("shiftleft expects exactly 2 arguments");
287        }
288        // Return type is the same as the first argument (the value to shift)
289        Ok(arg_types[0].clone())
290    }
291
292    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
293        if args.args.len() != 2 {
294            return plan_err!("shiftleft expects exactly 2 arguments");
295        }
296        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
297        make_scalar_function(inner, vec![])(&args.args)
298    }
299}
300
301#[derive(Debug, Hash, Eq, PartialEq)]
302pub struct SparkShiftRightUnsigned {
303    signature: Signature,
304}
305
306impl Default for SparkShiftRightUnsigned {
307    fn default() -> Self {
308        Self::new()
309    }
310}
311
312impl SparkShiftRightUnsigned {
313    pub fn new() -> Self {
314        Self {
315            signature: Signature::user_defined(Volatility::Immutable),
316        }
317    }
318}
319
320impl BitShiftUDF for SparkShiftRightUnsigned {
321    fn shift<T: ArrowPrimitiveType>(
322        &self,
323        value: &PrimitiveArray<T>,
324        shift: &PrimitiveArray<Int32Type>,
325    ) -> Result<PrimitiveArray<T>>
326    where
327        T::Native: ArrowNativeType
328            + std::ops::Shl<i32, Output = T::Native>
329            + std::ops::Shr<i32, Output = T::Native>
330            + UShr<i32>,
331    {
332        shift_right_unsigned(value, shift)
333    }
334}
335
336impl ScalarUDFImpl for SparkShiftRightUnsigned {
337    fn as_any(&self) -> &dyn Any {
338        self
339    }
340
341    fn name(&self) -> &str {
342        "shiftrightunsigned"
343    }
344
345    fn signature(&self) -> &Signature {
346        &self.signature
347    }
348
349    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
350        bit_shift_coerce_types(arg_types, "shiftrightunsigned")
351    }
352
353    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
354        if arg_types.len() != 2 {
355            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
356        }
357        // Return type is the same as the first argument (the value to shift)
358        Ok(arg_types[0].clone())
359    }
360
361    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
362        if args.args.len() != 2 {
363            return plan_err!("shiftrightunsigned expects exactly 2 arguments");
364        }
365        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
366        make_scalar_function(inner, vec![])(&args.args)
367    }
368}
369
370#[derive(Debug, Hash, Eq, PartialEq)]
371pub struct SparkShiftRight {
372    signature: Signature,
373}
374
375impl Default for SparkShiftRight {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
381impl SparkShiftRight {
382    pub fn new() -> Self {
383        Self {
384            signature: Signature::user_defined(Volatility::Immutable),
385        }
386    }
387}
388
389impl BitShiftUDF for SparkShiftRight {
390    fn shift<T: ArrowPrimitiveType>(
391        &self,
392        value: &PrimitiveArray<T>,
393        shift: &PrimitiveArray<Int32Type>,
394    ) -> Result<PrimitiveArray<T>>
395    where
396        T::Native: ArrowNativeType
397            + std::ops::Shl<i32, Output = T::Native>
398            + std::ops::Shr<i32, Output = T::Native>
399            + UShr<i32>,
400    {
401        shift_right(value, shift)
402    }
403}
404
405impl ScalarUDFImpl for SparkShiftRight {
406    fn as_any(&self) -> &dyn Any {
407        self
408    }
409
410    fn name(&self) -> &str {
411        "shiftright"
412    }
413
414    fn signature(&self) -> &Signature {
415        &self.signature
416    }
417
418    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
419        bit_shift_coerce_types(arg_types, "shiftright")
420    }
421
422    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
423        if arg_types.len() != 2 {
424            return plan_err!("shiftright expects exactly 2 arguments");
425        }
426        // Return type is the same as the first argument (the value to shift)
427        Ok(arg_types[0].clone())
428    }
429
430    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
431        if args.args.len() != 2 {
432            return plan_err!("shiftright expects exactly 2 arguments");
433        }
434        let inner = |arr: &[ArrayRef]| -> Result<ArrayRef> { self.spark_shift(arr) };
435        make_scalar_function(inner, vec![])(&args.args)
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, Int64Array, UInt32Array, UInt64Array};

    /// Runs `udf` over `values` and `shifts` and asserts that the complete
    /// result array (length, null positions and values) equals `expected`.
    fn assert_shift<U, T>(
        udf: &U,
        values: PrimitiveArray<T>,
        shifts: Int32Array,
        expected: PrimitiveArray<T>,
    ) where
        U: BitShiftUDF,
        T: ArrowPrimitiveType,
    {
        let args = [Arc::new(values) as ArrayRef, Arc::new(shifts) as ArrayRef];
        let result = udf.spark_shift(&args).unwrap();
        assert_eq!(result.as_primitive::<T>(), &expected);
    }

    #[test]
    fn test_shift_right_unsigned_int32() {
        // 4 >>> 1, 8 >>> 2, 16 >>> 3, 32 >>> 4 are all 2
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            Int32Array::from(vec![4, 8, 16, 32]),
            Int32Array::from(vec![1, 2, 3, 4]),
            Int32Array::from(vec![2, 2, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_int64() {
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            Int64Array::from(vec![4i64, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            Int64Array::from(vec![2i64, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_uint32() {
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            UInt32Array::from(vec![4u32, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            UInt32Array::from(vec![2u32, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_uint64() {
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            UInt64Array::from(vec![4u64, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            UInt64Array::from(vec![2u64, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_nulls() {
        // 4 >>> 1 = 2, null >>> 2 = null, 8 >>> null = null
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            Int32Array::from(vec![Some(4), None, Some(8)]),
            Int32Array::from(vec![Some(1), Some(2), None]),
            Int32Array::from(vec![Some(2), None, None]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_negative_shift() {
        // -1, -2, -3 wrap to 31, 30, 29, which shifts every set bit out
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            Int32Array::from(vec![4, 8, 16]),
            Int32Array::from(vec![-1, -2, -3]),
            Int32Array::from(vec![0, 0, 0]),
        );
    }

    #[test]
    fn test_shift_right_unsigned_negative_values() {
        // For unsigned right shift, negative values are treated as large
        // positive values:
        //   -4 as u32 = 4294967292,  -4 >>> 1 = 2147483646
        //   -8 as u32 = 4294967288,  -8 >>> 2 = 1073741822
        //  -16 as u32 = 4294967280, -16 >>> 3 = 536870910
        assert_shift(
            &SparkShiftRightUnsigned::new(),
            Int32Array::from(vec![-4, -8, -16]),
            Int32Array::from(vec![1, 2, 3]),
            Int32Array::from(vec![2147483646, 1073741822, 536870910]),
        );
    }

    #[test]
    fn test_shift_right_int32() {
        // 4 >> 1, 8 >> 2, 16 >> 3, 32 >> 4 are all 2
        assert_shift(
            &SparkShiftRight::new(),
            Int32Array::from(vec![4, 8, 16, 32]),
            Int32Array::from(vec![1, 2, 3, 4]),
            Int32Array::from(vec![2, 2, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_int64() {
        assert_shift(
            &SparkShiftRight::new(),
            Int64Array::from(vec![4i64, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            Int64Array::from(vec![2i64, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_uint32() {
        assert_shift(
            &SparkShiftRight::new(),
            UInt32Array::from(vec![4u32, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            UInt32Array::from(vec![2u32, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_uint64() {
        assert_shift(
            &SparkShiftRight::new(),
            UInt64Array::from(vec![4u64, 8, 16]),
            Int32Array::from(vec![1, 2, 3]),
            UInt64Array::from(vec![2u64, 2, 2]),
        );
    }

    #[test]
    fn test_shift_right_nulls() {
        // 4 >> 1 = 2, null >> 2 = null, 8 >> null = null
        assert_shift(
            &SparkShiftRight::new(),
            Int32Array::from(vec![Some(4), None, Some(8)]),
            Int32Array::from(vec![Some(1), Some(2), None]),
            Int32Array::from(vec![Some(2), None, None]),
        );
    }

    #[test]
    fn test_shift_right_large_shift() {
        // 32, 33, 64 wrap to 0, 1, 0: 1 >> 0 = 1, 2 >> 1 = 1, 3 >> 0 = 3
        assert_shift(
            &SparkShiftRight::new(),
            Int32Array::from(vec![1, 2, 3]),
            Int32Array::from(vec![32, 33, 64]),
            Int32Array::from(vec![1, 1, 3]),
        );
    }

    #[test]
    fn test_shift_right_negative_shift() {
        // -1, -2, -3 wrap to 31, 30, 29, which shifts every set bit out
        assert_shift(
            &SparkShiftRight::new(),
            Int32Array::from(vec![4, 8, 16]),
            Int32Array::from(vec![-1, -2, -3]),
            Int32Array::from(vec![0, 0, 0]),
        );
    }

    #[test]
    fn test_shift_right_negative_values() {
        // For signed integers, right shift preserves the sign bit:
        // -4 >> 1, -8 >> 2 and -16 >> 3 are all -2
        assert_shift(
            &SparkShiftRight::new(),
            Int32Array::from(vec![-4, -8, -16]),
            Int32Array::from(vec![1, 2, 3]),
            Int32Array::from(vec![-2, -2, -2]),
        );
    }

    #[test]
    fn test_shift_left_int32() {
        // 1 << 1 = 2, 2 << 2 = 8, 3 << 3 = 24, 4 << 4 = 64
        assert_shift(
            &SparkShiftLeft::new(),
            Int32Array::from(vec![1, 2, 3, 4]),
            Int32Array::from(vec![1, 2, 3, 4]),
            Int32Array::from(vec![2, 8, 24, 64]),
        );
    }

    #[test]
    fn test_shift_left_int64() {
        assert_shift(
            &SparkShiftLeft::new(),
            Int64Array::from(vec![1i64, 2, 3]),
            Int32Array::from(vec![1, 2, 3]),
            Int64Array::from(vec![2i64, 8, 24]),
        );
    }

    #[test]
    fn test_shift_left_uint32() {
        assert_shift(
            &SparkShiftLeft::new(),
            UInt32Array::from(vec![1u32, 2, 3]),
            Int32Array::from(vec![1, 2, 3]),
            UInt32Array::from(vec![2u32, 8, 24]),
        );
    }

    #[test]
    fn test_shift_left_uint64() {
        assert_shift(
            &SparkShiftLeft::new(),
            UInt64Array::from(vec![1u64, 2, 3]),
            Int32Array::from(vec![1, 2, 3]),
            UInt64Array::from(vec![2u64, 8, 24]),
        );
    }

    #[test]
    fn test_shift_left_nulls() {
        // 2 << 1 = 4, null << 2 = null, 3 << null = null
        assert_shift(
            &SparkShiftLeft::new(),
            Int32Array::from(vec![Some(2), None, Some(3)]),
            Int32Array::from(vec![Some(1), Some(2), None]),
            Int32Array::from(vec![Some(4), None, None]),
        );
    }

    #[test]
    fn test_shift_left_large_shift() {
        // The shift amount wraps at the 32-bit width; it does not overflow:
        // 32, 33, 64 wrap to 0, 1, 0, so 1 << 0 = 1, 2 << 1 = 4, 3 << 0 = 3
        assert_shift(
            &SparkShiftLeft::new(),
            Int32Array::from(vec![1, 2, 3]),
            Int32Array::from(vec![32, 33, 64]),
            Int32Array::from(vec![1, 4, 3]),
        );
    }

    #[test]
    fn test_shift_left_int64_large_shift() {
        // For 64-bit values the amount wraps at 64, not 32:
        // 32 stays 32, while 64 and 65 wrap to 0 and 1
        assert_shift(
            &SparkShiftLeft::new(),
            Int64Array::from(vec![1i64, 1, 1]),
            Int32Array::from(vec![32, 64, 65]),
            Int64Array::from(vec![4294967296i64, 1, 2]),
        );
    }

    #[test]
    fn test_shift_left_negative_shift() {
        // -1, -2, -3 wrap to 31, 30, 29; every set bit is shifted past bit 31
        assert_shift(
            &SparkShiftLeft::new(),
            Int32Array::from(vec![4, 8, 16]),
            Int32Array::from(vec![-1, -2, -3]),
            Int32Array::from(vec![0, 0, 0]),
        );
    }

    #[test]
    fn test_shift_rejects_non_int32_shift_amount() {
        let args = [
            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
            Arc::new(Int64Array::from(vec![1i64])) as ArrayRef,
        ];
        assert!(SparkShiftLeft::new().spark_shift(&args).is_err());
        assert!(SparkShiftRight::new().spark_shift(&args).is_err());
        assert!(SparkShiftRightUnsigned::new().spark_shift(&args).is_err());
    }
}