Skip to main content

datafusion_spark/function/url/
try_url_decode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19
20use arrow::array::ArrayRef;
21use arrow::datatypes::DataType;
22
23use datafusion_common::Result;
24use datafusion_expr::{
25    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_functions::utils::make_scalar_function;
28
29use crate::function::url::url_decode::{UrlDecode, spark_handled_url_decode};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
32pub struct TryUrlDecode {
33    signature: Signature,
34    url_decoder: UrlDecode,
35}
36
37impl Default for TryUrlDecode {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl TryUrlDecode {
44    pub fn new() -> Self {
45        Self {
46            signature: Signature::string(1, Volatility::Immutable),
47            url_decoder: UrlDecode::new(),
48        }
49    }
50}
51
52impl ScalarUDFImpl for TryUrlDecode {
53    fn as_any(&self) -> &dyn Any {
54        self
55    }
56
57    fn name(&self) -> &str {
58        "try_url_decode"
59    }
60
61    fn signature(&self) -> &Signature {
62        &self.signature
63    }
64
65    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
66        self.url_decoder.return_type(arg_types)
67    }
68
69    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
70        let ScalarFunctionArgs { args, .. } = args;
71        make_scalar_function(spark_try_url_decode, vec![])(&args)
72    }
73}
74
75fn spark_try_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
76    spark_handled_url_decode(args, |x| match x {
77        Err(_) => Ok(None),
78        result => result,
79    })
80}
81
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use datafusion_common::{Result, cast::as_string_array};

    use super::*;

    /// A malformed percent escape must produce NULL rather than an error,
    /// while a valid encoded URL is decoded and a NULL input stays NULL.
    #[test]
    fn test_try_decode_error_handled() -> Result<()> {
        // '%2s' is not a valid percent encoded character
        let malformed = Some("http%3A%2F%2spark.apache.org");
        // Valid cases
        let well_formed = Some("https%3A%2F%2Fspark.apache.org");
        let input: ArrayRef = Arc::new(StringArray::from(vec![malformed, well_formed, None]));

        let decoded = spark_try_url_decode(&[input])?;
        let actual = as_string_array(&decoded)?;

        let expected = StringArray::from(vec![None, Some("https://spark.apache.org"), None]);
        assert_eq!(&expected, actual);
        Ok(())
    }
}