datafusion_spark/function/url/
url_decode.rs1use std::any::Any;
19use std::borrow::Cow;
20use std::sync::Arc;
21
22use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
23use arrow::datatypes::DataType;
24use datafusion_common::cast::{
25 as_large_string_array, as_string_array, as_string_view_array,
26};
27use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err};
28use datafusion_expr::{
29 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32use percent_encoding::percent_decode;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct UrlDecode {
36 signature: Signature,
37}
38
39impl Default for UrlDecode {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl UrlDecode {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::string(1, Volatility::Immutable),
49 }
50 }
51
52 fn decode(value: &str) -> Result<String> {
69 Self::validate_percent_encoding(value)?;
71
72 let replaced = Self::replace_plus(value.as_bytes());
73 percent_decode(&replaced)
74 .decode_utf8()
75 .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
76 .map(|parsed| parsed.into_owned())
77 }
78
79 fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
82 match input.iter().position(|&b| b == b'+') {
83 None => Cow::Borrowed(input),
84 Some(first_position) => {
85 let mut replaced = input.to_owned();
86 replaced[first_position] = b' ';
87 for byte in &mut replaced[first_position + 1..] {
88 if *byte == b'+' {
89 *byte = b' ';
90 }
91 }
92 Cow::Owned(replaced)
93 }
94 }
95 }
96
97 fn validate_percent_encoding(value: &str) -> Result<()> {
99 let bytes = value.as_bytes();
100 let mut i = 0;
101
102 while i < bytes.len() {
103 if bytes[i] == b'%' {
104 if i + 2 >= bytes.len() {
106 return exec_err!(
107 "Invalid percent-encoding: incomplete sequence at position {}",
108 i
109 );
110 }
111
112 let hex1 = bytes[i + 1];
113 let hex2 = bytes[i + 2];
114
115 if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() {
116 return exec_err!(
117 "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}",
118 hex1 as char,
119 hex2 as char,
120 i
121 );
122 }
123 i += 3;
124 } else {
125 i += 1;
126 }
127 }
128 Ok(())
129 }
130}
131
132impl ScalarUDFImpl for UrlDecode {
133 fn as_any(&self) -> &dyn Any {
134 self
135 }
136
137 fn name(&self) -> &str {
138 "url_decode"
139 }
140
141 fn signature(&self) -> &Signature {
142 &self.signature
143 }
144
145 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
146 if arg_types.len() != 1 {
147 return plan_err!(
148 "{} expects 1 argument, but got {}",
149 self.name(),
150 arg_types.len()
151 );
152 }
153 Ok(arg_types[0].clone())
155 }
156
157 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
158 let ScalarFunctionArgs { args, .. } = args;
159 make_scalar_function(spark_url_decode, vec![])(&args)
160 }
161}
162
163fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
175 spark_handled_url_decode(args, |x| x)
176}
177
178pub fn spark_handled_url_decode(
179 args: &[ArrayRef],
180 err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
181) -> Result<ArrayRef> {
182 if args.len() != 1 {
183 return exec_err!("`url_decode` expects 1 argument");
184 }
185
186 match &args[0].data_type() {
187 DataType::Utf8 => as_string_array(&args[0])?
188 .iter()
189 .map(|x| x.map(UrlDecode::decode).transpose())
190 .map(&err_handle_fn)
191 .collect::<Result<StringArray>>()
192 .map(|array| Arc::new(array) as ArrayRef),
193 DataType::LargeUtf8 => as_large_string_array(&args[0])?
194 .iter()
195 .map(|x| x.map(UrlDecode::decode).transpose())
196 .map(&err_handle_fn)
197 .collect::<Result<LargeStringArray>>()
198 .map(|array| Arc::new(array) as ArrayRef),
199 DataType::Utf8View => as_string_view_array(&args[0])?
200 .iter()
201 .map(|x| x.map(UrlDecode::decode).transpose())
202 .map(&err_handle_fn)
203 .collect::<Result<StringViewArray>>()
204 .map(|array| Arc::new(array) as ArrayRef),
205 other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
206 }
207}
208
#[cfg(test)]
mod tests {
    use arrow::array::StringArray;
    use datafusion_common::Result;

    use super::*;

    #[test]
    fn test_decode() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("https%3A%2F%2Fspark.apache.org"),
            Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
            Some("%E4%BD%A0%E5%A5%BD"),
            Some(""),
            None,
        ]));
        let expected = StringArray::from(vec![
            Some("https://spark.apache.org"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            Some("~!@#$%^&*()_+"),
            Some("你好"),
            Some(""),
            None,
        ]);

        let result = spark_url_decode(&[input as ArrayRef])?;
        let result = as_string_array(&result)?;

        assert_eq!(&expected, result);

        Ok(())
    }

    #[test]
    fn test_decode_error() -> Result<()> {
        let input = Arc::new(StringArray::from(vec![
            Some("http%3A%2F%2spark.apache.org"),
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ]));

        let result = spark_url_decode(&[input]);
        assert!(
            result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding"))
        );

        Ok(())
    }

    /// A `%` with fewer than two bytes after it is rejected, wherever it appears.
    #[test]
    fn test_decode_incomplete_sequence() {
        for value in ["%", "abc%2", "trailing%"] {
            let input: ArrayRef = Arc::new(StringArray::from(vec![Some(value)]));
            let result = spark_url_decode(&[input]);
            assert!(
                result.is_err_and(|e| e.to_string().contains("incomplete sequence")),
                "expected an incomplete-sequence error for {value:?}"
            );
        }
    }

    /// Well-formed escapes that decode to invalid UTF-8 are reported as errors.
    #[test]
    fn test_decode_invalid_utf8() {
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("%FF")]));
        let result = spark_url_decode(&[input]);
        assert!(result.is_err_and(|e| e.to_string().contains("Invalid UTF-8 sequence")));
    }

    #[test]
    fn test_decode_large_utf8() -> Result<()> {
        let input: ArrayRef =
            Arc::new(LargeStringArray::from(vec![Some("a%20b+c"), None]));
        let result = spark_url_decode(&[input])?;
        let result = as_large_string_array(&result)?;
        assert_eq!(result.iter().collect::<Vec<_>>(), vec![Some("a b c"), None]);
        Ok(())
    }

    #[test]
    fn test_decode_utf8_view() -> Result<()> {
        let input: ArrayRef =
            Arc::new(StringViewArray::from(vec![Some("a%20b+c"), None]));
        let result = spark_url_decode(&[input])?;
        let result = as_string_view_array(&result)?;
        assert_eq!(result.iter().collect::<Vec<_>>(), vec![Some("a b c"), None]);
        Ok(())
    }

    #[test]
    fn test_decode_rejects_bad_arguments() {
        let non_string: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![1]));
        assert!(
            spark_url_decode(&[non_string])
                .is_err_and(|e| e.to_string().contains("Expr must be STRING"))
        );

        assert!(
            spark_url_decode(&[])
                .is_err_and(|e| e.to_string().contains("expects 1 argument"))
        );
    }

    /// A handler that swallows errors turns undecodable rows into nulls.
    #[test]
    fn test_handled_decode_maps_errors_to_null() -> Result<()> {
        let input: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a%20b"),
            Some("bad%zz"),
            Some("%FF"),
            None,
        ]));
        let result =
            spark_handled_url_decode(&[input], |res| Ok(res.unwrap_or_default()))?;
        let result = as_string_array(&result)?;
        assert_eq!(
            result.iter().collect::<Vec<_>>(),
            vec![Some("a b"), None, None, None]
        );
        Ok(())
    }
}