use std::sync::Arc;
use serde::Deserialize;
use wasmtime::component::{Component, Linker};
use wasmtime::{Config, Engine, Store};
use crate::adapter::ComponentScalarFn;
use crate::adapter_aggregate::ComponentAggregateFn;
use crate::adapter_procedure::ComponentProcedure;
use crate::bindings::aggregate::AggregatePlugin;
use crate::bindings::procedure::ProcedurePlugin as ProcedurePluginBindings;
use crate::bindings::scalar::ScalarPlugin;
use crate::error::WasmError;
use crate::host_state::HostState;
use crate::pool::WasmInstancePool;
/// Self-description of a component, deserialized from JSON.
///
/// Unknown JSON keys are rejected (`deny_unknown_fields`). Only `id` and
/// `version` are required; every other field falls back to its `Default`
/// when absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ComponentManifest {
/// Plugin identifier; turned into a `PluginId` when the component is loaded.
pub id: String,
/// Version string; parsed as semver when a `PluginManifest` is synthesized.
pub version: String,
/// ABI range string. When absent the v1 scalar linker is selected and
/// `"^1"` is assumed for the synthesized plugin manifest.
#[serde(default)]
pub abi: Option<String>,
/// Capabilities the component declares; intersected with the host grants
/// to obtain the effective set.
#[serde(default)]
pub capabilities: Vec<uni_plugin::ManifestCapability>,
/// Determinism label. `"pure"` and `"session-scoped"` / `"session_scoped"`
/// are recognized; anything else (including absence) is treated as
/// nondeterministic when the plugin manifest is synthesized.
#[serde(default)]
pub determinism: Option<String>,
/// Free-form description; becomes the synthesized manifest's `docs`.
#[serde(default)]
pub description: Option<String>,
/// When set, fuel consumption is enabled on the engine and each new store
/// is given this much fuel.
#[serde(default)]
pub fuel_per_call: Option<u64>,
/// NOTE(review): parsed but not read anywhere in the visible part of this
/// file — confirm where (or whether) the memory limit is enforced.
#[serde(default)]
pub memory_max_pages: Option<u32>,
/// When set, epoch interruption is enabled and each store's epoch deadline
/// is set to 1. NOTE(review): only the presence of the value is consulted
/// in this file, not its magnitude — confirm the epoch ticker elsewhere.
#[serde(default)]
pub timeout_ms: Option<u64>,
}
impl ComponentManifest {
    /// Converts the manifest's declared capability entries into a
    /// `CapabilitySet`.
    ///
    /// The entries are cloned; the manifest itself is left untouched.
    #[must_use]
    pub fn declared_capability_set(&self) -> uni_plugin::CapabilitySet {
        let declared = self.capabilities.iter().cloned();
        uni_plugin::CapabilitySet::from_manifest(declared)
    }
}
/// Function signature as it appears in a component's registration JSON.
///
/// Unknown JSON keys are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct WireFnSignature {
/// Argument types, in positional order.
pub args: Vec<WireArgType>,
/// Return type.
pub returns: WireArgType,
/// Volatility label; defaults to `"immutable"`. The conversion to the
/// internal signature accepts `"immutable"`, `"stable"` and `"volatile"`.
#[serde(default = "default_volatility")]
pub volatility: String,
/// Null-handling label; defaults to `"propagate"`. The conversion to the
/// internal signature accepts `"propagate"` and `"user_handled"`.
#[serde(default = "default_null_handling")]
pub null_handling: String,
}
/// Serde default for `WireFnSignature::volatility`.
fn default_volatility() -> String {
    String::from("immutable")
}
/// Serde default for `WireFnSignature::null_handling`.
fn default_null_handling() -> String {
    String::from("propagate")
}
/// Serde default for the `mode` of a `RegistrationEntry::Procedure`.
fn default_proc_mode() -> String {
    String::from("read")
}
/// Argument / return type as it appears in registration JSON.
///
/// Internally tagged by a `kind` key with snake_case variant names, i.e.
/// `{"kind": "primitive", "arrow": "..."}` or `{"kind": "cypher_value"}`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub enum WireArgType {
/// A primitive value identified by an Arrow type name.
Primitive {
/// Arrow type name; resolved to a `DataType` via `arrow_name_to_dt`.
arrow: String,
},
/// An opaque Cypher value.
CypherValue,
}
/// One item a component asks the host to register.
///
/// Internally tagged by a `kind` key with snake_case variant names
/// (`"scalar"`, `"aggregate"`, `"procedure"`).
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub enum RegistrationEntry {
/// A scalar function.
Scalar {
/// Qualified name; must parse as a `uni_plugin::QName`.
qname: String,
/// Argument / return signature.
signature: WireFnSignature,
},
/// An aggregate function.
Aggregate {
/// Qualified name; must parse as a `uni_plugin::QName`.
qname: String,
/// Argument / return signature.
signature: WireFnSignature,
/// Type of the accumulator state; only `Primitive` is accepted when the
/// entry is converted to the internal aggregate signature.
state: WireArgType,
},
/// A procedure.
Procedure {
/// Qualified name; must parse as a `uni_plugin::QName`.
qname: String,
/// Positional argument types.
args: Vec<WireArgType>,
/// Types of the yielded columns, in order.
yields: Vec<WireArgType>,
/// Procedure mode; defaults to `"read"`. Registration accepts `"read"`,
/// `"write"`, `"schema"` and `"dbms"`.
#[serde(default = "default_proc_mode")]
mode: String,
},
}
/// The full list of items a component wants registered, deserialized from
/// the JSON string returned by its `register` export.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RegistrationManifest {
/// Entries in the order the component reported them.
pub entries: Vec<RegistrationEntry>,
}
/// A manifest paired with the capabilities the host actually allows, ready
/// to be used for instantiation.
#[derive(Clone)]
pub struct PreparedComponent {
/// The component's manifest.
pub manifest: ComponentManifest,
/// Capabilities handed to the linker and to `HostState`.
pub effective: uni_plugin::CapabilitySet,
/// `Debug` renderings of declared capabilities that are not in `effective`.
pub denied_capabilities: Vec<String>,
/// Optional HTTP egress handle passed to each `HostState`.
pub http: Option<Arc<dyn uni_plugin::HttpEgress>>,
}
impl std::fmt::Debug for PreparedComponent {
    /// Formats every field; the HTTP handle is shown only as a presence flag
    /// because the trait object itself is not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_http = self.http.is_some();
        let mut out = f.debug_struct("PreparedComponent");
        out.field("manifest", &self.manifest);
        out.field("effective", &self.effective);
        out.field("denied_capabilities", &self.denied_capabilities);
        out.field("http", &has_http);
        out.finish()
    }
}
/// A live instance of a component bound to the scalar-plugin world.
pub struct ScalarPluginInstance {
// Store that owns the instance's host state; passed to every guest call.
store: Store<HostState>,
// Typed bindings used to call the component's exports.
bindings: ScalarPlugin,
}
impl std::fmt::Debug for ScalarPluginInstance {
    /// Prints only the type name; the store and bindings are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ScalarPluginInstance");
        out.finish_non_exhaustive()
    }
}
/// Common read-only view over the per-world `FnError` types so `map_call`
/// can format any of them the same way.
trait WasmCallErr {
/// Error code reported by the guest.
fn code(&self) -> u32;
/// Message reported by the guest.
fn message(&self) -> &str;
/// Whether the guest flagged the error as retryable.
fn retryable(&self) -> bool;
}
/// Implements `WasmCallErr` for a type that exposes `code`, `message` and
/// `retryable` fields, forwarding each accessor to the field of the same name.
macro_rules! impl_wasm_call_err {
    ($err:ty) => {
        impl WasmCallErr for $err {
            fn code(&self) -> u32 {
                self.code
            }

            fn message(&self) -> &str {
                &self.message
            }

            fn retryable(&self) -> bool {
                self.retryable
            }
        }
    };
}
// Each binding world has its own `FnError` type; give all three the shared
// accessor impl so `map_call` accepts results from any of them.
impl_wasm_call_err!(crate::bindings::scalar::FnError);
impl_wasm_call_err!(crate::bindings::aggregate::FnError);
impl_wasm_call_err!(crate::bindings::procedure::FnError);
/// Flattens the two-level result of a guest call into `Result<Vec<u8>, WasmError>`.
///
/// * The outer `Err` (a wasmtime error) becomes `WasmError::Invoke("{label} trap: ...")`.
/// * The inner `Err` (an error returned by the guest function) becomes
///   `WasmError::Invoke` carrying its code, retryable flag and message.
/// * Otherwise the returned bytes are passed through.
fn map_call<E: WasmCallErr>(
    label: &str,
    result: Result<Result<Vec<u8>, E>, wasmtime::Error>,
) -> Result<Vec<u8>, WasmError> {
    // Outer layer first: the call itself failed.
    let guest_result =
        result.map_err(|trap| WasmError::Invoke(format!("{label} trap: {trap}")))?;

    // Inner layer: the guest ran and reported its own error.
    guest_result.map_err(|fn_err| {
        WasmError::Invoke(format!(
            "{label} fn-error code={} retryable={}: {}",
            fn_err.code(),
            fn_err.retryable(),
            fn_err.message()
        ))
    })
}
impl ScalarPluginInstance {
    /// Calls the guest's scalar entry point for `qname` with the payload `ipc`
    /// and returns the bytes it produced.
    ///
    /// # Errors
    ///
    /// Returns `WasmError::Invoke` when the call fails or the guest returns an
    /// error (see `map_call`).
    pub fn invoke_scalar(&mut self, qname: &str, ipc: &[u8]) -> Result<Vec<u8>, WasmError> {
        map_call(
            "invoke-scalar",
            self.bindings.call_invoke_scalar(&mut self.store, qname, ipc),
        )
    }

