use arrow::{array::*, datatypes::DataType};
use datafusion::common::{
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result,
};
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use std::{any::Any, sync::Arc};
/// Scalar UDF that computes the bitwise NOT (`!x`) of signed integer arrays.
///
/// Registered under the name `bit_not`.
#[derive(Debug)]
pub struct SparkBitwiseNot {
    /// Signature handed back by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names handed back by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
impl Default for SparkBitwiseNot {
fn default() -> Self {
Self::new()
}
}
impl SparkBitwiseNot {
    /// Builds the UDF with an immutable, user-defined signature and no aliases.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        let aliases = Vec::new();
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for SparkBitwiseNot {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "bit_not"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns the result type for the given argument types.
    ///
    /// The result type equals the type of the first argument, which must be
    /// one of `Int8`, `Int16`, `Int32`, `Int64` or `Null`.
    ///
    /// # Errors
    ///
    /// Returns an execution error when `arg_types` is empty or when the first
    /// argument has any other type.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // `first()` instead of `arg_types[0]`: an empty slice must surface as
        // an error, not as an index-out-of-bounds panic.
        match arg_types.first() {
            Some(
                dt @ (DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Null),
            ) => Ok(dt.clone()),
            Some(_) => {
                exec_err!("{} function can only accept integral arrays", self.name())
            }
            None => exec_err!("{} function expects exactly one argument", self.name()),
        }
    }

    /// Evaluates the function by delegating to [`bitwise_not`].
    ///
    /// # Errors
    ///
    /// Returns an internal error unless exactly one argument is supplied, and
    /// propagates any error produced by [`bitwise_not`].
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args: [ColumnarValue; 1] = args
            .args
            .try_into()
            .map_err(|_| internal_datafusion_err!("bit_not expects exactly one argument"))?;
        bitwise_not(args)
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
/// Downcasts `$OPERAND` to the primitive array type `$DT` and produces a new
/// array of the same type holding the bitwise NOT of every value.
///
/// The expansion uses `?`, so it must be invoked inside a function returning a
/// `Result` whose error type is convertible from `DataFusionError`. The macro
/// itself evaluates to `Ok(Arc<$DT>)`.
///
/// Fails with an execution error when `$OPERAND` is not actually a `$DT`.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        })?;
        // `unary` maps the whole value buffer in one pass and carries the
        // input's validity over to the output, so null slots stay null without
        // a per-element `Option` round trip.
        let result: $DT = operand.unary(|y| !y);
        Ok(Arc::new(result))
    }};
}
pub fn bitwise_not(args: [ColumnarValue; 1]) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Array(array)] => {
let result: Result<ArrayRef> = match array.data_type() {
DataType::Int8 => compute_op!(array, Int8Array),
DataType::Int16 => compute_op!(array, Int16Array),
DataType::Int32 => compute_op!(array, Int32Array),
DataType::Int64 => compute_op!(array, Int64Array),
_ => exec_err!("bit_not can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()),
};
result.map(ColumnarValue::Array)
}
[ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise not scalar path"),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::common::{cast::as_int32_array, Result};

    /// `bitwise_not` flips every bit of each `Int32` value and keeps nulls.
    #[test]
    fn bitwise_not_op() -> Result<()> {
        let input = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ]);
        let expected = Int32Array::from(vec![
            Some(-2),
            Some(-3),
            None,
            Some(-12346),
            Some(-90),
            Some(3455),
        ]);

        let output = bitwise_not([ColumnarValue::Array(Arc::new(input))])?;
        let ColumnarValue::Array(output) = output else {
            panic!("Expected array");
        };
        let output = as_int32_array(&output).expect("failed to downcast to Int32Array");

        assert_eq!(output, &expected);
        Ok(())
    }

    /// Arrays that are not signed integers are rejected with an error.
    #[test]
    fn bitwise_not_rejects_non_integer_array() {
        let input = Float64Array::from(vec![Some(1.0), None]);
        let output = bitwise_not([ColumnarValue::Array(Arc::new(input))]);
        assert!(output.is_err());
    }
}