1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23 StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27 as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_datafusion_err, exec_err};
30use datafusion_expr::{
31 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32 Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct ParseUrl {
39 signature: Signature,
40}
41
42impl Default for ParseUrl {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ParseUrl {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::one_of(
52 vec![TypeSignature::String(2), TypeSignature::String(3)],
53 Volatility::Immutable,
54 ),
55 }
56 }
57 fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84 let url: std::result::Result<Url, ParseError> = Url::parse(value);
85 if let Err(ParseError::RelativeUrlWithoutBase) = url {
86 return if !value.contains("://") {
87 Ok(None)
88 } else {
89 Err(exec_datafusion_err!(
90 "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"
91 ))
92 };
93 };
94 url.map_err(|e| exec_datafusion_err!("{e:?}"))
95 .map(|url| match part {
96 "HOST" => url.host_str().map(String::from),
97 "PATH" => {
98 let path: String = url.path().to_string();
99 let path: String = if path == "/" { "".to_string() } else { path };
100 Some(path)
101 }
102 "QUERY" => match key {
103 None => url.query().map(String::from),
104 Some(key) => url
105 .query_pairs()
106 .find(|(k, _)| k == key)
107 .map(|(_, v)| v.into_owned()),
108 },
109 "REF" => url.fragment().map(String::from),
110 "PROTOCOL" => Some(url.scheme().to_string()),
111 "FILE" => {
112 let path = url.path();
113 match url.query() {
114 Some(query) => Some(format!("{path}?{query}")),
115 None => Some(path.to_string()),
116 }
117 }
118 "AUTHORITY" => Some(url.authority().to_string()),
119 "USERINFO" => {
120 let username = url.username();
121 if username.is_empty() {
122 return None;
123 }
124 match url.password() {
125 Some(password) => Some(format!("{username}:{password}")),
126 None => Some(username.to_string()),
127 }
128 }
129 _ => None,
130 })
131 }
132}
133
134impl ScalarUDFImpl for ParseUrl {
135 fn as_any(&self) -> &dyn Any {
136 self
137 }
138
139 fn name(&self) -> &str {
140 "parse_url"
141 }
142
143 fn signature(&self) -> &Signature {
144 &self.signature
145 }
146
147 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148 Ok(arg_types[0].clone())
149 }
150
151 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
152 let ScalarFunctionArgs { args, .. } = args;
153 make_scalar_function(spark_parse_url, vec![])(&args)
154 }
155}
156
157fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
173 spark_handled_parse_url(args, |x| x)
174}
175
176pub fn spark_handled_parse_url(
177 args: &[ArrayRef],
178 handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
179) -> Result<ArrayRef> {
180 if args.len() < 2 || args.len() > 3 {
181 return exec_err!(
182 "{} expects 2 or 3 arguments, but got {}",
183 "`parse_url`",
184 args.len()
185 );
186 }
187 let url = &args[0];
189 let part = &args[1];
190
191 if args.len() == 3 {
192 let key = &args[2];
194
195 match (url.data_type(), part.data_type(), key.data_type()) {
196 (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
197 process_parse_url::<_, _, _, StringArray>(
198 as_string_array(url)?,
199 as_string_array(part)?,
200 as_string_array(key)?,
201 handler_err,
202 )
203 }
204 (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
205 process_parse_url::<_, _, _, StringViewArray>(
206 as_string_view_array(url)?,
207 as_string_view_array(part)?,
208 as_string_view_array(key)?,
209 handler_err,
210 )
211 }
212 (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
213 process_parse_url::<_, _, _, LargeStringArray>(
214 as_large_string_array(url)?,
215 as_large_string_array(part)?,
216 as_large_string_array(key)?,
217 handler_err,
218 )
219 }
220 _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
221 }
222 } else {
223 let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
226 for _ in 0..args[0].len() {
227 builder.append_null();
228 }
229 let key = builder.finish();
230
231 match (url.data_type(), part.data_type()) {
232 (DataType::Utf8, DataType::Utf8) => {
233 process_parse_url::<_, _, _, StringArray>(
234 as_string_array(url)?,
235 as_string_array(part)?,
236 &key,
237 handler_err,
238 )
239 }
240 (DataType::Utf8View, DataType::Utf8View) => {
241 process_parse_url::<_, _, _, StringViewArray>(
242 as_string_view_array(url)?,
243 as_string_view_array(part)?,
244 &key,
245 handler_err,
246 )
247 }
248 (DataType::LargeUtf8, DataType::LargeUtf8) => {
249 process_parse_url::<_, _, _, LargeStringArray>(
250 as_large_string_array(url)?,
251 as_large_string_array(part)?,
252 &key,
253 handler_err,
254 )
255 }
256 _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
257 }
258 }
259}
260
261fn process_parse_url<'a, A, B, C, T>(
262 url_array: &'a A,
263 part_array: &'a B,
264 key_array: &'a C,
265 handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
266) -> Result<ArrayRef>
267where
268 &'a A: StringArrayType<'a>,
269 &'a B: StringArrayType<'a>,
270 &'a C: StringArrayType<'a>,
271 T: Array + FromIterator<Option<String>> + 'static,
272{
273 url_array
274 .iter()
275 .zip(part_array.iter())
276 .zip(key_array.iter())
277 .map(|((url, part), key)| {
278 if let (Some(url), Some(part), key) = (url, part, key) {
279 handle(ParseUrl::parse(url, part, key))
280 } else {
281 Ok(None)
282 }
283 })
284 .collect::<Result<T>>()
285 .map(|array| Arc::new(array) as ArrayRef)
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    /// Input with no scheme and no `://` is treated as "not a URL" and yields
    /// NULL rather than an error.
    #[test]
    fn test_parse_url_without_scheme_returns_none() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    /// Input that contains `://` but has no valid scheme is rejected with the
    /// "invalid url" execution error.
    #[test]
    fn test_parse_malformed_url_returns_error() {
        let err = ParseUrl::parse("://example.com", "HOST", None).unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}