    /// Calls the guest's `manifest` export and parses the returned JSON.
    ///
    /// # Errors
    ///
    /// `WasmError::Instantiate` when the call fails, `WasmError::InvalidWasm`
    /// when the JSON does not deserialize into a `ComponentManifest`.
    fn read_manifest(&mut self) -> Result<ComponentManifest, WasmError> {
        let raw = match self.bindings.call_manifest(&mut self.store) {
            Ok(json) => json,
            Err(e) => return Err(WasmError::Instantiate(format!("call manifest: {e}"))),
        };
        match serde_json::from_str(&raw) {
            Ok(manifest) => Ok(manifest),
            Err(e) => Err(WasmError::InvalidWasm(format!("manifest json parse: {e}"))),
        }
    }

    /// Calls the guest's `register` export and parses the returned JSON.
    ///
    /// # Errors
    ///
    /// `WasmError::Instantiate` when the call fails, `WasmError::InvalidWasm`
    /// when the JSON does not deserialize into a `RegistrationManifest`.
    fn read_register(&mut self) -> Result<RegistrationManifest, WasmError> {
        let raw = match self.bindings.call_register(&mut self.store) {
            Ok(json) => json,
            Err(e) => return Err(WasmError::Instantiate(format!("call register: {e}"))),
        };
        match serde_json::from_str(&raw) {
            Ok(registration) => Ok(registration),
            Err(e) => Err(WasmError::InvalidWasm(format!("register json parse: {e}"))),
        }
    }
}
/// A live instance of a component bound to the aggregate-plugin world.
pub struct AggregatePluginInstance {
// Store that owns the instance's host state; passed to every guest call.
store: Store<HostState>,
// Typed bindings used to call the component's aggregate exports.
bindings: AggregatePlugin,
}
impl std::fmt::Debug for AggregatePluginInstance {
    /// Prints only the type name; the store and bindings are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AggregatePluginInstance");
        out.finish_non_exhaustive()
    }
}
impl AggregatePluginInstance {
    /// Asks the guest for a fresh accumulator state for aggregate `qname`.
    ///
    /// # Errors
    ///
    /// `WasmError::Invoke` when the call fails or the guest returns an error.
    pub fn agg_new(&mut self, qname: &str) -> Result<Vec<u8>, WasmError> {
        let outcome = self.bindings.call_agg_new(&mut self.store, qname);
        map_call("agg-new", outcome)
    }

