use std::path::Path;
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::Error;
use crate::middleware::MiddlewareKind;
/// One middleware entry point exported by a loaded plugin component
/// (listed in [`PluginMetadata::exports`]).
#[derive(Debug, Clone)]
pub struct PluginExport {
    /// Export name; presumably the `export_name` later passed to the
    /// `WasmRuntime::invoke_*` methods — confirm.
    pub name: String,
    /// Which middleware stage this export implements (see `MiddlewareKind`).
    pub kind: MiddlewareKind,
    /// NOTE(review): by name, marks an export that keeps no state between calls;
    /// nothing in this module enforces it — confirm semantics with the runtime.
    pub stateless: bool,
    /// NOTE(review): by name, whether the export wants the message body (cf. the
    /// `Option<BytesView>` body fields on the L7 inputs) — confirm.
    pub needs_body: bool,
    /// NOTE(review): names the export inspects; presumably context paths or
    /// header names — confirm against the plugin ABI.
    pub inspects: Vec<String>,
}
/// Descriptive information about a plugin component, as returned (inside an
/// `Arc`) by `WasmRuntime::load_component`.
#[derive(Debug)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,
    /// Plugin version string (format is not validated by this type).
    pub version: String,
    /// ABI version string of the component; presumably checked for
    /// compatibility by the runtime — confirm.
    pub abi_version: String,
    /// Middleware entry points the plugin exports.
    pub exports: Vec<PluginExport>,
}
/// Opaque identifier of a loaded plugin module, passed to every
/// `WasmRuntime::invoke_*` call.
///
/// Backed by `Arc<str>`: cloning only bumps a reference count, and equality and
/// hashing compare the string contents.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ModuleId(pub Arc<str>);
/// Typed value held by a [`ContextEntry`].
#[derive(Debug, Clone)]
pub enum ContextValue {
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// Signed 64-bit integer.
    Int64(i64),
    /// Unsigned 64-bit integer.
    Uint64(u64),
    /// Boolean flag.
    Boolean(bool),
    /// List of text values.
    ListText(Vec<String>),
}
/// One keyed value from the context handed to a plugin (the `context` field of
/// every `*Input` type in this module).
#[derive(Debug, Clone)]
pub struct ContextEntry {
    /// Key of the entry; presumably a hierarchical path string — confirm format.
    pub path: String,
    /// The entry's value.
    pub value: ContextValue,
}
/// A single HTTP header as an owned name/value pair.
///
/// Headers are carried as `Vec<Header>`, so ordering and repeated names are
/// representable. This type does no case normalisation or validation.
#[derive(Debug, Clone)]
pub struct Header {
    /// Header name.
    pub name: String,
    /// Header value.
    pub value: String,
}
/// A byte buffer handed to a plugin, plus a flag saying whether it was cut short.
#[derive(Debug, Clone)]
pub struct BytesView {
    /// The bytes made available to the plugin.
    pub data: Vec<u8>,
    /// Presumably `true` when `data` is only a prefix of the full payload
    /// (e.g. capped by a size limit) — confirm with the producer.
    pub truncated: bool,
}
/// Input to `WasmRuntime::invoke_l4_peek`.
// Derive `Debug` like every other public data type in this module, so inputs
// can appear in logs, `dbg!` and assertion messages.
#[derive(Debug)]
pub struct L4PeekInput {
    /// Bytes peeked from the connection; presumably not consumed from the
    /// stream — confirm with the runtime.
    pub peek: Vec<u8>,
    /// Context entries visible to the plugin.
    pub context: Vec<ContextEntry>,
}
/// Verdict returned by an L4 peek export.
///
/// Carries no payload, so it is `Copy` and can be compared with `==` instead of
/// needing `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum L4PeekDecision {
    /// Let processing of the connection continue.
    Continue,
    /// Ask for the connection to be closed.
    Close,
}
/// Input to `WasmRuntime::invoke_l4_bytes`.
// Derive `Debug` like every other public data type in this module, so inputs
// can appear in logs, `dbg!` and assertion messages.
#[derive(Debug)]
pub struct L4BytesInput {
    /// Connection bytes handed to the plugin (possibly truncated; see
    /// `BytesView::truncated`).
    pub bytes: BytesView,
    /// Context entries visible to the plugin.
    pub context: Vec<ContextEntry>,
}
/// Verdict returned by an L4 bytes export.
///
/// Carries no payload, so it is `Copy` and can be compared with `==` instead of
/// needing `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum L4BytesDecision {
    /// Let processing continue.
    Continue,
    /// Presumably: stop inspecting and relay the connection as an opaque
    /// tunnel — confirm with the runtime.
    Tunnel,
    /// Ask for the connection to be closed.
    Close,
}
/// Input to `WasmRuntime::invoke_l7_request`: the request head, an optional body
/// view, and the context.
// Derive `Debug` like every other public data type in this module, so inputs
// can appear in logs, `dbg!` and assertion messages.
#[derive(Debug)]
pub struct L7RequestInput {
    /// Request method as a string (not validated here).
    pub method: String,
    /// Request URI as a string (not parsed here).
    pub uri: String,
    /// Request headers; a `Vec`, so order and repeated names are preserved.
    pub headers: Vec<Header>,
    /// Request body when provided; presumably only present for exports that set
    /// `PluginExport::needs_body` — confirm.
    pub body: Option<BytesView>,
    /// Context entries visible to the plugin.
    pub context: Vec<ContextEntry>,
}
/// A response synthesised by a plugin, carried by `L7RequestDecision::Short`.
#[derive(Debug, Clone)]
pub struct SynthResponse {
    /// HTTP status code (not range-checked by this type).
    pub status: u16,
    /// Response headers.
    pub headers: Vec<Header>,
    /// Response body bytes.
    pub body: Vec<u8>,
}
/// Verdict returned by an L7 request export.
// `Clone` is derivable because the only payload, `SynthResponse`, is `Clone`;
// it lets callers keep or reuse a decision without re-invoking the plugin.
#[derive(Debug, Clone)]
pub enum L7RequestDecision {
    /// Let the request continue.
    Continue,
    /// Short-circuit: answer with the plugin-supplied response, presumably
    /// instead of forwarding the request — confirm with the runtime.
    Short(SynthResponse),
    /// Ask for the connection to be closed.
    Close,
}
/// Input to `WasmRuntime::invoke_l7_response`: response status and headers, an
/// optional body view, and the context.
// Derive `Debug` like every other public data type in this module, so inputs
// can appear in logs, `dbg!` and assertion messages.
#[derive(Debug)]
pub struct L7ResponseInput {
    /// Response status code.
    pub status: u16,
    /// Response headers; a `Vec`, so order and repeated names are preserved.
    pub headers: Vec<Header>,
    /// Response body when provided; presumably only present for exports that
    /// set `PluginExport::needs_body` — confirm.
    pub body: Option<BytesView>,
    /// Context entries visible to the plugin.
    pub context: Vec<ContextEntry>,
}
/// Partial override of a response, carried by `L7ResponseDecision::Modify`.
///
/// Each field is optional; presumably `None` means "leave unchanged" — confirm
/// with the code that applies the modification.
#[derive(Debug, Clone)]
pub struct ModifiedResponse {
    /// Replacement status code, if any.
    pub status: Option<u16>,
    /// Replacement header list, if any (a whole list, not a per-header patch).
    pub headers: Option<Vec<Header>>,
    /// Replacement body bytes, if any.
    pub body: Option<Vec<u8>>,
}
/// Verdict returned by an L7 response export.
// `Clone` is derivable because the only payload, `ModifiedResponse`, is `Clone`;
// it lets callers keep or reuse a decision without re-invoking the plugin.
#[derive(Debug, Clone)]
pub enum L7ResponseDecision {
    /// Pass the response through as-is.
    Continue,
    /// Apply the given partial override to the response.
    Modify(ModifiedResponse),
    /// Presumably: abandon delivery of the response — confirm with the runtime.
    Abort,
}
/// Failure returned by the `WasmRuntime::invoke_*` methods.
///
/// `#[non_exhaustive]`: `match`es in other crates must include a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PluginError {
    /// Error reported by the plugin. Displays as `plugin {code}: {message}`;
    /// `on_error_hint` is not part of the message (presumably advice to the host
    /// on how to react — confirm).
    #[error("plugin {code}: {message}")]
    Plugin { code: String, message: String, on_error_hint: Option<String> },
    /// The guest trapped. Displays as `plugin trap: <trap message>`.
    // NOTE(review): the trap is both interpolated into Display (`{0}`) and
    // exposed via `source()`, so reporters that walk the source chain print the
    // trap message twice. Consider keeping only one of the two.
    #[error("plugin trap: {0}")]
    Trap(#[source] PluginTrap),
    /// No pooled plugin instance was available for the call.
    #[error("plugin pool exhausted: no instance available")]
    Exhausted,
}
/// Message describing a guest trap; wrapped by `PluginError::Trap`.
///
/// `Display` prints the wrapped message verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PluginTrap(pub String);

