pub mod config;
pub mod runtime;
pub mod loader;
pub mod host_functions;
pub mod host_imports;
pub mod sandbox;
pub mod hot_reload;
pub mod metrics;
pub use config::{PluginRuntimeConfig, PluginRuntimeConfigBuilder, PluginConfig};
pub use runtime::{WasmPluginRuntime, LoadedPlugin, PluginState, PluginError};
pub use loader::{PluginLoader, PluginManifest, PluginLoadError, SignatureVerifier};
pub use host_functions::HostFunctionRegistry;
pub use sandbox::{PluginSandbox, SecurityPolicy, Permission, ResourceLimits};
pub use hot_reload::{HotReloader, ReloadEvent, ReloadError};
pub use metrics::{PluginMetrics, PluginStats, HookLatency};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use dashmap::DashMap;
/// Descriptive metadata attached to a plugin.
///
/// Carries the plugin's identity strings, the hooks it handles, the
/// permissions it lists, and a pair of memory bounds.
#[derive(Clone, Debug)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,
    /// Plugin version string (the `Default` impl uses `"0.0.0"`).
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Author string.
    pub author: String,
    /// Hook types associated with the plugin.
    pub hooks: Vec<HookType>,
    /// Permissions associated with the plugin.
    pub permissions: Vec<Permission>,
    /// Lower memory bound. NOTE(review): presumably bytes — confirm against the runtime.
    pub min_memory: usize,
    /// Upper memory bound. NOTE(review): presumably bytes — confirm against the runtime.
    pub max_memory: usize,
}
impl Default for PluginMetadata {
fn default() -> Self {
Self {
name: String::new(),
version: "0.0.0".to_string(),
description: String::new(),
author: String::new(),
hooks: Vec::new(),
permissions: Vec::new(),
min_memory: 1024 * 1024, max_memory: 64 * 1024 * 1024, }
}
}
/// The kinds of hook a plugin can register for.
///
/// Each variant maps to a fixed export name (see [`HookType::export_name`])
/// and can be parsed back from that name or from a short alias
/// (see [`HookType::from_str`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    /// Export name `pre_query`.
    PreQuery,
    /// Export name `post_query`.
    PostQuery,
    /// Export name `authenticate`.
    Authenticate,
    /// Export name `authorize`.
    Authorize,
    /// Export name `cache_get`.
    CacheGet,
    /// Export name `cache_set`.
    CacheSet,
    /// Export name `route`.
    Route,
    /// Export name `rewrite`.
    Rewrite,
    /// Export name `metrics`.
    Metrics,
    /// Export name `on_connect`.
    OnConnect,
    /// Export name `on_disconnect`.
    OnDisconnect,
    /// Export name `custom_hook`.
    Custom,
}