    /// Passes `state` and `values_ipc` to the guest's update export for
    /// `qname` and returns the bytes it produced.
    ///
    /// # Errors
    ///
    /// `WasmError::Invoke` when the call fails or the guest returns an error.
    pub fn agg_update(
        &mut self,
        qname: &str,
        state: &[u8],
        values_ipc: &[u8],
    ) -> Result<Vec<u8>, WasmError> {
        let outcome = self
            .bindings
            .call_agg_update(&mut self.store, qname, state, values_ipc);
        map_call("agg-update", outcome)
    }

    /// Passes `state` and `other_states_ipc` to the guest's merge export for
    /// `qname` and returns the bytes it produced.
    ///
    /// # Errors
    ///
    /// `WasmError::Invoke` when the call fails or the guest returns an error.
    pub fn agg_merge(
        &mut self,
        qname: &str,
        state: &[u8],
        other_states_ipc: &[u8],
    ) -> Result<Vec<u8>, WasmError> {
        let outcome = self
            .bindings
            .call_agg_merge(&mut self.store, qname, state, other_states_ipc);
        map_call("agg-merge", outcome)
    }

    /// Passes `state` to the guest's evaluate export for `qname` and returns
    /// the bytes it produced.
    ///
    /// # Errors
    ///
    /// `WasmError::Invoke` when the call fails or the guest returns an error.
    pub fn agg_evaluate(&mut self, qname: &str, state: &[u8]) -> Result<Vec<u8>, WasmError> {
        let outcome = self
            .bindings
            .call_agg_evaluate(&mut self.store, qname, state);
        map_call("agg-evaluate", outcome)
    }
}
/// A live instance of a component bound to the procedure-plugin world.
pub struct ProcedurePluginInstance {
// Store that owns the instance's host state; passed to every guest call.
store: Store<HostState>,
// Typed bindings used to call the component's procedure export.
bindings: ProcedurePluginBindings,
}
impl std::fmt::Debug for ProcedurePluginInstance {
    /// Prints only the type name; the store and bindings are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ProcedurePluginInstance");
        out.finish_non_exhaustive()
    }
}
impl ProcedurePluginInstance {
    /// Calls the guest's procedure entry point for `qname` with the payload
    /// `args_ipc` and returns the bytes it produced.
    ///
    /// # Errors
    ///
    /// `WasmError::Invoke` when the call fails or the guest returns an error.
    pub fn invoke_procedure(&mut self, qname: &str, args_ipc: &[u8]) -> Result<Vec<u8>, WasmError> {
        let outcome = self
            .bindings
            .call_invoke_procedure(&mut self.store, qname, args_ipc);
        map_call("invoke-procedure", outcome)
    }
}
/// Entry point for preparing, instantiating and loading wasm components.
#[derive(Default)]
pub struct WasmLoader {
// Optional HTTP egress handle copied into every `PreparedComponent`.
http: Option<Arc<dyn uni_plugin::HttpEgress>>,
}
impl std::fmt::Debug for WasmLoader {
    /// Shows only whether an HTTP egress handle is configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_http = self.http.is_some();
        let mut out = f.debug_struct("WasmLoader");
        out.field("http", &has_http);
        out.finish()
    }
}
impl WasmLoader {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_http(mut self, http: Arc<dyn uni_plugin::HttpEgress>) -> Self {
self.http = Some(http);
self
}
fn bootstrap_prepared(&self, host_grants: &uni_plugin::CapabilitySet) -> PreparedComponent {
PreparedComponent {
manifest: ComponentManifest {
id: String::new(),
version: String::new(),
abi: None,
capabilities: Vec::new(),
determinism: None,
description: None,
fuel_per_call: None,
memory_max_pages: None,
timeout_ms: None,
},
effective: host_grants.clone(),
denied_capabilities: Vec::new(),
http: self.http.clone(),
}
}
pub fn prepare(
&self,
manifest_json: &[u8],
grants: &uni_plugin::CapabilitySet,
) -> Result<PreparedComponent, WasmError> {
let manifest: ComponentManifest = serde_json::from_slice(manifest_json)
.map_err(|e| WasmError::InvalidWasm(format!("manifest json parse: {e}")))?;
Ok(self.prepare_parsed(manifest, grants))
}
pub fn prepare_parsed(
&self,
manifest: ComponentManifest,
grants: &uni_plugin::CapabilitySet,
) -> PreparedComponent {
let declared = manifest.declared_capability_set();
let effective = declared.intersect(grants);
let denied: Vec<String> = declared
.iter()
.filter(|c| !effective.contains_variant(c))
.map(|c| format!("{c:?}"))
.collect();
PreparedComponent {
manifest,
effective,
denied_capabilities: denied,
http: self.http.clone(),
}
}
pub fn instantiate(
&self,
bytes: &[u8],
prepared: &PreparedComponent,
) -> Result<ScalarPluginInstance, WasmError> {
let engine = build_engine(&prepared.manifest)?;
let component = Component::from_binary(&engine, bytes)
.map_err(|e| WasmError::InvalidWasm(format!("component compile: {e}")))?;
let linker: Linker<HostState> =
select_linker_for_manifest(&engine, &prepared.manifest, &prepared.effective)?;
let mut store = Store::new(
&engine,
HostState::new(prepared.effective.clone(), prepared.http.clone()),
);
apply_resource_limits(&mut store, &prepared.manifest);
let bindings = ScalarPlugin::instantiate(&mut store, &component, &linker)
.map_err(|e| WasmError::Instantiate(format!("scalar-plugin instantiate: {e}")))?;
Ok(ScalarPluginInstance { store, bindings })
}
pub fn load(
&self,
bytes: &[u8],
host_grants: &uni_plugin::CapabilitySet,
registrar: &mut uni_plugin::PluginRegistrar<'_>,
) -> Result<LoadOutcome, WasmError> {
let bootstrap = self.bootstrap_prepared(host_grants);
let mut bootstrap_inst = self.instantiate(bytes, &bootstrap)?;
let parsed_manifest = bootstrap_inst.read_manifest()?;
drop(bootstrap_inst);
registrar.set_plugin_id(uni_plugin::PluginId::new(parsed_manifest.id.clone()));
let prepared = self.prepare_parsed(parsed_manifest, host_grants);
let pool = build_scalar_pool(bytes, &prepared)?;
let registration = {
let mut leased = crate::pool::PooledInstance::acquire(Arc::clone(&pool))
.map_err(|e| WasmError::Instantiate(format!("acquire warm instance: {e}")))?;
let r = leased.get_mut().read_register()?;
drop(leased);
r
};
let names = apply_registration(bytes, &prepared, &pool, registration, registrar)?;
Ok(LoadOutcome {
plugin_id: prepared.manifest.id.clone(),
version: prepared.manifest.version.clone(),
effective_capabilities: capability_names(&prepared.effective),
denied_capabilities: prepared.denied_capabilities,
scalars_registered: names.scalars,
aggregates_registered: names.aggregates,
procedures_registered: names.procedures,
pool,
})
}
pub fn load_as_plugin(
&self,
bytes: &[u8],
host_grants: &uni_plugin::CapabilitySet,
) -> Result<Box<dyn uni_plugin::Plugin + Send + Sync>, WasmError> {
let bootstrap = self.bootstrap_prepared(host_grants);
let mut bootstrap_inst = self.instantiate(bytes, &bootstrap)?;
let parsed_manifest = bootstrap_inst.read_manifest()?;
drop(bootstrap_inst);
let prepared = self.prepare_parsed(parsed_manifest, host_grants);
let scalar_pool = build_scalar_pool(bytes, &prepared)?;
let registration = {
let mut leased = crate::pool::PooledInstance::acquire(Arc::clone(&scalar_pool))
.map_err(|e| WasmError::Instantiate(format!("acquire warm instance: {e}")))?;
let r = leased.get_mut().read_register()?;
drop(leased);
r
};
let manifest = synthesize_plugin_manifest(&prepared.manifest, ®istration)?;
Ok(Box::new(ComponentPlugin {
manifest,
bytes: bytes.to_vec(),
prepared,
scalar_pool,
registration,
}))
}
}
/// Summary of a successful `WasmLoader::load`.
pub struct LoadOutcome {
/// `id` from the component's manifest.
pub plugin_id: String,
/// `version` from the component's manifest.
pub version: String,
/// `Debug` renderings of the effective capabilities.
pub effective_capabilities: Vec<String>,
/// `Debug` renderings of declared capabilities that were not granted.
pub denied_capabilities: Vec<String>,
/// Qualified names of the scalar functions registered, in registration order.
pub scalars_registered: Vec<String>,
/// Qualified names of the aggregate functions registered, in registration order.
pub aggregates_registered: Vec<String>,
/// Qualified names of the procedures registered, in registration order.
pub procedures_registered: Vec<String>,
/// The scalar instance pool built during the load.
pub pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
}
impl std::fmt::Debug for LoadOutcome {
    /// Formats every field except `pool`, which is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LoadOutcome");
        out.field("plugin_id", &self.plugin_id);
        out.field("version", &self.version);
        out.field("effective_capabilities", &self.effective_capabilities);
        out.field("denied_capabilities", &self.denied_capabilities);
        out.field("scalars_registered", &self.scalars_registered);
        out.field("aggregates_registered", &self.aggregates_registered);
        out.field("procedures_registered", &self.procedures_registered);
        out.finish_non_exhaustive()
    }
}
/// Renders every capability in `caps` with its `Debug` formatting, in the
/// set's iteration order.
fn capability_names(caps: &uni_plugin::CapabilitySet) -> Vec<String> {
    let mut names = Vec::new();
    for capability in caps.iter() {
        names.push(format!("{capability:?}"));
    }
    names
}
/// Picks the scalar linker matching the ABI the manifest declares.
///
/// A manifest without an `abi` gets the v1 linker. Otherwise the range is
/// parsed, reduced to a major version, and mapped to the v1 or v2 linker.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when the ABI string does not parse,
/// `WasmError::AbiUnsupported` when the major version is neither 1 nor 2, and
/// whatever `major_for_abi` or the linker builders return.
fn select_linker_for_manifest(
    engine: &Engine,
    manifest: &ComponentManifest,
    effective_caps: &uni_plugin::CapabilitySet,
) -> Result<Linker<HostState>, WasmError> {
    use crate::linker::{build_scalar_linker_v1, build_scalar_linker_v2};
    use crate::multi_version::{SUPPORTED_MAJORS, major_for_abi};

    let abi_str = match manifest.abi.as_deref() {
        Some(declared) => declared,
        // No declared ABI: fall back to v1.
        None => return build_scalar_linker_v1(engine, effective_caps),
    };

    let abi = match uni_plugin::AbiRange::parse(abi_str) {
        Ok(range) => range,
        Err(e) => return Err(WasmError::InvalidWasm(format!("manifest abi parse: {e}"))),
    };

    let major = major_for_abi(&abi)?;
    match major {
        1 => build_scalar_linker_v1(engine, effective_caps),
        2 => build_scalar_linker_v2(engine, effective_caps),
        _ => Err(WasmError::AbiUnsupported {
            requested: abi_str.to_owned(),
            supported: SUPPORTED_MAJORS.to_vec(),
        }),
    }
}
/// Creates an engine with the component model enabled.
///
/// Fuel consumption is switched on only when the manifest sets
/// `fuel_per_call`, and epoch interruption only when it sets `timeout_ms`.
///
/// # Errors
///
/// `WasmError::Instantiate` when wasmtime rejects the configuration.
fn build_engine(manifest: &ComponentManifest) -> Result<Engine, WasmError> {
    let wants_fuel = manifest.fuel_per_call.is_some();
    let wants_epochs = manifest.timeout_ms.is_some();

    let mut cfg = Config::new();
    // Both switches default to off, so passing `false` leaves them untouched.
    cfg.wasm_component_model(true)
        .consume_fuel(wants_fuel)
        .epoch_interruption(wants_epochs);

    match Engine::new(&cfg) {
        Ok(engine) => Ok(engine),
        Err(e) => Err(WasmError::Instantiate(format!("engine config: {e}"))),
    }
}
/// Applies the manifest's per-store limits to `store`.
///
/// * `timeout_ms` present: the epoch deadline is set to 1.
/// * `fuel_per_call` present: the store's fuel is set to that amount; a
///   failure of `set_fuel` is deliberately ignored.
fn apply_resource_limits(store: &mut Store<HostState>, manifest: &ComponentManifest) {
    let has_timeout = manifest.timeout_ms.is_some();
    if has_timeout {
        store.set_epoch_deadline(1);
    }

    let Some(budget) = manifest.fuel_per_call else {
        return;
    };
    // Best effort: the result is intentionally discarded.
    let _ = store.set_fuel(budget);
}
/// Builds an instance pool whose factory creates instances of type `I`.
///
/// The component bytes and the prepared component are copied into shared
/// ownership so the factory can outlive the caller's borrows. Every factory
/// invocation builds its own engine, compiles the component, selects the
/// linker, creates a store with the effective capabilities and resource
/// limits, and finally delegates to `build_instance`.
///
/// # Errors
///
/// Returns whatever `WasmInstancePool::new` reports; the factory itself
/// yields `WasmError` values for compile, linker or instantiation failures.
fn build_pool<I, F>(
    bytes: &[u8],
    prepared: &PreparedComponent,
    build_instance: F,
) -> Result<Arc<WasmInstancePool<I>>, WasmError>
where
    I: Send + 'static,
    F: Fn(Store<HostState>, &Component, &Linker<HostState>) -> Result<I, WasmError>
        + Send
        + Sync
        + 'static,
{
    let wasm: Arc<Vec<u8>> = Arc::new(bytes.to_vec());
    let shared_prepared: Arc<PreparedComponent> = Arc::new(prepared.clone());
    let build_instance = Arc::new(build_instance);

    // The shared handles are moved straight into the factory closure.
    let factory = move || -> Result<I, WasmError> {
        let manifest = &shared_prepared.manifest;
        let engine = build_engine(manifest)?;
        let component = Component::from_binary(&engine, &wasm)
            .map_err(|e| WasmError::InvalidWasm(format!("component compile: {e}")))?;
        let linker: Linker<HostState> =
            select_linker_for_manifest(&engine, manifest, &shared_prepared.effective)?;
        let host_state = HostState::new(
            shared_prepared.effective.clone(),
            shared_prepared.http.clone(),
        );
        let mut store = Store::new(&engine, host_state);
        apply_resource_limits(&mut store, manifest);
        build_instance(store, &component, &linker)
    };

    let pool = WasmInstancePool::new(crate::pool::PoolConfig::default(), factory)?;
    Ok(Arc::new(pool))
}
/// Qualified names that `apply_registration` registered, grouped by kind and
/// kept in registration order.
struct RegisteredQNames {
// Names of registered scalar functions.
scalars: Vec<String>,
// Names of registered aggregate functions.
aggregates: Vec<String>,
// Names of registered procedures.
procedures: Vec<String>,
}
fn apply_registration(
bytes: &[u8],
prepared: &PreparedComponent,
scalar_pool: &Arc<WasmInstancePool<ScalarPluginInstance>>,
registration: RegistrationManifest,
registrar: &mut uni_plugin::PluginRegistrar<'_>,
) -> Result<RegisteredQNames, WasmError> {
let mut scalars = Vec::new();
let mut aggregates = Vec::new();
let mut procedures = Vec::new();
let mut agg_pool: Option<Arc<WasmInstancePool<AggregatePluginInstance>>> = None;
let mut proc_pool: Option<Arc<WasmInstancePool<ProcedurePluginInstance>>> = None;
for entry in registration.entries {
match entry {
RegistrationEntry::Scalar { qname, signature } => {
let parsed_qname = uni_plugin::QName::parse(&qname)
.map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
let sig = wire_fn_sig_to_internal(&signature)?;
let adapter = Arc::new(ComponentScalarFn::new(
Arc::clone(scalar_pool),
parsed_qname.clone(),
sig.clone(),
));
registrar
.scalar_fn(parsed_qname, sig, adapter)
.map_err(|e| {
WasmError::Internal(format!("registrar.scalar_fn `{qname}`: {e}"))
})?;
scalars.push(qname);
}
RegistrationEntry::Aggregate {
qname,
signature,
state,
} => {
let parsed_qname = uni_plugin::QName::parse(&qname)
.map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
let sig = wire_agg_sig_to_internal(&signature, &state)?;
let pool_ref = match &agg_pool {
Some(p) => Arc::clone(p),
None => {
let p = build_aggregate_pool(bytes, prepared)?;
agg_pool = Some(Arc::clone(&p));
p
}
};
let adapter = Arc::new(ComponentAggregateFn::new(
pool_ref,
parsed_qname.clone(),
sig.clone(),
));
registrar
.aggregate_fn(parsed_qname, sig, adapter)
.map_err(|e| {
WasmError::Internal(format!("registrar.aggregate_fn `{qname}`: {e}"))
})?;
aggregates.push(qname);
}
RegistrationEntry::Procedure {
qname,
args,
yields,
mode,
} => {
let parsed_qname = uni_plugin::QName::parse(&qname)
.map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
let sig = wire_proc_sig_to_internal(&args, &yields, &mode)?;
let pool_ref = match &proc_pool {
Some(p) => Arc::clone(p),
None => {
let p = build_procedure_pool(bytes, prepared)?;
proc_pool = Some(Arc::clone(&p));
p
}
};
let adapter = Arc::new(ComponentProcedure::new(
pool_ref,
parsed_qname.clone(),
sig.clone(),
));
registrar
.procedure(parsed_qname, sig, adapter)
.map_err(|e| {
WasmError::Internal(format!("registrar.procedure `{qname}`: {e}"))
})?;
procedures.push(qname);
}
}
}
Ok(RegisteredQNames {
scalars,
aggregates,
procedures,
})
}
/// Derives a `uni_plugin::PluginManifest` from a component's own manifest and
/// its registration list.
///
/// * `version` is parsed as semver; `abi` defaults to `"^1"` when absent.
/// * Capabilities are inferred from the kinds of entries registered, plus the
///   procedure mode (`"write"`, `"schema"`, `"dbms"`).
/// * Side effects start as read-only and become `Writes` as soon as any
///   procedure is in `"write"` or `"schema"` mode.
/// * Determinism is taken from the manifest label; unrecognized or missing
///   labels map to `Nondeterministic`.
///
/// Procedure modes other than the three listed above add no extra capability
/// here; validating the mode string is left to the registration step.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when the version is not valid semver or the ABI
/// range does not parse.
fn synthesize_plugin_manifest(
    component: &ComponentManifest,
    registration: &RegistrationManifest,
) -> Result<uni_plugin::PluginManifest, WasmError> {
    use uni_plugin::{
        AbiRange, Capability, CapabilitySet, Determinism, PluginId, ProvidedSurfaces, Scope,
        SideEffects,
    };

    let version = semver::Version::parse(&component.version).map_err(|e| {
        WasmError::InvalidWasm(format!("manifest version `{}`: {e}", component.version))
    })?;
    let abi = AbiRange::parse(component.abi.as_deref().unwrap_or("^1"))
        .map_err(|e| WasmError::InvalidWasm(format!("manifest abi: {e}")))?;

    let mut capabilities = CapabilitySet::new();
    let mut side_effects = SideEffects::ReadOnly;
    for entry in &registration.entries {
        match entry {
            RegistrationEntry::Scalar { .. } => {
                capabilities.insert(Capability::ScalarFn);
            }
            RegistrationEntry::Aggregate { .. } => {
                capabilities.insert(Capability::AggregateFn);
            }
            RegistrationEntry::Procedure { mode, .. } => {
                capabilities.insert(Capability::Procedure);
                match mode.as_str() {
                    "write" => {
                        capabilities.insert(Capability::ProcedureWrites);
                        side_effects = SideEffects::Writes;
                    }
                    "schema" => {
                        capabilities.insert(Capability::ProcedureSchema);
                        side_effects = SideEffects::Writes;
                    }
                    "dbms" => {
                        capabilities.insert(Capability::ProcedureDbms);
                    }
                    // "read" and anything unrecognized add nothing beyond
                    // the base `Procedure` capability.
                    _ => {}
                }
            }
        }
    }

    let determinism = match component.determinism.as_deref() {
        Some("pure") => Determinism::Pure,
        Some("session-scoped" | "session_scoped") => Determinism::SessionScoped,
        _ => Determinism::Nondeterministic,
    };

    Ok(uni_plugin::PluginManifest {
        id: PluginId::new(component.id.clone()),
        version,
        abi,
        depends_on: Vec::new(),
        capabilities,
        determinism,
        side_effects,
        scope: Scope::Instance,
        hash: None,
        signature: None,
        provides: ProvidedSurfaces::default(),
        docs: component.description.clone().unwrap_or_default(),
        metadata: std::collections::BTreeMap::new(),
    })
}
/// A loaded component packaged as a `uni_plugin::Plugin`.
///
/// Holds everything `register` needs to apply the component's registration
/// list to a registrar at a later time.
pub struct ComponentPlugin {
// Plugin manifest synthesized from the component manifest and registration.
manifest: uni_plugin::PluginManifest,
// Owned copy of the component bytes, used to build further pools on demand.
bytes: Vec<u8>,
// Manifest plus effective capabilities used for every instantiation.
prepared: PreparedComponent,
// Pool of scalar instances built when the plugin was loaded.
scalar_pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
// Registration list read from the component when the plugin was loaded.
registration: RegistrationManifest,
}
impl std::fmt::Debug for ComponentPlugin {
    /// Shows the plugin id and the number of registration entries.
    ///
    /// The count covers every entry kind (scalars, aggregates and
    /// procedures), so it is labelled `entries` rather than `scalars`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ComponentPlugin")
            .field("id", &self.manifest.id.as_str())
            .field("entries", &self.registration.entries.len())
            .finish()
    }
}
impl uni_plugin::Plugin for ComponentPlugin {
    /// Returns the manifest synthesized when the component was loaded.
    fn manifest(&self) -> &uni_plugin::PluginManifest {
        &self.manifest
    }

