use serde::{Deserialize, Serialize};
use crate::error::ExtismError;
use crate::loader::ExtismPluginManifest;
/// Signature of a plugin-provided function as carried in the registration JSON.
///
/// Deserialization rejects any JSON field not declared here
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct WireFnSignature {
/// Types of the function's arguments.
pub args: Vec<WireArgType>,
/// Type of the function's return value.
pub returns: WireArgType,
/// Volatility label; falls back to `"immutable"` when the field is omitted.
// NOTE(review): the label is kept as a free-form string and is not validated
// here — confirm the accepted set against whatever consumes this signature.
#[serde(default = "default_volatility")]
pub volatility: String,
/// Null-handling label; falls back to `"propagate"` when the field is omitted.
// NOTE(review): also an unvalidated free-form string — confirm accepted values.
#[serde(default = "default_null_handling")]
pub null_handling: String,
}
/// Serde default for [`WireFnSignature::volatility`]: the label `"immutable"`.
fn default_volatility() -> String {
    String::from("immutable")
}
/// Serde default for [`WireFnSignature::null_handling`]: the label `"propagate"`.
fn default_null_handling() -> String {
    String::from("propagate")
}
/// Wire representation of an argument, return, state, or yield type.
///
/// Serialized as an internally tagged object: the `kind` field holds the
/// snake_case variant name (`primitive`, `cypher_value`, `vector`,
/// `variadic`) and the variant's fields sit alongside it. Unknown fields are
/// rejected.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub enum WireArgType {
/// A primitive type identified by name.
Primitive {
/// Type name, e.g. `"float64"` or `"utf8"`.
// NOTE(review): presumably an Arrow data type name — not validated here.
arrow: String,
},
/// The `cypher_value` kind; carries no additional fields.
CypherValue,
/// A vector with a declared length and element type name.
Vector {
/// Declared vector length.
len: usize,
/// Element type name, e.g. `"float32"`.
// NOTE(review): presumably the same naming scheme as `Primitive::arrow` — confirm.
element: String,
},
/// A variadic wrapper around another type.
Variadic {
/// The wrapped type (boxed because the enum is recursive).
inner: Box<WireArgType>,
},
}
/// One item a plugin registers, as carried in the registration JSON.
///
/// Serialized as an internally tagged object: the `kind` field holds the
/// snake_case variant name (`scalar`, `aggregate`, `procedure`). Unknown
/// fields are rejected.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub enum RegistrationEntry {
/// A scalar function.
Scalar {
/// Qualified name, e.g. `"geo.haversine"`.
qname: String,
/// Argument/return signature of the function.
signature: WireFnSignature,
},
/// An aggregate function.
Aggregate {
/// Qualified name, e.g. `"stats.weighted_mean"`.
qname: String,
/// Argument/return signature of the aggregate.
signature: WireFnSignature,
/// Type of the aggregate's state.
// NOTE(review): presumably the intermediate accumulator representation — confirm.
state: WireArgType,
},
/// A procedure.
Procedure {
/// Qualified name, e.g. `"myorg.scan"`.
qname: String,
/// Types of the procedure's arguments.
args: Vec<WireArgType>,
/// Types of the values the procedure yields.
yields: Vec<WireArgType>,
/// Mode label; falls back to `"read"` when the field is omitted.
// NOTE(review): free-form string, not validated here — confirm accepted values.
#[serde(default = "default_proc_mode")]
mode: String,
},
}
/// Serde default for the `mode` field of `RegistrationEntry::Procedure`: `"read"`.
fn default_proc_mode() -> String {
    String::from("read")
}
impl RegistrationEntry {
    /// Returns the entry's qualified name, whichever variant it is.
    #[must_use]
    pub fn qname(&self) -> &str {
        // Every variant carries a `qname`; listing the arms separately keeps
        // the match exhaustive without a catch-all.
        let name: &String = match self {
            Self::Scalar { qname, .. } => qname,
            Self::Aggregate { qname, .. } => qname,
            Self::Procedure { qname, .. } => qname,
        };
        name.as_str()
    }
}
/// Top-level document parsed from a plugin's registration JSON.
///
/// Deserialization rejects any JSON field other than `entries`
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct RegistrationManifest {
/// The registered entries, in the order they appear in the JSON array.
pub entries: Vec<RegistrationEntry>,
}
/// Parses raw JSON bytes into an [`ExtismPluginManifest`].
///
/// # Errors
///
/// Returns [`ExtismError::ManifestInvalid`] carrying the `serde_json` error
/// text when `bytes` cannot be deserialized into a manifest.
pub fn parse_manifest_json(bytes: &[u8]) -> Result<ExtismPluginManifest, ExtismError> {
    match serde_json::from_slice(bytes) {
        Ok(manifest) => Ok(manifest),
        Err(cause) => Err(ExtismError::ManifestInvalid(format!("json parse: {cause}"))),
    }
}
/// Parses raw JSON bytes into a [`RegistrationManifest`].
///
/// # Errors
///
/// Returns [`ExtismError::OutputDecode`] carrying the `serde_json` error text
/// when `bytes` cannot be deserialized into a registration manifest.
pub fn parse_registration_json(bytes: &[u8]) -> Result<RegistrationManifest, ExtismError> {
    match serde_json::from_slice(bytes) {
        Ok(registration) => Ok(registration),
        Err(cause) => Err(ExtismError::OutputDecode(format!(
            "register json parse: {cause}"
        ))),
    }
}
/// Calls the plugin's `manifest` export with an empty input and parses the
/// returned bytes as a plugin manifest.
///
/// # Errors
///
/// * [`ExtismError::InvalidPlugin`] if the plugin has no `manifest` export or
///   the call itself fails.
/// * Whatever [`parse_manifest_json`] returns if the output is not a valid
///   manifest.
pub fn read_manifest_export(
    plugin: &mut extism::Plugin,
) -> Result<ExtismPluginManifest, ExtismError> {
    if plugin.function_exists("manifest") {
        // The output slice borrows from the plugin, so parse it right away.
        let outcome: Result<&[u8], _> = plugin.call("manifest", "");
        let raw = outcome
            .map_err(|cause| ExtismError::InvalidPlugin(format!("call manifest: {cause}")))?;
        parse_manifest_json(raw)
    } else {
        Err(ExtismError::InvalidPlugin(
            "plugin does not export required `manifest` function".to_owned(),
        ))
    }
}
/// Calls the plugin's `register` export with an empty input and parses the
/// returned bytes as a registration manifest.
///
/// # Errors
///
/// * [`ExtismError::InvalidPlugin`] if the plugin has no `register` export or
///   the call itself fails.
/// * Whatever [`parse_registration_json`] returns if the output is not a
///   valid registration document.
pub fn read_register_export(
    plugin: &mut extism::Plugin,
) -> Result<RegistrationManifest, ExtismError> {
    if plugin.function_exists("register") {
        // The output slice borrows from the plugin, so parse it right away.
        let outcome: Result<&[u8], _> = plugin.call("register", "");
        let raw = outcome
            .map_err(|cause| ExtismError::InvalidPlugin(format!("call register: {cause}")))?;
        parse_registration_json(raw)
    } else {
        Err(ExtismError::InvalidPlugin(
            "plugin does not export required `register` function".to_owned(),
        ))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the JSON wire-format parsers. They exercise only
    // `parse_manifest_json` / `parse_registration_json`; the plugin-calling
    // readers are not covered here.
    use super::*;

    /// A manifest with only `id` and `version` parses, leaving the optional
    /// fields empty.
    #[test]
    fn parses_minimal_manifest() {
        let json = br#"{"id":"a.b","version":"0.0.1"}"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.id, "a.b");
        assert_eq!(m.version, "0.0.1");
        assert!(m.capabilities.is_empty());
        assert!(m.fuel_per_call.is_none());
    }

    /// Resource-limit fields and declared capabilities are picked up.
    #[test]
    fn parses_manifest_with_resource_limits() {
        let json = br#"{
            "id": "a.b",
            "version": "0.0.1",
            "capabilities": ["filesystem"],
            "fuel_per_call": 1000,
            "memory_max_pages": 4,
            "timeout_ms": 500
        }"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.fuel_per_call, Some(1000));
        assert_eq!(m.memory_max_pages, Some(4));
        assert_eq!(m.timeout_ms, Some(500));
        assert!(m.declared_capability_set().contains_variant(
            &uni_plugin::Capability::Filesystem {
                read: vec![],
                write: vec![],
            }
        ));
    }

    /// Unknown manifest fields are rejected with `ManifestInvalid`.
    #[test]
    fn rejects_unknown_manifest_field() {
        let json = br#"{"id":"a.b","version":"0.0.1","mystery":"surprise"}"#;
        let err = parse_manifest_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    /// An empty `entries` array is a valid registration document.
    #[test]
    fn parses_empty_registration() {
        let json = br#"{"entries":[]}"#;
        let r = parse_registration_json(json).unwrap();
        assert!(r.entries.is_empty());
    }

    /// A scalar entry parses and the omitted signature fields take their
    /// defaults.
    #[test]
    fn parses_scalar_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "geo.haversine",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"}
                }
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        assert_eq!(r.entries.len(), 1);
        match &r.entries[0] {
            RegistrationEntry::Scalar { qname, signature } => {
                assert_eq!(qname, "geo.haversine");
                assert_eq!(signature.args.len(), 4);
                assert_eq!(signature.volatility, "immutable");
                assert_eq!(signature.null_handling, "propagate");
                assert!(matches!(
                    signature.returns,
                    WireArgType::Primitive { ref arrow } if arrow == "float64"
                ));
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    /// An aggregate entry parses, including an explicit volatility and the
    /// state type.
    #[test]
    fn parses_aggregate_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "aggregate",
                "qname": "stats.weighted_mean",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"},
                    "volatility": "stable"
                },
                "state": {"kind":"primitive","arrow":"binary"}
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Aggregate {
                qname,
                signature,
                state,
            } => {
                assert_eq!(qname, "stats.weighted_mean");
                assert_eq!(signature.volatility, "stable");
                assert!(matches!(state, WireArgType::Primitive { arrow } if arrow == "binary"));
            }
            other => panic!("expected Aggregate, got: {other:?}"),
        }
    }

    /// A procedure entry parses with its args, yields and explicit mode.
    #[test]
    fn parses_procedure_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [{"kind":"primitive","arrow":"utf8"}],
                "yields": [
                    {"kind":"primitive","arrow":"int64"},
                    {"kind":"cypher_value"}
                ],
                "mode": "write"
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Procedure {
                qname,
                args,
                yields,
                mode,
            } => {
                assert_eq!(qname, "myorg.scan");
                assert_eq!(args.len(), 1);
                assert_eq!(yields.len(), 2);
                assert_eq!(mode, "write");
                assert!(matches!(yields[1], WireArgType::CypherValue));
            }
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    /// Omitting `mode` on a procedure yields the `"read"` default.
    #[test]
    fn procedure_mode_defaults_to_read() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [],
                "yields": []
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Procedure { mode, .. } => assert_eq!(mode, "read"),
            // This arm is reachable if the parser regresses, so report the
            // variant actually produced instead of `unreachable!()`.
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    /// `RegistrationEntry::qname` returns the stored qualified name.
    #[test]
    fn registration_entry_exposes_qname() {
        let e = RegistrationEntry::Scalar {
            qname: "x.y".to_owned(),
            signature: WireFnSignature {
                args: vec![],
                returns: WireArgType::CypherValue,
                volatility: "immutable".to_owned(),
                null_handling: "propagate".to_owned(),
            },
        };
        assert_eq!(e.qname(), "x.y");
    }

    /// An unrecognized `kind` tag is rejected with `OutputDecode`.
    #[test]
    fn rejects_unknown_registration_kind() {
        let json = br#"{"entries":[{"kind":"telegraphic","qname":"x"}]}"#;
        let err = parse_registration_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    /// `vector` and (nested) `variadic` argument types round through the
    /// parser with their fields intact.
    #[test]
    fn parses_vector_and_variadic_argtypes() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "vec.norm",
                "signature": {
                    "args": [
                        {"kind":"vector","len":128,"element":"float32"},
                        {"kind":"variadic","inner":{"kind":"primitive","arrow":"int64"}}
                    ],
                    "returns": {"kind":"primitive","arrow":"float32"}
                }
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Scalar { signature, .. } => {
                assert!(matches!(
                    signature.args[0],
                    WireArgType::Vector { len: 128, ref element } if element == "float32"
                ));
                assert!(matches!(
                    signature.args[1],
                    WireArgType::Variadic { ref inner } if matches!(
                        **inner,
                        WireArgType::Primitive { ref arrow } if arrow == "int64"
                    )
                ));
            }
            // This arm is reachable if the parser regresses, so report the
            // variant actually produced instead of `unreachable!()`.
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }
}