Skip to main content

datafusion_spark/function/url/
try_url_decode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::ArrayRef;
19use arrow::datatypes::DataType;
20
21use datafusion_common::Result;
22use datafusion_expr::{
23    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24};
25use datafusion_functions::utils::make_scalar_function;
26
27use crate::function::url::url_decode::{UrlDecode, spark_handled_url_decode};
28
29#[derive(Debug, PartialEq, Eq, Hash)]
30pub struct TryUrlDecode {
31    signature: Signature,
32    url_decoder: UrlDecode,
33}
34
35impl Default for TryUrlDecode {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TryUrlDecode {
42    pub fn new() -> Self {
43        Self {
44            signature: Signature::string(1, Volatility::Immutable),
45            url_decoder: UrlDecode::new(),
46        }
47    }
48}
49
50impl ScalarUDFImpl for TryUrlDecode {
51    fn name(&self) -> &str {
52        "try_url_decode"
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
60        self.url_decoder.return_type(arg_types)
61    }
62
63    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
64        let ScalarFunctionArgs { args, .. } = args;
65        make_scalar_function(spark_try_url_decode, vec![])(&args)
66    }
67}
68
69fn spark_try_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
70    spark_handled_url_decode(args, |x| match x {
71        Err(_) => Ok(None),
72        result => result,
73    })
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datafusion_common::cast::as_string_array;

    use super::*;

    /// A malformed percent sequence must yield a null row instead of an error,
    /// while valid and null inputs behave as usual.
    #[test]
    fn test_try_decode_error_handled() -> Result<()> {
        // '%2s' is not a valid percent encoded character
        let malformed = Some("http%3A%2F%2spark.apache.org");
        // Valid cases
        let well_formed = Some("https%3A%2F%2Fspark.apache.org");
        let missing = None;

        let input: ArrayRef =
            Arc::new(StringArray::from(vec![malformed, well_formed, missing]));

        let decoded = spark_try_url_decode(&[input])?;
        let actual = as_string_array(&decoded)?;

        let expected =
            StringArray::from(vec![None, Some("https://spark.apache.org"), None]);

        assert_eq!(&expected, actual);
        Ok(())
    }
}