use std::{any::Any, sync::Arc};
use arrow::array::{Array, AsArray, BooleanArray};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Boolean;
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
/// Scalar UDF exposed under the name `luhn_check`.
///
/// Takes a single string argument (`Utf8`, `Utf8View` or `LargeUtf8`) and
/// yields a boolean telling whether the string passes the Luhn checksum.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkLuhnCheck {
// Accepted argument types and volatility; built once in `new`.
signature: Signature,
}
impl Default for SparkLuhnCheck {
fn default() -> Self {
Self::new()
}
}
impl SparkLuhnCheck {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8View]),
TypeSignature::Exact(vec![DataType::LargeUtf8]),
],
Volatility::Immutable,
),
}
}
}
impl ScalarUDFImpl for SparkLuhnCheck {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"luhn_check"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [array] = take_function_args(self.name(), &args.args)?;
match array {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8View => {
let str_array = array.as_string_view();
let values = str_array
.iter()
.map(|s| s.map(luhn_check_impl))
.collect::<BooleanArray>();
Ok(ColumnarValue::Array(Arc::new(values)))
}
DataType::Utf8 => {
let str_array = array.as_string::<i32>();
let values = str_array
.iter()
.map(|s| s.map(luhn_check_impl))
.collect::<BooleanArray>();
Ok(ColumnarValue::Array(Arc::new(values)))
}
DataType::LargeUtf8 => {
let str_array = array.as_string::<i64>();
let values = str_array
.iter()
.map(|s| s.map(luhn_check_impl))
.collect::<BooleanArray>();
Ok(ColumnarValue::Array(Arc::new(values)))
}
other => {
exec_err!("Unsupported data type {other:?} for function `luhn_check`")
}
},
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
),
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
}
other => {
exec_err!("Unsupported data type {other:?} for function `luhn_check`")
}
}
}
}
/// Returns `true` when `input` is a non-empty string made only of ASCII
/// digits whose Luhn checksum is a multiple of 10.
///
/// Walking from the rightmost byte, every second digit is doubled (with 9
/// subtracted when the doubled value exceeds 9) and the resulting values are
/// added up. Any byte outside `b'0'..=b'9'` — spaces, signs, or any byte of a
/// multi-byte UTF-8 character — makes the result `false`, as does the empty
/// string.
fn luhn_check_impl(input: &str) -> bool {
    // A string without any digit is never valid. Every other rejection
    // happens inside the loop, so no digit counter is needed.
    if input.is_empty() {
        return false;
    }

    // The running total is kept reduced modulo 10, so it stays below 10 and
    // cannot overflow no matter how long the input is.
    let mut checksum = 0u8;
    let mut double = false;
    for &byte in input.as_bytes().iter().rev() {
        if !byte.is_ascii_digit() {
            return false;
        }
        let mut value = byte - b'0';
        if double {
            value *= 2;
            if value > 9 {
                value -= 9;
            }
        }
        checksum = (checksum + value) % 10;
        double = !double;
    }
    checksum == 0
}