datafusion_spark/function/url/
parse_url.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray,
23    StringArrayType, StringViewArray,
24};
25use arrow::datatypes::DataType;
26use datafusion_common::cast::{
27    as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{exec_datafusion_err, exec_err, Result};
30use datafusion_expr::{
31    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32    Volatility,
33};
34use datafusion_functions::utils::make_scalar_function;
35use url::{ParseError, Url};
36
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct ParseUrl {
39    signature: Signature,
40}
41
42impl Default for ParseUrl {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ParseUrl {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::one_of(
52                vec![TypeSignature::String(2), TypeSignature::String(3)],
53                Volatility::Immutable,
54            ),
55        }
56    }
57    /// Parses a URL and extracts the specified component.
58    ///
59    /// This function takes a URL string and extracts different parts of it based on the
60    /// `part` parameter. For query parameters, an optional `key` can be specified to
61    /// extract a specific query parameter value.
62    ///
63    /// # Arguments
64    ///
65    /// * `value` - The URL string to parse
66    /// * `part` - The component of the URL to extract. Valid values are:
67    ///   - `"HOST"` - The hostname (e.g., "example.com")
68    ///   - `"PATH"` - The path portion (e.g., "/path/to/resource")
69    ///   - `"QUERY"` - The query string or a specific query parameter
70    ///   - `"REF"` - The fragment/anchor (the part after #)
71    ///   - `"PROTOCOL"` - The URL scheme (e.g., "https", "http")
72    ///   - `"FILE"` - The path with query string (e.g., "/path?query=value")
73    ///   - `"AUTHORITY"` - The authority component (host:port)
74    ///   - `"USERINFO"` - The user information (username:password)
75    /// * `key` - Optional parameter used only with `"QUERY"`. When provided, extracts
76    ///   the value of the specific query parameter with this key name.
77    ///
78    /// # Returns
79    ///
80    /// * `Ok(Some(String))` - The extracted URL component as a string
81    /// * `Ok(None)` - If the requested component doesn't exist or is empty
82    /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed
83    fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
84        let url: std::result::Result<Url, ParseError> = Url::parse(value);
85        if let Err(ParseError::RelativeUrlWithoutBase) = url {
86            return if !value.contains("://") {
87                Ok(None)
88            } else {
89                Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"))
90            };
91        };
92        url.map_err(|e| exec_datafusion_err!("{e:?}"))
93            .map(|url| match part {
94                "HOST" => url.host_str().map(String::from),
95                "PATH" => {
96                    let path: String = url.path().to_string();
97                    let path: String = if path == "/" { "".to_string() } else { path };
98                    Some(path)
99                }
100                "QUERY" => match key {
101                    None => url.query().map(String::from),
102                    Some(key) => url
103                        .query_pairs()
104                        .find(|(k, _)| k == key)
105                        .map(|(_, v)| v.into_owned()),
106                },
107                "REF" => url.fragment().map(String::from),
108                "PROTOCOL" => Some(url.scheme().to_string()),
109                "FILE" => {
110                    let path = url.path();
111                    match url.query() {
112                        Some(query) => Some(format!("{path}?{query}")),
113                        None => Some(path.to_string()),
114                    }
115                }
116                "AUTHORITY" => Some(url.authority().to_string()),
117                "USERINFO" => {
118                    let username = url.username();
119                    if username.is_empty() {
120                        return None;
121                    }
122                    match url.password() {
123                        Some(password) => Some(format!("{username}:{password}")),
124                        None => Some(username.to_string()),
125                    }
126                }
127                _ => None,
128            })
129    }
130}
131
132impl ScalarUDFImpl for ParseUrl {
133    fn as_any(&self) -> &dyn Any {
134        self
135    }
136
137    fn name(&self) -> &str {
138        "parse_url"
139    }
140
141    fn signature(&self) -> &Signature {
142        &self.signature
143    }
144
145    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
146        Ok(arg_types[0].clone())
147    }
148
149    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
150        let ScalarFunctionArgs { args, .. } = args;
151        make_scalar_function(spark_parse_url, vec![])(&args)
152    }
153}
154
155/// Core implementation of URL parsing function.
156///
157/// # Arguments
158///
159/// * `args` - A slice of ArrayRef containing the input arrays:
160///   - `args[0]` - URL array: The URLs to parse
161///   - `args[1]` - Part array: The URL components to extract (HOST, PATH, QUERY, etc.)
162///   - `args[2]` - Key array (optional): For QUERY part, the specific parameter names to extract
163///
164/// # Return Value
165///
166/// Returns `Result<ArrayRef>` containing:
167/// - A string array with extracted URL components
168/// - `None` values where extraction failed or component doesn't exist
169/// - The output array type (StringArray or LargeStringArray) is determined by input types
170fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
171    spark_handled_parse_url(args, |x| x)
172}
173
174pub fn spark_handled_parse_url(
175    args: &[ArrayRef],
176    handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
177) -> Result<ArrayRef> {
178    if args.len() < 2 || args.len() > 3 {
179        return exec_err!(
180            "{} expects 2 or 3 arguments, but got {}",
181            "`parse_url`",
182            args.len()
183        );
184    }
185    // Required arguments
186    let url = &args[0];
187    let part = &args[1];
188
189    let result = if args.len() == 3 {
190        // In this case, the 'key' argument is passed
191        let key = &args[2];
192
193        match (url.data_type(), part.data_type(), key.data_type()) {
194            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => {
195                process_parse_url::<_, _, _, StringArray>(
196                    as_string_array(url)?,
197                    as_string_array(part)?,
198                    as_string_array(key)?,
199                    handler_err,
200                )
201            }
202            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
203                process_parse_url::<_, _, _, StringViewArray>(
204                    as_string_view_array(url)?,
205                    as_string_view_array(part)?,
206                    as_string_view_array(key)?,
207                    handler_err,
208                )
209            }
210            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
211                process_parse_url::<_, _, _, LargeStringArray>(
212                    as_large_string_array(url)?,
213                    as_large_string_array(part)?,
214                    as_large_string_array(key)?,
215                    handler_err,
216                )
217            }
218            _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
219        }
220    } else {
221        // The 'key' argument is omitted, assume all values are null
222        // Create 'null' string array for 'key' argument
223        let mut builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
224        for _ in 0..args[0].len() {
225            builder.append_null();
226        }
227        let key = builder.finish();
228
229        match (url.data_type(), part.data_type()) {
230            (DataType::Utf8, DataType::Utf8) => {
231                process_parse_url::<_, _, _, StringArray>(
232                    as_string_array(url)?,
233                    as_string_array(part)?,
234                    &key,
235                    handler_err,
236                )
237            }
238            (DataType::Utf8View, DataType::Utf8View) => {
239                process_parse_url::<_, _, _, StringViewArray>(
240                    as_string_view_array(url)?,
241                    as_string_view_array(part)?,
242                    &key,
243                    handler_err,
244                )
245            }
246            (DataType::LargeUtf8, DataType::LargeUtf8) => {
247                process_parse_url::<_, _, _, LargeStringArray>(
248                    as_large_string_array(url)?,
249                    as_large_string_array(part)?,
250                    &key,
251                    handler_err,
252                )
253            }
254            _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
255        }
256    };
257    result
258}
259
260fn process_parse_url<'a, A, B, C, T>(
261    url_array: &'a A,
262    part_array: &'a B,
263    key_array: &'a C,
264    handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
265) -> Result<ArrayRef>
266where
267    &'a A: StringArrayType<'a>,
268    &'a B: StringArrayType<'a>,
269    &'a C: StringArrayType<'a>,
270    T: Array + FromIterator<Option<String>> + 'static,
271{
272    url_array
273        .iter()
274        .zip(part_array.iter())
275        .zip(key_array.iter())
276        .map(|((url, part), key)| {
277            if let (Some(url), Some(part), key) = (url, part, key) {
278                handle(ParseUrl::parse(url, part, key))
279            } else {
280                Ok(None)
281            }
282        })
283        .collect::<Result<T>>()
284        .map(|array| Arc::new(array) as ArrayRef)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `Utf8` array from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    /// A plain string with no scheme and no `://` is not an error: it maps to NULL.
    #[test]
    fn test_parse_string_without_scheme_returns_null() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    /// Input that contains `://` but has no valid scheme is reported as an error.
    #[test]
    fn test_parse_invalid_url_with_scheme_separator_returns_error() {
        let err =
            ParseUrl::parse("inva lid://example.com/path", "HOST", None).unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}