datafusion_comet_spark_expr/hash_funcs/sha2.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::math_funcs::hex::hex_strings;
use arrow_array::{Array, StringArray};
use arrow_schema::DataType;
use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
use datafusion_common::cast::as_binary_array;
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};
use std::sync::Arc;

/// `sha224` function that simulates Spark's `sha2` expression with bit width 224
pub fn spark_sha224(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    wrap_digest_result_as_hex_string(args, sha224())
}

/// `sha256` function that simulates Spark's `sha2` expression with bit width 0 or 256
pub fn spark_sha256(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    wrap_digest_result_as_hex_string(args, sha256())
}

/// `sha384` function that simulates Spark's `sha2` expression with bit width 384
pub fn spark_sha384(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    wrap_digest_result_as_hex_string(args, sha384())
}

/// `sha512` function that simulates Spark's `sha2` expression with bit width 512
pub fn spark_sha512(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    wrap_digest_result_as_hex_string(args, sha512())
}

// Spark requires the result of `sha2` to be a hex string, so we wrap the binary
// output of the underlying digest function as a hex-encoded string.
fn wrap_digest_result_as_hex_string(
    args: &[ColumnarValue],
    digest: Arc<ScalarUDF>,
) -> Result<ColumnarValue, DataFusionError> {
    // The digest UDF needs a row count; a scalar argument counts as a single row.
    let row_count = match &args[0] {
        ColumnarValue::Array(array) => array.len(),
        ColumnarValue::Scalar(_) => 1,
    };
    let value = digest.invoke_with_args(ScalarFunctionArgs {
        args: args.into(),
        number_rows: row_count,
        return_type: &DataType::Utf8,
    })?;
    match value {
        // Array result: hex-encode each non-null binary digest
        ColumnarValue::Array(array) => {
            let binary_array = as_binary_array(&array)?;
            let string_array: StringArray = binary_array
                .iter()
                .map(|opt| opt.map(hex_strings::<_>))
                .collect();
            Ok(ColumnarValue::Array(Arc::new(string_array)))
        }
        // Scalar result: hex-encode the single binary digest, preserving NULL
        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
            ScalarValue::Utf8(opt.map(hex_strings::<_>)),
        )),
        _ => {
            exec_err!(
                "digest function should return binary value, but got: {:?}",
                value.data_type()
            )
        }
    }
}
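
// A minimal usage sketch, not part of the original file: it illustrates the
// behavior described above (binary digest re-encoded as a hex string, nulls
// preserved). The test name and inputs are illustrative assumptions; the only
// property asserted is that a SHA-256 digest (32 bytes) yields 64 hex characters.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sha256_returns_hex_strings() -> Result<(), DataFusionError> {
        let input: StringArray = vec![Some("Spark"), None].into_iter().collect();
        let result = spark_sha256(&[ColumnarValue::Array(Arc::new(input))])?;
        match result {
            ColumnarValue::Array(array) => {
                let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
                // 32-byte SHA-256 digest => 64-character hex string
                assert_eq!(strings.value(0).len(), 64);
                // Null inputs stay null after hex encoding
                assert!(strings.is_null(1));
            }
            other => panic!("expected an array result, got {:?}", other.data_type()),
        }
        Ok(())
    }
}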