datafusion_comet_spark_expr/bitwise_funcs/
bitwise_not.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{array::*, datatypes::DataType};
19use datafusion::common::{
20    exec_err, internal_datafusion_err, internal_err, DataFusionError, Result,
21};
22use datafusion::logical_expr::{ColumnarValue, Volatility};
23use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
24use std::{any::Any, sync::Arc};
25
26#[derive(Debug)]
27pub struct SparkBitwiseNot {
28    signature: Signature,
29    aliases: Vec<String>,
30}
31
32impl Default for SparkBitwiseNot {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl SparkBitwiseNot {
39    pub fn new() -> Self {
40        Self {
41            signature: Signature::user_defined(Volatility::Immutable),
42            aliases: vec![],
43        }
44    }
45}
46
47impl ScalarUDFImpl for SparkBitwiseNot {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn name(&self) -> &str {
53        "bit_not"
54    }
55
56    fn signature(&self) -> &Signature {
57        &self.signature
58    }
59
60    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
61        Ok(match arg_types[0] {
62            DataType::Int8 => DataType::Int8,
63            DataType::Int16 => DataType::Int16,
64            DataType::Int32 => DataType::Int32,
65            DataType::Int64 => DataType::Int64,
66            DataType::Null => DataType::Null,
67            _ => return exec_err!("{} function can only accept integral arrays", self.name()),
68        })
69    }
70
71    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72        let args: [ColumnarValue; 1] = args
73            .args
74            .try_into()
75            .map_err(|_| internal_datafusion_err!("bit_not expects exactly one argument"))?;
76        bitwise_not(args)
77    }
78
79    fn aliases(&self) -> &[String] {
80        &self.aliases
81    }
82}
83
/// Downcasts `$OPERAND` (an `ArrayRef`) to the primitive array type `$DT` and applies
/// bitwise NOT to every element, evaluating to `Ok(ArrayRef)`.
///
/// If the downcast fails, the `?` returns an execution error early from the
/// *enclosing* function.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        })?;
        // `unary` maps the values buffer in one pass and reuses the input's null buffer,
        // instead of rebuilding the array element by element through `Option`s. Values
        // behind null slots are also complemented, but they stay masked by the validity
        // bitmap, so the logical result is unchanged.
        let result: $DT = operand.unary(|x| !x);
        Ok(Arc::new(result))
    }};
}
96
97pub fn bitwise_not(args: [ColumnarValue; 1]) -> Result<ColumnarValue> {
98    match args {
99        [ColumnarValue::Array(array)] => {
100            let result: Result<ArrayRef> = match array.data_type() {
101                DataType::Int8 => compute_op!(array, Int8Array),
102                DataType::Int16 => compute_op!(array, Int16Array),
103                DataType::Int32 => compute_op!(array, Int32Array),
104                DataType::Int64 => compute_op!(array, Int64Array),
105                _ => exec_err!("bit_not can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()),
106            };
107            result.map(ColumnarValue::Array)
108        }
109        [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise not scalar path"),
110    }
111}
112
#[cfg(test)]
mod tests {
    use datafusion::common::{cast::as_int32_array, Result};

    use super::*;

    /// Runs `bitwise_not` on an array argument and unwraps the array result.
    fn eval(array: ArrayRef) -> Result<ArrayRef> {
        match bitwise_not([ColumnarValue::Array(array)])? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => panic!("Expected array"),
        }
    }

    #[test]
    fn bitwise_not_op() -> Result<()> {
        let int_array = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ]);
        let expected = &Int32Array::from(vec![
            Some(-2),
            Some(-3),
            None,
            Some(-12346),
            Some(-90),
            Some(3455),
        ]);

        let result = eval(Arc::new(int_array))?;
        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
        assert_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn bitwise_not_other_widths() -> Result<()> {
        let result = eval(Arc::new(Int8Array::from(vec![
            Some(0),
            None,
            Some(i8::MAX),
            Some(i8::MIN),
        ])))?;
        let result = result
            .as_any()
            .downcast_ref::<Int8Array>()
            .expect("failed to downcast to Int8Array");
        assert_eq!(
            result,
            &Int8Array::from(vec![Some(-1), None, Some(i8::MIN), Some(i8::MAX)])
        );

        let result = eval(Arc::new(Int16Array::from(vec![Some(i16::MIN), None, Some(7)])))?;
        let result = result
            .as_any()
            .downcast_ref::<Int16Array>()
            .expect("failed to downcast to Int16Array");
        assert_eq!(result, &Int16Array::from(vec![Some(i16::MAX), None, Some(-8)]));

        let result = eval(Arc::new(Int64Array::from(vec![Some(1), Some(-1), None])))?;
        let result = result
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("failed to downcast to Int64Array");
        assert_eq!(result, &Int64Array::from(vec![Some(-2), Some(0), None]));

        Ok(())
    }

    #[test]
    fn bitwise_not_rejects_non_integral() {
        let floats: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0]));
        assert!(eval(floats).is_err());
    }
}