1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23 StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27 as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{exec_datafusion_err, exec_err, Result};
30use datafusion_expr::{
31 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32 Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
/// Scalar UDF exposed under the name `parse_url`.
///
/// It extracts one component (host, path, query, ...) from a URL string and
/// accepts either two string arguments (`url`, `part`) or three
/// (`url`, `part`, `key`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ParseUrl {
    /// Accepted argument lists: two or three string arguments, immutable
    /// volatility (see [`ParseUrl::new`]).
    signature: Signature,
}
41
42impl Default for ParseUrl {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ParseUrl {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::one_of(
52 vec![TypeSignature::String(2), TypeSignature::String(3)],
53 Volatility::Immutable,
54 ),
55 }
56 }
57 fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84 let url: std::result::Result<Url, ParseError> = Url::parse(value);
85 if let Err(ParseError::RelativeUrlWithoutBase) = url {
86 return if !value.contains("://") {
87 Ok(None)
88 } else {
89 Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"))
90 };
91 };
92 url.map_err(|e| exec_datafusion_err!("{e:?}"))
93 .map(|url| match part {
94 "HOST" => url.host_str().map(String::from),
95 "PATH" => {
96 let path: String = url.path().to_string();
97 let path: String = if path == "/" { "".to_string() } else { path };
98 Some(path)
99 }
100 "QUERY" => match key {
101 None => url.query().map(String::from),
102 Some(key) => url
103 .query_pairs()
104 .find(|(k, _)| k == key)
105 .map(|(_, v)| v.into_owned()),
106 },
107 "REF" => url.fragment().map(String::from),
108 "PROTOCOL" => Some(url.scheme().to_string()),
109 "FILE" => {
110 let path = url.path();
111 match url.query() {
112 Some(query) => Some(format!("{path}?{query}")),
113 None => Some(path.to_string()),
114 }
115 }
116 "AUTHORITY" => Some(url.authority().to_string()),
117 "USERINFO" => {
118 let username = url.username();
119 if username.is_empty() {
120 return None;
121 }
122 match url.password() {
123 Some(password) => Some(format!("{username}:{password}")),
124 None => Some(username.to_string()),
125 }
126 }
127 _ => None,
128 })
129 }
130}
131
132impl ScalarUDFImpl for ParseUrl {
133 fn as_any(&self) -> &dyn Any {
134 self
135 }
136
137 fn name(&self) -> &str {
138 "parse_url"
139 }
140
141 fn signature(&self) -> &Signature {
142 &self.signature
143 }
144
145 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
146 Ok(arg_types[0].clone())
147 }
148
149 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
150 let ScalarFunctionArgs { args, .. } = args;
151 make_scalar_function(spark_parse_url, vec![])(&args)
152 }
153}
154
155fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
171 spark_handled_parse_url(args, |x| x)
172}
173
174pub fn spark_handled_parse_url(
175 args: &[ArrayRef],
176 handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
177) -> Result<ArrayRef> {
178 if args.len() < 2 || args.len() > 3 {
179 return exec_err!(
180 "{} expects 2 or 3 arguments, but got {}",
181 "`parse_url`",
182 args.len()
183 );
184 }
185 let url = &args[0];
187 let part = &args[1];
188
189 let result = if args.len() == 3 {
190 let key = &args[2];
192
193 match (url.data_type(), part.data_type(), key.data_type()) {
194 (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
195 process_parse_url::<_, _, _, StringArray>(
196 as_string_array(url)?,
197 as_string_array(part)?,
198 as_string_array(key)?,
199 handler_err,
200 )
201 }
202 (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
203 process_parse_url::<_, _, _, StringViewArray>(
204 as_string_view_array(url)?,
205 as_string_view_array(part)?,
206 as_string_view_array(key)?,
207 handler_err,
208 )
209 }
210 (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
211 process_parse_url::<_, _, _, LargeStringArray>(
212 as_large_string_array(url)?,
213 as_large_string_array(part)?,
214 as_large_string_array(key)?,
215 handler_err,
216 )
217 }
218 _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
219 }
220 } else {
221 let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
224 for _ in 0..args[0].len() {
225 builder.append_null();
226 }
227 let key = builder.finish();
228
229 match (url.data_type(), part.data_type()) {
230 (DataType::Utf8, DataType::Utf8) => {
231 process_parse_url::<_, _, _, StringArray>(
232 as_string_array(url)?,
233 as_string_array(part)?,
234 &key,
235 handler_err,
236 )
237 }
238 (DataType::Utf8View, DataType::Utf8View) => {
239 process_parse_url::<_, _, _, StringViewArray>(
240 as_string_view_array(url)?,
241 as_string_view_array(part)?,
242 &key,
243 handler_err,
244 )
245 }
246 (DataType::LargeUtf8, DataType::LargeUtf8) => {
247 process_parse_url::<_, _, _, LargeStringArray>(
248 as_large_string_array(url)?,
249 as_large_string_array(part)?,
250 &key,
251 handler_err,
252 )
253 }
254 _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
255 }
256 };
257 result
258}
259
260fn process_parse_url<'a, A, B, C, T>(
261 url_array: &'a A,
262 part_array: &'a B,
263 key_array: &'a C,
264 handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
265) -> Result<ArrayRef>
266where
267 &'a A: StringArrayType<'a>,
268 &'a B: StringArrayType<'a>,
269 &'a C: StringArrayType<'a>,
270 T: Array + FromIterator<Option<String>> + 'static,
271{
272 url_array
273 .iter()
274 .zip(part_array.iter())
275 .zip(key_array.iter())
276 .map(|((url, part), key)| {
277 if let (Some(url), Some(part), key) = (url, part, key) {
278 handle(ParseUrl::parse(url, part, key))
279 } else {
280 Ok(None)
281 }
282 })
283 .collect::<Result<T>>()
284 .map(|array| Arc::new(array) as ArrayRef)
285}
286
/// Unit tests for the scalar parser (`ParseUrl::parse`) and the array-level
/// entry point (`spark_handled_parse_url`).
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    // `HOST` returns the host name only.
    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    // `QUERY` without a key returns the whole query string; with a key it
    // returns that key's value, or `None` when the key is absent.
    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    // Exercises the remaining parts against a single URL with credentials,
    // port, query and fragment.
    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        // The expected authority carries no `:21` port.
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    // A root path (`/`) is reported as the empty string.
    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    // NOTE(review): despite the name, this test expects `Ok(None)`, not an
    // error, for text that contains no `://`. Consider renaming it.
    #[test]
    fn test_parse_malformed_url_returns_error() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    // Two-argument form over `Utf8` arrays, with a different part per row.
    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    // Three-argument form: a key that is present yields its value, a missing
    // key yields NULL.
    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    // `USERINFO` is NULL both for a URL without credentials and for a NULL
    // URL row.
    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    // One argument and four arguments are both rejected.
    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    // A non-string argument array is rejected with a type error.
    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}