datafusion_spark/function/string/luhn_check.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{any::Any, sync::Arc};
19
20use arrow::array::{Array, AsArray, BooleanArray};
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::Boolean;
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{exec_err, Result, ScalarValue};
25use datafusion_expr::{
26    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
27    Volatility,
28};
29
30/// Spark-compatible `luhn_check` expression
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#luhn_check>
32#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLuhnCheck {
34    signature: Signature,
35}
36
37impl Default for SparkLuhnCheck {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl SparkLuhnCheck {
44    pub fn new() -> Self {
45        Self {
46            signature: Signature::one_of(
47                vec![
48                    TypeSignature::Exact(vec![DataType::Utf8]),
49                    TypeSignature::Exact(vec![DataType::Utf8View]),
50                    TypeSignature::Exact(vec![DataType::LargeUtf8]),
51                ],
52                Volatility::Immutable,
53            ),
54        }
55    }
56}
57
58impl ScalarUDFImpl for SparkLuhnCheck {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn name(&self) -> &str {
64        "luhn_check"
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72        Ok(Boolean)
73    }
74
75    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76        let [array] = take_function_args(self.name(), &args.args)?;
77
78        match array {
79            ColumnarValue::Array(array) => match array.data_type() {
80                DataType::Utf8View => {
81                    let str_array = array.as_string_view();
82                    let values = str_array
83                        .iter()
84                        .map(|s| s.map(luhn_check_impl))
85                        .collect::<BooleanArray>();
86                    Ok(ColumnarValue::Array(Arc::new(values)))
87                }
88                DataType::Utf8 => {
89                    let str_array = array.as_string::<i32>();
90                    let values = str_array
91                        .iter()
92                        .map(|s| s.map(luhn_check_impl))
93                        .collect::<BooleanArray>();
94                    Ok(ColumnarValue::Array(Arc::new(values)))
95                }
96                DataType::LargeUtf8 => {
97                    let str_array = array.as_string::<i64>();
98                    let values = str_array
99                        .iter()
100                        .map(|s| s.map(luhn_check_impl))
101                        .collect::<BooleanArray>();
102                    Ok(ColumnarValue::Array(Arc::new(values)))
103                }
104                other => {
105                    exec_err!("Unsupported data type {other:?} for function `luhn_check`")
106                }
107            },
108            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
109            | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
110            | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
111                ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
112            ),
113            ColumnarValue::Scalar(ScalarValue::Utf8(None))
114            | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
115            | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
116                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
117            }
118            other => {
119                exec_err!("Unsupported data type {other:?} for function `luhn_check`")
120            }
121        }
122    }
123}
124
/// Validates a string using the Luhn (mod 10) algorithm.
///
/// Returns `true` only when `input` meets all three conditions:
/// - it is non-empty;
/// - it consists solely of the ASCII digits `0`-`9`;
/// - its Luhn checksum is divisible by 10.
///
/// Any other byte (whitespace, a sign, a separator, or any non-ASCII
/// character) makes the result `false`.
fn luhn_check_impl(input: &str) -> bool {
    // No digits at all is never a valid Luhn number.
    if input.is_empty() {
        return false;
    }

    // Only the checksum modulo 10 matters, so keep the running total reduced.
    // An unbounded `u32` sum would overflow on very long inputs (a panic in
    // debug builds, a wrong answer in release builds).
    let mut sum = 0u32;

    // Walk from the rightmost (check) digit; every second digit is doubled.
    for (position, byte) in input.bytes().rev().enumerate() {
        if !byte.is_ascii_digit() {
            return false;
        }
        let digit = u32::from(byte - b'0');
        let contribution = if position % 2 == 1 {
            // A doubled digit is at most 18; subtracting 9 equals summing its two digits.
            let doubled = digit * 2;
            if doubled > 9 {
                doubled - 9
            } else {
                doubled
            }
        } else {
            digit
        };
        sum = (sum + contribution) % 10;
    }

    sum == 0
}