1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23 StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27 as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_datafusion_err, exec_err};
30use datafusion_expr::{
31 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32 Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
/// Scalar UDF that extracts a named part (host, path, query, ...) of a URL string.
///
/// It is exposed to the query engine under the name `parse_url`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ParseUrl {
    /// Accepted argument lists: either two or three string arguments.
    signature: Signature,
}
41
42impl Default for ParseUrl {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ParseUrl {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::one_of(
52 vec![TypeSignature::String(2), TypeSignature::String(3)],
53 Volatility::Immutable,
54 ),
55 }
56 }
57 fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84 let url: std::result::Result<Url, ParseError> = Url::parse(value);
85 if let Err(ParseError::RelativeUrlWithoutBase) = url {
86 return if !value.contains("://") {
87 Ok(None)
88 } else {
89 Err(exec_datafusion_err!(
90 "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"
91 ))
92 };
93 };
94 url.map_err(|e| exec_datafusion_err!("{e:?}"))
95 .map(|url| match part {
96 "HOST" => url.host_str().map(String::from),
97 "PATH" => {
98 let path: String = url.path().to_string();
99 let path: String = if path == "/" { "".to_string() } else { path };
100 Some(path)
101 }
102 "QUERY" => match key {
103 None => url.query().map(String::from),
104 Some(key) => url
105 .query_pairs()
106 .find(|(k, _)| k == key)
107 .map(|(_, v)| v.into_owned()),
108 },
109 "REF" => url.fragment().map(String::from),
110 "PROTOCOL" => Some(url.scheme().to_string()),
111 "FILE" => {
112 let path = url.path();
113 match url.query() {
114 Some(query) => Some(format!("{path}?{query}")),
115 None => Some(path.to_string()),
116 }
117 }
118 "AUTHORITY" => Some(url.authority().to_string()),
119 "USERINFO" => {
120 let username = url.username();
121 if username.is_empty() {
122 return None;
123 }
124 match url.password() {
125 Some(password) => Some(format!("{username}:{password}")),
126 None => Some(username.to_string()),
127 }
128 }
129 _ => None,
130 })
131 }
132}
133
134impl ScalarUDFImpl for ParseUrl {
135 fn as_any(&self) -> &dyn Any {
136 self
137 }
138
139 fn name(&self) -> &str {
140 "parse_url"
141 }
142
143 fn signature(&self) -> &Signature {
144 &self.signature
145 }
146
147 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148 Ok(arg_types[0].clone())
149 }
150
151 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
152 let ScalarFunctionArgs { args, .. } = args;
153 make_scalar_function(spark_parse_url, vec![])(&args)
154 }
155}
156
157fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
173 spark_handled_parse_url(args, |x| x)
174}
175
176pub fn spark_handled_parse_url(
177 args: &[ArrayRef],
178 handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
179) -> Result<ArrayRef> {
180 if args.len() < 2 || args.len() > 3 {
181 return exec_err!(
182 "{} expects 2 or 3 arguments, but got {}",
183 "`parse_url`",
184 args.len()
185 );
186 }
187 let url = &args[0];
189 let part = &args[1];
190
191 if args.len() == 3 {
192 let key = &args[2];
194
195 match (url.data_type(), part.data_type(), key.data_type()) {
196 (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
197 process_parse_url::<_, _, _, StringArray>(
198 as_string_array(url)?,
199 as_string_array(part)?,
200 as_string_array(key)?,
201 handler_err,
202 )
203 }
204 (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
205 process_parse_url::<_, _, _, StringViewArray>(
206 as_string_view_array(url)?,
207 as_string_view_array(part)?,
208 as_string_view_array(key)?,
209 handler_err,
210 )
211 }
212 (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
213 process_parse_url::<_, _, _, LargeStringArray>(
214 as_large_string_array(url)?,
215 as_large_string_array(part)?,
216 as_large_string_array(key)?,
217 handler_err,
218 )
219 }
220 _ => exec_err!(
221 "`parse_url` expects STRING arguments, got ({}, {}, {})",
222 url.data_type(),
223 part.data_type(),
224 key.data_type()
225 ),
226 }
227 } else {
228 let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
231 for _ in 0..args[0].len() {
232 builder.append_null();
233 }
234 let key = builder.finish();
235
236 match (url.data_type(), part.data_type()) {
237 (DataType::Utf8, DataType::Utf8) => {
238 process_parse_url::<_, _, _, StringArray>(
239 as_string_array(url)?,
240 as_string_array(part)?,
241 &key,
242 handler_err,
243 )
244 }
245 (DataType::Utf8View, DataType::Utf8View) => {
246 process_parse_url::<_, _, _, StringViewArray>(
247 as_string_view_array(url)?,
248 as_string_view_array(part)?,
249 &key,
250 handler_err,
251 )
252 }
253 (DataType::LargeUtf8, DataType::LargeUtf8) => {
254 process_parse_url::<_, _, _, LargeStringArray>(
255 as_large_string_array(url)?,
256 as_large_string_array(part)?,
257 &key,
258 handler_err,
259 )
260 }
261 _ => exec_err!(
262 "`parse_url` expects STRING arguments, got ({}, {})",
263 url.data_type(),
264 part.data_type()
265 ),
266 }
267 }
268}
269
270fn process_parse_url<'a, A, B, C, T>(
271 url_array: &'a A,
272 part_array: &'a B,
273 key_array: &'a C,
274 handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
275) -> Result<ArrayRef>
276where
277 &'a A: StringArrayType<'a>,
278 &'a B: StringArrayType<'a>,
279 &'a C: StringArrayType<'a>,
280 T: Array + FromIterator<Option<String>> + 'static,
281{
282 url_array
283 .iter()
284 .zip(part_array.iter())
285 .zip(key_array.iter())
286 .map(|((url, part), key)| {
287 if let (Some(url), Some(part), key) = (url, part, key) {
288 handle(ParseUrl::parse(url, part, key))
289 } else {
290 Ok(None)
291 }
292 })
293 .collect::<Result<T>>()
294 .map(|array| Arc::new(array) as ArrayRef)
295}
296
/// Unit tests for the scalar parser (`ParseUrl::parse`) and the array
/// entry point (`spark_handled_parse_url`).
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    // Text without `://` is not treated as an invalid URL: it yields NULL.
    #[test]
    fn test_parse_url_without_scheme_returns_none() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    // Text that contains `://` but cannot be parsed is reported as an error.
    #[test]
    fn test_parse_invalid_url_with_scheme_separator_returns_error() {
        let err =
            ParseUrl::parse("inva lid://example.com/path", "HOST", None).unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}