Skip to main content

datafusion_spark/function/url/
url_decode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::borrow::Cow;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
22use arrow::datatypes::DataType;
23use datafusion_common::cast::{
24    as_large_string_array, as_string_array, as_string_view_array,
25};
26use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err};
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_functions::utils::make_scalar_function;
31use percent_encoding::percent_decode;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct UrlDecode {
35    signature: Signature,
36}
37
38impl Default for UrlDecode {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl UrlDecode {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::string(1, Volatility::Immutable),
48        }
49    }
50
51    /// Decodes a URL-encoded string from application/x-www-form-urlencoded format.
52    /// Although the `url::form_urlencoded` support decoding, it does not return error when the string is malformed
53    ///     For example: "%2s" is not a valid percent-encoding, the `decode` function from `url::form_urlencoded`
54    ///                  will ignore this instead of return error
55    /// This function reproduce the same decoding process, plus an extra validation step
56    /// See <https://github.com/servo/rust-url/blob/b06048d70d4cc9cf4ffb277f06cfcebd53b2141e/form_urlencoded/src/lib.rs#L70-L76>
57    ///
58    /// # Arguments
59    ///
60    /// * `value` - The URL-encoded string to decode
61    ///
62    /// # Returns
63    ///
64    /// * `Ok(String)` - The decoded string
65    /// * `Err(DataFusionError)` - If the input is malformed or contains invalid UTF-8
66    ///
67    fn decode(value: &str) -> Result<String> {
68        // Check if the string has valid percent encoding
69        Self::validate_percent_encoding(value)?;
70
71        let replaced = Self::replace_plus(value.as_bytes());
72        percent_decode(&replaced)
73            .decode_utf8()
74            .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
75            .map(|parsed| parsed.into_owned())
76    }
77
78    /// Replace b'+' with b' '
79    /// See: <https://github.com/servo/rust-url/blob/dbd526178ed9276176602dd039022eba89e8fc93/form_urlencoded/src/lib.rs#L79-L93>
80    fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
81        match input.iter().position(|&b| b == b'+') {
82            None => Cow::Borrowed(input),
83            Some(first_position) => {
84                let mut replaced = input.to_owned();
85                replaced[first_position] = b' ';
86                for byte in &mut replaced[first_position + 1..] {
87                    if *byte == b'+' {
88                        *byte = b' ';
89                    }
90                }
91                Cow::Owned(replaced)
92            }
93        }
94    }
95
96    /// Validate percent-encoding of the string
97    fn validate_percent_encoding(value: &str) -> Result<()> {
98        let bytes = value.as_bytes();
99        let mut i = 0;
100
101        while i < bytes.len() {
102            if bytes[i] == b'%' {
103                // Check if we have at least 2 more characters
104                if i + 2 >= bytes.len() {
105                    return exec_err!(
106                        "Invalid percent-encoding: incomplete sequence at position {}",
107                        i
108                    );
109                }
110
111                let hex1 = bytes[i + 1];
112                let hex2 = bytes[i + 2];
113
114                if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() {
115                    return exec_err!(
116                        "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}",
117                        hex1 as char,
118                        hex2 as char,
119                        i
120                    );
121                }
122                i += 3;
123            } else {
124                i += 1;
125            }
126        }
127        Ok(())
128    }
129}
130
131impl ScalarUDFImpl for UrlDecode {
132    fn name(&self) -> &str {
133        "url_decode"
134    }
135
136    fn signature(&self) -> &Signature {
137        &self.signature
138    }
139
140    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
141        if arg_types.len() != 1 {
142            return plan_err!(
143                "{} expects 1 argument, but got {}",
144                self.name(),
145                arg_types.len()
146            );
147        }
148        // As the type signature is already checked, we can safely return the type of the first argument
149        Ok(arg_types[0].clone())
150    }
151
152    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
153        let ScalarFunctionArgs { args, .. } = args;
154        make_scalar_function(spark_url_decode, vec![])(&args)
155    }
156}
157
158/// Core implementation of URL decoding function.
159///
160/// # Arguments
161///
162/// * `args` - A slice containing exactly one ArrayRef with the URL-encoded strings to decode
163///
164/// # Returns
165///
166/// * `Ok(ArrayRef)` - A new array of the same type containing decoded strings
167/// * `Err(DataFusionError)` - If validation fails or invalid arguments are provided
168///
169fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
170    spark_handled_url_decode(args, |x| x)
171}
172
173pub fn spark_handled_url_decode(
174    args: &[ArrayRef],
175    err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
176) -> Result<ArrayRef> {
177    if args.len() != 1 {
178        return exec_err!("`url_decode` expects 1 argument");
179    }
180
181    match &args[0].data_type() {
182        DataType::Utf8 => as_string_array(&args[0])?
183            .iter()
184            .map(|x| x.map(UrlDecode::decode).transpose())
185            .map(&err_handle_fn)
186            .collect::<Result<StringArray>>()
187            .map(|array| Arc::new(array) as ArrayRef),
188        DataType::LargeUtf8 => as_large_string_array(&args[0])?
189            .iter()
190            .map(|x| x.map(UrlDecode::decode).transpose())
191            .map(&err_handle_fn)
192            .collect::<Result<LargeStringArray>>()
193            .map(|array| Arc::new(array) as ArrayRef),
194        DataType::Utf8View => as_string_view_array(&args[0])?
195            .iter()
196            .map(|x| x.map(UrlDecode::decode).transpose())
197            .map(&err_handle_fn)
198            .collect::<Result<StringViewArray>>()
199            .map(|array| Arc::new(array) as ArrayRef),
200        other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
201    }
202}
203
#[cfg(test)]
mod tests {

