Skip to main content

datafusion_spark/function/url/
parse_url.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23    StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27    as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_datafusion_err, exec_err};
30use datafusion_expr::{
31    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32    Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
/// The `parse_url` scalar UDF: extracts a named component (host, path, query,
/// fragment, ...) from a URL string.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ParseUrl {
    /// Accepts two or three string arguments; built in `ParseUrl::new`.
    signature: Signature,
}
41
42impl Default for ParseUrl {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ParseUrl {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::one_of(
52                vec![TypeSignature::String(2), TypeSignature::String(3)],
53                Volatility::Immutable,
54            ),
55        }
56    }
57    /// Parses a URL and extracts the specified component.
58    ///
59    /// This function takes a URL string and extracts different parts of it based on the
60    /// `part` parameter. For query parameters, an optional `key` can be specified to
61    /// extract a specific query parameter value.
62    ///
63    /// # Arguments
64    ///
65    /// * `value` - The URL string to parse
66    /// * `part` - The component of the URL to extract. Valid values are:
67    ///   - `"HOST"` - The hostname (e.g., "example.com")
68    ///   - `"PATH"` - The path portion (e.g., "/path/to/resource")
69    ///   - `"QUERY"` - The query string or a specific query parameter
70    ///   - `"REF"` - The fragment/anchor (the part after #)
71    ///   - `"PROTOCOL"` - The URL scheme (e.g., "https", "http")
72    ///   - `"FILE"` - The path with query string (e.g., "/path?query=value")
73    ///   - `"AUTHORITY"` - The authority component (host:port)
74    ///   - `"USERINFO"` - The user information (username:password)
75    /// * `key` - Optional parameter used only with `"QUERY"`. When provided, extracts
76    ///   the value of the specific query parameter with this key name.
77    ///
78    /// # Returns
79    ///
80    /// * `Ok(Some(String))` - The extracted URL component as a string
81    /// * `Ok(None)` - If the requested component doesn't exist or is empty
82    /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed
83    fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84        let url: std::result::Result<Url, ParseError> = Url::parse(value);
85        if let Err(ParseError::RelativeUrlWithoutBase) = url {
86            return if !value.contains("://") {
87                Ok(None)
88            } else {
89                Err(exec_datafusion_err!(
90                    "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"
91                ))
92            };
93        };
94        url.map_err(|e| exec_datafusion_err!("{e:?}"))
95            .map(|url| match part {
96                "HOST" => url.host_str().map(String::from),
97                "PATH" => {
98                    let path: String = url.path().to_string();
99                    let path: String = if path == "/" { "".to_string() } else { path };
100                    Some(path)
101                }
102                "QUERY" => match key {
103                    None => url.query().map(String::from),
104                    Some(key) => url
105                        .query_pairs()
106                        .find(|(k, _)| k == key)
107                        .map(|(_, v)| v.into_owned()),
108                },
109                "REF" => url.fragment().map(String::from),
110                "PROTOCOL" => Some(url.scheme().to_string()),
111                "FILE" => {
112                    let path = url.path();
113                    match url.query() {
114                        Some(query) => Some(format!("{path}?{query}")),
115                        None => Some(path.to_string()),
116                    }
117                }
118                "AUTHORITY" => Some(url.authority().to_string()),
119                "USERINFO" => {
120                    let username = url.username();
121                    if username.is_empty() {
122                        return None;
123                    }
124                    match url.password() {
125                        Some(password) => Some(format!("{username}:{password}")),
126                        None => Some(username.to_string()),
127                    }
128                }
129                _ => None,
130            })
131    }
132}
133
134impl ScalarUDFImpl for ParseUrl {
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138
139    fn name(&self) -> &str {
140        "parse_url"
141    }
142
143    fn signature(&self) -> &Signature {
144        &self.signature
145    }
146
147    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148        Ok(arg_types[0].clone())
149    }
150
151    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
152        let ScalarFunctionArgs { args, .. } = args;
153        make_scalar_function(spark_parse_url, vec![])(&args)
154    }
155}
156
157/// Core implementation of URL parsing function.
158///
159/// # Arguments
160///
161/// * `args` - A slice of ArrayRef containing the input arrays:
162///   - `args[0]` - URL array: The URLs to parse
163///   - `args[1]` - Part array: The URL components to extract (HOST, PATH, QUERY, etc.)
164///   - `args[2]` - Key array (optional): For QUERY part, the specific parameter names to extract
165///
166/// # Return Value
167///
168/// Returns `Result<ArrayRef>` containing:
169/// - A string array with extracted URL components
170/// - `None` values where extraction failed or component doesn't exist
171/// - The output array type (StringArray or LargeStringArray) is determined by input types
172fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
173    spark_handled_parse_url(args, |x| x)
174}
175
176pub fn spark_handled_parse_url(
177    args: &[ArrayRef],
178    handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
179) -> Result<ArrayRef> {
180    if args.len() < 2 || args.len() > 3 {
181        return exec_err!(
182            "{} expects 2 or 3 arguments, but got {}",
183            "`parse_url`",
184            args.len()
185        );
186    }
187    // Required arguments
188    let url = &args[0];
189    let part = &args[1];
190
191    if args.len() == 3 {
192        // In this case, the 'key' argument is passed
193        let key = &args[2];
194
195        match (url.data_type(), part.data_type(), key.data_type()) {
196            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
197                process_parse_url::<_, _, _, StringArray>(
198                    as_string_array(url)?,
199                    as_string_array(part)?,
200                    as_string_array(key)?,
201                    handler_err,
202                )
203            }
204            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
205                process_parse_url::<_, _, _, StringViewArray>(
206                    as_string_view_array(url)?,
207                    as_string_view_array(part)?,
208                    as_string_view_array(key)?,
209                    handler_err,
210                )
211            }
212            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
213                process_parse_url::<_, _, _, LargeStringArray>(
214                    as_large_string_array(url)?,
215                    as_large_string_array(part)?,
216                    as_large_string_array(key)?,
217                    handler_err,
218                )
219            }
220            _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
221        }
222    } else {
223        // The 'key' argument is omitted, assume all values are null
224        // Create 'null' string array for 'key' argument
225        let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
226        for _ in 0..args[0].len() {
227            builder.append_null();
228        }
229        let key = builder.finish();
230
231        match (url.data_type(), part.data_type()) {
232            (DataType::Utf8, DataType::Utf8) => {
233                process_parse_url::<_, _, _, StringArray>(
234                    as_string_array(url)?,
235                    as_string_array(part)?,
236                    &key,
237                    handler_err,
238                )
239            }
240            (DataType::Utf8View, DataType::Utf8View) => {
241                process_parse_url::<_, _, _, StringViewArray>(
242                    as_string_view_array(url)?,
243                    as_string_view_array(part)?,
244                    &key,
245                    handler_err,
246                )
247            }
248            (DataType::LargeUtf8, DataType::LargeUtf8) => {
249                process_parse_url::<_, _, _, LargeStringArray>(
250                    as_large_string_array(url)?,
251                    as_large_string_array(part)?,
252                    &key,
253                    handler_err,
254                )
255            }
256            _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
257        }
258    }
259}
260
261fn process_parse_url<'a, A, B, C, T>(
262    url_array: &'a A,
263    part_array: &'a B,
264    key_array: &'a C,
265    handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
266) -> Result<ArrayRef>
267where
268    &'a A: StringArrayType<'a>,
269    &'a B: StringArrayType<'a>,
270    &'a C: StringArrayType<'a>,
271    T: Array + FromIterator<Option<String>> + 'static,
272{
273    url_array
274        .iter()
275        .zip(part_array.iter())
276        .zip(key_array.iter())
277        .map(|((url, part), key)| {
278            if let (Some(url), Some(part), key) = (url, part, key) {
279                handle(ParseUrl::parse(url, part, key))
280            } else {
281                Ok(None)
282            }
283        })
284        .collect::<Result<T>>()
285        .map(|array| Arc::new(array) as ArrayRef)
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        ArrayRef, Int32Array, LargeStringArray, StringArray, StringViewArray,
    };
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    /// A scheme-less string is treated as a relative reference, not an error.
    #[test]
    fn test_parse_relative_url_returns_none() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    /// Input that contains "://" but has no valid scheme is rejected.
    #[test]
    fn test_parse_invalid_absolute_url_returns_error() {
        let err = ParseUrl::parse("://example.com", "HOST", None).unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    #[test]
    fn test_parse_unknown_part_returns_none() -> Result<()> {
        assert_eq!(ParseUrl::parse("https://example.com", "PORT", None)?, None);
        Ok(())
    }

    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    /// The error handler decides whether a bad row aborts the call or becomes NULL.
    #[test]
    fn test_handler_can_suppress_row_errors() -> Result<()> {
        let urls = sa(&[Some("://example.com"), Some("https://example.com")]);
        let parts = sa(&[Some("HOST"), Some("HOST")]);

        assert!(spark_handled_parse_url(&[urls.clone(), parts.clone()], |x| x).is_err());

        let out = spark_handled_parse_url(&[urls, parts], |r| Ok(r.unwrap_or(None)))?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(out_sa.is_null(0));
        assert_eq!(out_sa.value(1), "example.com");
        Ok(())
    }

    #[test]
    fn test_spark_large_utf8_and_utf8_view() -> Result<()> {
        let urls: ArrayRef =
            Arc::new(LargeStringArray::from(vec![Some("https://example.com/a")]));
        let parts: ArrayRef = Arc::new(LargeStringArray::from(vec![Some("HOST")]));
        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_large = out.as_any().downcast_ref::<LargeStringArray>().unwrap();
        assert_eq!(out_large.value(0), "example.com");

        let urls: ArrayRef =
            Arc::new(StringViewArray::from(vec![Some("https://example.com/a")]));
        let parts: ArrayRef = Arc::new(StringViewArray::from(vec![Some("PATH")]));
        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_view = out.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(out_view.value(0), "/a");
        Ok(())
    }

    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}