use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow::datatypes::DataType;
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
};
use datafusion::prelude::SessionContext;
use uni_query_functions::custom_functions::{CustomFunctionRegistry, LEGACY_USER_PLUGIN_ID};
use crate::query::executor::plugin_adapter::ValueRowFn;
tokio::task_local! {
/// Task-local slot holding the plugin registry for the code currently
/// running inside a `scoped_with_session_plugin_registry` future.
///
/// It is only populated while such a scoped future is being polled; read it
/// through `current_session_plugin_registry`, which yields `None` when the
/// slot is not set.
pub static SESSION_PLUGIN_REGISTRY:
std::sync::Arc<uni_plugin::PluginRegistry>;
}
pub fn scoped_with_session_plugin_registry<F: std::future::Future>(
registry: std::sync::Arc<uni_plugin::PluginRegistry>,
fut: F,
) -> tokio::task::futures::TaskLocalFuture<std::sync::Arc<uni_plugin::PluginRegistry>, F> {
SESSION_PLUGIN_REGISTRY.scope(registry, fut)
}
#[must_use]
pub fn current_session_plugin_registry() -> Option<std::sync::Arc<uni_plugin::PluginRegistry>> {
SESSION_PLUGIN_REGISTRY.try_with(|r| r.clone()).ok()
}
pub use uni_plugin::host::principal::{
CURRENT_PRINCIPAL, current_principal, maybe_scope_with_principal, scoped_with_principal,
};
pub async fn scoped_with_session_context<F: std::future::Future>(
registry: std::sync::Arc<uni_plugin::PluginRegistry>,
principal: Option<std::sync::Arc<uni_plugin::traits::connector::Principal>>,
fut: F,
) -> F::Output {
scoped_with_session_plugin_registry(registry, maybe_scope_with_principal(principal, fut)).await
}
/// Registers the scalar functions of the instance-level registry and then,
/// if one is supplied, those of the session-level registry.
///
/// The order is significant: `session` is processed after `instance`, so its
/// registrations are the last ones applied to `ctx`.
///
/// # Errors
///
/// Returns the first error produced by [`register_plugin_scalar_udfs`];
/// registries after the failing one are not processed.
pub fn register_plugin_scalar_udfs_pair(
    ctx: &SessionContext,
    instance: &uni_plugin::PluginRegistry,
    session: Option<&uni_plugin::PluginRegistry>,
) -> DFResult<()> {
    std::iter::once(instance)
        .chain(session)
        .try_for_each(|registry| register_plugin_scalar_udfs(ctx, registry))
}
/// Exposes every scalar function in `plugin_registry` to DataFusion.
///
/// Each entry is registered on `ctx` under up to three names, in this order:
///
/// 1. the lower-cased local name (skipped when it is identical to the
///    upper-cased form, e.g. for names without cased characters),
/// 2. the upper-cased local name,
/// 3. the qualified name as rendered by `qname.to_string()`.
///
/// All names share the same underlying registry entry.
///
/// # Errors
///
/// The current implementation never fails and always returns `Ok(())`.
pub fn register_plugin_scalar_udfs(
    ctx: &SessionContext,
    plugin_registry: &uni_plugin::PluginRegistry,
) -> DFResult<()> {
    for (qname, entry) in plugin_registry.iter_scalars() {
        let bare = qname.local();
        let lowered = bare.to_lowercase();
        let uppered = bare.to_uppercase();

        // Only emit the lower-case alias when it is a distinct name.
        let has_distinct_cases = lowered != uppered;
        let names = has_distinct_cases
            .then_some(lowered)
            .into_iter()
            .chain([uppered, qname.to_string()]);

        for name in names {
            let udf = PluginScalarUdf::new(name, Arc::clone(&entry));
            ctx.register_udf(ScalarUDF::new_from_impl(udf));
        }
    }
    Ok(())
}
fn plugin_registry_for_custom_functions(
registry: &CustomFunctionRegistry,
) -> uni_plugin::PluginRegistry {
use datafusion::logical_expr::Volatility;
use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
use uni_plugin::{Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName};
let pr = PluginRegistry::new();
let plugin_id = PluginId::new(LEGACY_USER_PLUGIN_ID);
let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
for (name, func) in registry.iter() {
let upper = name.to_uppercase();
let mut r = PluginRegistrar::new(plugin_id.clone(), &caps, &pr);
let qname = QName::new(LEGACY_USER_PLUGIN_ID, &upper);
let adapter = Arc::new(ValueRowFn::new(upper.clone(), Arc::clone(func)));
let sig = FnSignature {
args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
returns: ArgType::CypherValue,
volatility: Volatility::Volatile,
null_handling: NullHandling::UserHandled,
};
if let Err(e) = r.scalar_fn(qname, sig, adapter) {
tracing::warn!(error = ?e, fn_name = %upper, "shadow registration failed");
continue;
}
if let Err(e) = r.commit_to_registry() {
tracing::warn!(error = ?e, fn_name = %upper, "shadow commit failed");
}
}
pr
}
/// Registers every function of a [`CustomFunctionRegistry`] on `ctx` by
/// routing it through the plugin scalar path.
///
/// The custom functions are first mirrored into a temporary plugin registry
/// (see [`plugin_registry_for_custom_functions`]) and that registry is then
/// handed to [`register_plugin_scalar_udfs`].
///
/// # Errors
///
/// Propagates any error returned by [`register_plugin_scalar_udfs`].
pub fn register_custom_functions_as_plugin_scalars(
    ctx: &SessionContext,
    registry: &CustomFunctionRegistry,
) -> DFResult<()> {
    register_plugin_scalar_udfs(ctx, &plugin_registry_for_custom_functions(registry))
}
/// DataFusion scalar UDF that forwards each invocation to a scalar function
/// held in a plugin registry entry.
struct PluginScalarUdf {
/// Name reported to DataFusion through `ScalarUDFImpl::name`.
name: String,
/// Registry entry whose `function` is invoked and whose `plugin` is used to
/// label execution errors.
entry: Arc<uni_plugin::registry::ScalarEntry>,
/// DataFusion signature; built in `new` as `VariadicAny` with the entry's
/// volatility.
signature: Signature,
/// Return type computed once in `new` via `derive_return_type` and handed
/// back from `ScalarUDFImpl::return_type`.
return_type: DataType,
}
impl PluginScalarUdf {
    /// Creates a UDF named `name` backed by `entry`.
    ///
    /// The DataFusion signature accepts any number of arguments of any type
    /// (`TypeSignature::VariadicAny`) and inherits the volatility declared in
    /// the entry's plugin signature. The return type is derived once, up
    /// front, from the entry's declared return type.
    fn new(name: String, entry: Arc<uni_plugin::registry::ScalarEntry>) -> Self {
        let signature = Signature::new(TypeSignature::VariadicAny, entry.signature.volatility);
        let return_type = derive_return_type(&entry);
        Self {
            name,
            entry,
            signature,
            return_type,
        }
    }
}
/// Maps a plugin entry's declared return type to the Arrow type announced to
/// DataFusion.
///
/// A `Primitive` return is passed through unchanged; every other kind of
/// return is reported as `DataType::LargeBinary` (presumably the encoded
/// form of non-primitive values — confirm against the plugin adapter).
fn derive_return_type(entry: &uni_plugin::registry::ScalarEntry) -> DataType {
    use uni_plugin::traits::scalar::ArgType;

    if let ArgType::Primitive(arrow_type) = &entry.signature.returns {
        arrow_type.clone()
    } else {
        DataType::LargeBinary
    }
}
/// Debug output shows only the UDF name; the entry, signature and return
/// type are intentionally left out of the rendering.
impl std::fmt::Debug for PluginScalarUdf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PluginScalarUdf");
        out.field("name", &self.name);
        out.finish()
    }
}
/// Two UDFs are equal when they carry the same name and the same DataFusion
/// signature.
///
/// The name has to take part in the comparison: every `PluginScalarUdf` is
/// built with `TypeSignature::VariadicAny`, so comparing signatures alone
/// would treat unrelated plugin functions that merely share a volatility as
/// equal, and would disagree with the `Hash` impl below, which is keyed on
/// the name (`a == b` must imply `hash(a) == hash(b)`).
impl PartialEq for PluginScalarUdf {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.signature == other.signature
    }
}

