// datafusion_spark/function/url/url_decode.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::any::Any;
19use std::borrow::Cow;
20use std::sync::Arc;
21
22use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
23use arrow::datatypes::DataType;
24use datafusion_common::cast::{
25    as_large_string_array, as_string_array, as_string_view_array,
26};
27use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err};
28use datafusion_expr::{
29    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32use percent_encoding::percent_decode;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct UrlDecode {
36    signature: Signature,
37}
38
39impl Default for UrlDecode {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl UrlDecode {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::string(1, Volatility::Immutable),
49        }
50    }
51
52    /// Decodes a URL-encoded string from application/x-www-form-urlencoded format.
53    /// Although the `url::form_urlencoded` support decoding, it does not return error when the string is malformed
54    ///     For example: "%2s" is not a valid percent-encoding, the `decode` function from `url::form_urlencoded`
55    ///                  will ignore this instead of return error
56    /// This function reproduce the same decoding process, plus an extra validation step
57    /// See <https://github.com/servo/rust-url/blob/b06048d70d4cc9cf4ffb277f06cfcebd53b2141e/form_urlencoded/src/lib.rs#L70-L76>
58    ///
59    /// # Arguments
60    ///
61    /// * `value` - The URL-encoded string to decode
62    ///
63    /// # Returns
64    ///
65    /// * `Ok(String)` - The decoded string
66    /// * `Err(DataFusionError)` - If the input is malformed or contains invalid UTF-8
67    ///
68    fn decode(value: &str) -> Result<String> {
69        // Check if the string has valid percent encoding
70        Self::validate_percent_encoding(value)?;
71
72        let replaced = Self::replace_plus(value.as_bytes());
73        percent_decode(&replaced)
74            .decode_utf8()
75            .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
76            .map(|parsed| parsed.into_owned())
77    }
78
79    /// Replace b'+' with b' '
80    /// See: <https://github.com/servo/rust-url/blob/dbd526178ed9276176602dd039022eba89e8fc93/form_urlencoded/src/lib.rs#L79-L93>
81    fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
82        match input.iter().position(|&b| b == b'+') {
83            None => Cow::Borrowed(input),
84            Some(first_position) => {
85                let mut replaced = input.to_owned();
86                replaced[first_position] = b' ';
87                for byte in &mut replaced[first_position + 1..] {
88                    if *byte == b'+' {
89                        *byte = b' ';
90                    }
91                }
92                Cow::Owned(replaced)
93            }
94        }
95    }
96
97    /// Validate percent-encoding of the string
98    fn validate_percent_encoding(value: &str) -> Result<()> {
99        let bytes = value.as_bytes();
100        let mut i = 0;
101
102        while i < bytes.len() {
103            if bytes[i] == b'%' {
104                // Check if we have at least 2 more characters
105                if i + 2 >= bytes.len() {
106                    return exec_err!(
107                        "Invalid percent-encoding: incomplete sequence at position {}",
108                        i
109                    );
110                }
111
112                let hex1 = bytes[i + 1];
113                let hex2 = bytes[i + 2];
114
115                if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() {
116                    return exec_err!(
117                        "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}",
118                        hex1 as char,
119                        hex2 as char,
120                        i
121                    );
122                }
123                i += 3;
124            } else {
125                i += 1;
126            }
127        }
128        Ok(())
129    }
130}
131
132impl ScalarUDFImpl for UrlDecode {
133    fn as_any(&self) -> &dyn Any {
134        self
135    }
136
137    fn name(&self) -> &str {
138        "url_decode"
139    }
140
141    fn signature(&self) -> &Signature {
142        &self.signature
143    }
144
145    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
146        if arg_types.len() != 1 {
147            return plan_err!(
148                "{} expects 1 argument, but got {}",
149                self.name(),
150                arg_types.len()
151            );
152        }
153        // As the type signature is already checked, we can safely return the type of the first argument
154        Ok(arg_types[0].clone())
155    }
156
157    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
158        let ScalarFunctionArgs { args, .. } = args;
159        make_scalar_function(spark_url_decode, vec![])(&args)
160    }
161}
162
163/// Core implementation of URL decoding function.
164///
165/// # Arguments
166///
167/// * `args` - A slice containing exactly one ArrayRef with the URL-encoded strings to decode
168///
169/// # Returns
170///
171/// * `Ok(ArrayRef)` - A new array of the same type containing decoded strings
172/// * `Err(DataFusionError)` - If validation fails or invalid arguments are provided
173///
174fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
175    spark_handled_url_decode(args, |x| x)
176}
177
178pub fn spark_handled_url_decode(
179    args: &[ArrayRef],
180    err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
181) -> Result<ArrayRef> {
182    if args.len() != 1 {
183        return exec_err!("`url_decode` expects 1 argument");
184    }
185
186    match &args[0].data_type() {
187        DataType::Utf8 => as_string_array(&args[0])?
188            .iter()
189            .map(|x| x.map(UrlDecode::decode).transpose())
190            .map(&err_handle_fn)
191            .collect::<Result<StringArray>>()
192            .map(|array| Arc::new(array) as ArrayRef),
193        DataType::LargeUtf8 => as_large_string_array(&args[0])?
194            .iter()
195            .map(|x| x.map(UrlDecode::decode).transpose())
196            .map(&err_handle_fn)
197            .collect::<Result<LargeStringArray>>()
198            .map(|array| Arc::new(array) as ArrayRef),
199        DataType::Utf8View => as_string_view_array(&args[0])?
200            .iter()
201            .map(|x| x.map(UrlDecode::decode).transpose())
202            .map(&err_handle_fn)
203            .collect::<Result<StringViewArray>>()
204            .map(|array| Arc::new(array) as ArrayRef),
205        other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
206    }
207}
208
#[cfg(test)]
mod tests {
    use arrow::array::StringArray;
    use datafusion_common::Result;

