1pub mod config;
40pub mod runtime;
41pub mod loader;
42pub mod host_functions;
43pub mod host_imports;
44pub mod sandbox;
45pub mod hot_reload;
46pub mod metrics;
47
48pub use config::{PluginRuntimeConfig, PluginRuntimeConfigBuilder, PluginConfig};
49pub use runtime::{WasmPluginRuntime, LoadedPlugin, PluginState, PluginError};
50pub use loader::{PluginLoader, PluginManifest, PluginLoadError, SignatureVerifier};
51pub use host_functions::HostFunctionRegistry;
52pub use sandbox::{PluginSandbox, SecurityPolicy, Permission, ResourceLimits};
53pub use hot_reload::{HotReloader, ReloadEvent, ReloadError};
54pub use metrics::{PluginMetrics, PluginStats, HookLatency};
55
56use std::collections::HashMap;
57use std::sync::Arc;
58use std::time::Duration;
59use parking_lot::RwLock;
60use dashmap::DashMap;
61
62#[derive(Debug, Clone)]
64pub struct PluginMetadata {
65 pub name: String,
67
68 pub version: String,
70
71 pub description: String,
73
74 pub author: String,
76
77 pub hooks: Vec<HookType>,
79
80 pub permissions: Vec<Permission>,
82
83 pub min_memory: usize,
85
86 pub max_memory: usize,
88}
89
90impl Default for PluginMetadata {
91 fn default() -> Self {
92 Self {
93 name: String::new(),
94 version: "0.0.0".to_string(),
95 description: String::new(),
96 author: String::new(),
97 hooks: Vec::new(),
98 permissions: Vec::new(),
99 min_memory: 1024 * 1024, max_memory: 64 * 1024 * 1024, }
102 }
103}
104
/// Extension points a plugin can subscribe to.
///
/// Each variant maps to a WASM export symbol via [`HookType::export_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    /// Runs before a query; dispatched by `PluginManager::execute_pre_query`.
    PreQuery,
    /// Runs after a query; dispatched by `PluginManager::execute_post_query`.
    PostQuery,
    /// Authentication; dispatched by `PluginManager::execute_authenticate`.
    Authenticate,
    /// Export symbol `authorize`.
    Authorize,
    /// Export symbol `cache_get`.
    CacheGet,
    /// Export symbol `cache_set`.
    CacheSet,
    /// Query routing; dispatched by `PluginManager::execute_route`.
    Route,
    /// Export symbol `rewrite`.
    Rewrite,
    /// Export symbol `metrics`.
    Metrics,
    /// Export symbol `on_connect`.
    OnConnect,
    /// Export symbol `on_disconnect`; `PluginManager::unload_plugin` also
    /// invokes it when a plugin is unloaded.
    OnDisconnect,
    /// Export symbol `custom_hook`.
    Custom,
}
144
145impl HookType {
146 pub fn export_name(&self) -> &'static str {
148 match self {
149 HookType::PreQuery => "pre_query",
150 HookType::PostQuery => "post_query",
151 HookType::Authenticate => "authenticate",
152 HookType::Authorize => "authorize",
153 HookType::CacheGet => "cache_get",
154 HookType::CacheSet => "cache_set",
155 HookType::Route => "route",
156 HookType::Rewrite => "rewrite",
157 HookType::Metrics => "metrics",
158 HookType::OnConnect => "on_connect",
159 HookType::OnDisconnect => "on_disconnect",
160 HookType::Custom => "custom_hook",
161 }
162 }
163
164 pub fn from_str(s: &str) -> Option<Self> {
166 match s.to_lowercase().as_str() {
167 "pre_query" | "prequery" => Some(HookType::PreQuery),
168 "post_query" | "postquery" => Some(HookType::PostQuery),
169 "authenticate" | "auth" => Some(HookType::Authenticate),
170 "authorize" => Some(HookType::Authorize),
171 "cache_get" | "cacheget" => Some(HookType::CacheGet),
172 "cache_set" | "cacheset" => Some(HookType::CacheSet),
173 "route" | "routing" => Some(HookType::Route),
174 "rewrite" => Some(HookType::Rewrite),
175 "metrics" => Some(HookType::Metrics),
176 "on_connect" | "connect" => Some(HookType::OnConnect),
177 "on_disconnect" | "disconnect" => Some(HookType::OnDisconnect),
178 "custom" => Some(HookType::Custom),
179 _ => None,
180 }
181 }
182}
183
184#[derive(Debug, Clone, serde::Serialize)]
186pub struct HookContext {
187 pub request_id: String,
189
190 pub client_id: Option<String>,
192
193 pub identity: Option<String>,
195
196 pub database: Option<String>,
198
199 pub branch: Option<String>,
201
202 pub attributes: HashMap<String, String>,
204}
205
206impl Default for HookContext {
207 fn default() -> Self {
208 Self {
209 request_id: uuid::Uuid::new_v4().to_string(),
210 client_id: None,
211 identity: None,
212 database: None,
213 branch: None,
214 attributes: HashMap::new(),
215 }
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct QueryContext {
222 pub query: String,
224
225 pub normalized: String,
227
228 pub tables: Vec<String>,
230
231 pub is_read_only: bool,
233
234 pub hook_context: HookContext,
236}
237
/// Verdict returned by a pre-query hook.
///
/// `PluginManager::execute_pre_query` moves to the next plugin on `Continue`
/// and returns the first other variant it receives.
#[derive(Debug, Clone)]
pub enum PreQueryResult {
    /// No decision; let the next plugin (or the default path) handle the query.
    Continue,
    /// Replacement query text (presumably executed instead of the original).
    Rewrite(String),
    /// Reject the query; the string is presumably the reason.
    Block(String),
    /// Serve these bytes instead of executing (presumably a cached response).
    Cached(Vec<u8>),
}
253
254#[derive(Debug, Clone, serde::Serialize)]
260pub struct PostQueryOutcome {
261 pub success: bool,
263
264 pub target_node: Option<String>,
266
267 pub elapsed_us: u64,
269
270 pub response_bytes: u64,
272
273 pub error: Option<String>,
275}
276
277#[derive(Debug, Clone)]
279pub enum AuthResult {
280 Success(Identity),
282
283 Denied(String),
285
286 Defer,
288}
289
/// An authenticated principal, as produced by an authentication hook.
#[derive(Debug, Clone)]
pub struct Identity {
    /// Stable user identifier.
    pub user_id: String,
    /// Display / login name.
    pub username: String,
    /// Role names granted to the user.
    pub roles: Vec<String>,
    /// Tenant the user belongs to, if any.
    pub tenant_id: Option<String>,
    /// Additional string claims.
    pub claims: HashMap<String, String>,
}
308
309impl Default for Identity {
310 fn default() -> Self {
311 Self {
312 user_id: String::new(),
313 username: String::new(),
314 roles: Vec::new(),
315 tenant_id: None,
316 claims: HashMap::new(),
317 }
318 }
319}
320
/// Routing decision returned by a route hook.
///
/// `PluginManager::execute_route` moves to the next plugin on `Default` and
/// returns the first other variant it receives.
#[derive(Debug, Clone)]
pub enum RouteResult {
    /// No decision; use the next plugin or the built-in routing.
    Default,
    /// Send the query to a specific node (presumably identified by this string).
    Node(String),
    /// Send the query to the primary.
    Primary,
    /// Send the query to a standby.
    Standby,
    /// Send the query to the named branch.
    Branch(String),
    /// Refuse to route; the string is presumably the reason.
    Block(String),
}
346
347pub struct PluginManager {
349 runtime: Arc<WasmPluginRuntime>,
351
352 plugins: DashMap<String, Arc<LoadedPlugin>>,
354
355 hooks: RwLock<HashMap<HookType, Vec<String>>>,
357
358 config: PluginRuntimeConfig,
360
361 hot_reloader: Option<HotReloader>,
363
364 metrics: Arc<PluginMetrics>,
366}
367
368impl PluginManager {
369 pub fn new(config: PluginRuntimeConfig) -> Result<Self, PluginError> {
371 let runtime = Arc::new(WasmPluginRuntime::new(&config)?);
372 let metrics = Arc::new(PluginMetrics::new());
373
374 let hot_reloader = if config.hot_reload {
375 Some(HotReloader::new(&config.plugin_dir)?)
376 } else {
377 None
378 };
379
380 Ok(Self {
381 runtime,
382 plugins: DashMap::new(),
383 hooks: RwLock::new(HashMap::new()),
384 config,
385 hot_reloader,
386 metrics,
387 })
388 }
389
390 pub fn load_plugin(&self, path: &std::path::Path) -> Result<(), PluginError> {
394 let mut loader = PluginLoader::new();
395 if let Some(ref dir) = self.runtime.config().trust_root {
396 let verifier = SignatureVerifier::from_trust_root(dir)
397 .map_err(|e| PluginError::LoadError(e.to_string()))?;
398 loader = loader.with_signature_verifier(verifier);
399 }
400 let (manifest, wasm_bytes) = loader.load(path)?;
401
402 let plugin = self.runtime.instantiate(&manifest, &wasm_bytes)?;
403 let plugin = Arc::new(plugin);
404
405 {
407 let mut hooks = self.hooks.write();
408 for hook in &manifest.hooks {
409 hooks
410 .entry(*hook)
411 .or_insert_with(Vec::new)
412 .push(manifest.name.clone());
413 }
414 }
415
416 self.plugins.insert(manifest.name.clone(), plugin);
417
418 tracing::info!(
419 plugin = %manifest.name,
420 version = %manifest.version,
421 hooks = ?manifest.hooks,
422 "Plugin loaded"
423 );
424
425 Ok(())
426 }
427
428 pub fn unload_plugin(&self, name: &str) -> Result<(), PluginError> {
430 if let Some((_, plugin)) = self.plugins.remove(name) {
431 let mut hooks = self.hooks.write();
433 for hook_plugins in hooks.values_mut() {
434 hook_plugins.retain(|p| p != name);
435 }
436
437 if let Err(e) = self.runtime.call_hook(&plugin, HookType::OnDisconnect, &[]) {
439 tracing::warn!(plugin = %name, error = %e, "Error calling on_unload");
440 }
441
442 tracing::info!(plugin = %name, "Plugin unloaded");
443 }
444
445 Ok(())
446 }
447
448 pub fn reload_plugin(&self, name: &str) -> Result<(), PluginError> {
450 if let Some(plugin) = self.plugins.get(name) {
451 let path = plugin.path.clone();
452 drop(plugin);
453
454 self.unload_plugin(name)?;
455 self.load_plugin(&path)?;
456 }
457
458 Ok(())
459 }
460
461 pub fn has_hook(&self, hook: HookType) -> bool {
465 self.hooks
466 .read()
467 .get(&hook)
468 .map_or(false, |names| !names.is_empty())
469 }
470
471 pub fn execute_pre_query(&self, ctx: &QueryContext) -> PreQueryResult {
473 let hooks = self.hooks.read();
474 let plugin_names = hooks.get(&HookType::PreQuery).cloned().unwrap_or_default();
475 drop(hooks);
476
477 for plugin_name in plugin_names {
478 if let Some(plugin) = self.plugins.get(&plugin_name) {
479 let start = std::time::Instant::now();
480
481 match self.runtime.call_pre_query(&plugin, ctx) {
482 Ok(result) => {
483 self.metrics.record_hook_call(
484 &plugin_name,
485 HookType::PreQuery,
486 start.elapsed(),
487 true,
488 );
489
490 match result {
491 PreQueryResult::Continue => continue,
492 other => return other,
493 }
494 }
495 Err(e) => {
496 self.metrics.record_hook_call(
497 &plugin_name,
498 HookType::PreQuery,
499 start.elapsed(),
500 false,
501 );
502 tracing::warn!(
503 plugin = %plugin_name,
504 error = %e,
505 "Pre-query hook failed"
506 );
507 }
508 }
509 }
510 }
511
512 PreQueryResult::Continue
513 }
514
515 pub fn execute_post_query(&self, ctx: &QueryContext, outcome: &PostQueryOutcome) {
522 let hooks = self.hooks.read();
523 let plugin_names = hooks.get(&HookType::PostQuery).cloned().unwrap_or_default();
524 drop(hooks);
525
526 if plugin_names.is_empty() {
527 return;
528 }
529
530 let payload = match serde_json::to_vec(&(ctx, outcome)) {
533 Ok(v) => v,
534 Err(e) => {
535 tracing::warn!(error = %e, "Post-query serialisation failed");
536 return;
537 }
538 };
539
540 for plugin_name in plugin_names {
541 if let Some(plugin) = self.plugins.get(&plugin_name) {
542 let start = std::time::Instant::now();
543
544 match self.runtime.call_hook(&plugin, HookType::PostQuery, &payload) {
545 Ok(_) => {
546 self.metrics.record_hook_call(
547 &plugin_name,
548 HookType::PostQuery,
549 start.elapsed(),
550 true,
551 );
552 }
553 Err(e) => {
554 self.metrics.record_hook_call(
555 &plugin_name,
556 HookType::PostQuery,
557 start.elapsed(),
558 false,
559 );
560 tracing::warn!(
561 plugin = %plugin_name,
562 error = %e,
563 "Post-query hook failed"
564 );
565 }
566 }
567 }
568 }
569 }
570
571 pub fn execute_authenticate(&self, request: &AuthRequest) -> AuthResult {
573 let hooks = self.hooks.read();
574 let plugin_names = hooks.get(&HookType::Authenticate).cloned().unwrap_or_default();
575 drop(hooks);
576
577 for plugin_name in plugin_names {
578 if let Some(plugin) = self.plugins.get(&plugin_name) {
579 let start = std::time::Instant::now();
580
581 match self.runtime.call_authenticate(&plugin, request) {
582 Ok(result) => {
583 self.metrics.record_hook_call(
584 &plugin_name,
585 HookType::Authenticate,
586 start.elapsed(),
587 true,
588 );
589
590 match result {
591 AuthResult::Defer => continue,
592 other => return other,
593 }
594 }
595 Err(e) => {
596 self.metrics.record_hook_call(
597 &plugin_name,
598 HookType::Authenticate,
599 start.elapsed(),
600 false,
601 );
602 tracing::warn!(
603 plugin = %plugin_name,
604 error = %e,
605 "Authenticate hook failed"
606 );
607 }
608 }
609 }
610 }
611
612 AuthResult::Defer
613 }
614
615 pub fn execute_route(&self, ctx: &QueryContext) -> RouteResult {
617 let hooks = self.hooks.read();
618 let plugin_names = hooks.get(&HookType::Route).cloned().unwrap_or_default();
619 drop(hooks);
620
621 for plugin_name in plugin_names {
622 if let Some(plugin) = self.plugins.get(&plugin_name) {
623 let start = std::time::Instant::now();
624
625 match self.runtime.call_route(&plugin, ctx) {
626 Ok(result) => {
627 self.metrics.record_hook_call(
628 &plugin_name,
629 HookType::Route,
630 start.elapsed(),
631 true,
632 );
633
634 match result {
635 RouteResult::Default => continue,
636 other => return other,
637 }
638 }
639 Err(e) => {
640 self.metrics.record_hook_call(
641 &plugin_name,
642 HookType::Route,
643 start.elapsed(),
644 false,
645 );
646 tracing::warn!(
647 plugin = %plugin_name,
648 error = %e,
649 "Route hook failed"
650 );
651 }
652 }
653 }
654 }
655
656 RouteResult::Default
657 }
658
659 pub fn list_plugins(&self) -> Vec<PluginInfo> {
661 self.plugins
662 .iter()
663 .map(|entry| {
664 let plugin = entry.value();
665 let stats = self.metrics.get_plugin_stats(&plugin.metadata.name);
666
667 PluginInfo {
668 name: plugin.metadata.name.clone(),
669 version: plugin.metadata.version.clone(),
670 description: plugin.metadata.description.clone(),
671 hooks: plugin.metadata.hooks.clone(),
672 state: plugin.state.clone(),
673 stats,
674 }
675 })
676 .collect()
677 }
678
679 pub fn get_metrics(&self) -> PluginManagerMetrics {
681 PluginManagerMetrics {
682 plugins_loaded: self.plugins.len(),
683 total_hook_calls: self.metrics.total_calls(),
684 total_errors: self.metrics.total_errors(),
685 avg_latency: self.metrics.avg_latency(),
686 plugins: self.list_plugins(),
687 }
688 }
689
690 pub fn check_updates(&self) -> Result<Vec<ReloadEvent>, PluginError> {
692 if let Some(ref reloader) = self.hot_reloader {
693 let events = reloader.check()?;
694
695 for event in &events {
696 match event {
697 ReloadEvent::Modified(name) => {
698 tracing::info!(plugin = %name, "Hot reloading plugin");
699 if let Err(e) = self.reload_plugin(name) {
700 tracing::error!(plugin = %name, error = %e, "Hot reload failed");
701 }
702 }
703 ReloadEvent::Removed(name) => {
704 tracing::info!(plugin = %name, "Plugin file removed, unloading");
705 if let Err(e) = self.unload_plugin(name) {
706 tracing::error!(plugin = %name, error = %e, "Unload failed");
707 }
708 }
709 ReloadEvent::Added(path) => {
710 tracing::info!(path = %path.display(), "New plugin detected, loading");
711 if let Err(e) = self.load_plugin(path) {
712 tracing::error!(path = %path.display(), error = %e, "Load failed");
713 }
714 }
715 }
716 }
717
718 Ok(events)
719 } else {
720 Ok(Vec::new())
721 }
722 }
723}
724
/// Input to authentication hooks.
///
/// `Debug` is implemented by hand so that credentials are never printed. The
/// password and the values of `Authorization`, `Proxy-Authorization` and
/// `Cookie` headers (matched case-insensitively) are shown as `<redacted>`.
#[derive(Clone)]
pub struct AuthRequest {
    /// Request headers.
    pub headers: HashMap<String, String>,
    /// Username supplied by the client, if any.
    pub username: Option<String>,
    /// Plaintext password supplied by the client, if any.
    pub password: Option<String>,
    /// Client IP address as a string.
    pub client_ip: String,
    /// Target database, if any.
    pub database: Option<String>,
}

impl std::fmt::Debug for AuthRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header names whose values carry credentials.
        fn is_credential_header(name: &str) -> bool {
            name.eq_ignore_ascii_case("authorization")
                || name.eq_ignore_ascii_case("proxy-authorization")
                || name.eq_ignore_ascii_case("cookie")
        }
        const REDACTED: &str = "<redacted>";

        let headers: HashMap<&str, &str> = self
            .headers
            .iter()
            .map(|(name, value)| {
                let shown = if is_credential_header(name) {
                    REDACTED
                } else {
                    value.as_str()
                };
                (name.as_str(), shown)
            })
            .collect();

        f.debug_struct("AuthRequest")
            .field("headers", &headers)
            .field("username", &self.username)
            // Keep the Some/None distinction visible, but never the secret.
            .field("password", &self.password.as_ref().map(|_| REDACTED))
            .field("client_ip", &self.client_ip)
            .field("database", &self.database)
            .finish()
    }
}
743
744#[derive(Debug, Clone)]
746pub struct PluginInfo {
747 pub name: String,
749
750 pub version: String,
752
753 pub description: String,
755
756 pub hooks: Vec<HookType>,
758
759 pub state: PluginState,
761
762 pub stats: PluginStats,
764}
765
766#[derive(Debug, Clone)]
768pub struct PluginManagerMetrics {
769 pub plugins_loaded: usize,
771
772 pub total_hook_calls: u64,
774
775 pub total_errors: u64,
777
778 pub avg_latency: Duration,
780
781 pub plugins: Vec<PluginInfo>,
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal read-only `SELECT 1` context shared by the manager tests.
    fn select_one() -> QueryContext {
        QueryContext {
            query: "SELECT 1".to_string(),
            normalized: "SELECT 1".to_string(),
            tables: Vec::new(),
            is_read_only: true,
            hook_context: HookContext::default(),
        }
    }

    #[test]
    fn test_hook_type_export_name() {
        let cases = [
            (HookType::PreQuery, "pre_query"),
            (HookType::Authenticate, "authenticate"),
            (HookType::Route, "route"),
        ];
        for (hook, expected) in cases.iter() {
            assert_eq!(hook.export_name(), *expected);
        }
    }

    #[test]
    fn test_hook_type_from_str() {
        let cases = [
            ("pre_query", Some(HookType::PreQuery)),
            ("authenticate", Some(HookType::Authenticate)),
            ("unknown", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(HookType::from_str(input), *expected);
        }
    }

    #[test]
    fn test_plugin_metadata_default() {
        let PluginMetadata {
            name,
            version,
            hooks,
            ..
        } = PluginMetadata::default();
        assert!(name.is_empty());
        assert_eq!(version, "0.0.0");
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_hook_context_default() {
        let HookContext {
            request_id,
            client_id,
            ..
        } = HookContext::default();
        assert!(!request_id.is_empty());
        assert!(client_id.is_none());
    }

    #[test]
    fn test_pre_query_result() {
        assert!(matches!(PreQueryResult::Continue, PreQueryResult::Continue));
        assert!(matches!(
            PreQueryResult::Block("blocked".to_string()),
            PreQueryResult::Block(_)
        ));
    }

    #[test]
    fn test_auth_result() {
        assert!(matches!(
            AuthResult::Denied("invalid".to_string()),
            AuthResult::Denied(_)
        ));
        assert!(matches!(AuthResult::Defer, AuthResult::Defer));
    }

    #[test]
    fn test_route_result() {
        assert!(matches!(RouteResult::Default, RouteResult::Default));
        assert!(matches!(
            RouteResult::Branch("test".to_string()),
            RouteResult::Branch(_)
        ));
    }

    #[test]
    fn test_identity_default() {
        let Identity {
            user_id,
            roles,
            tenant_id,
            ..
        } = Identity::default();
        assert!(user_id.is_empty());
        assert!(roles.is_empty());
        assert!(tenant_id.is_none());
    }

    #[test]
    fn test_execute_post_query_no_plugins_is_noop() {
        let pm = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");

        let outcome = PostQueryOutcome {
            success: true,
            target_node: Some("primary".to_string()),
            elapsed_us: 42,
            response_bytes: 128,
            error: None,
        };
        pm.execute_post_query(&select_one(), &outcome);

        let metrics = pm.get_metrics();
        assert_eq!(metrics.plugins_loaded, 0);
        assert_eq!(metrics.total_hook_calls, 0);
    }

    #[test]
    fn test_execute_pre_query_no_plugins_returns_continue() {
        let pm = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");
        let verdict = pm.execute_pre_query(&select_one());
        assert!(matches!(verdict, PreQueryResult::Continue));
    }

    #[test]
    fn test_post_query_outcome_serialisation() {
        let outcome = PostQueryOutcome {
            success: false,
            target_node: None,
            elapsed_us: 1234,
            response_bytes: 0,
            error: Some("backend timeout".to_string()),
        };
        let json = serde_json::to_string(&outcome).expect("serialise");
        for needle in ["\"success\":false", "\"elapsed_us\":1234", "backend timeout"].iter() {
            assert!(json.contains(*needle));
        }
    }
}