datafusion_comet_spark_expr/hash_funcs/sha2.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::scalar_funcs::hex_strings;
use arrow_array::{Array, StringArray};
use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
use datafusion_common::cast::as_binary_array;
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDF};
use std::sync::Arc;
/// Spark-compatible SHA-224, mirroring Spark's `sha2` expression with bit width 224.
///
/// Delegates to DataFusion's `sha224` UDF and returns its output as a hex string.
pub fn spark_sha224(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let digest_udf = sha224();
    wrap_digest_result_as_hex_string(args, digest_udf)
}
/// Spark-compatible SHA-256, mirroring Spark's `sha2` expression with bit width 0 or 256.
///
/// Delegates to DataFusion's `sha256` UDF and returns its output as a hex string.
pub fn spark_sha256(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let digest_udf = sha256();
    wrap_digest_result_as_hex_string(args, digest_udf)
}
/// Spark-compatible SHA-384, mirroring Spark's `sha2` expression with bit width 384.
///
/// Delegates to DataFusion's `sha384` UDF and returns its output as a hex string.
pub fn spark_sha384(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let digest_udf = sha384();
    wrap_digest_result_as_hex_string(args, digest_udf)
}
/// Spark-compatible SHA-512, mirroring Spark's `sha2` expression with bit width 512.
///
/// Delegates to DataFusion's `sha512` UDF and returns its output as a hex string.
pub fn spark_sha512(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    let digest_udf = sha512();
    wrap_digest_result_as_hex_string(args, digest_udf)
}
/// Invokes `digest` on `args` and converts its binary output into a hex string.
///
/// Spark requires the result of its `sha2` functions to be a hex string, so the
/// binary value produced by the digest function is passed through `hex_strings`.
/// Null entries stay null.
///
/// The batch row count is taken from the first argument: the array length for an
/// array argument, or 1 for a scalar.
///
/// # Errors
///
/// Returns an error if `args` is empty, if invoking `digest` fails, if an array
/// result cannot be viewed as a binary array, or if the result is a scalar other
/// than `ScalarValue::Binary`.
fn wrap_digest_result_as_hex_string(
    args: &[ColumnarValue],
    digest: Arc<ScalarUDF>,
) -> Result<ColumnarValue, DataFusionError> {
    // Checked access: `args[0]` would panic on an empty argument list, so report
    // the problem as a regular execution error instead.
    let row_count = match args.first() {
        Some(ColumnarValue::Array(array)) => array.len(),
        Some(ColumnarValue::Scalar(_)) => 1,
        None => {
            return exec_err!("digest function expects at least one argument, but got none");
        }
    };
    let value = digest.invoke_batch(args, row_count)?;
    match value {
        ColumnarValue::Array(array) => {
            let binary_array = as_binary_array(&array)?;
            // Hex-encode every non-null digest; nulls are carried through unchanged.
            let string_array: StringArray = binary_array
                .iter()
                .map(|opt| opt.map(hex_strings::<_>))
                .collect();
            Ok(ColumnarValue::Array(Arc::new(string_array)))
        }
        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
            ScalarValue::Utf8(opt.map(hex_strings::<_>)),
        )),
        _ => {
            exec_err!(
                "digest function should return binary value, but got: {:?}",
                value.data_type()
            )
        }
    }
}