Skip to main content

datafusion_spark/function/string/
luhn_check.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{Array, AsArray, BooleanArray};
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::Boolean;
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, exec_err};
25use datafusion_expr::{
26    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
27    Volatility,
28};
29
30/// Spark-compatible `luhn_check` expression
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#luhn_check>
32#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLuhnCheck {
34    signature: Signature,
35}
36
37impl Default for SparkLuhnCheck {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl SparkLuhnCheck {
44    pub fn new() -> Self {
45        Self {
46            signature: Signature::one_of(
47                vec![
48                    TypeSignature::Exact(vec![DataType::Utf8]),
49                    TypeSignature::Exact(vec![DataType::Utf8View]),
50                    TypeSignature::Exact(vec![DataType::LargeUtf8]),
51                ],
52                Volatility::Immutable,
53            ),
54        }
55    }
56}
57
58impl ScalarUDFImpl for SparkLuhnCheck {
59    fn name(&self) -> &str {
60        "luhn_check"
61    }
62
63    fn signature(&self) -> &Signature {
64        &self.signature
65    }
66
67    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
68        Ok(Boolean)
69    }
70
71    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72        let [array] = take_function_args(self.name(), &args.args)?;
73
74        match array {
75            ColumnarValue::Array(array) => match array.data_type() {
76                DataType::Utf8View => {
77                    let str_array = array.as_string_view();
78                    let values = str_array
79                        .iter()
80                        .map(|s| s.map(luhn_check_impl))
81                        .collect::<BooleanArray>();
82                    Ok(ColumnarValue::Array(Arc::new(values)))
83                }
84                DataType::Utf8 => {
85                    let str_array = array.as_string::<i32>();
86                    let values = str_array
87                        .iter()
88                        .map(|s| s.map(luhn_check_impl))
89                        .collect::<BooleanArray>();
90                    Ok(ColumnarValue::Array(Arc::new(values)))
91                }
92                DataType::LargeUtf8 => {
93                    let str_array = array.as_string::<i64>();
94                    let values = str_array
95                        .iter()
96                        .map(|s| s.map(luhn_check_impl))
97                        .collect::<BooleanArray>();
98                    Ok(ColumnarValue::Array(Arc::new(values)))
99                }
100                other => {
101                    exec_err!("Unsupported data type {other:?} for function `luhn_check`")
102                }
103            },
104            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
105            | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s)))
106            | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok(
107                ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))),
108            ),
109            ColumnarValue::Scalar(ScalarValue::Utf8(None))
110            | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))
111            | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {
112                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
113            }
114            other => {
115                exec_err!("Unsupported data type {other:?} for function `luhn_check`")
116            }
117        }
118    }
119}
120
/// Validates a string using the Luhn algorithm.
///
/// Returns `true` only when `input` is non-empty, consists solely of ASCII
/// digits (`0`-`9`), and its Luhn checksum is a multiple of 10. Any other
/// byte (whitespace, separators, signs, non-ASCII characters) makes the
/// result `false`.
///
/// The checksum is kept reduced modulo 10 while scanning, so the result is
/// correct for inputs of any length: no counter or accumulator can overflow.
fn luhn_check_impl(input: &str) -> bool {
    // An empty string contains no digits and is never a valid Luhn number.
    if input.is_empty() {
        return false;
    }

    // Running checksum, always in 0..=9.
    let mut checksum = 0u8;
    // Every second digit, counted from the rightmost one, gets doubled.
    let mut double_next = false;

    for byte in input.bytes().rev() {
        if !byte.is_ascii_digit() {
            return false;
        }

        let mut value = byte - b'0';
        if double_next {
            value *= 2;
            // Equivalent to summing the two decimal digits of the product.
            if value > 9 {
                value -= 9;
            }
        }

        // Both operands are at most 9, so the u8 addition cannot overflow.
        checksum = (checksum + value) % 10;
        double_next = !double_next;
    }

    checksum == 0
}