Skip to main content

datafusion_spark/function/bitwise/
bitwise_not.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::*;
19use arrow::compute::kernels::bitwise;
20use arrow::datatypes::{
21    DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type,
22};
23use datafusion_common::{Result, internal_err, plan_err};
24use datafusion_expr::{ColumnarValue, TypeSignature, Volatility};
25use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature};
26use datafusion_functions::utils::make_scalar_function;
27use std::sync::Arc;
28
29#[derive(Debug, PartialEq, Eq, Hash)]
30pub struct SparkBitwiseNot {
31    signature: Signature,
32}
33
34impl Default for SparkBitwiseNot {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl SparkBitwiseNot {
41    pub fn new() -> Self {
42        Self {
43            signature: Signature::one_of(
44                vec![
45                    TypeSignature::Exact(vec![DataType::Int8]),
46                    TypeSignature::Exact(vec![DataType::Int16]),
47                    TypeSignature::Exact(vec![DataType::Int32]),
48                    TypeSignature::Exact(vec![DataType::Int64]),
49                ],
50                Volatility::Immutable,
51            ),
52        }
53    }
54}
55
56impl ScalarUDFImpl for SparkBitwiseNot {
57    fn name(&self) -> &str {
58        "bitwise_not"
59    }
60
61    fn signature(&self) -> &Signature {
62        &self.signature
63    }
64
65    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66        internal_err!(
67            "SparkBitwiseNot: return_type() is not used; return_field_from_args() is implemented"
68        )
69    }
70
71    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
72        Ok(Arc::new(Field::new(
73            self.name(),
74            args.arg_fields[0].data_type().clone(),
75            args.arg_fields[0].is_nullable(),
76        )))
77    }
78
79    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
80        if args.args.len() != 1 {
81            return plan_err!("bitwise_not expects exactly 1 argument");
82        }
83        make_scalar_function(spark_bitwise_not, vec![])(&args.args)
84    }
85}
86
87pub fn spark_bitwise_not(args: &[ArrayRef]) -> Result<ArrayRef> {
88    let array = args[0].as_ref();
89    match array.data_type() {
90        DataType::Int8 => {
91            let result: Int8Array =
92                bitwise::bitwise_not(array.as_primitive::<Int8Type>())?;
93            Ok(Arc::new(result))
94        }
95        DataType::Int16 => {
96            let result: Int16Array =
97                bitwise::bitwise_not(array.as_primitive::<Int16Type>())?;
98            Ok(Arc::new(result))
99        }
100        DataType::Int32 => {
101            let result: Int32Array =
102                bitwise::bitwise_not(array.as_primitive::<Int32Type>())?;
103            Ok(Arc::new(result))
104        }
105        DataType::Int64 => {
106            let result: Int64Array =
107                bitwise::bitwise_not(array.as_primitive::<Int64Type>())?;
108            Ok(Arc::new(result))
109        }
110        _ => {
111            plan_err!(
112                "bitwise_not function does not support data type: {}",
113                array.data_type()
114            )
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Computes the output field of `bitwise_not` for a single input field.
    fn output_field(input: FieldRef) -> FieldRef {
        SparkBitwiseNot::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                // single-argument function -> one scalar_argument slot (None)
                scalar_arguments: &[None],
            })
            .unwrap()
    }

    /// Runs the array kernel on a single input array.
    fn not_array(input: ArrayRef) -> ArrayRef {
        spark_bitwise_not(&[input]).unwrap()
    }

    #[test]
    fn test_bitwise_not_nullability() {
        // Every accepted type must keep its data type and propagate the input
        // field's nullability unchanged.
        for data_type in [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
        ] {
            for nullable in [false, true] {
                let out =
                    output_field(Arc::new(Field::new("c", data_type.clone(), nullable)));
                assert_eq!(out.is_nullable(), nullable);
                assert_eq!(out.data_type(), &data_type);
                assert_eq!(out.name(), "bitwise_not");
            }
        }
    }

    #[test]
    fn test_bitwise_not_int8_boundaries() {
        let out = not_array(Arc::new(Int8Array::from(vec![i8::MIN, i8::MAX, 0, -1])));
        assert_eq!(out.data_type(), &DataType::Int8);
        assert_eq!(
            out.as_primitive::<Int8Type>(),
            &Int8Array::from(vec![i8::MAX, i8::MIN, -1, 0])
        );
    }

    #[test]
    fn test_bitwise_not_int16() {
        let out = not_array(Arc::new(Int16Array::from(vec![0x00FF_i16, i16::MIN])));
        assert_eq!(out.data_type(), &DataType::Int16);
        assert_eq!(
            out.as_primitive::<Int16Type>(),
            &Int16Array::from(vec![-256_i16, i16::MAX])
        );
    }

    #[test]
    fn test_bitwise_not_int32_with_nulls() {
        let out = not_array(Arc::new(Int32Array::from(vec![
            Some(5),
            None,
            Some(-6),
            Some(i32::MAX),
        ])));
        assert_eq!(out.data_type(), &DataType::Int32);
        assert_eq!(
            out.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(-6), None, Some(5), Some(i32::MIN)])
        );
    }

    #[test]
    fn test_bitwise_not_int64_with_nulls() {
        let out = not_array(Arc::new(Int64Array::from(vec![
            Some(0_i64),
            None,
            Some(i64::MAX),
        ])));
        assert_eq!(out.data_type(), &DataType::Int64);
        assert_eq!(
            out.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![Some(-1_i64), None, Some(i64::MIN)])
        );
    }

    #[test]
    fn test_bitwise_not_rejects_unsupported_type() {
        let input: ArrayRef = Arc::new(UInt8Array::from(vec![1_u8]));
        assert!(spark_bitwise_not(&[input]).is_err());
    }
}