use std::any::Any;
use std::sync::Arc;
use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;
use datafusion_common::cast::{
as_large_string_array, as_string_array, as_string_view_array,
};
use datafusion_common::{Result, exec_err, plan_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
use url::form_urlencoded::byte_serialize;
/// Spark-compatible `url_encode` scalar function.
///
/// Each non-null input string is serialized with the
/// `application/x-www-form-urlencoded` serializer from the `url` crate
/// (`url::form_urlencoded::byte_serialize`).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UrlEncode {
    /// Accepts exactly one string argument; the function is immutable.
    signature: Signature,
}

impl UrlEncode {
    /// Builds the UDF with a one-argument string signature marked
    /// [`Volatility::Immutable`].
    pub fn new() -> Self {
        let signature = Signature::string(1, Volatility::Immutable);
        Self { signature }
    }

    /// Form-urlencodes the UTF-8 bytes of `value`.
    ///
    /// The body always returns `Ok`. The `Result` return type lets the
    /// array kernel chain this with `Option::transpose` and collect the
    /// values fallibly.
    fn encode(value: &str) -> Result<String> {
        let encoded: String = byte_serialize(value.as_bytes()).collect();
        Ok(encoded)
    }
}

impl Default for UrlEncode {
    /// Equivalent to [`UrlEncode::new`].
    fn default() -> Self {
        UrlEncode::new()
    }
}
impl ScalarUDFImpl for UrlEncode {
    /// Exposes `self` as `dyn Any` so it can be downcast back to `UrlEncode`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "url_encode"
    }

    /// The signature created in [`UrlEncode::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type mirrors the type of the single argument
    /// (for example, `LargeUtf8` in gives `LargeUtf8` out).
    ///
    /// # Errors
    ///
    /// Returns a planning error when the argument count is not exactly one.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match arg_types {
            [only] => Ok(only.clone()),
            _ => plan_err!(
                "{} expects 1 argument, but got {}",
                self.name(),
                arg_types.len()
            ),
        }
    }

    /// Runs `spark_url_encode` through `make_scalar_function`, which adapts
    /// the `ColumnarValue` arguments to the kernel's `&[ArrayRef]` signature.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let kernel = make_scalar_function(spark_url_encode, vec![]);
        kernel(&args.args)
    }
}
fn spark_url_encode(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("`url_encode` expects 1 argument");
}
match &args[0].data_type() {
DataType::Utf8 => as_string_array(&args[0])?
.iter()
.map(|x| x.map(UrlEncode::encode).transpose())
.collect::<Result<StringArray>>()
.map(|array| Arc::new(array) as ArrayRef),
DataType::LargeUtf8 => as_large_string_array(&args[0])?
.iter()
.map(|x| x.map(UrlEncode::encode).transpose())
.collect::<Result<LargeStringArray>>()
.map(|array| Arc::new(array) as ArrayRef),
DataType::Utf8View => as_string_view_array(&args[0])?
.iter()
.map(|x| x.map(UrlEncode::encode).transpose())
.collect::<Result<StringViewArray>>()
.map(|array| Arc::new(array) as ArrayRef),
other => exec_err!("`url_encode`: Expr must be STRING, got {other:?}"),
}
}