#![cfg(feature = "rhai-runtime")]
use std::sync::Arc;
use uni_plugin::{
Capability, CapabilitySet, HttpEgress, KmsProvider, PluginError, PluginId, PluginRegistrar,
QName,
};
use arrow_schema::Field;
use uni_plugin::capability::SideEffects;
use uni_plugin::secrets::SecretStore;
use uni_plugin::traits::procedure::{NamedArgType, ProcedureMode, ProcedureSignature};
use crate::adapter::RhaiScalarFn;
use crate::adapter_aggregate::{RhaiAggregateFn, build_agg_signature};
use crate::adapter_procedure::RhaiProcedure;
use crate::engine::build_engine;
use crate::error::RhaiError;
use crate::host_fns::RhaiHostFnRegistry;
use crate::manifest::{ProcedureEntry, RhaiManifest, compile, parse_manifest};
use crate::runtime::RhaiPluginRuntime;
use crate::wire_translate::{build_fn_signature, type_name_to_argtype, type_name_to_datatype};
/// Result of a successful [`RhaiLoader::load`] call.
///
/// Describes what the loaded plugin declared, what it was allowed to register,
/// and the runtime that backs every registered adapter.
#[derive(Debug)]
pub struct LoadOutcome {
    /// Plugin id taken from the manifest's `id` field.
    pub plugin_id: PluginId,
    /// Version string taken from the manifest's `version` field.
    pub version: String,
    /// Capabilities the manifest declared that were also granted by the caller.
    pub effective_capabilities: CapabilitySet,
    /// Capabilities the manifest declared but the caller did not grant; functions
    /// of these kinds were skipped during registration.
    pub denied_capabilities: Vec<Capability>,
    /// Qualified names (rendered via `to_string`) of the scalar functions registered.
    pub scalars_registered: Vec<String>,
    /// Qualified names (rendered via `to_string`) of the aggregate functions registered.
    pub aggregates_registered: Vec<String>,
    /// Qualified names (rendered via `to_string`) of the procedures registered.
    pub procedures_registered: Vec<String>,
    /// Runtime (engine + compiled AST) shared by all registered adapters.
    pub runtime: Arc<RhaiPluginRuntime>,
}
/// Loads Rhai plugin scripts and registers their manifest-declared functions.
///
/// Holds the host-function registry used to build each script engine, plus
/// optional KMS, secret-store, and HTTP providers configured through the
/// `with_*` builders.
#[derive(Default, Clone)]
pub struct RhaiLoader {
    /// Host functions exposed to scripts; passed to `build_engine` on every load.
    host_fns: RhaiHostFnRegistry,
    // NOTE(review): the three providers below are not read by `load` in this
    // file; presumably callers fetch them via the accessors — confirm.
    /// Optional key-management provider.
    kms: Option<Arc<dyn KmsProvider>>,
    /// Optional secret store.
    secrets: Option<Arc<SecretStore>>,
    /// Optional HTTP egress provider.
    http: Option<Arc<dyn HttpEgress>>,
}
impl std::fmt::Debug for RhaiLoader {
    /// Formats the loader for diagnostics.
    ///
    /// The KMS and HTTP providers are trait objects and the secret store is an
    /// opaque handle, so they may not implement `Debug` themselves. Only whether
    /// each one is configured is reported, which is enough to spot a loader
    /// that was built without its providers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RhaiLoader")
            .field("host_fn_count", &self.host_fns.len())
            .field("has_kms", &self.kms.is_some())
            .field("has_secret_store", &self.secrets.is_some())
            .field("has_http", &self.http.is_some())
            .finish()
    }
}
impl RhaiLoader {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn host_fns_mut(&mut self) -> &mut RhaiHostFnRegistry {
&mut self.host_fns
}
#[must_use]
pub fn host_fns(&self) -> &RhaiHostFnRegistry {
&self.host_fns
}
#[must_use]
pub fn host_fn_count(&self) -> usize {
self.host_fns.len()
}
#[must_use]
pub fn with_kms(mut self, kms: Arc<dyn KmsProvider>) -> Self {
self.kms = Some(kms);
self
}
#[must_use]
pub fn with_secret_store(mut self, store: Arc<SecretStore>) -> Self {
self.secrets = Some(store);
self
}
#[must_use]
pub fn with_http(mut self, http: Arc<dyn HttpEgress>) -> Self {
self.http = Some(http);
self
}
#[must_use]
pub fn kms(&self) -> Option<Arc<dyn KmsProvider>> {
self.kms.clone()
}
#[must_use]
pub fn secret_store(&self) -> Option<Arc<SecretStore>> {
self.secrets.clone()
}
#[must_use]
pub fn http(&self) -> Option<Arc<dyn HttpEgress>> {
self.http.clone()
}
pub fn load(
&self,
script: &str,
registrar: &mut PluginRegistrar<'_>,
registrar_caps: &CapabilitySet,
) -> Result<LoadOutcome, RhaiError> {
let engine = build_engine(registrar_caps, &self.host_fns);
let ast = compile(&engine, script)?;
let manifest = parse_manifest(&engine, &ast)?;
let plugin_id = PluginId::new(manifest.id.clone());
let declared = derive_declared_capabilities(&manifest);
let (effective, denied) = intersect_caps(&declared, registrar_caps);
let runtime = RhaiPluginRuntime::new(plugin_id.clone(), engine, ast);
registrar.set_plugin_id(plugin_id.clone());
let mut scalars_registered = Vec::with_capacity(manifest.scalar_fns.len());
if effective.contains(&Capability::ScalarFn) {
for entry in &manifest.scalar_fns {
let sig = build_fn_signature(&entry.args, &entry.returns, &manifest.determinism)?;
let qname = QName::new(plugin_id.as_str(), entry.name.clone());
let adapter = if entry.vectorized {
RhaiScalarFn::new_vectorized(
Arc::clone(&runtime),
entry.name.clone(),
sig.clone(),
)
} else {
RhaiScalarFn::new(Arc::clone(&runtime), entry.name.clone(), sig.clone())
};
registrar
.scalar_fn(qname.clone(), sig, Arc::new(adapter))
.map_err(plugin_to_rhai_err)?;
scalars_registered.push(qname.to_string());
}
}
let mut aggregates_registered = Vec::with_capacity(manifest.aggregate_fns.len());
if effective.contains(&Capability::AggregateFn) {
for entry in &manifest.aggregate_fns {
let sig = build_agg_signature(&entry.args, &entry.returns, &manifest.determinism)?;
let qname = QName::new(plugin_id.as_str(), entry.name.clone());
let adapter =
RhaiAggregateFn::new(Arc::clone(&runtime), entry.name.clone(), sig.clone());
registrar
.aggregate_fn(qname.clone(), sig, Arc::new(adapter))
.map_err(plugin_to_rhai_err)?;
aggregates_registered.push(qname.to_string());
}
}
let mut procedures_registered = Vec::with_capacity(manifest.procedures.len());
if effective.contains(&Capability::Procedure) {
for entry in &manifest.procedures {
let sig = build_procedure_signature(entry)?;
let qname = QName::new(plugin_id.as_str(), entry.name.clone());
let adapter =
RhaiProcedure::new(Arc::clone(&runtime), entry.name.clone(), sig.clone());
registrar
.procedure(qname.clone(), sig, Arc::new(adapter))
.map_err(plugin_to_rhai_err)?;
procedures_registered.push(qname.to_string());
}
}
Ok(LoadOutcome {
plugin_id,
version: manifest.version,
effective_capabilities: effective,
denied_capabilities: denied,
scalars_registered,
aggregates_registered,
procedures_registered,
runtime,
})
}
}
fn build_procedure_signature(entry: &ProcedureEntry) -> Result<ProcedureSignature, RhaiError> {
let args: Vec<NamedArgType> = entry
.args
.iter()
.enumerate()
.map(|(i, t)| {
let ty = type_name_to_argtype(t)?;
Ok(NamedArgType {
name: format!("arg{i}").into(),
ty,
default: None,
doc: String::new(),
})
})
.collect::<Result<_, RhaiError>>()?;
let yields: Vec<Field> = entry
.yields
.iter()
.enumerate()
.map(|(i, t)| {
let dt = type_name_to_datatype(t)?;
Ok(Field::new(format!("col{i}"), dt, true))
})
.collect::<Result<_, RhaiError>>()?;
let mode = match entry.mode.trim().to_ascii_lowercase().as_str() {
"write" => ProcedureMode::Write,
"schema" => ProcedureMode::Schema,
"dbms" => ProcedureMode::Dbms,
_ => ProcedureMode::Read,
};
let side_effects = match mode {
ProcedureMode::Read => SideEffects::ReadOnly,
_ => SideEffects::Writes,
};
Ok(ProcedureSignature {
args,
yields,
mode,
side_effects,
retry_contract: None,
batch_input: None,
docs: String::new(),
})
}
/// Returns the capabilities a manifest needs.
///
/// Each function kind (scalar, aggregate, procedure) contributes its
/// capability only when the manifest declares at least one function of that
/// kind.
fn derive_declared_capabilities(m: &RhaiManifest) -> CapabilitySet {
    let requirements = [
        (!m.scalar_fns.is_empty(), Capability::ScalarFn),
        (!m.aggregate_fns.is_empty(), Capability::AggregateFn),
        (!m.procedures.is_empty(), Capability::Procedure),
    ];
    let mut needed = CapabilitySet::new();
    for (declared, cap) in requirements {
        if declared {
            needed.insert(cap);
        }
    }
    needed
}
/// Splits `declared` capabilities by whether `granted` allows them.
///
/// Returns `(effective, denied)`:
/// - `effective` is `declared ∩ granted`.
/// - `denied` lists every declared capability absent from `granted`, in the
///   iteration order of `declared`.
fn intersect_caps(
    declared: &CapabilitySet,
    granted: &CapabilitySet,
) -> (CapabilitySet, Vec<Capability>) {
    let denied = declared
        .iter()
        .filter_map(|cap| (!granted.contains(cap)).then(|| cap.clone()))
        .collect::<Vec<Capability>>();
    (declared.intersect(granted), denied)
}
/// Converts a registrar error into a loader error.
///
/// Duplicate registrations and missing registrar capabilities are reported as
/// `RhaiError::ManifestInvalid`. Every other registrar error is reported as
/// `RhaiError::Internal`, prefixed with `registrar: `.
fn plugin_to_rhai_err(e: PluginError) -> RhaiError {
    match e {
        PluginError::DuplicateRegistration(qname) => {
            let detail = format!("duplicate registration: {qname}");
            RhaiError::ManifestInvalid(detail)
        }
        PluginError::CapabilityRequired(cap) => {
            let detail = format!("registrar caps missing: {cap:?}");
            RhaiError::ManifestInvalid(detail)
        }
        unexpected => {
            let detail = format!("registrar: {unexpected}");
            RhaiError::Internal(detail)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_plugin::PluginRegistry;

    /// Builds a default loader plus a capability set granting every function
    /// kind the loader can register (scalar, aggregate, procedure).
    fn loader_with_caps() -> (RhaiLoader, CapabilitySet) {
        let loader = RhaiLoader::new();
        let caps = CapabilitySet::from_iter_of([
            Capability::ScalarFn,
            Capability::AggregateFn,
            Capability::Procedure,
        ]);
        (loader, caps)
    }

    /// A single scalar function is registered under the manifest id and
    /// becomes visible in the registry after commit.
    #[test]
    fn loads_minimal_scalar_plugin() {
        let script = r#"
fn uni_manifest() {
#{
id: "ai.test.scalar",
version: "0.1.0",
scalar_fns: [
#{ name: "double", args: ["float"], returns: "float" },
],
}
}
fn double(x) { x * 2.0 }
"#;
        let (loader, caps) = loader_with_caps();
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let outcome = loader.load(script, &mut r, &caps).expect("loads");
        assert_eq!(outcome.plugin_id.as_str(), "ai.test.scalar");
        assert_eq!(outcome.scalars_registered.len(), 1);
        assert!(outcome.denied_capabilities.is_empty());
        r.commit_to_registry().expect("commits");
        let q = QName::new("ai.test.scalar", "double");
        assert!(registry.scalar_fn(&q).is_some());
    }

    /// A manifest declaring aggregates, loaded with only `ScalarFn` granted,
    /// reports `AggregateFn` as denied but still registers its scalar.
    #[test]
    fn declared_but_not_granted_caps_show_as_denied() {
        let script = r#"
fn uni_manifest() {
#{
id: "ai.test.denied",
version: "0.1.0",
scalar_fns: [
#{ name: "noop", args: [], returns: "int" },
],
aggregate_fns: [
#{ name: "agg", args: ["float"], returns: "float", state: "map" },
],
}
}
fn noop() { 0 }
"#;
        let loader = RhaiLoader::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let outcome = loader.load(script, &mut r, &caps).expect("loads");
        assert!(
            outcome
                .denied_capabilities
                .contains(&Capability::AggregateFn)
        );
        assert_eq!(outcome.scalars_registered.len(), 1);
    }

    /// Syntactically invalid Rhai surfaces as `RhaiError::ParseFailed`.
    #[test]
    fn parse_failure_returns_parse_error() {
        let script = r#"this is not valid rhai @@@"#;
        let (loader, caps) = loader_with_caps();
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let err = loader.load(script, &mut r, &caps).unwrap_err();
        assert!(matches!(err, RhaiError::ParseFailed(_)));
    }
}