impl Eq for PluginScalarUdf {}

/// Hashes the UDF name only. Because equality requires equal names, equal
/// values always produce equal hashes.
impl Hash for PluginScalarUdf {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name().hash(state);
    }
}
impl ScalarUDFImpl for PluginScalarUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(self.return_type.clone())
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let entry = Arc::clone(&self.entry);
let rows = args.number_rows;
let cols = args.args;
entry.function.invoke(&cols, rows).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"plugin `{}` fn `{}` failed: {e}",
entry.plugin, self.name
))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::Volatility;

    /// A custom function registered through the legacy registry must be
    /// reachable on the session context under its lower-case, upper-case and
    /// plugin-qualified names.
    #[test]
    fn test_register_plugin_scalars_routes_through_plugin_registry() {
        use uni_common::Value;
        use uni_query_functions::custom_functions::{CustomFunctionRegistry, CustomScalarFn};

        let mut reg = CustomFunctionRegistry::new();
        let f: CustomScalarFn =
            Arc::new(|_args: &[Value]| Ok(Value::String("plugin-path".to_owned())));
        reg.register("MYFN".to_owned(), f);

        let ctx = SessionContext::new();
        // Pass the registry by reference (the source previously contained a
        // mis-encoded `&reg` here, which did not compile).
        register_custom_functions_as_plugin_scalars(&ctx, &reg).unwrap();

        assert!(ctx.udf("myfn").is_ok());
        assert!(ctx.udf("MYFN").is_ok());
        let qname = format!("{LEGACY_USER_PLUGIN_ID}.MYFN");
        assert!(ctx.udf(&qname).is_ok());
    }

    /// A plugin function that declares a primitive Arrow return type must
    /// surface that exact type to DataFusion rather than `LargeBinary`.
    #[test]
    fn test_native_arrow_udf_declares_primitive_return_type() {
        use std::sync::OnceLock;
        use uni_plugin::FnError;
        use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
        use uni_plugin::{
            Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName,
        };

        // Minimal plugin function: echoes its first argument back.
        struct DoubleIt;
        impl ScalarPluginFn for DoubleIt {
            fn signature(&self) -> &FnSignature {
                static S: OnceLock<FnSignature> = OnceLock::new();
                S.get_or_init(|| FnSignature {
                    args: vec![ArgType::Primitive(DataType::Float64)],
                    returns: ArgType::Primitive(DataType::Float64),
                    volatility: Volatility::Immutable,
                    null_handling: NullHandling::PropagateNulls,
                })
            }

            fn invoke(
                &self,
                args: &[ColumnarValue],
                _rows: usize,
            ) -> Result<ColumnarValue, FnError> {
                Ok(args.first().cloned().unwrap())
            }
        }

        let pr = PluginRegistry::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let mut r = PluginRegistrar::new(PluginId::new("test.fast"), &caps, &pr);
        r.scalar_fn(
            QName::new("test.fast", "double"),
            FnSignature {
                args: vec![ArgType::Primitive(DataType::Float64)],
                returns: ArgType::Primitive(DataType::Float64),
                volatility: Volatility::Immutable,
                null_handling: NullHandling::PropagateNulls,
            },
            Arc::new(DoubleIt),
        )
        .unwrap();
        r.commit_to_registry().unwrap();

        let ctx = SessionContext::new();
        register_plugin_scalar_udfs(&ctx, &pr).unwrap();

        let udf = ctx.udf("double").expect("udf registered");
        let rt = udf.return_type(&[DataType::Float64]).unwrap();
        assert_eq!(
            rt,
            DataType::Float64,
            "primitive-typed plugin should declare Float64 directly, not LargeBinary"
        );
    }
}