impl HookType {
    /// Returns the fixed export name associated with this hook type.
    ///
    /// Every name returned here is accepted by [`HookType::from_str`], so
    /// `HookType::from_str(h.export_name()) == Some(h)` for all variants.
    pub fn export_name(&self) -> &'static str {
        match self {
            HookType::PreQuery => "pre_query",
            HookType::PostQuery => "post_query",
            HookType::Authenticate => "authenticate",
            HookType::Authorize => "authorize",
            HookType::CacheGet => "cache_get",
            HookType::CacheSet => "cache_set",
            HookType::Route => "route",
            HookType::Rewrite => "rewrite",
            HookType::Metrics => "metrics",
            HookType::OnConnect => "on_connect",
            HookType::OnDisconnect => "on_disconnect",
            HookType::Custom => "custom_hook",
        }
    }

    /// Parses a hook type from its export name or a short alias.
    ///
    /// Matching is case-insensitive (the input is lower-cased first).
    /// Returns `None` for any unrecognised string.
    // Kept as an inherent method returning `Option` because existing callers
    // rely on this signature; it intentionally is not `std::str::FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "pre_query" | "prequery" => Some(HookType::PreQuery),
            "post_query" | "postquery" => Some(HookType::PostQuery),
            "authenticate" | "auth" => Some(HookType::Authenticate),
            "authorize" => Some(HookType::Authorize),
            "cache_get" | "cacheget" => Some(HookType::CacheGet),
            "cache_set" | "cacheset" => Some(HookType::CacheSet),
            "route" | "routing" => Some(HookType::Route),
            "rewrite" => Some(HookType::Rewrite),
            "metrics" => Some(HookType::Metrics),
            "on_connect" | "connect" => Some(HookType::OnConnect),
            "on_disconnect" | "disconnect" => Some(HookType::OnDisconnect),
            // "custom_hook" is what `export_name` yields for `Custom`; it must
            // parse so that export names round-trip for every variant.
            "custom" | "custom_hook" => Some(HookType::Custom),
            _ => None,
        }
    }
}
/// Per-request context handed to hooks; serialisable via serde.
#[derive(Clone, Debug, serde::Serialize)]
pub struct HookContext {
    /// Identifier of the request (the `Default` impl generates a v4 UUID).
    pub request_id: String,
    /// Client identifier, when known.
    pub client_id: Option<String>,
    /// Identity string, when known.
    pub identity: Option<String>,
    /// Database name, when known.
    pub database: Option<String>,
    /// Branch name, when known.
    pub branch: Option<String>,
    /// Arbitrary string key/value attributes.
    pub attributes: HashMap<String, String>,
}
impl Default for HookContext {
fn default() -> Self {
Self {
request_id: uuid::Uuid::new_v4().to_string(),
client_id: None,
identity: None,
database: None,
branch: None,
attributes: HashMap::new(),
}
}
}
#[derive(Debug, Clone)]
pub struct QueryContext {
pub query: String,
pub normalized: String,
pub tables: Vec<String>,
pub is_read_only: bool,
pub hook_context: HookContext,
}
/// Outcome of a pre-query hook.
///
/// `PluginManager::execute_pre_query` keeps asking further plugins while it
/// receives `Continue` and returns the first other variant it sees.
#[derive(Clone, Debug)]
pub enum PreQueryResult {
    /// No decision; carry on.
    Continue,
    /// Carries a replacement query string.
    Rewrite(String),
    /// Carries a message explaining the block.
    Block(String),
    /// Carries raw bytes of a cached response.
    Cached(Vec<u8>),
}
/// Result summary of an executed query, passed to post-query hooks;
/// serialisable via serde.
#[derive(Clone, Debug, serde::Serialize)]
pub struct PostQueryOutcome {
    /// Whether the query succeeded.
    pub success: bool,
    /// Node the query was sent to, when known.
    pub target_node: Option<String>,
    /// Elapsed time. NOTE(review): the `_us` suffix suggests microseconds — confirm.
    pub elapsed_us: u64,
    /// Size of the response in bytes.
    pub response_bytes: u64,
    /// Error description, when there is one.
    pub error: Option<String>,
}
/// Outcome of an authenticate hook.
///
/// `PluginManager::execute_authenticate` keeps asking further plugins while
/// it receives `Defer` and returns the first other variant it sees.
#[derive(Clone, Debug)]
pub enum AuthResult {
    /// Authentication succeeded; carries the resolved identity.
    Success(Identity),
    /// Authentication was denied; carries a message.
    Denied(String),
    /// No decision from this plugin.
    Defer,
}
/// An authenticated principal.
///
/// `Default` is derived: it yields empty strings, no roles, no tenant and no
/// claims — exactly what the previous hand-written impl produced.
#[derive(Debug, Clone, Default)]
pub struct Identity {
    /// User identifier.
    pub user_id: String,
    /// User name.
    pub username: String,
    /// Role names held by the user.
    pub roles: Vec<String>,
    /// Tenant identifier, when there is one.
    pub tenant_id: Option<String>,
    /// Additional string key/value claims.
    pub claims: HashMap<String, String>,
}
/// Outcome of a route hook.
///
/// `PluginManager::execute_route` keeps asking further plugins while it
/// receives `Default` and returns the first other variant it sees.
#[derive(Clone, Debug)]
pub enum RouteResult {
    /// No routing decision from this plugin.
    Default,
    /// Carries a node name.
    Node(String),
    /// Selects the primary.
    Primary,
    /// Selects a standby.
    Standby,
    /// Carries a branch name.
    Branch(String),
    /// Carries a message explaining the block.
    Block(String),
}
/// Owns the loaded plugins and dispatches hook calls to them.
pub struct PluginManager {
    /// Shared WASM runtime used to instantiate plugins and call their hooks.
    runtime: Arc<WasmPluginRuntime>,
    /// Loaded plugins, keyed by plugin name.
    plugins: DashMap<String, Arc<LoadedPlugin>>,
    /// For each hook type, the names of the plugins registered for it,
    /// in registration order.
    hooks: RwLock<HashMap<HookType, Vec<String>>>,
    /// Configuration the manager was built from.
    config: PluginRuntimeConfig,
    /// Present only when `config.hot_reload` was set at construction.
    hot_reloader: Option<HotReloader>,
    /// Collector for per-plugin hook call metrics.
    metrics: Arc<PluginMetrics>,
}
impl PluginManager {
pub fn new(config: PluginRuntimeConfig) -> Result<Self, PluginError> {
let runtime = Arc::new(WasmPluginRuntime::new(&config)?);
let metrics = Arc::new(PluginMetrics::new());
let hot_reloader = if config.hot_reload {
Some(HotReloader::new(&config.plugin_dir)?)
} else {
None
};
Ok(Self {
runtime,
plugins: DashMap::new(),
hooks: RwLock::new(HashMap::new()),
config,
hot_reloader,
metrics,
})
}
pub fn load_plugin(&self, path: &std::path::Path) -> Result<(), PluginError> {
let mut loader = PluginLoader::new();
if let Some(ref dir) = self.runtime.config().trust_root {
let verifier = SignatureVerifier::from_trust_root(dir)
.map_err(|e| PluginError::LoadError(e.to_string()))?;
loader = loader.with_signature_verifier(verifier);
}
let (manifest, wasm_bytes) = loader.load(path)?;
let plugin = self.runtime.instantiate(&manifest, &wasm_bytes)?;
let plugin = Arc::new(plugin);
{
let mut hooks = self.hooks.write();
for hook in &manifest.hooks {
hooks
.entry(*hook)
.or_insert_with(Vec::new)
.push(manifest.name.clone());
}
}
self.plugins.insert(manifest.name.clone(), plugin);
tracing::info!(
plugin = %manifest.name,
version = %manifest.version,
hooks = ?manifest.hooks,
"Plugin loaded"
);
Ok(())
}
pub fn unload_plugin(&self, name: &str) -> Result<(), PluginError> {
if let Some((_, plugin)) = self.plugins.remove(name) {
let mut hooks = self.hooks.write();
for hook_plugins in hooks.values_mut() {
hook_plugins.retain(|p| p != name);
}
if let Err(e) = self.runtime.call_hook(&plugin, HookType::OnDisconnect, &[]) {
tracing::warn!(plugin = %name, error = %e, "Error calling on_unload");
}
tracing::info!(plugin = %name, "Plugin unloaded");
}
Ok(())
}
pub fn reload_plugin(&self, name: &str) -> Result<(), PluginError> {
if let Some(plugin) = self.plugins.get(name) {
let path = plugin.path.clone();
drop(plugin);
self.unload_plugin(name)?;
self.load_plugin(&path)?;
}
Ok(())
}
pub fn execute_pre_query(&self, ctx: &QueryContext) -> PreQueryResult {
let hooks = self.hooks.read();
let plugin_names = hooks.get(&HookType::PreQuery).cloned().unwrap_or_default();
drop(hooks);
for plugin_name in plugin_names {
if let Some(plugin) = self.plugins.get(&plugin_name) {
let start = std::time::Instant::now();
match self.runtime.call_pre_query(&plugin, ctx) {
Ok(result) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::PreQuery,
start.elapsed(),
true,
);
match result {
PreQueryResult::Continue => continue,
other => return other,
}
}
Err(e) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::PreQuery,
start.elapsed(),
false,
);
tracing::warn!(
plugin = %plugin_name,
error = %e,
"Pre-query hook failed"
);
}
}
}
}
PreQueryResult::Continue
}
pub fn execute_post_query(&self, ctx: &QueryContext, outcome: &PostQueryOutcome) {
let hooks = self.hooks.read();
let plugin_names = hooks.get(&HookType::PostQuery).cloned().unwrap_or_default();
drop(hooks);
for plugin_name in plugin_names {
if let Some(plugin) = self.plugins.get(&plugin_name) {
let start = std::time::Instant::now();
let payload = match serde_json::to_vec(&(ctx, outcome)) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
plugin = %plugin_name,
error = %e,
"Post-query serialisation failed"
);
continue;
}
};
match self.runtime.call_hook(&plugin, HookType::PostQuery, &payload) {
Ok(_) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::PostQuery,
start.elapsed(),
true,
);
}
Err(e) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::PostQuery,
start.elapsed(),
false,
);
tracing::warn!(
plugin = %plugin_name,
error = %e,
"Post-query hook failed"
);
}
}
}
}
}
pub fn execute_authenticate(&self, request: &AuthRequest) -> AuthResult {
let hooks = self.hooks.read();
let plugin_names = hooks.get(&HookType::Authenticate).cloned().unwrap_or_default();
drop(hooks);
for plugin_name in plugin_names {
if let Some(plugin) = self.plugins.get(&plugin_name) {
let start = std::time::Instant::now();
match self.runtime.call_authenticate(&plugin, request) {
Ok(result) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::Authenticate,
start.elapsed(),
true,
);
match result {
AuthResult::Defer => continue,
other => return other,
}
}
Err(e) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::Authenticate,
start.elapsed(),
false,
);
tracing::warn!(
plugin = %plugin_name,
error = %e,
"Authenticate hook failed"
);
}
}
}
}
AuthResult::Defer
}
pub fn execute_route(&self, ctx: &QueryContext) -> RouteResult {
let hooks = self.hooks.read();
let plugin_names = hooks.get(&HookType::Route).cloned().unwrap_or_default();
drop(hooks);
for plugin_name in plugin_names {
if let Some(plugin) = self.plugins.get(&plugin_name) {
let start = std::time::Instant::now();
match self.runtime.call_route(&plugin, ctx) {
Ok(result) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::Route,
start.elapsed(),
true,
);
match result {
RouteResult::Default => continue,
other => return other,
}
}
Err(e) => {
self.metrics.record_hook_call(
&plugin_name,
HookType::Route,
start.elapsed(),
false,
);
tracing::warn!(
plugin = %plugin_name,
error = %e,
"Route hook failed"
);
}
}
}
}
RouteResult::Default
}
pub fn list_plugins(&self) -> Vec<PluginInfo> {
self.plugins
.iter()
.map(|entry| {
let plugin = entry.value();
let stats = self.metrics.get_plugin_stats(&plugin.metadata.name);
PluginInfo {
name: plugin.metadata.name.clone(),
version: plugin.metadata.version.clone(),
description: plugin.metadata.description.clone(),
hooks: plugin.metadata.hooks.clone(),
state: plugin.state.clone(),
stats,
}
})
.collect()
}
pub fn get_metrics(&self) -> PluginManagerMetrics {
PluginManagerMetrics {
plugins_loaded: self.plugins.len(),
total_hook_calls: self.metrics.total_calls(),
total_errors: self.metrics.total_errors(),
avg_latency: self.metrics.avg_latency(),
plugins: self.list_plugins(),
}
}
pub fn check_updates(&self) -> Result<Vec<ReloadEvent>, PluginError> {
if let Some(ref reloader) = self.hot_reloader {
let events = reloader.check()?;
for event in &events {
match event {
ReloadEvent::Modified(name) => {
tracing::info!(plugin = %name, "Hot reloading plugin");
if let Err(e) = self.reload_plugin(name) {
tracing::error!(plugin = %name, error = %e, "Hot reload failed");
}
}
ReloadEvent::Removed(name) => {
tracing::info!(plugin = %name, "Plugin file removed, unloading");
if let Err(e) = self.unload_plugin(name) {
tracing::error!(plugin = %name, error = %e, "Unload failed");
}
}
ReloadEvent::Added(path) => {
tracing::info!(path = %path.display(), "New plugin detected, loading");
if let Err(e) = self.load_plugin(path) {
tracing::error!(path = %path.display(), error = %e, "Load failed");
}
}
}
}
Ok(events)
} else {
Ok(Vec::new())
}
}
}
/// Input handed to authenticate hooks.
#[derive(Clone, Debug)]
pub struct AuthRequest {
    /// Request headers as string key/value pairs.
    pub headers: HashMap<String, String>,
    /// Supplied user name, if any.
    pub username: Option<String>,
    /// Supplied password, if any.
    pub password: Option<String>,
    /// Client IP address, as a string.
    pub client_ip: String,
    /// Requested database, if any.
    pub database: Option<String>,
}
/// Summary of one loaded plugin, as produced by `PluginManager::list_plugins`.
#[derive(Clone, Debug)]
pub struct PluginInfo {
    /// Plugin name, copied from its metadata.
    pub name: String,
    /// Plugin version, copied from its metadata.
    pub version: String,
    /// Plugin description, copied from its metadata.
    pub description: String,
    /// Hook types listed in the plugin's metadata.
    pub hooks: Vec<HookType>,
    /// The plugin's state at the time of listing.
    pub state: PluginState,
    /// The plugin's stats from the metrics collector.
    pub stats: PluginStats,
}
/// Aggregate metrics snapshot, as produced by `PluginManager::get_metrics`.
#[derive(Clone, Debug)]
pub struct PluginManagerMetrics {
    /// Number of plugins currently loaded.
    pub plugins_loaded: usize,
    /// Total hook calls reported by the metrics collector.
    pub total_hook_calls: u64,
    /// Total errors reported by the metrics collector.
    pub total_errors: u64,
    /// Average latency reported by the metrics collector.
    pub avg_latency: Duration,
    /// Per-plugin summaries.
    pub plugins: Vec<PluginInfo>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-only `SELECT 1` context shared by the manager tests.
    fn select_one_context() -> QueryContext {
        QueryContext {
            query: "SELECT 1".to_string(),
            normalized: "SELECT 1".to_string(),
            tables: Vec::new(),
            is_read_only: true,
            hook_context: HookContext::default(),
        }
    }