impl PluginTrap {
    /// Builds a trap from anything convertible into an owned `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let text: String = message.into();
        PluginTrap(text)
    }
}
impl PluginError {
#[must_use]
pub fn trap(message: impl Into<String>) -> Self {
Self::Trap(PluginTrap::new(message))
}
}
/// Host-side interface to the WebAssembly plugin engine.
///
/// `#[async_trait]` boxes each method's future, so the trait can be used as
/// `dyn WasmRuntime`. The `Send + Sync` bound lets one runtime be shared across
/// threads and tasks.
///
/// Every `invoke_*` method identifies the target by `module_id` + `export_name`
/// and receives `args_json`, a string passed through unchanged. That string is
/// presumably JSON-encoded per-use configuration for the export — confirm.
#[async_trait]
pub trait WasmRuntime: Send + Sync {
    /// Loads the plugin component at `path` and returns its metadata.
    ///
    /// # Errors
    /// Returns the crate-level [`Error`] if the component cannot be loaded.
    async fn load_component(&self, path: &Path) -> Result<Arc<PluginMetadata>, Error>;
    /// Runs an L4 peek export on bytes peeked from a connection.
    ///
    /// # Errors
    /// Returns a [`PluginError`] (plugin-reported error, trap, or pool
    /// exhaustion, per its variants).
    async fn invoke_l4_peek(
        &self,
        module_id: &ModuleId,
        export_name: &str,
        args_json: &str,
        input: L4PeekInput,
    ) -> Result<L4PeekDecision, PluginError>;
    /// Runs an L4 bytes export on a view of connection bytes.
    ///
    /// # Errors
    /// Returns a [`PluginError`] (plugin-reported error, trap, or pool
    /// exhaustion, per its variants).
    async fn invoke_l4_bytes(
        &self,
        module_id: &ModuleId,
        export_name: &str,
        args_json: &str,
        input: L4BytesInput,
    ) -> Result<L4BytesDecision, PluginError>;
    /// Runs an L7 request export on an HTTP request head (and optional body).
    ///
    /// # Errors
    /// Returns a [`PluginError`] (plugin-reported error, trap, or pool
    /// exhaustion, per its variants).
    async fn invoke_l7_request(
        &self,
        module_id: &ModuleId,
        export_name: &str,
        args_json: &str,
        input: L7RequestInput,
    ) -> Result<L7RequestDecision, PluginError>;
    /// Runs an L7 response export on an HTTP response head (and optional body).
    ///
    /// # Errors
    /// Returns a [`PluginError`] (plugin-reported error, trap, or pool
    /// exhaustion, per its variants).
    async fn invoke_l7_response(
        &self,
        module_id: &ModuleId,
        export_name: &str,
        args_json: &str,
        input: L7ResponseInput,
    ) -> Result<L7ResponseDecision, PluginError>;
}
/// Point-in-time statistics for one plugin instance pool
/// (produced by [`WasmPoolStats::snapshot`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WasmPoolSummary {
    /// Pool kind label (free-form string).
    pub kind: String,
    /// Pool key; presumably identifies the module the pool belongs to — confirm.
    pub key: String,
    /// Name of the export the pool serves.
    pub export: String,
    /// Presumably the maximum number of instances in the pool.
    pub capacity: usize,
    /// Presumably the number of instances currently free.
    pub available: usize,
    /// Presumably a cumulative count of instance allocations.
    pub total_allocations: u64,
    /// Presumably a cumulative count of failed allocations/invocations — confirm.
    pub failures: u64,
}
/// Source of plugin instance-pool statistics; `Send + Sync` so it can be shared
/// with observability code running on other threads.
pub trait WasmPoolStats: Send + Sync {
    /// Returns the current pool summaries (presumably one per pool).
    fn snapshot(&self) -> Vec<WasmPoolSummary>;
}
/// Outbound-HTTP policy for one plugin, deserialised from an entry of
/// `policy.json`.
///
/// Every field is optional in JSON. Absent fields take the same values as
/// `PluginHttpPolicy::default()`, which denies all: no allowed hosts,
/// `allow_insecure = false`, 1 MiB body cap, 30 000 ms timeout, and 5 redirects.
// NOTE(review): there is no `#[serde(deny_unknown_fields)]`, so a misspelled
// key is silently ignored and that field falls back to its default.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct PluginHttpPolicy {
    /// Presumably permits plain (non-TLS) requests — confirm. Defaults to `false`.
    #[serde(default)]
    pub allow_insecure: bool,
    /// Host patterns the plugin may contact; empty means deny all. Tests use
    /// patterns like `*.example.com` and `*`, so wildcard matching is presumably
    /// supported by the enforcing code — confirm.
    #[serde(default)]
    pub allowed_hosts: Vec<String>,
    /// Body size cap (presumably bytes). Defaults to `1024 * 1024`.
    #[serde(default = "default_max_body_size")]
    pub max_body_size: u32,
    /// Default request timeout in milliseconds. Defaults to `30_000`.
    #[serde(default = "default_timeout_ms")]
    pub default_timeout_ms: u32,
    /// Default redirect limit. Defaults to `5`; `0` presumably disables
    /// following redirects.
    #[serde(default = "default_follow_redirects")]
    pub default_follow_redirects: u32,
}
/// Fallback for `PluginHttpPolicy::max_body_size`: 1 MiB.
const fn default_max_body_size() -> u32 {
    1 << 20
}