    /// Applies the stored registration list to `r`.
    ///
    /// The list is cloned on every call because `apply_registration` consumes
    /// it.
    ///
    /// # Errors
    ///
    /// Any failure is reported as `PluginError::WasmInstantiate` with the
    /// underlying error appended to `"component register: "`.
    fn register(
        &self,
        r: &mut uni_plugin::PluginRegistrar<'_>,
    ) -> Result<(), uni_plugin::PluginError> {
        let entries = self.registration.clone();
        match apply_registration(&self.bytes, &self.prepared, &self.scalar_pool, entries, r) {
            Ok(_registered) => Ok(()),
            Err(e) => Err(uni_plugin::PluginError::WasmInstantiate(format!(
                "component register: {e}"
            ))),
        }
    }
}
/// Builds a pool of `ScalarPluginInstance`s for the given component.
///
/// # Errors
///
/// Propagates pool construction errors; each instance creation reports
/// `WasmError::Instantiate` when the scalar-plugin world fails to instantiate.
fn build_scalar_pool(
    bytes: &[u8],
    prepared: &PreparedComponent,
) -> Result<Arc<WasmInstancePool<ScalarPluginInstance>>, WasmError> {
    build_pool(bytes, prepared, |mut store, component, linker| {
        let instantiated = ScalarPlugin::instantiate(&mut store, component, linker);
        let bindings = instantiated
            .map_err(|e| WasmError::Instantiate(format!("scalar-plugin instantiate: {e}")))?;
        Ok(ScalarPluginInstance { store, bindings })
    })
}
/// Builds a pool of `AggregatePluginInstance`s for the given component.
///
/// # Errors
///
/// Propagates pool construction errors; each instance creation reports
/// `WasmError::Instantiate` when the aggregate-plugin world fails to
/// instantiate.
fn build_aggregate_pool(
    bytes: &[u8],
    prepared: &PreparedComponent,
) -> Result<Arc<WasmInstancePool<AggregatePluginInstance>>, WasmError> {
    build_pool(bytes, prepared, |mut store, component, linker| {
        let instantiated = AggregatePlugin::instantiate(&mut store, component, linker);
        let bindings = instantiated
            .map_err(|e| WasmError::Instantiate(format!("aggregate-plugin instantiate: {e}")))?;
        Ok(AggregatePluginInstance { store, bindings })
    })
}
/// Builds a pool of `ProcedurePluginInstance`s for the given component.
///
/// # Errors
///
/// Propagates pool construction errors; each instance creation reports
/// `WasmError::Instantiate` when the procedure-plugin world fails to
/// instantiate.
fn build_procedure_pool(
    bytes: &[u8],
    prepared: &PreparedComponent,
) -> Result<Arc<WasmInstancePool<ProcedurePluginInstance>>, WasmError> {
    build_pool(bytes, prepared, |mut store, component, linker| {
        let instantiated = ProcedurePluginBindings::instantiate(&mut store, component, linker);
        let bindings = instantiated
            .map_err(|e| WasmError::Instantiate(format!("procedure-plugin instantiate: {e}")))?;
        Ok(ProcedurePluginInstance { store, bindings })
    })
}
/// Converts a wire-level argument type into the internal `ArgType`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when a primitive's Arrow type name is not
/// supported.
fn wire_arg(w: &WireArgType) -> Result<uni_plugin::traits::scalar::ArgType, WasmError> {
    use uni_plugin::traits::scalar::ArgType;

    match w {
        WireArgType::CypherValue => Ok(ArgType::CypherValue),
        WireArgType::Primitive { arrow } => {
            let data_type = arrow_name_to_dt(arrow)?;
            Ok(ArgType::Primitive(data_type))
        }
    }
}
/// Maps a volatility label to DataFusion's `Volatility`.
///
/// Accepted labels: `"immutable"`, `"stable"`, `"volatile"`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` for any other label.
fn parse_volatility(s: &str) -> Result<datafusion::logical_expr::Volatility, WasmError> {
    use datafusion::logical_expr::Volatility;

    match s {
        "immutable" => Ok(Volatility::Immutable),
        "stable" => Ok(Volatility::Stable),
        "volatile" => Ok(Volatility::Volatile),
        other => Err(WasmError::InvalidWasm(format!(
            "unsupported volatility: `{other}`"
        ))),
    }
}
/// Maps a null-handling label to the internal `NullHandling`.
///
/// Accepted labels: `"propagate"`, `"user_handled"`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` for any other label.
fn parse_null_handling(s: &str) -> Result<uni_plugin::traits::scalar::NullHandling, WasmError> {
    use uni_plugin::traits::scalar::NullHandling;

    match s {
        "propagate" => Ok(NullHandling::PropagateNulls),
        "user_handled" => Ok(NullHandling::UserHandled),
        other => Err(WasmError::InvalidWasm(format!(
            "unsupported null_handling: `{other}`"
        ))),
    }
}
/// Maps a procedure mode label to the internal `ProcedureMode`.
///
/// Accepted labels: `"read"`, `"write"`, `"schema"`, `"dbms"`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` for any other label.
fn parse_proc_mode(s: &str) -> Result<uni_plugin::traits::procedure::ProcedureMode, WasmError> {
    use uni_plugin::traits::procedure::ProcedureMode;

    match s {
        "read" => Ok(ProcedureMode::Read),
        "write" => Ok(ProcedureMode::Write),
        "schema" => Ok(ProcedureMode::Schema),
        "dbms" => Ok(ProcedureMode::Dbms),
        other => Err(WasmError::InvalidWasm(format!(
            "unsupported procedure mode: `{other}`"
        ))),
    }
}
/// Converts a wire-level aggregate description into an `AggSignature`.
///
/// The function part is converted like a scalar signature; the state must be
/// a primitive and becomes a single nullable field named `"state"`.
/// `supports_partial` is always set to `true`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when the function signature is invalid, when the
/// state is not a `Primitive`, or when its Arrow type name is unsupported.
fn wire_agg_sig_to_internal(
    wire_sig: &WireFnSignature,
    wire_state: &WireArgType,
) -> Result<uni_plugin::traits::aggregate::AggSignature, WasmError> {
    use arrow_schema::Field;
    use uni_plugin::traits::aggregate::AggSignature;

    // The function signature is validated before the state type.
    let internal = wire_fn_sig_to_internal(wire_sig)?;

    let WireArgType::Primitive { arrow } = wire_state else {
        return Err(WasmError::InvalidWasm(
            "aggregate state must be a Primitive Arrow type".to_owned(),
        ));
    };
    let state_type = arrow_name_to_dt(arrow)?;
    let state_field = Field::new("state", state_type, true);

    Ok(AggSignature {
        volatility: internal.volatility,
        args: internal.args,
        returns: internal.returns,
        state_fields: vec![state_field],
        supports_partial: true,
    })
}
/// Converts a wire-level procedure description into a `ProcedureSignature`.
///
/// Arguments are named `arg0`, `arg1`, … and yielded columns `yield0`,
/// `yield1`, …; every yield field is nullable. Cypher-value (and variadic)
/// yields are typed as `LargeBinary`. Side effects, retry contract, batch
/// input and docs are left at their defaults / empty.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when an argument or yield type is unsupported or
/// `mode` is not a recognized procedure mode. Arguments are checked first,
/// then yields, then the mode.
fn wire_proc_sig_to_internal(
    args: &[WireArgType],
    yields: &[WireArgType],
    mode: &str,
) -> Result<uni_plugin::traits::procedure::ProcedureSignature, WasmError> {
    use arrow_schema::Field;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::{NamedArgType, ProcedureSignature};
    use uni_plugin::traits::scalar::ArgType;

    let mut named_args: Vec<NamedArgType> = Vec::with_capacity(args.len());
    for (index, wire) in args.iter().enumerate() {
        let ty = wire_arg(wire)?;
        named_args.push(NamedArgType {
            name: format!("arg{index}").into(),
            ty,
            default: None,
            doc: String::new(),
        });
    }

    let mut yield_fields: Vec<Field> = Vec::with_capacity(yields.len());
    for (index, wire) in yields.iter().enumerate() {
        let data_type = match wire_arg(wire)? {
            ArgType::Primitive(d) => d,
            ArgType::CypherValue | ArgType::Variadic(_) => arrow_schema::DataType::LargeBinary,
            ArgType::Vector { element, .. } => element,
        };
        yield_fields.push(Field::new(format!("yield{index}"), data_type, true));
    }

    let parsed_mode = parse_proc_mode(mode)?;

    Ok(ProcedureSignature {
        args: named_args,
        yields: yield_fields,
        mode: parsed_mode,
        side_effects: SideEffects::default(),
        retry_contract: None,
        batch_input: None,
        docs: String::new(),
    })
}
/// Resolves an Arrow type name to a `DataType`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when the shared lookup does not know the name.
fn arrow_name_to_dt(name: &str) -> Result<arrow_schema::DataType, WasmError> {
    match uni_plugin::adapter_common::arrow_types::arrow_name_to_datatype(name) {
        Some(data_type) => Ok(data_type),
        None => Err(WasmError::InvalidWasm(format!(
            "unsupported arrow primitive: `{name}`"
        ))),
    }
}
/// Converts a wire-level function signature into the internal `FnSignature`.
///
/// # Errors
///
/// `WasmError::InvalidWasm` when an argument or return type is unsupported,
/// or the volatility / null-handling label is not recognized. The checks run
/// in that order: arguments, return type, volatility, null handling.
fn wire_fn_sig_to_internal(
    wire: &WireFnSignature,
) -> Result<uni_plugin::traits::scalar::FnSignature, WasmError> {
    use uni_plugin::traits::scalar::{ArgType, FnSignature};

    let mut args: Vec<ArgType> = Vec::with_capacity(wire.args.len());
    for wire_ty in &wire.args {
        args.push(wire_arg(wire_ty)?);
    }
    let returns = wire_arg(&wire.returns)?;
    let volatility = parse_volatility(&wire.volatility)?;
    let null_handling = parse_null_handling(&wire.null_handling)?;

    Ok(FnSignature {
        args,
        returns,
        volatility,
        null_handling,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_plugin::{Capability, CapabilitySet};

    /// Builds a minimal manifest JSON document whose `capabilities` array
    /// holds the given names as plain strings.
    fn manifest_json(caps: &[&str]) -> String {
        let mut quoted: Vec<String> = Vec::with_capacity(caps.len());
        for cap in caps {
            quoted.push(format!("\"{cap}\""));
        }
        format!(
            r#"{{ "id": "ai.example.test", "version": "1.0.0", "capabilities": [{}] }}"#,
            quoted.join(", ")
        )
    }

    /// Constructing a loader must not panic.
    #[test]
    fn loader_constructs() {
        let _loader = WasmLoader::new();
    }

    /// A manifest with no capabilities parses and yields an empty effective set.
    #[test]
    fn prepare_parses_minimal_manifest() {
        let loader = WasmLoader::new();
        let doc = manifest_json(&[]);
        let no_grants = CapabilitySet::new();
        let prepared = loader.prepare(doc.as_bytes(), &no_grants).unwrap();
        assert_eq!(prepared.manifest.id, "ai.example.test");
        assert!(prepared.effective.is_empty());
    }

    /// Only capabilities that are both declared and granted survive.
    #[test]
    fn prepare_intersects_capabilities() {
        let loader = WasmLoader::new();
        let doc = manifest_json(&["filesystem", "network", "kms"]);
        let filesystem = Capability::Filesystem {
            read: vec![],
            write: vec![],
        };
        let network = Capability::Network { allow: vec![] };
        let grants = CapabilitySet::from_iter_of([filesystem, network]);

        let prepared = loader.prepare(doc.as_bytes(), &grants).unwrap();

        assert_eq!(prepared.effective.len(), 2);
        let has_network = prepared
            .effective
            .contains_variant(&Capability::Network { allow: vec![] });
        assert!(has_network);
        let has_kms = prepared
            .effective
            .contains_variant(&Capability::Kms { key_ids: vec![] });
        assert!(!has_kms);
    }

    /// A structured network capability keeps its allow-list after preparation.
    #[test]
    fn prepare_carries_structured_network_allowlist() {
        let loader = WasmLoader::new();
        let doc = r#"{ "id": "a.b", "version": "1.0.0",
            "capabilities": [{"kind":"network","allow":["https://api.example/**"]}] }"#;
        let grants = CapabilitySet::from_iter_of([Capability::Network {
            allow: vec!["https://api.example/**".into()],
        }]);

        let prepared = loader.prepare(doc.as_bytes(), &grants).unwrap();

        let allows = |url: &str| prepared.effective.iter().any(|c| c.network_allows(url));
        assert!(allows("https://api.example/v1/x"));
        assert!(!allows("https://evil.example/x"));
    }

    /// Non-JSON input is rejected as `InvalidWasm`.
    #[test]
    fn prepare_rejects_malformed_manifest() {
        let loader = WasmLoader::new();
        let outcome = loader.prepare(b"not json", &CapabilitySet::new());
        let err = outcome.unwrap_err();
        assert!(matches!(err, WasmError::InvalidWasm(_)));
    }

    /// Bytes that are not a wasm component are rejected as `InvalidWasm`.
    #[test]
    fn instantiate_rejects_garbage_bytes() {
        let loader = WasmLoader::new();
        let prepared = loader
            .prepare(br#"{"id":"a.b","version":"0.0.0"}"#, &CapabilitySet::new())
            .unwrap();
        let err = loader.instantiate(b"not real wasm", &prepared).unwrap_err();
        assert!(matches!(err, WasmError::InvalidWasm(_)));
    }
}