Skip to main content

datafusion_spark/function/hash/
sha1.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, StringArray};
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::cast::{
23    as_binary_array, as_binary_view_array, as_fixed_size_binary_array,
24    as_large_binary_array,
25};
26use datafusion_common::types::{NativeType, logical_string};
27use datafusion_common::utils::take_function_args;
28use datafusion_common::{Result, internal_err};
29use datafusion_expr::{
30    Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
31    Signature, TypeSignatureClass, Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34use sha1::{Digest, Sha1};
35
36/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha1>
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkSha1 {
39    signature: Signature,
40    aliases: Vec<String>,
41}
42
43impl Default for SparkSha1 {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl SparkSha1 {
50    pub fn new() -> Self {
51        Self {
52            signature: Signature::coercible(
53                vec![Coercion::new_implicit(
54                    TypeSignatureClass::Binary,
55                    vec![TypeSignatureClass::Native(logical_string())],
56                    NativeType::Binary,
57                )],
58                Volatility::Immutable,
59            ),
60            aliases: vec!["sha".to_string()],
61        }
62    }
63}
64
65impl ScalarUDFImpl for SparkSha1 {
66    fn name(&self) -> &str {
67        "sha1"
68    }
69
70    fn aliases(&self) -> &[String] {
71        &self.aliases
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79        internal_err!("return_field_from_args should be used instead")
80    }
81
82    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
83        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
84        Ok(Arc::new(Field::new(self.name(), DataType::Utf8, nullable)))
85    }
86
87    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
88        make_scalar_function(spark_sha1, vec![])(&args.args)
89    }
90}
91
92/// Hex encoding lookup table for fast byte-to-hex conversion
93const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
94
95#[inline]
96fn spark_sha1_digest(value: &[u8]) -> String {
97    let result = Sha1::digest(value);
98    let mut s = String::with_capacity(result.len() * 2);
99    for &b in result.as_slice() {
100        s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char);
101        s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char);
102    }
103    s
104}
105
106fn spark_sha1_impl<'a>(input: impl Iterator<Item = Option<&'a [u8]>>) -> ArrayRef {
107    let result = input
108        .map(|value| value.map(spark_sha1_digest))
109        .collect::<StringArray>();
110    Arc::new(result)
111}
112
113fn spark_sha1(args: &[ArrayRef]) -> Result<ArrayRef> {
114    let [input] = take_function_args("sha1", args)?;
115
116    match input.data_type() {
117        DataType::Null => Ok(Arc::new(StringArray::new_null(input.len()))),
118        DataType::Binary => {
119            let input = as_binary_array(input)?;
120            Ok(spark_sha1_impl(input.iter()))
121        }
122        DataType::LargeBinary => {
123            let input = as_large_binary_array(input)?;
124            Ok(spark_sha1_impl(input.iter()))
125        }
126        DataType::BinaryView => {
127            let input = as_binary_view_array(input)?;
128            Ok(spark_sha1_impl(input.iter()))
129        }
130        DataType::FixedSizeBinary(_) => {
131            let input = as_fixed_size_binary_array(input)?;
132            Ok(spark_sha1_impl(input.iter()))
133        }
134        dt => {
135            internal_err!("Unsupported data type for sha1: {dt}")
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, BinaryArray, NullArray};
    use datafusion_common::cast::as_string_array;

    #[test]
    fn test_sha1_nullability() -> Result<()> {
        let func = SparkSha1::new();

        // Non-nullable input keeps output non-nullable
        let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Binary, false));
        let out = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&non_nullable)],
            scalar_arguments: &[None],
        })?;
        assert!(!out.is_nullable());
        assert_eq!(out.data_type(), &DataType::Utf8);

        // Nullable input makes output nullable
        let nullable: FieldRef = Arc::new(Field::new("col", DataType::Binary, true));
        let out = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[Arc::clone(&nullable)],
            scalar_arguments: &[None],
        })?;
        assert!(out.is_nullable());
        assert_eq!(out.data_type(), &DataType::Utf8);

        Ok(())
    }

    /// Checks the digest path against well-known SHA-1 test vectors
    /// (FIPS 180 / RFC 3174) and verifies that null elements stay null.
    #[test]
    fn test_sha1_known_digests() -> Result<()> {
        let values: Vec<Option<&[u8]>> =
            vec![Some(b"abc".as_slice()), Some(b"".as_slice()), None];
        let input: ArrayRef = Arc::new(BinaryArray::from(values));

        let result = spark_sha1(&[input])?;
        let hashed = as_string_array(&result)?;

        assert_eq!(hashed.len(), 3);
        assert_eq!(hashed.value(0), "a9993e364706816aba3e25717850c26c9cd0d89d");
        assert_eq!(hashed.value(1), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
        assert!(hashed.is_null(2));

        Ok(())
    }

    /// A `Null`-typed input yields an all-null `Utf8` array of equal length.
    #[test]
    fn test_sha1_null_type() -> Result<()> {
        let input: ArrayRef = Arc::new(NullArray::new(2));

        let result = spark_sha1(&[input])?;

        assert_eq!(result.data_type(), &DataType::Utf8);
        assert_eq!(result.len(), 2);
        assert_eq!(result.null_count(), 2);

        Ok(())
    }
}