use std::collections::HashMap;
use std::sync::Arc;
use super::backends::BackendRegistry;
use super::{FunctionInfo, FunctionSignature, ScalarFunction};
use crate::core::{Error, Result, Value};
/// A scalar function whose body is source code executed by a language backend.
///
/// The backend is resolved from `backend_registry` by `language` each time
/// the function is evaluated.
pub struct UserDefinedScalarFunction {
    /// Function name exactly as supplied at construction (casing preserved).
    name: String,
    /// Declared signature, reported through `ScalarFunction::info`.
    signature: FunctionSignature,
    /// Language tag used to look up the execution backend.
    language: String,
    /// Source code handed to the backend's `execute`.
    code: String,
    /// Parameter names forwarded to the backend together with the arguments.
    param_names: Vec<String>,
    /// Shared registry of execution backends.
    backend_registry: Arc<BackendRegistry>,
}
impl UserDefinedScalarFunction {
    /// Builds a user-defined scalar function.
    ///
    /// * `name` - function name, stored as given.
    /// * `code` - source code later passed to the backend.
    /// * `language` - language tag used to pick the backend at evaluation time.
    /// * `param_names` - parameter names forwarded to the backend.
    /// * `signature` - signature reported via [`ScalarFunction::info`].
    /// * `backend_registry` - shared registry used to resolve `language`.
    pub fn new(
        name: impl Into<String>,
        code: impl Into<String>,
        language: impl Into<String>,
        param_names: Vec<String>,
        signature: FunctionSignature,
        backend_registry: Arc<BackendRegistry>,
    ) -> Self {
        let (name, code, language) = (name.into(), code.into(), language.into());
        Self {
            name,
            code,
            language,
            param_names,
            signature,
            backend_registry,
        }
    }
}
impl ScalarFunction for UserDefinedScalarFunction {
fn name(&self) -> &str {
&self.name
}
fn info(&self) -> FunctionInfo {
FunctionInfo::new(
self.name.clone(),
super::FunctionType::Scalar,
"User-defined function".to_string(),
self.signature.clone(),
)
}
fn evaluate(&self, args: &[Value]) -> Result<Value> {
let backend = self
.backend_registry
.get_backend(&self.language)
.ok_or_else(|| {
Error::internal(format!(
"No backend available for language: {}",
self.language
))
})?;
backend.execute(
&self.code,
args,
&self
.param_names
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
)
}
fn clone_box(&self) -> Box<dyn ScalarFunction> {
Box::new(Self {
name: self.name.clone(),
code: self.code.clone(),
language: self.language.clone(),
param_names: self.param_names.clone(),
signature: self.signature.clone(),
backend_registry: self.backend_registry.clone(),
})
}
}
/// Registry of user-defined scalar functions, looked up case-insensitively.
pub struct UserDefinedFunctionRegistry {
    /// Backend registry shared with every function this registry creates.
    backend_registry: Arc<BackendRegistry>,
    /// Registered functions keyed by their upper-cased name.
    functions: HashMap<String, Arc<UserDefinedScalarFunction>>,
}
impl UserDefinedFunctionRegistry {
    /// Creates an empty registry whose functions execute through `backend_registry`.
    pub fn new(backend_registry: Arc<BackendRegistry>) -> Self {
        Self {
            functions: HashMap::new(),
            backend_registry,
        }
    }

    /// Registers a user-defined scalar function, replacing any existing one
    /// with the same (case-insensitive) name.
    ///
    /// The function keeps `name` in its original casing; the lookup key is
    /// `name` upper-cased.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `language` is not supported by the
    /// backend registry. In that case nothing is registered.
    pub fn register(
        &mut self,
        name: String,
        code: String,
        language: String,
        param_names: Vec<String>,
        signature: FunctionSignature,
    ) -> Result<()> {
        if !self.backend_registry.is_language_supported(&language) {
            return Err(Error::internal(format!(
                "Unsupported language: {}",
                language
            )));
        }
        // Derive the key before `name` is moved into the function, so the
        // original name does not need to be cloned.
        let key = name.to_uppercase();
        let udf = Arc::new(UserDefinedScalarFunction::new(
            name,
            code,
            language,
            param_names,
            signature,
            Arc::clone(&self.backend_registry),
        ));
        self.functions.insert(key, udf);
        Ok(())
    }

    /// Looks up a function by name, ignoring case.
    pub fn get(&self, name: &str) -> Option<Arc<UserDefinedScalarFunction>> {
        self.functions.get(&name.to_uppercase()).cloned()
    }

    /// Returns `true` if a function with this name (ignoring case) is registered.
    pub fn exists(&self, name: &str) -> bool {
        self.functions.contains_key(&name.to_uppercase())
    }

    /// Returns `true` if the backend registry supports `language`.
    pub fn is_language_supported(&self, language: &str) -> bool {
        self.backend_registry.is_language_supported(language)
    }

    /// Removes the function with this name (ignoring case).
    ///
    /// # Errors
    ///
    /// Returns `Error::FunctionNotFound` carrying `name` as given when no such
    /// function is registered.
    pub fn unregister(&mut self, name: &str) -> Result<()> {
        let key = name.to_uppercase();
        if self.functions.remove(&key).is_none() {
            return Err(Error::FunctionNotFound(name.to_string()));
        }
        Ok(())
    }

    /// Returns the lookup keys (upper-cased names) of all registered
    /// functions, in ascending order.
    pub fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.functions.keys().cloned().collect();
        // `HashMap` iteration order is arbitrary and varies between processes;
        // sort so callers get a stable, reproducible listing.
        names.sort_unstable();
        names
    }
}
impl Clone for UserDefinedFunctionRegistry {
fn clone(&self) -> Self {
Self {
functions: self.functions.clone(),
backend_registry: self.backend_registry.clone(),
}
}
}
impl Default for UserDefinedFunctionRegistry {
    /// Always panics: a registry cannot be built without a backend registry.
    /// Use [`UserDefinedFunctionRegistry::new`] instead.
    ///
    /// # Panics
    ///
    /// Unconditionally.
    // NOTE(review): a `Default` that always panics breaks code that expects
    // `Default` to be infallible (e.g. `std::mem::take`, `..Default::default()`).
    // Consider removing this impl if no caller depends on the trait bound.
    fn default() -> Self {
        const MISUSE: &str = "UserDefinedFunctionRegistry::default() should not be called directly. Use new() with a backend registry.";
        panic!("{}", MISUSE)
    }
}