/// Fallback for `PluginHttpPolicy::default_timeout_ms`: 30 seconds.
const fn default_timeout_ms() -> u32 {
    30 * 1_000
}

/// Fallback for `PluginHttpPolicy::default_follow_redirects`.
const fn default_follow_redirects() -> u32 {
    5
}
impl Default for PluginHttpPolicy {
    /// The deny-all policy: no allowed hosts and insecure transport off, with the
    /// same numeric limits that serde fills in for absent JSON fields.
    fn default() -> Self {
        let body_cap = default_max_body_size();
        let timeout = default_timeout_ms();
        let redirects = default_follow_redirects();
        Self {
            allowed_hosts: vec![],
            allow_insecure: false,
            max_body_size: body_cap,
            default_timeout_ms: timeout,
            default_follow_redirects: redirects,
        }
    }
}
/// Per-plugin HTTP policies, keyed by the `stem` passed to `get_or_default`
/// (presumably the plugin's file stem — confirm).
#[derive(Debug, Clone, Default)]
pub struct PluginPolicyTable {
    /// Map from plugin key to its policy; `get_or_default` falls back to
    /// `PluginHttpPolicy::default()` (deny-all) for keys not present.
    pub policies: std::collections::HashMap<String, PluginHttpPolicy>,
}
impl PluginPolicyTable {
#[must_use]
pub fn new() -> Self {
Self { policies: std::collections::HashMap::new() }
}
pub fn from_json(s: &str) -> Result<Self, Error> {
let policies: std::collections::HashMap<String, PluginHttpPolicy> =
serde_json::from_str(s).map_err(|e| Error::compile(format!("wasm/policy.json: {e}")))?;
Ok(Self { policies })
}
pub fn load_from_dir(wasm_dir: &std::path::Path) -> Result<Self, Error> {
let path = wasm_dir.join("policy.json");
match std::fs::read_to_string(&path) {
Ok(s) => Self::from_json(&s),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Self::new()),
Err(e) => Err(Error::compile(format!("wasm/policy.json: read {}: {e}", path.display()))),
}
}
#[must_use]
pub fn get_or_default(&self, stem: &str) -> PluginHttpPolicy {
self.policies.get(stem).cloned().unwrap_or_default()
}
}
// Unit tests for `PluginHttpPolicy` defaults and `PluginPolicyTable` parsing and
// loading.
#[cfg(test)]
mod policy_tests {
    use super::*;
    // `Default` must deny all outbound hosts and use the documented numeric limits.
    #[test]
    fn default_policy_is_deny_all() {
        let p = PluginHttpPolicy::default();
        assert!(!p.allow_insecure);
        assert!(p.allowed_hosts.is_empty(), "deny-all by default");
        assert_eq!(p.max_body_size, 1024 * 1024);
        assert_eq!(p.default_timeout_ms, 30_000);
        assert_eq!(p.default_follow_redirects, 5);
    }
    // Every field set explicitly in JSON is parsed as written, including a
    // non-default `0` for redirects.
    #[test]
    fn policy_table_round_trips_explicit_fields() {
        // The raw string's lines are deliberately left unindented: their bytes
        // are the fixture.
        let json = r#"{
"edge": {
"allow_insecure": true,
"allowed_hosts": ["api.internal", "*.example.com"],
"max_body_size": 65536,
"default_timeout_ms": 5000,
"default_follow_redirects": 0
}
}"#;
        let t = PluginPolicyTable::from_json(json).expect("parse");
        let p = t.get_or_default("edge");
        assert!(p.allow_insecure);
        assert_eq!(p.allowed_hosts, vec!["api.internal".to_string(), "*.example.com".to_string()]);
        assert_eq!(p.max_body_size, 65_536);
        assert_eq!(p.default_timeout_ms, 5000);
        assert_eq!(p.default_follow_redirects, 0);
    }
    // Fields omitted from an entry are filled from the serde defaults.
    #[test]
    fn policy_table_partial_entry_fills_defaults() {
        let json = r#"{ "edge": { "allowed_hosts": ["x.y"] } }"#;
        let t = PluginPolicyTable::from_json(json).expect("parse");
        let p = t.get_or_default("edge");
        assert_eq!(p.allowed_hosts, vec!["x.y".to_string()]);
        assert_eq!(p.max_body_size, 1024 * 1024, "default fills");
        assert_eq!(p.default_timeout_ms, 30_000);
    }
    // Looking up a plugin with no entry yields exactly `PluginHttpPolicy::default()`.
    #[test]
    fn policy_table_missing_plugin_returns_deny_all_default() {
        let t = PluginPolicyTable::from_json(r#"{ "other": {} }"#).expect("parse");
        let p = t.get_or_default("missing");
        assert_eq!(p, PluginHttpPolicy::default());
    }
    // A directory without `policy.json` loads as an empty table, not an error.
    #[test]
    fn policy_table_load_from_dir_handles_absent_file() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let t = PluginPolicyTable::load_from_dir(tmp.path()).expect("absent ok");
        assert!(t.policies.is_empty());
    }
    // A present `policy.json` is read and parsed.
    #[test]
    fn policy_table_load_from_dir_parses_json() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("policy.json"), r#"{ "x": { "allowed_hosts": ["*"] } }"#)
            .expect("write");
        let t = PluginPolicyTable::load_from_dir(tmp.path()).expect("parse");
        assert_eq!(t.get_or_default("x").allowed_hosts, vec!["*".to_string()]);
    }
    // Malformed JSON is surfaced as an error that mentions the policy file.
    #[test]
    fn policy_table_load_from_dir_propagates_parse_errors() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("policy.json"), "{ this is not json").expect("write");
        let err = PluginPolicyTable::load_from_dir(tmp.path()).expect_err("must fail");
        assert!(err.to_string().contains("policy.json"));
    }
}