datafusion_spark/function/hash/
sha2.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18extern crate datafusion_functions;
19
20use crate::function::error_utils::{
21    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
22};
23use crate::function::math::hex::spark_sha2_hex;
24use arrow::array::{ArrayRef, AsArray, StringArray};
25use arrow::datatypes::{DataType, Int32Type};
26use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
27use datafusion_expr::Signature;
28use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
29pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512};
30use std::any::Any;
31use std::sync::Arc;
32
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkSha2 {
36    signature: Signature,
37    aliases: Vec<String>,
38}
39
40impl Default for SparkSha2 {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SparkSha2 {
47    pub fn new() -> Self {
48        Self {
49            signature: Signature::user_defined(Volatility::Immutable),
50            aliases: vec![],
51        }
52    }
53}
54
55impl ScalarUDFImpl for SparkSha2 {
56    fn as_any(&self) -> &dyn Any {
57        self
58    }
59
60    fn name(&self) -> &str {
61        "sha2"
62    }
63
64    fn signature(&self) -> &Signature {
65        &self.signature
66    }
67
68    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
69        if arg_types[1].is_null() {
70            return Ok(DataType::Null);
71        }
72        Ok(match arg_types[0] {
73            DataType::Utf8View
74            | DataType::LargeUtf8
75            | DataType::Utf8
76            | DataType::Binary
77            | DataType::BinaryView
78            | DataType::LargeBinary => DataType::Utf8,
79            DataType::Null => DataType::Null,
80            _ => {
81                return exec_err!(
82                    "{} function can only accept strings or binary arrays.",
83                    self.name()
84                )
85            }
86        })
87    }
88
89    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90        let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
91            internal_datafusion_err!("Expected 2 arguments for function sha2")
92        })?;
93
94        sha2(args)
95    }
96
97    fn aliases(&self) -> &[String] {
98        &self.aliases
99    }
100
101    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
102        if arg_types.len() != 2 {
103            return Err(invalid_arg_count_exec_err(
104                self.name(),
105                (2, 2),
106                arg_types.len(),
107            ));
108        }
109        let expr_type = match &arg_types[0] {
110            DataType::Utf8View
111            | DataType::LargeUtf8
112            | DataType::Utf8
113            | DataType::Binary
114            | DataType::BinaryView
115            | DataType::LargeBinary
116            | DataType::Null => Ok(arg_types[0].clone()),
117            _ => Err(unsupported_data_type_exec_err(
118                self.name(),
119                "String, Binary",
120                &arg_types[0],
121            )),
122        }?;
123        let bit_length_type = if arg_types[1].is_numeric() {
124            Ok(DataType::Int32)
125        } else if arg_types[1].is_null() {
126            Ok(DataType::Null)
127        } else {
128            Err(unsupported_data_type_exec_err(
129                self.name(),
130                "Numeric Type",
131                &arg_types[1],
132            ))
133        }?;
134
135        Ok(vec![expr_type, bit_length_type])
136    }
137}
138
139pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
140    match args {
141        [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
142            compute_sha2(
143                bit_length_arg,
144                &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
145            )
146        }
147        [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
148            compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)])
149        }
150        [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] =>
151        {
152            let arr: StringArray = bit_length_arg
153                .as_primitive::<Int32Type>()
154                .iter()
155                .map(|bit_length| {
156                    match sha2([
157                        ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
158                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
159                    ])
160                    .unwrap()
161                    {
162                        ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
163                        ColumnarValue::Array(arr) => arr
164                            .as_string::<i32>()
165                            .iter()
166                            .map(|str| str.unwrap().to_string())
167                            .next(), // first element
168                        _ => unreachable!(),
169                    }
170                })
171                .collect();
172            Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
173        }
174        [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => {
175            let expr_iter = expr_arg.as_string::<i32>().iter();
176            let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
177            let arr: StringArray = expr_iter
178                .zip(bit_length_iter)
179                .map(|(expr, bit_length)| {
180                    match sha2([
181                        ColumnarValue::Scalar(ScalarValue::Utf8(Some(
182                            expr.unwrap().to_string(),
183                        ))),
184                        ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
185                    ])
186                    .unwrap()
187                    {
188                        ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
189                        ColumnarValue::Array(arr) => arr
190                            .as_string::<i32>()
191                            .iter()
192                            .map(|str| str.unwrap().to_string())
193                            .next(), // first element
194                        _ => unreachable!(),
195                    }
196                })
197                .collect();
198            Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
199        }
200        _ => exec_err!("Unsupported argument types for sha2 function"),
201    }
202}
203
204fn compute_sha2(
205    bit_length_arg: i32,
206    expr_arg: &[ColumnarValue],
207) -> Result<ColumnarValue> {
208    match bit_length_arg {
209        0 | 256 => sha256(expr_arg),
210        224 => sha224(expr_arg),
211        384 => sha384(expr_arg),
212        512 => sha512(expr_arg),
213        _ => {
214            // Return null for unsupported bit lengths instead of error, because spark sha2 does not
215            // error out for this.
216            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
217        }
218    }
219    .map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
220}