    #[test]
    fn test_hook_type_export_name() {
        let expected = [
            (HookType::PreQuery, "pre_query"),
            (HookType::Authenticate, "authenticate"),
            (HookType::Route, "route"),
        ];
        for (hook, name) in expected.iter() {
            assert_eq!(hook.export_name(), *name);
        }
    }

    #[test]
    fn test_hook_type_from_str() {
        let expected = [
            ("pre_query", Some(HookType::PreQuery)),
            ("authenticate", Some(HookType::Authenticate)),
            ("unknown", None),
        ];
        for (input, hook) in expected.iter() {
            assert_eq!(HookType::from_str(input), *hook);
        }
    }

    #[test]
    fn test_plugin_metadata_default() {
        let defaults = PluginMetadata::default();
        assert!(defaults.name.is_empty());
        assert_eq!(defaults.version, "0.0.0");
        assert!(defaults.hooks.is_empty());
    }

    #[test]
    fn test_hook_context_default() {
        let context = HookContext::default();
        assert!(!context.request_id.is_empty());
        assert!(context.client_id.is_none());
    }

    #[test]
    fn test_pre_query_result() {
        assert!(matches!(PreQueryResult::Continue, PreQueryResult::Continue));
        assert!(matches!(
            PreQueryResult::Block("blocked".to_string()),
            PreQueryResult::Block(_)
        ));
    }

