datafusion_spark/function/string/
luhn_check.rs1use std::sync::Arc;
19
20use arrow::array::{Array, AsArray, BooleanArray};
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::Boolean;
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, exec_err};
25use datafusion_expr::{
26 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
27 Volatility,
28};
29
30#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLuhnCheck {
34 signature: Signature,
35}
36
37impl Default for SparkLuhnCheck {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl SparkLuhnCheck {
44 pub fn new() -> Self {
45 Self {
46 signature: Signature::one_of(
47 vec![
48 TypeSignature::Exact(vec![DataType::Utf8]),
49 TypeSignature::Exact(vec![DataType::Utf8View]),
50 TypeSignature::Exact(vec![DataType::LargeUtf8]),
51 ],
52 Volatility::Immutable,
53 ),
54 }
55 }
56}
57
58impl ScalarUDFImpl for SparkLuhnCheck {
59 fn name(&self) -> &str {
60 "luhn_check"
61 }
62
63 fn signature(&self) -> &Signature {
64 &self.signature
65 }
66
67 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
68 Ok(Boolean)
69 }
70
71 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72 let [array] = take_function_args(self.name(), &args.args)?;
73
74 match array {
75 ColumnarValue::Array(array) => match array.data_type() {
76 DataType::Utf8View => {
77 let str_array = array.as_string_view();
78 let values = str_array
79 .iter()
80 .map(|s| s.map(luhn_check_impl))
81 .collect::<BooleanArray>();
82 Ok(ColumnarValue::Array(Arc::new(values)))
83 }
84 DataType::Utf8 => {
85 let str_array = array.as_string::<i32>();
86 let values = str_array
87 .iter()
88 .map(|s| s.map(luhn_check_impl))
89 .collect::<BooleanArray>();
90 Ok(ColumnarValue::Array(Arc::new(values)))
91 }
92 DataType::LargeUtf8 => {
93 let str_array = array.as_string::<i64>();
94 let values = str_array
95 .iter()
96 .map(|s| s.map(luhn_check_impl))
97 .collect::<BooleanArray>();
98 Ok(ColumnarValue::Array(Arc::new(values)))
99 }
100 other => {
101 exec_err!("Unsupported data type {other:?} for function `luhn_check`")
102 }
103 },
104 ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
105 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
106 | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
107 ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
108 ),
109 ColumnarValue::Scalar(ScalarValue::Utf8(None))
110 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
111 | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
112 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
113 }
114 other => {
115 exec_err!("Unsupported data type {other:?} for function `luhn_check`")
116 }
117 }
118 }
119}
120
/// Validates `input` against the Luhn checksum.
///
/// Returns `true` only when `input` is non-empty, consists solely of ASCII
/// digits (`0`-`9`), and its Luhn sum is divisible by 10. Any other byte,
/// including whitespace, separators and non-ASCII digits, makes the result
/// `false`.
fn luhn_check_impl(input: &str) -> bool {
    // Every non-digit byte is rejected inside the loop, so "at least one
    // digit was seen" is the same condition as "the input is not empty".
    if input.is_empty() {
        return false;
    }

    // Only the sum's remainder modulo 10 matters. Keeping it reduced bounds
    // the accumulator at 9 + 9, so arbitrarily long inputs cannot overflow.
    let mut sum_mod_10 = 0u32;
    // Starting from the rightmost digit, every second digit is doubled.
    let mut double_next = false;

    for &byte in input.as_bytes().iter().rev() {
        if !byte.is_ascii_digit() {
            return false;
        }

        let mut value = u32::from(byte - b'0');
        if double_next {
            value *= 2;
            // Equivalent to adding the two decimal digits of the product.
            if value > 9 {
                value -= 9;
            }
        }

        sum_mod_10 = (sum_mod_10 + value) % 10;
        double_next = !double_next;
    }

    sum_mod_10 == 0
}