Skip to main content

datafusion_spark/function/bitwise/
bitwise_not.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::*;
19use arrow::compute::kernels::bitwise;
20use arrow::datatypes::{
21    DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type,
22};
23use datafusion_common::{Result, internal_err, plan_err};
24use datafusion_expr::{ColumnarValue, TypeSignature, Volatility};
25use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature};
26use datafusion_functions::utils::make_scalar_function;
27use std::{any::Any, sync::Arc};
28
29#[derive(Debug, PartialEq, Eq, Hash)]
30pub struct SparkBitwiseNot {
31    signature: Signature,
32}
33
34impl Default for SparkBitwiseNot {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl SparkBitwiseNot {
41    pub fn new() -> Self {
42        Self {
43            signature: Signature::one_of(
44                vec![
45                    TypeSignature::Exact(vec![DataType::Int8]),
46                    TypeSignature::Exact(vec![DataType::Int16]),
47                    TypeSignature::Exact(vec![DataType::Int32]),
48                    TypeSignature::Exact(vec![DataType::Int64]),
49                ],
50                Volatility::Immutable,
51            ),
52        }
53    }
54}
55
56impl ScalarUDFImpl for SparkBitwiseNot {
57    fn as_any(&self) -> &dyn Any {
58        self
59    }
60
61    fn name(&self) -> &str {
62        "bitwise_not"
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70        internal_err!(
71            "SparkBitwiseNot: return_type() is not used; return_field_from_args() is implemented"
72        )
73    }
74
75    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
76        if args.arg_fields.len() != 1 {
77            return plan_err!("bitwise_not expects exactly 1 argument");
78        }
79
80        let input_field = &args.arg_fields[0];
81
82        let out_dt = input_field.data_type().clone();
83        let mut out_nullable = input_field.is_nullable();
84
85        let scalar_null_present = args
86            .scalar_arguments
87            .iter()
88            .any(|opt_s| opt_s.is_some_and(|sv| sv.is_null()));
89
90        if scalar_null_present {
91            out_nullable = true;
92        }
93
94        Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
95    }
96
97    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98        if args.args.len() != 1 {
99            return plan_err!("bitwise_not expects exactly 1 argument");
100        }
101        make_scalar_function(spark_bitwise_not, vec![])(&args.args)
102    }
103}
104
105pub fn spark_bitwise_not(args: &[ArrayRef]) -> Result<ArrayRef> {
106    let array = args[0].as_ref();
107    match array.data_type() {
108        DataType::Int8 => {
109            let result: Int8Array =
110                bitwise::bitwise_not(array.as_primitive::<Int8Type>())?;
111            Ok(Arc::new(result))
112        }
113        DataType::Int16 => {
114            let result: Int16Array =
115                bitwise::bitwise_not(array.as_primitive::<Int16Type>())?;
116            Ok(Arc::new(result))
117        }
118        DataType::Int32 => {
119            let result: Int32Array =
120                bitwise::bitwise_not(array.as_primitive::<Int32Type>())?;
121            Ok(Arc::new(result))
122        }
123        DataType::Int64 => {
124            let result: Int64Array =
125                bitwise::bitwise_not(array.as_primitive::<Int64Type>())?;
126            Ok(Arc::new(result))
127        }
128        _ => {
129            plan_err!(
130                "bitwise_not function does not support data type: {}",
131                array.data_type()
132            )
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        ArrayRef, AsArray, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
    };
    use arrow::datatypes::{DataType, Field, Int8Type, Int16Type, Int32Type, Int64Type};
    use datafusion_common::ScalarValue;
    use std::sync::Arc;

    use datafusion_expr::ReturnFieldArgs;

    /// Calls `return_field_from_args` with a single argument field and an
    /// optional scalar literal for that argument.
    fn field_for(
        func: &SparkBitwiseNot,
        input: &FieldRef,
        scalar: Option<&ScalarValue>,
    ) -> Result<FieldRef> {
        func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(input)],
            scalar_arguments: &[scalar],
        })
    }

    #[test]
    fn test_bitwise_not_nullability() {
        let bitwise_not = SparkBitwiseNot::new();

        // (input type, input nullability): the output must preserve both.
        let cases = [
            (DataType::Int32, false),
            (DataType::Int32, true),
            (DataType::Int64, false),
            (DataType::Int64, true),
        ];

        for (data_type, nullable) in cases {
            let input: FieldRef = Arc::new(Field::new("c", data_type.clone(), nullable));
            let output = field_for(&bitwise_not, &input, None).unwrap();

            assert_eq!(output.is_nullable(), nullable);
            assert_eq!(output.data_type(), &data_type);
        }
    }

    #[test]
    fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> {
        let func = SparkBitwiseNot::new();
        let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false));

        let out = field_for(&func, &non_nullable, None)?;
        assert!(!out.is_nullable());
        assert_eq!(out.data_type(), &DataType::Int32);

        // A NULL literal argument must make the output nullable.
        let null_scalar = ScalarValue::Int32(None);
        let out_with_null_scalar = field_for(&func, &non_nullable, Some(&null_scalar))?;
        assert!(out_with_null_scalar.is_nullable());
        assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32);

        Ok(())
    }

    #[test]
    fn test_bitwise_not_values() -> Result<()> {
        // Int8: boundaries flip into each other, nulls are kept.
        let int8: ArrayRef = Arc::new(Int8Array::from(vec![Some(0), None, Some(i8::MIN)]));
        let out = spark_bitwise_not(&[int8])?;
        assert_eq!(
            out.as_primitive::<Int8Type>(),
            &Int8Array::from(vec![Some(-1), None, Some(i8::MAX)])
        );

        let int16: ArrayRef =
            Arc::new(Int16Array::from(vec![Some(0x00FF), Some(-256), None]));
        let out = spark_bitwise_not(&[int16])?;
        assert_eq!(
            out.as_primitive::<Int16Type>(),
            &Int16Array::from(vec![Some(-256), Some(0x00FF), None])
        );

        let int32: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(0),
            Some(-1),
            Some(5),
            None,
            Some(i32::MAX),
        ]));
        let out = spark_bitwise_not(&[int32])?;
        assert_eq!(
            out.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(-1), Some(0), Some(-6), None, Some(i32::MIN)])
        );

        let int64: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MIN), Some(42)]));
        let out = spark_bitwise_not(&[int64])?;
        assert_eq!(
            out.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![Some(i64::MAX), Some(-43)])
        );

        Ok(())
    }

    #[test]
    fn test_bitwise_not_unsupported_type() {
        let floats: ArrayRef = Arc::new(Float64Array::from(vec![1.0_f64]));
        assert!(spark_bitwise_not(&[floats]).is_err());
    }
}