    use super::*;
    use arrow::array::Int32Array;

    #[test]
    fn test_decode() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("https%3A%2F%2Fspark.apache.org"),
            Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
            Some("%E4%BD%A0%E5%A5%BD"),
            Some(""),
            None,
        ]));
        let expected = StringArray::from(vec![
            Some("https://spark.apache.org"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("~!@#$%^&*()_+"),
            Some("你好"),
            Some(""),
            None,
        ]);

        let result = spark_url_decode(&[input as ArrayRef])?;
        let result = as_string_array(&result)?;

        assert_eq!(&expected, result);

        Ok(())
    }

    #[test]
    fn test_decode_error() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent encoded character
            // Valid cases
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ]));

        let result = spark_url_decode(&[input]);
        assert!(
            result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding"))
        );

        Ok(())
    }

    /// LargeUtf8 and Utf8View inputs are decoded and keep their own array type.
    #[test]
    fn test_decode_large_utf8_and_utf8_view() -> Result<()> {
        let expected = vec![Some("a b c"), Some("你好"), None];

        let large = Arc::new(LargeStringArray::from(vec![
            Some("a%20b+c"),
            Some("%E4%BD%A0%E5%A5%BD"),
            None,
        ])) as ArrayRef;
        let result = spark_url_decode(&[large])?;
        let result = as_large_string_array(&result)?;
        assert_eq!(result.iter().collect::<Vec<_>>(), expected);

        let view = Arc::new(StringViewArray::from(vec![
            Some("a%20b+c"),
            Some("%E4%BD%A0%E5%A5%BD"),
            None,
        ])) as ArrayRef;
        let result = spark_url_decode(&[view])?;
        let result = as_string_view_array(&result)?;
        assert_eq!(result.iter().collect::<Vec<_>>(), expected);

        Ok(())
    }

    /// A `%` without two following bytes is rejected as incomplete.
    #[test]
    fn test_decode_incomplete_sequence() {
        for malformed in ["%", "abc%2", "100%"] {
            let input = Arc::new(StringArray::from(vec![Some(malformed)])) as ArrayRef;
            let result = spark_url_decode(&[input]);
            assert!(
                result.is_err_and(|e| e.to_string().contains("incomplete sequence")),
                "expected an incomplete-sequence error for {malformed:?}"
            );
        }
    }

    /// Well-formed escapes that decode to invalid UTF-8 are rejected.
    #[test]
    fn test_decode_invalid_utf8() {
        // 0xFF can never appear in well-formed UTF-8.
        let input = Arc::new(StringArray::from(vec![Some("%FF")])) as ArrayRef;
        let result = spark_url_decode(&[input]);
        assert!(result.is_err_and(|e| e.to_string().contains("Invalid UTF-8 sequence")));
    }

    /// A custom handler can turn undecodable rows into nulls instead of failing.
    #[test]
    fn test_handled_decode_can_null_out_errors() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("http%3A%2F%2spark.apache.org"),
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ])) as ArrayRef;

        let result = spark_handled_url_decode(&[input], |row| row.or(Ok(None)))?;
        let result = as_string_array(&result)?;
        assert_eq!(
            result.iter().collect::<Vec<_>>(),
            vec![None, Some("https://spark.apache.org"), None]
        );

        Ok(())
    }

    /// Wrong arity and non-string inputs are rejected.
    #[test]
    fn test_decode_invalid_arguments() {
        assert!(
            spark_url_decode(&[])
                .is_err_and(|e| e.to_string().contains("expects 1 argument"))
        );

        let ints = Arc::new(Int32Array::from(vec![1])) as ArrayRef;
        assert!(
            spark_url_decode(&[ints])
                .is_err_and(|e| e.to_string().contains("Expr must be STRING"))
        );
    }
}