    #[test]
    fn test_auth_result() {
        assert!(matches!(
            AuthResult::Denied("invalid".to_string()),
            AuthResult::Denied(_)
        ));
        assert!(matches!(AuthResult::Defer, AuthResult::Defer));
    }

    #[test]
    fn test_route_result() {
        assert!(matches!(RouteResult::Default, RouteResult::Default));
        assert!(matches!(
            RouteResult::Branch("test".to_string()),
            RouteResult::Branch(_)
        ));
    }

    #[test]
    fn test_identity_default() {
        let anonymous = Identity::default();
        assert!(anonymous.user_id.is_empty());
        assert!(anonymous.roles.is_empty());
        assert!(anonymous.tenant_id.is_none());
    }

    #[test]
    fn test_execute_post_query_no_plugins_is_noop() {
        let manager = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");
        let outcome = PostQueryOutcome {
            success: true,
            target_node: Some("primary".to_string()),
            elapsed_us: 42,
            response_bytes: 128,
            error: None,
        };

        manager.execute_post_query(&select_one_context(), &outcome);

        // With nothing loaded, no hook may have been invoked.
        let snapshot = manager.get_metrics();
        assert_eq!(snapshot.plugins_loaded, 0);
        assert_eq!(snapshot.total_hook_calls, 0);
    }

    #[test]
    fn test_execute_pre_query_no_plugins_returns_continue() {
        let manager = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");
        let verdict = manager.execute_pre_query(&select_one_context());
        assert!(matches!(verdict, PreQueryResult::Continue));
    }

    #[test]
    fn test_post_query_outcome_serialisation() {
        let failed = PostQueryOutcome {
            success: false,
            target_node: None,
            elapsed_us: 1234,
            response_bytes: 0,
            error: Some("backend timeout".to_string()),
        };
        let encoded = serde_json::to_string(&failed).expect("serialise");
        for fragment in ["\"success\":false", "\"elapsed_us\":1234", "backend timeout"].iter() {
            assert!(encoded.contains(fragment));
        }
    }
}