use std::any::Any;
use std::borrow::Cow;
use std::sync::Arc;
use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;
use datafusion_common::cast::{
as_large_string_array, as_string_array, as_string_view_array,
};
use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
use percent_encoding::percent_decode;
/// Scalar UDF exposed under the name `url_decode`.
///
/// The only state it carries is the [`Signature`] describing the arguments
/// it accepts.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UrlDecode {
    /// Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
}
impl Default for UrlDecode {
fn default() -> Self {
Self::new()
}
}
impl UrlDecode {
    /// Creates the UDF with a one-argument string signature marked
    /// [`Volatility::Immutable`].
    pub fn new() -> Self {
        Self {
            signature: Signature::string(1, Volatility::Immutable),
        }
    }

    /// Decodes a URL-encoded string.
    ///
    /// Every `+` is turned into a space, every `%XX` escape is replaced by
    /// the byte it encodes, and the resulting bytes are interpreted as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an execution error when `value` contains a malformed `%`
    /// escape (see [`Self::validate_percent_encoding`]) or when the decoded
    /// bytes are not valid UTF-8.
    fn decode(value: &str) -> Result<String> {
        // Reject malformed escapes up front so they surface as errors.
        Self::validate_percent_encoding(value)?;

        // `+` must be rewritten *before* percent-decoding; otherwise a literal
        // plus written as `%2B` would be turned into a space as well.
        let replaced = Self::replace_plus(value.as_bytes());
        percent_decode(&replaced)
            .decode_utf8()
            .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
            .map(|parsed| parsed.into_owned())
    }

    /// Replaces every `+` byte in `input` with a space.
    ///
    /// The input is returned borrowed (no allocation) when it contains no
    /// `+`; otherwise an owned, patched copy is returned.
    fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
        match input.iter().position(|&b| b == b'+') {
            None => Cow::Borrowed(input),
            Some(first_position) => {
                let mut replaced = input.to_owned();
                // Nothing before `first_position` can be a `+`, so only the
                // tail needs to be scanned.
                for byte in &mut replaced[first_position..] {
                    if *byte == b'+' {
                        *byte = b' ';
                    }
                }
                Cow::Owned(replaced)
            }
        }
    }

    /// Checks that every `%` in `value` is followed by two ASCII hex digits.
    ///
    /// # Errors
    ///
    /// Returns an execution error naming the byte offset of the offending
    /// `%` when fewer than two bytes follow it ("incomplete sequence") or
    /// when either of the two following bytes is not an ASCII hex digit
    /// ("invalid hex sequence").
    fn validate_percent_encoding(value: &str) -> Result<()> {
        let bytes = value.as_bytes();
        let mut i = 0;

        while i < bytes.len() {
            if bytes[i] != b'%' {
                i += 1;
                continue;
            }

            // `get` yields `None` exactly when fewer than two bytes follow `%`.
            match bytes.get(i + 1..i + 3) {
                None => {
                    return exec_err!(
                        "Invalid percent-encoding: incomplete sequence at position {}",
                        i
                    );
                }
                Some([hex1, hex2])
                    if hex1.is_ascii_hexdigit() && hex2.is_ascii_hexdigit() =>
                {
                    i += 3;
                }
                Some(_) => {
                    // Echo the real characters that follow `%` instead of
                    // casting raw bytes to `char`: a byte-wise cast would
                    // render multi-byte UTF-8 input as unrelated Latin-1
                    // characters. `%` is ASCII, so `i + 1` is always a char
                    // boundary and the slice below cannot panic.
                    let offending: String = value[i + 1..].chars().take(2).collect();
                    return exec_err!(
                        "Invalid percent-encoding: invalid hex sequence '%{}' at position {}",
                        offending,
                        i
                    );
                }
            }
        }
        Ok(())
    }
}
impl ScalarUDFImpl for UrlDecode {
    /// Exposes `self` as [`Any`].
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of this function: `url_decode`.
    fn name(&self) -> &str {
        "url_decode"
    }

