datafusion_comet_spark_expr/bitwise_funcs/
bitwise_not.rs1use arrow::{array::*, datatypes::DataType};
19use datafusion::common::{
20 exec_err, internal_datafusion_err, internal_err, DataFusionError, Result,
21};
22use datafusion::logical_expr::{ColumnarValue, Volatility};
23use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
24use std::{any::Any, sync::Arc};
25
26#[derive(Debug)]
27pub struct SparkBitwiseNot {
28 signature: Signature,
29 aliases: Vec<String>,
30}
31
32impl Default for SparkBitwiseNot {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl SparkBitwiseNot {
39 pub fn new() -> Self {
40 Self {
41 signature: Signature::user_defined(Volatility::Immutable),
42 aliases: vec![],
43 }
44 }
45}
46
47impl ScalarUDFImpl for SparkBitwiseNot {
48 fn as_any(&self) -> &dyn Any {
49 self
50 }
51
52 fn name(&self) -> &str {
53 "bit_not"
54 }
55
56 fn signature(&self) -> &Signature {
57 &self.signature
58 }
59
60 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
61 Ok(match arg_types[0] {
62 DataType::Int8 => DataType::Int8,
63 DataType::Int16 => DataType::Int16,
64 DataType::Int32 => DataType::Int32,
65 DataType::Int64 => DataType::Int64,
66 DataType::Null => DataType::Null,
67 _ => return exec_err!("{} function can only accept integral arrays", self.name()),
68 })
69 }
70
71 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72 let args: [ColumnarValue; 1] = args
73 .args
74 .try_into()
75 .map_err(|_| internal_datafusion_err!("bit_not expects exactly one argument"))?;
76 bitwise_not(args)
77 }
78
79 fn aliases(&self) -> &[String] {
80 &self.aliases
81 }
82}
83
/// Downcasts `$OPERAND` (an `ArrayRef`) to the primitive array type `$DT` and applies
/// bitwise NOT to every value, evaluating to `Result<ArrayRef>`.
///
/// `PrimitiveArray::unary` maps the values buffer directly and reuses the input's null
/// buffer, instead of rebuilding the array element-by-element from `Option`s. It also
/// applies the closure to null slots, which is fine because `!x` cannot fail.
///
/// A failed downcast returns early from the *calling* function via `?`.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        })?;
        let result: $DT = operand.unary(|x| !x);
        Ok(Arc::new(result))
    }};
}
96
97pub fn bitwise_not(args: [ColumnarValue; 1]) -> Result<ColumnarValue> {
98 match args {
99 [ColumnarValue::Array(array)] => {
100 let result: Result<ArrayRef> = match array.data_type() {
101 DataType::Int8 => compute_op!(array, Int8Array),
102 DataType::Int16 => compute_op!(array, Int16Array),
103 DataType::Int32 => compute_op!(array, Int32Array),
104 DataType::Int64 => compute_op!(array, Int64Array),
105 _ => exec_err!("bit_not can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()),
106 };
107 result.map(ColumnarValue::Array)
108 }
109 [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise not scalar path"),
110 }
111}
112
#[cfg(test)]
mod tests {
    use datafusion::common::{cast::as_int32_array, Result};

    use super::*;

    /// Runs `bitwise_not` on an array input and unwraps the array result.
    fn eval_array(array: ArrayRef) -> Result<ArrayRef> {
        match bitwise_not([ColumnarValue::Array(array)])? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => panic!("Expected array"),
        }
    }

    #[test]
    fn bitwise_not_op() -> Result<()> {
        let int_array = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ]);
        let expected = &Int32Array::from(vec![
            Some(-2),
            Some(-3),
            None,
            Some(-12346),
            Some(-90),
            Some(3455),
        ]);

        let result = eval_array(Arc::new(int_array))?;
        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
        assert_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn bitwise_not_int8_boundaries() -> Result<()> {
        let input = Int8Array::from(vec![Some(i8::MIN), Some(0), None, Some(i8::MAX)]);
        let expected = Int8Array::from(vec![Some(i8::MAX), Some(-1), None, Some(i8::MIN)]);

        let result = eval_array(Arc::new(input))?;
        let result = result
            .as_any()
            .downcast_ref::<Int8Array>()
            .expect("failed to downcast to Int8Array");
        assert_eq!(result, &expected);

        Ok(())
    }

    #[test]
    fn bitwise_not_int64() -> Result<()> {
        let input = Int64Array::from(vec![Some(i64::MIN), Some(-1), None, Some(0)]);
        let expected = Int64Array::from(vec![Some(i64::MAX), Some(0), None, Some(-1)]);

        let result = eval_array(Arc::new(input))?;
        let result = result
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("failed to downcast to Int64Array");
        assert_eq!(result, &expected);

        Ok(())
    }

    #[test]
    fn bitwise_not_rejects_unsigned() {
        let input: ArrayRef = Arc::new(UInt32Array::from(vec![1u32]));
        assert!(bitwise_not([ColumnarValue::Array(input)]).is_err());
    }
}