datafusion_spark/function/url/
url_decode.rs1use std::borrow::Cow;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
22use arrow::datatypes::DataType;
23use datafusion_common::cast::{
24 as_large_string_array, as_string_array, as_string_view_array,
25};
26use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err};
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_functions::utils::make_scalar_function;
31use percent_encoding::percent_decode;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct UrlDecode {
35 signature: Signature,
36}
37
38impl Default for UrlDecode {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl UrlDecode {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::string(1, Volatility::Immutable),
48 }
49 }
50
51 fn decode(value: &str) -> Result<String> {
68 Self::validate_percent_encoding(value)?;
70
71 let replaced = Self::replace_plus(value.as_bytes());
72 percent_decode(&replaced)
73 .decode_utf8()
74 .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
75 .map(|parsed| parsed.into_owned())
76 }
77
78 fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
81 match input.iter().position(|&b| b == b'+') {
82 None => Cow::Borrowed(input),
83 Some(first_position) => {
84 let mut replaced = input.to_owned();
85 replaced[first_position] = b' ';
86 for byte in &mut replaced[first_position + 1..] {
87 if *byte == b'+' {
88 *byte = b' ';
89 }
90 }
91 Cow::Owned(replaced)
92 }
93 }
94 }
95
96 fn validate_percent_encoding(value: &str) -> Result<()> {
98 let bytes = value.as_bytes();
99 let mut i = 0;
100
101 while i < bytes.len() {
102 if bytes[i] == b'%' {
103 if i + 2 >= bytes.len() {
105 return exec_err!(
106 "Invalid percent-encoding: incomplete sequence at position {}",
107 i
108 );
109 }
110
111 let hex1 = bytes[i + 1];
112 let hex2 = bytes[i + 2];
113
114 if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() {
115 return exec_err!(
116 "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}",
117 hex1 as char,
118 hex2 as char,
119 i
120 );
121 }
122 i += 3;
123 } else {
124 i += 1;
125 }
126 }
127 Ok(())
128 }
129}
130
131impl ScalarUDFImpl for UrlDecode {
132 fn name(&self) -> &str {
133 "url_decode"
134 }
135
136 fn signature(&self) -> &Signature {
137 &self.signature
138 }
139
140 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
141 if arg_types.len() != 1 {
142 return plan_err!(
143 "{} expects 1 argument, but got {}",
144 self.name(),
145 arg_types.len()
146 );
147 }
148 Ok(arg_types[0].clone())
150 }
151
152 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
153 let ScalarFunctionArgs { args, .. } = args;
154 make_scalar_function(spark_url_decode, vec![])(&args)
155 }
156}
157
158fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
170 spark_handled_url_decode(args, |x| x)
171}
172
173pub fn spark_handled_url_decode(
174 args: &[ArrayRef],
175 err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
176) -> Result<ArrayRef> {
177 if args.len() != 1 {
178 return exec_err!("`url_decode` expects 1 argument");
179 }
180
181 match &args[0].data_type() {
182 DataType::Utf8 => as_string_array(&args[0])?
183 .iter()
184 .map(|x| x.map(UrlDecode::decode).transpose())
185 .map(&err_handle_fn)
186 .collect::<Result<StringArray>>()
187 .map(|array| Arc::new(array) as ArrayRef),
188 DataType::LargeUtf8 => as_large_string_array(&args[0])?
189 .iter()
190 .map(|x| x.map(UrlDecode::decode).transpose())
191 .map(&err_handle_fn)
192 .collect::<Result<LargeStringArray>>()
193 .map(|array| Arc::new(array) as ArrayRef),
194 DataType::Utf8View => as_string_view_array(&args[0])?
195 .iter()
196 .map(|x| x.map(UrlDecode::decode).transpose())
197 .map(&err_handle_fn)
198 .collect::<Result<StringViewArray>>()
199 .map(|array| Arc::new(array) as ArrayRef),
200 other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Each well-formed input decodes to its paired expected value; nulls
    /// stay null.
    #[test]
    fn test_decode() -> Result<()> {
        // (encoded input, expected decoded output)
        let cases = [
            (
                Some("https%3A%2F%2Fspark.apache.org"),
                Some("https://spark.apache.org"),
            ),
            (
                Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            ),
            (
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            ),
            (
                Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
                Some("~!@#$%^&*()_+"),
            ),
            (Some("%E4%BD%A0%E5%A5%BD"), Some("你好")),
            (Some(""), Some("")),
            (None, None),
        ];

        let input: StringArray = cases.iter().map(|(encoded, _)| *encoded).collect();
        let expected: StringArray = cases.iter().map(|(_, decoded)| *decoded).collect();

        let decoded = spark_url_decode(&[Arc::new(input) as ArrayRef])?;
        let decoded = as_string_array(&decoded)?;

        assert_eq!(&expected, decoded);

        Ok(())
    }

    /// A malformed escape (`%2s`) anywhere in the array fails the whole call.
    #[test]
    fn test_decode_error() -> Result<()> {
        let input = StringArray::from(vec![
            Some("http%3A%2F%2spark.apache.org"),
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ]);

        let outcome = spark_url_decode(&[Arc::new(input) as ArrayRef]);
        let failed_as_expected =
            outcome.is_err_and(|e| e.to_string().contains("Invalid percent-encoding"));
        assert!(failed_as_expected);

        Ok(())
    }
}