datafusion_spark/function/string/
luhn_check.rs1use std::{any::Any, sync::Arc};
19
20use arrow::array::{Array, AsArray, BooleanArray};
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::Boolean;
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{exec_err, Result, ScalarValue};
25use datafusion_expr::{
26 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
27 Volatility,
28};
29
30#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLuhnCheck {
34 signature: Signature,
35}
36
37impl Default for SparkLuhnCheck {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl SparkLuhnCheck {
44 pub fn new() -> Self {
45 Self {
46 signature: Signature::one_of(
47 vec![
48 TypeSignature::Exact(vec![DataType::Utf8]),
49 TypeSignature::Exact(vec![DataType::Utf8View]),
50 TypeSignature::Exact(vec![DataType::LargeUtf8]),
51 ],
52 Volatility::Immutable,
53 ),
54 }
55 }
56}
57
58impl ScalarUDFImpl for SparkLuhnCheck {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn name(&self) -> &str {
64 "luhn_check"
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72 Ok(Boolean)
73 }
74
75 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76 let [array] = take_function_args(self.name(), &args.args)?;
77
78 match array {
79 ColumnarValue::Array(array) => match array.data_type() {
80 DataType::Utf8View => {
81 let str_array = array.as_string_view();
82 let values = str_array
83 .iter()
84 .map(|s| s.map(luhn_check_impl))
85 .collect::<BooleanArray>();
86 Ok(ColumnarValue::Array(Arc::new(values)))
87 }
88 DataType::Utf8 => {
89 let str_array = array.as_string::<i32>();
90 let values = str_array
91 .iter()
92 .map(|s| s.map(luhn_check_impl))
93 .collect::<BooleanArray>();
94 Ok(ColumnarValue::Array(Arc::new(values)))
95 }
96 DataType::LargeUtf8 => {
97 let str_array = array.as_string::<i64>();
98 let values = str_array
99 .iter()
100 .map(|s| s.map(luhn_check_impl))
101 .collect::<BooleanArray>();
102 Ok(ColumnarValue::Array(Arc::new(values)))
103 }
104 other => {
105 exec_err!("Unsupported data type {other:?} for function `luhn_check`")
106 }
107 },
108 ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
109 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
110 | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
111 ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
112 ),
113 ColumnarValue::Scalar(ScalarValue::Utf8(None))
114 | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
115 | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
116 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
117 }
118 other => {
119 exec_err!("Unsupported data type {other:?} for function `luhn_check`")
120 }
121 }
122 }
123}
124
/// Returns `true` when `input` is a non-empty string of ASCII decimal digits
/// whose Luhn (mod 10) checksum is valid.
///
/// How the checksum is computed:
/// - Digits are processed right to left.
/// - Every second digit, starting from the second-rightmost, is doubled.
/// - 9 is subtracted from any doubled value greater than 9.
/// - The input is valid when the total is divisible by 10.
///
/// Any byte outside `'0'..='9'` makes the result `false`. This includes
/// whitespace, signs and non-ASCII characters. The empty string is also `false`.
///
/// NOTE(review): only ASCII digits are accepted. Confirm this matches Spark's
/// behaviour for non-ASCII Unicode digits.
fn luhn_check_impl(input: &str) -> bool {
    // No digits at all is never a valid Luhn number.
    if input.is_empty() {
        return false;
    }

    // Track only the running sum modulo 10; that is all the final check needs.
    // A plain accumulating `u32` would overflow on very long inputs (about 477M
    // digits). That wraps silently in release builds and panics in debug builds.
    let mut sum_mod_10 = 0u8;
    for (position, byte) in input.bytes().rev().enumerate() {
        if !byte.is_ascii_digit() {
            return false;
        }
        let digit = byte - b'0';
        // Odd positions from the right (0-based) are the doubled ones.
        let contribution = if position % 2 == 1 {
            let doubled = digit * 2;
            if doubled > 9 { doubled - 9 } else { doubled }
        } else {
            digit
        };
        sum_mod_10 = (sum_mod_10 + contribution) % 10;
    }

    sum_mod_10 == 0
}