    /// Argument signature stored at construction time.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The output type is the type of the single input argument.
    ///
    /// # Errors
    ///
    /// Returns a planning error when `arg_types` does not hold exactly one
    /// element.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match arg_types {
            [input_type] => Ok(input_type.clone()),
            _ => plan_err!(
                "{} expects 1 argument, but got {}",
                self.name(),
                arg_types.len()
            ),
        }
    }

    /// Evaluates the function by running `spark_url_decode` through
    /// `make_scalar_function` on the supplied argument values.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let kernel = make_scalar_function(spark_url_decode, vec![]);
        kernel(&args.args)
    }
}
/// Decodes every element of the single string array in `args`.
///
/// Per-element results are forwarded unchanged, so a decoding error on any
/// element becomes the error of the whole call.
fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
    let propagate = |outcome: Result<Option<String>>| outcome;
    spark_handled_url_decode(args, propagate)
}
/// Decodes every element of a string array, letting the caller decide what
/// happens to each per-element outcome.
///
/// # Arguments
///
/// * `args` - must contain exactly one array of type `Utf8`, `LargeUtf8` or
///   `Utf8View`.
/// * `err_handle_fn` - called with the decode result of every element (null
///   elements arrive as `Ok(None)`); whatever it returns is what gets
///   collected. A caller could, for example, map `Err` to `Ok(None)`.
///
/// # Returns
///
/// An array of the same string type as the input.
///
/// # Errors
///
/// Returns an execution error when `args` does not hold exactly one array,
/// when that array is not one of the supported string types, or when
/// `err_handle_fn` yields an `Err` for any element.
pub fn spark_handled_url_decode(
    args: &[ArrayRef],
    err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
) -> Result<ArrayRef> {
    let [input] = args else {
        return exec_err!("`url_decode` expects 1 argument");
    };

    // Decode one (possibly null) element and hand the outcome to the caller's
    // handler; shared by all three string layouts below.
    let decode_one =
        |value: Option<&str>| err_handle_fn(value.map(UrlDecode::decode).transpose());

    match input.data_type() {
        DataType::Utf8 => {
            let decoded = as_string_array(input)?
                .iter()
                .map(&decode_one)
                .collect::<Result<StringArray>>()?;
            Ok(Arc::new(decoded) as ArrayRef)
        }
        DataType::LargeUtf8 => {
            let decoded = as_large_string_array(input)?
                .iter()
                .map(&decode_one)
                .collect::<Result<LargeStringArray>>()?;
            Ok(Arc::new(decoded) as ArrayRef)
        }
        DataType::Utf8View => {
            let decoded = as_string_view_array(input)?
                .iter()
                .map(&decode_one)
                .collect::<Result<StringViewArray>>()?;
            Ok(Arc::new(decoded) as ArrayRef)
        }
        other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
    }
}
#[cfg(test)]
mod tests {
    use arrow::array::StringArray;
    use datafusion_common::Result;

    use super::*;

    /// Valid inputs decode to the expected strings; nulls stay null.
    #[test]
    fn test_decode() -> Result<()> {
        // (encoded input, expected decoded output)
        let cases: Vec<(Option<&str>, Option<&str>)> = vec![
            (
                Some("https%3A%2F%2Fspark.apache.org"),
                Some("https://spark.apache.org"),
            ),
            (
                Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            ),
            (
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
                Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
            ),
            (
                Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
                Some("~!@#$%^&*()_+"),
            ),
            (Some("%E4%BD%A0%E5%A5%BD"), Some("你好")),
            (Some(""), Some("")),
            (None, None),
        ];

        let encoded: Vec<Option<&str>> = cases.iter().map(|(raw, _)| *raw).collect();
        let decoded: Vec<Option<&str>> = cases.iter().map(|(_, want)| *want).collect();
        let input = Arc::new(StringArray::from(encoded));
        let expected = StringArray::from(decoded);

        let output = spark_url_decode(&[input as ArrayRef])?;
        let actual = as_string_array(&output)?;

        assert_eq!(&expected, actual);
        Ok(())
    }

    /// A malformed escape anywhere in the array fails the whole call.
    #[test]
    fn test_decode_error() -> Result<()> {
        let input: ArrayRef = Arc::new(StringArray::from(vec![
            // `%2s` is not a valid percent-encoded sequence.
            Some("http%3A%2F%2spark.apache.org"),
            // The remaining elements are valid.
            Some("https%3A%2F%2Fspark.apache.org"),
            None,
        ]));

        let outcome = spark_url_decode(&[input]);
        let reported_invalid_encoding = match outcome {
            Err(err) => err.to_string().contains("Invalid percent-encoding"),
            Ok(_) => false,
        };

        assert!(reported_invalid_encoding);
        Ok(())
    }
}