use arrow::array::{Array, BooleanArray};
use arrow::array::{Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;
/// Spark-compatible `isnan`: reports, per value, whether a floating-point input is NaN.
///
/// Only the first element of `args` is examined. It may be a `Float32`/`Float64`
/// array or scalar. Unlike a plain element-wise kernel, the result never contains
/// nulls: a null input yields `false`.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] when `args` is empty or when the first
/// argument is neither `Float32` nor `Float64`.
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    /// Produces a null-free boolean array in which every slot that was null in
    /// `is_nan` is forced to `false`.
    fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
        match is_nan.nulls() {
            Some(nulls) => {
                // The validity bitmap has a set bit for every non-null slot, so
                // AND-ing it with the value bits clears whatever value happens to
                // sit under a null slot.
                let is_not_null = nulls.inner();
                ColumnarValue::Array(Arc::new(BooleanArray::new(
                    is_nan.values() & is_not_null,
                    None,
                )))
            }
            // No null buffer: the array can be returned as-is.
            None => ColumnarValue::Array(Arc::new(is_nan)),
        }
    }

    // Use a checked lookup so that a call with no arguments surfaces as an error
    // instead of an out-of-bounds panic.
    let value = args.first().ok_or_else(|| {
        DataFusionError::Internal(
            "Function isnan expects one argument, but none were provided".to_string(),
        )
    })?;

    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Float64 => {
                let array = array
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .expect("array reporting DataType::Float64 must be a Float64Array");
                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
                Ok(set_nulls_to_false(is_nan))
            }
            DataType::Float32 => {
                let array = array
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .expect("array reporting DataType::Float32 must be a Float32Array");
                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
                Ok(set_nulls_to_false(is_nan))
            }
            other => Err(DataFusionError::Internal(format!(
                "Unsupported data type {other:?} for function isnan",
            ))),
        },
        ColumnarValue::Scalar(a) => match a {
            // A null scalar (`None`) maps to `false`, mirroring the array path.
            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
                a.map(|x| x.is_nan()).unwrap_or(false),
            )))),
            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
                a.map(|x| x.is_nan()).unwrap_or(false),
            )))),
            _ => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function isnan",
                value.data_type(),
            ))),
        },
    }
}