Skip to main content

datafusion_spark/function/url/
parse_url.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23    StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27    as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_datafusion_err, exec_err};
30use datafusion_expr::{
31    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32    Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
/// Scalar UDF registered under the name `parse_url`: extracts a single
/// component (host, path, query, ...) from a URL string.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ParseUrl {
    /// Accepted argument lists: two or three string arguments.
    signature: Signature,
}
41
42impl Default for ParseUrl {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ParseUrl {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::one_of(
52                vec![TypeSignature::String(2), TypeSignature::String(3)],
53                Volatility::Immutable,
54            ),
55        }
56    }
57    /// Parses a URL and extracts the specified component.
58    ///
59    /// This function takes a URL string and extracts different parts of it based on the
60    /// `part` parameter. For query parameters, an optional `key` can be specified to
61    /// extract a specific query parameter value.
62    ///
63    /// # Arguments
64    ///
65    /// * `value` - The URL string to parse
66    /// * `part` - The component of the URL to extract. Valid values are:
67    ///   - `"HOST"` - The hostname (e.g., "example.com")
68    ///   - `"PATH"` - The path portion (e.g., "/path/to/resource")
69    ///   - `"QUERY"` - The query string or a specific query parameter
70    ///   - `"REF"` - The fragment/anchor (the part after #)
71    ///   - `"PROTOCOL"` - The URL scheme (e.g., "https", "http")
72    ///   - `"FILE"` - The path with query string (e.g., "/path?query=value")
73    ///   - `"AUTHORITY"` - The authority component (host:port)
74    ///   - `"USERINFO"` - The user information (username:password)
75    /// * `key` - Optional parameter used only with `"QUERY"`. When provided, extracts
76    ///   the value of the specific query parameter with this key name.
77    ///
78    /// # Returns
79    ///
80    /// * `Ok(Some(String))` - The extracted URL component as a string
81    /// * `Ok(None)` - If the requested component doesn't exist or is empty
82    /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed
83    fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84        let url: std::result::Result<Url, ParseError> = Url::parse(value);
85        if let Err(ParseError::RelativeUrlWithoutBase) = url {
86            return if !value.contains("://") {
87                Ok(None)
88            } else {
89                Err(exec_datafusion_err!(
90                    "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"
91                ))
92            };
93        };
94        url.map_err(|e| exec_datafusion_err!("{e:?}"))
95            .map(|url| match part {
96                "HOST" => url.host_str().map(String::from),
97                "PATH" => {
98                    let path: String = url.path().to_string();
99                    let path: String = if path == "/" { "".to_string() } else { path };
100                    Some(path)
101                }
102                "QUERY" => match key {
103                    None => url.query().map(String::from),
104                    Some(key) => url
105                        .query_pairs()
106                        .find(|(k, _)| k == key)
107                        .map(|(_, v)| v.into_owned()),
108                },
109                "REF" => url.fragment().map(String::from),
110                "PROTOCOL" => Some(url.scheme().to_string()),
111                "FILE" => {
112                    let path = url.path();
113                    match url.query() {
114                        Some(query) => Some(format!("{path}?{query}")),
115                        None => Some(path.to_string()),
116                    }
117                }
118                "AUTHORITY" => Some(url.authority().to_string()),
119                "USERINFO" => {
120                    let username = url.username();
121                    if username.is_empty() {
122                        return None;
123                    }
124                    match url.password() {
125                        Some(password) => Some(format!("{username}:{password}")),
126                        None => Some(username.to_string()),
127                    }
128                }
129                _ => None,
130            })
131    }
132}
133
134impl ScalarUDFImpl for ParseUrl {
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138
139    fn name(&self) -> &str {
140        "parse_url"
141    }
142
143    fn signature(&self) -> &Signature {
144        &self.signature
145    }
146
147    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148        Ok(arg_types[0].clone())
149    }
150
151    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
152        let ScalarFunctionArgs { args, .. } = args;
153        make_scalar_function(spark_parse_url, vec![])(&args)
154    }
155}
156
157/// Core implementation of URL parsing function.
158///
159/// # Arguments
160///
161/// * `args` - A slice of ArrayRef containing the input arrays:
162///   - `args[0]` - URL array: The URLs to parse
163///   - `args[1]` - Part array: The URL components to extract (HOST, PATH, QUERY, etc.)
164///   - `args[2]` - Key array (optional): For QUERY part, the specific parameter names to extract
165///
166/// # Return Value
167///
168/// Returns `Result<ArrayRef>` containing:
169/// - A string array with extracted URL components
170/// - `None` values where extraction failed or component doesn't exist
171/// - The output array type (StringArray or LargeStringArray) is determined by input types
172fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
173    spark_handled_parse_url(args, |x| x)
174}
175
176pub fn spark_handled_parse_url(
177    args: &[ArrayRef],
178    handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
179) -> Result<ArrayRef> {
180    if args.len() < 2 || args.len() > 3 {
181        return exec_err!(
182            "{} expects 2 or 3 arguments, but got {}",
183            "`parse_url`",
184            args.len()
185        );
186    }
187    // Required arguments
188    let url = &args[0];
189    let part = &args[1];
190
191    if args.len() == 3 {
192        // In this case, the 'key' argument is passed
193        let key = &args[2];
194
195        match (url.data_type(), part.data_type(), key.data_type()) {
196            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
197                process_parse_url::<_, _, _, StringArray>(
198                    as_string_array(url)?,
199                    as_string_array(part)?,
200                    as_string_array(key)?,
201                    handler_err,
202                )
203            }
204            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
205                process_parse_url::<_, _, _, StringViewArray>(
206                    as_string_view_array(url)?,
207                    as_string_view_array(part)?,
208                    as_string_view_array(key)?,
209                    handler_err,
210                )
211            }
212            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
213                process_parse_url::<_, _, _, LargeStringArray>(
214                    as_large_string_array(url)?,
215                    as_large_string_array(part)?,
216                    as_large_string_array(key)?,
217                    handler_err,
218                )
219            }
220            _ => exec_err!(
221                "`parse_url` expects STRING arguments, got ({}, {}, {})",
222                url.data_type(),
223                part.data_type(),
224                key.data_type()
225            ),
226        }
227    } else {
228        // The 'key' argument is omitted, assume all values are null
229        // Create 'null' string array for 'key' argument
230        let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
231        for _ in 0..args[0].len() {
232            builder.append_null();
233        }
234        let key = builder.finish();
235
236        match (url.data_type(), part.data_type()) {
237            (DataType::Utf8, DataType::Utf8) => {
238                process_parse_url::<_, _, _, StringArray>(
239                    as_string_array(url)?,
240                    as_string_array(part)?,
241                    &key,
242                    handler_err,
243                )
244            }
245            (DataType::Utf8View, DataType::Utf8View) => {
246                process_parse_url::<_, _, _, StringViewArray>(
247                    as_string_view_array(url)?,
248                    as_string_view_array(part)?,
249                    &key,
250                    handler_err,
251                )
252            }
253            (DataType::LargeUtf8, DataType::LargeUtf8) => {
254                process_parse_url::<_, _, _, LargeStringArray>(
255                    as_large_string_array(url)?,
256                    as_large_string_array(part)?,
257                    &key,
258                    handler_err,
259                )
260            }
261            _ => exec_err!(
262                "`parse_url` expects STRING arguments, got ({}, {})",
263                url.data_type(),
264                part.data_type()
265            ),
266        }
267    }
268}
269
270fn process_parse_url<'a, A, B, C, T>(
271    url_array: &'a A,
272    part_array: &'a B,
273    key_array: &'a C,
274    handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
275) -> Result<ArrayRef>
276where
277    &'a A: StringArrayType<'a>,
278    &'a B: StringArrayType<'a>,
279    &'a C: StringArrayType<'a>,
280    T: Array + FromIterator<Option<String>> + 'static,
281{
282    url_array
283        .iter()
284        .zip(part_array.iter())
285        .zip(key_array.iter())
286        .map(|((url, part), key)| {
287            if let (Some(url), Some(part), key) = (url, part, key) {
288                handle(ParseUrl::parse(url, part, key))
289            } else {
290                Ok(None)
291            }
292        })
293        .collect::<Result<T>>()
294        .map(|array| Arc::new(array) as ArrayRef)
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    // Input without a scheme and without "://" is not an error: it yields NULL.
    #[test]
    fn test_parse_url_without_scheme_returns_none() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    // Input that contains "://" but has no valid scheme is reported as an error.
    #[test]
    fn test_parse_invalid_url_returns_error() {
        let err = ParseUrl::parse("inva lid://example.com/a?x=1", "HOST", None)
            .unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}