    use super::*;

    #[test]
    fn test_decode() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("https%3A%2F%2Fspark.apache.org"),
            Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
            Some("%E4%BD%A0%E5%A5%BD"),
            Some(""),
            None,
        ]));
        let expected = StringArray::from(vec![
            Some("https://spark.apache.org"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("~!@#$%^&*()_+"),
            Some("你好"),
            Some(""),
            None,
        ]);

        let result = spark_url_decode(&[input as ArrayRef])?;
        let result = as_string_array(&result)?;

        assert_eq!(&expected, result);

        Ok(())
    }

    #[test]
    fn test_decode_error() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent encoded character
            // Valid cases
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ]));

        let result = spark_url_decode(&[input]);
        assert!(
            result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding"))
        );

        Ok(())
    }

    #[test]
    fn test_decode_incomplete_sequence() {
        // A `%` needs two following characters; each of these ends too early.
        for malformed in ["%", "%4", "abc%2"] {
            let input: ArrayRef = Arc::new(StringArray::from(vec![Some(malformed)]));
            let result = spark_url_decode(&[input]);
            assert!(
                result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding")),
                "expected an error for {malformed:?}"
            );
        }
    }

    #[test]
    fn test_decode_invalid_utf8() {
        // `%FF` is a well-formed escape, but the byte 0xFF on its own is not valid UTF-8.
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("%FF")]));
        let result = spark_url_decode(&[input]);
        assert!(result.is_err_and(|e| e.to_string().contains("Invalid UTF-8 sequence")));
    }

    #[test]
    fn test_decode_large_utf8_and_utf8_view() -> Result<()> {
        let expected = vec![Some("a b+c/"), None];

        let large: ArrayRef = Arc::new(
            [Some("a+b%2Bc%2F"), None]
                .into_iter()
                .collect::<LargeStringArray>(),
        );
        let result = spark_url_decode(&[large])?;
        assert_eq!(
            as_large_string_array(&result)?.iter().collect::<Vec<_>>(),
            expected
        );

        let view: ArrayRef = Arc::new(
            [Some("a+b%2Bc%2F"), None]
                .into_iter()
                .collect::<StringViewArray>(),
        );
        let result = spark_url_decode(&[view])?;
        assert_eq!(
            as_string_view_array(&result)?.iter().collect::<Vec<_>>(),
            expected
        );

        Ok(())
    }

    #[test]
    fn test_handled_decode_can_null_out_errors() -> Result<()> {
        let input: ArrayRef = Arc::new(StringArray::from(vec![
            Some("%2s"),
            Some("a%20b"),
            None,
        ]));

        // Treat per-row decoding failures as NULL instead of failing the whole call.
        let result = spark_handled_url_decode(&[input], |decoded| {
            Ok(decoded.unwrap_or_default())
        })?;
        assert_eq!(
            as_string_array(&result)?.iter().collect::<Vec<_>>(),
            vec![None, Some("a b"), None]
        );

        Ok(())
    }

    #[test]
    fn test_wrong_argument_count() {
        assert!(spark_url_decode(&[]).is_err());
    }
}