datafusion_comet_spark_expr/bitwise_funcs/
bitwise_not.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::{
19    array::*,
20    datatypes::{DataType, Schema},
21    record_batch::RecordBatch,
22};
23use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
24use datafusion_common::Result;
25use datafusion_physical_expr::PhysicalExpr;
26use std::hash::Hash;
27use std::{any::Any, sync::Arc};
28
/// Applies bitwise NOT (`!`) to every non-null element of a typed array.
///
/// * `$OPERAND` - an expression exposing `as_any()` for downcasting.
/// * `$DT` - the concrete array type to downcast to and to collect into.
///
/// Expands to `Ok(Arc::new(<new $DT>))`; null slots stay null.
///
/// # Panics
///
/// Panics if `$OPERAND` cannot be downcast to `$DT`.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let typed = $OPERAND
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        // Flip the bits of each present value; `None` entries pass through untouched.
        let inverted = typed
            .iter()
            .map(|slot| slot.map(|bits| !bits))
            .collect::<$DT>();
        Ok(Arc::new(inverted))
    }};
}
39
40/// BitwiseNot expression
41#[derive(Debug, Eq)]
42pub struct BitwiseNotExpr {
43    /// Input expression
44    arg: Arc<dyn PhysicalExpr>,
45}
46
47impl Hash for BitwiseNotExpr {
48    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
49        self.arg.hash(state);
50    }
51}
52
53impl PartialEq for BitwiseNotExpr {
54    fn eq(&self, other: &Self) -> bool {
55        self.arg.eq(&other.arg)
56    }
57}
58
59impl BitwiseNotExpr {
60    /// Create new bitwise not expression
61    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
62        Self { arg }
63    }
64
65    /// Get the input expression
66    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
67        &self.arg
68    }
69}
70
71impl std::fmt::Display for BitwiseNotExpr {
72    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
73        write!(f, "(~ {})", self.arg)
74    }
75}
76
77impl PhysicalExpr for BitwiseNotExpr {
78    /// Return a reference to Any that can be used for downcasting
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
84        self.arg.data_type(input_schema)
85    }
86
87    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
88        self.arg.nullable(input_schema)
89    }
90
91    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
92        let arg = self.arg.evaluate(batch)?;
93        match arg {
94            ColumnarValue::Array(array) => {
95                let result: Result<ArrayRef> = match array.data_type() {
96                    DataType::Int8 => compute_op!(array, Int8Array),
97                    DataType::Int16 => compute_op!(array, Int16Array),
98                    DataType::Int32 => compute_op!(array, Int32Array),
99                    DataType::Int64 => compute_op!(array, Int64Array),
100                    _ => Err(DataFusionError::Execution(format!(
101                        "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int",
102                        self,
103                        array.data_type(),
104                    ))),
105                };
106                result.map(ColumnarValue::Array)
107            }
108            ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
109                "shouldn't go to bitwise not scalar path".to_string(),
110            )),
111        }
112    }
113
114    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
115        vec![&self.arg]
116    }
117
118    fn with_new_children(
119        self: Arc<Self>,
120        children: Vec<Arc<dyn PhysicalExpr>>,
121    ) -> Result<Arc<dyn PhysicalExpr>> {
122        Ok(Arc::new(BitwiseNotExpr::new(Arc::clone(&children[0]))))
123    }
124}
125
126pub fn bitwise_not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
127    Ok(Arc::new(BitwiseNotExpr::new(arg)))
128}
129
#[cfg(test)]
mod tests {
    use arrow::datatypes::*;
    use datafusion_common::{cast::as_int32_array, Result};
    use datafusion_physical_expr::expressions::col;

    use super::*;

    /// Bitwise NOT over a nullable Int32 column: `!x == -x - 1`, nulls stay null.
    #[test]
    fn bitwise_not_op() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        let expr = bitwise_not(col("a", &schema)?)?;

        let input = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ]);
        let expected = &Int32Array::from(vec![
            Some(-2),
            Some(-3),
            None,
            Some(-12346),
            Some(-90),
            Some(3455),
        ]);

        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;

        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
        assert_eq!(result, expected);

        Ok(())
    }

    /// A non-integer input column must produce an error rather than a result.
    #[test]
    fn bitwise_not_rejects_non_integer_input() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);

        let expr = bitwise_not(col("a", &schema)?)?;

        let input = Float64Array::from(vec![Some(1.0), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input)])?;

        assert!(expr.evaluate(&batch).is_err());

        Ok(())
    }
}
169}