1pub mod config;
40pub mod host_functions;
41pub mod host_imports;
42pub mod hot_reload;
43pub mod loader;
44pub mod metrics;
45pub mod runtime;
46pub mod sandbox;
47
48pub use config::{PluginConfig, PluginRuntimeConfig, PluginRuntimeConfigBuilder};
49pub use host_functions::HostFunctionRegistry;
50pub use hot_reload::{HotReloader, ReloadError, ReloadEvent};
51pub use loader::{PluginLoadError, PluginLoader, PluginManifest, SignatureVerifier};
52pub use metrics::{HookLatency, PluginMetrics, PluginStats};
53pub use runtime::{LoadedPlugin, PluginError, PluginState, WasmPluginRuntime};
54pub use sandbox::{Permission, PluginSandbox, ResourceLimits, SecurityPolicy};
55
56use dashmap::DashMap;
57use parking_lot::RwLock;
58use std::collections::HashMap;
59use std::sync::Arc;
60use std::time::Duration;
61
/// Descriptive and resource metadata for a plugin.
///
/// See the `Default` impl for the baseline values used when fields are not
/// supplied.
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,

    /// Plugin version string (defaults to `"0.0.0"`).
    pub version: String,

    /// Free-form human-readable description.
    pub description: String,

    /// Plugin author.
    pub author: String,

    /// Hook points the plugin declares.
    pub hooks: Vec<HookType>,

    /// Sandbox permissions requested by the plugin.
    pub permissions: Vec<Permission>,

    /// Lower memory bound; NOTE(review): presumably bytes — confirm with the
    /// sandbox/runtime that consumes it.
    pub min_memory: usize,

    /// Upper memory bound; NOTE(review): presumably bytes — confirm with the
    /// sandbox/runtime that consumes it.
    pub max_memory: usize,
}
89
90impl Default for PluginMetadata {
91 fn default() -> Self {
92 Self {
93 name: String::new(),
94 version: "0.0.0".to_string(),
95 description: String::new(),
96 author: String::new(),
97 hooks: Vec::new(),
98 permissions: Vec::new(),
99 min_memory: 1024 * 1024, max_memory: 64 * 1024 * 1024, }
102 }
103}
104
/// The hook points a plugin may implement.
///
/// Every variant maps to a WASM export name via [`HookType::export_name`], and
/// every export name parses back to its variant via [`HookType::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    /// Export `pre_query`.
    PreQuery,

    /// Export `post_query`.
    PostQuery,

    /// Export `authenticate`.
    Authenticate,

    /// Export `authorize`.
    Authorize,

    /// Export `cache_get`.
    CacheGet,

    /// Export `cache_set`.
    CacheSet,

    /// Export `route`.
    Route,

    /// Export `rewrite`.
    Rewrite,

    /// Export `metrics`.
    Metrics,

    /// Export `on_connect`.
    OnConnect,

    /// Export `on_disconnect`.
    OnDisconnect,

    /// Export `custom_hook`.
    Custom,
}

impl HookType {
    /// Returns the name of the WASM export that implements this hook.
    pub fn export_name(&self) -> &'static str {
        match self {
            HookType::PreQuery => "pre_query",
            HookType::PostQuery => "post_query",
            HookType::Authenticate => "authenticate",
            HookType::Authorize => "authorize",
            HookType::CacheGet => "cache_get",
            HookType::CacheSet => "cache_set",
            HookType::Route => "route",
            HookType::Rewrite => "rewrite",
            HookType::Metrics => "metrics",
            HookType::OnConnect => "on_connect",
            HookType::OnDisconnect => "on_disconnect",
            HookType::Custom => "custom_hook",
        }
    }

    /// Parses a hook name case-insensitively, returning `None` if it is not recognised.
    ///
    /// Accepts every value produced by [`HookType::export_name`] plus a few
    /// short aliases (`prequery`, `auth`, `routing`, `connect`, `custom`, ...).
    // An inherent fn returning `Option` is kept instead of `FromStr`, which
    // would require an error type; hence the clippy allow.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "pre_query" | "prequery" => Some(HookType::PreQuery),
            "post_query" | "postquery" => Some(HookType::PostQuery),
            "authenticate" | "auth" => Some(HookType::Authenticate),
            "authorize" => Some(HookType::Authorize),
            "cache_get" | "cacheget" => Some(HookType::CacheGet),
            "cache_set" | "cacheset" => Some(HookType::CacheSet),
            "route" | "routing" => Some(HookType::Route),
            "rewrite" => Some(HookType::Rewrite),
            "metrics" => Some(HookType::Metrics),
            "on_connect" | "connect" => Some(HookType::OnConnect),
            "on_disconnect" | "disconnect" => Some(HookType::OnDisconnect),
            // `custom_hook` is the export name; without it the round-trip
            // `from_str(Custom.export_name())` returned `None`.
            "custom" | "custom_hook" => Some(HookType::Custom),
            _ => None,
        }
    }
}
184
/// Per-request context shared with hooks.
///
/// Derives `serde::Serialize`, so field names and order are part of the
/// serialized form.
#[derive(Debug, Clone, serde::Serialize)]
pub struct HookContext {
    /// Request identifier; the `Default` impl fills it with a random v4 UUID.
    pub request_id: String,

    /// Client identifier, if known.
    pub client_id: Option<String>,

    /// Identity string, if known; NOTE(review): presumably the authenticated
    /// user — confirm with the caller that sets it.
    pub identity: Option<String>,

    /// Target database, if known.
    pub database: Option<String>,

    /// Target branch, if known.
    pub branch: Option<String>,

    /// Arbitrary string key/value attributes.
    pub attributes: HashMap<String, String>,
}
206
207impl Default for HookContext {
208 fn default() -> Self {
209 Self {
210 request_id: uuid::Uuid::new_v4().to_string(),
211 client_id: None,
212 identity: None,
213 database: None,
214 branch: None,
215 attributes: HashMap::new(),
216 }
217 }
218}
219
/// Query information passed to the query-level hooks (`execute_pre_query`,
/// `execute_post_query`, `execute_route` on `PluginManager`).
///
/// NOTE(review): `PluginManager::execute_post_query` passes `&QueryContext` to
/// `serde_json::to_vec`, but only `Debug` and `Clone` are derived here —
/// confirm that a `Serialize` impl is provided elsewhere in the crate.
#[derive(Debug, Clone)]
pub struct QueryContext {
    /// Query text as received.
    pub query: String,

    /// Normalized form of the query; NOTE(review): normalization rules are
    /// defined by the caller, not here.
    pub normalized: String,

    /// Tables referenced by the query.
    pub tables: Vec<String>,

    /// Whether the query has been classified as read-only.
    pub is_read_only: bool,

    /// Per-request hook context.
    pub hook_context: HookContext,
}
238
/// Decision returned by a `pre_query` hook.
///
/// `PluginManager::execute_pre_query` moves on to the next plugin on
/// `Continue` and returns the first non-`Continue` result.
#[derive(Debug, Clone)]
pub enum PreQueryResult {
    /// No decision; let the next plugin (or the normal path) proceed.
    Continue,

    /// Replace the query; NOTE(review): presumably the payload is the new SQL text.
    Rewrite(String),

    /// Reject the query; NOTE(review): presumably the payload is the reason.
    Block(String),

    /// Short-circuit with a response; NOTE(review): presumably serialized
    /// cached result bytes — confirm the encoding with the consumer.
    Cached(Vec<u8>),
}
254
/// Outcome of an executed query, serialized (together with the
/// `QueryContext`) as the payload for `post_query` hooks.
///
/// Field names are part of the serialized JSON seen by plugins.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PostQueryOutcome {
    /// Whether the query succeeded.
    pub success: bool,

    /// Node the query ran on, if known.
    pub target_node: Option<String>,

    /// Elapsed time; the `_us` suffix suggests microseconds — confirm with callers.
    pub elapsed_us: u64,

    /// Size of the response in bytes.
    pub response_bytes: u64,

    /// Error message when the query failed.
    pub error: Option<String>,
}
277
/// Decision returned by an `authenticate` hook.
///
/// `PluginManager::execute_authenticate` moves on to the next plugin on
/// `Defer` and returns the first non-`Defer` result.
#[derive(Debug, Clone)]
pub enum AuthResult {
    /// Authentication succeeded with the given identity.
    Success(Identity),

    /// Authentication refused; NOTE(review): presumably the payload is the reason.
    Denied(String),

    /// No decision; defer to the next plugin (or the built-in path).
    Defer,
}
290
/// Authenticated principal produced by an `authenticate` hook.
#[derive(Debug, Clone, Default)]
pub struct Identity {
    /// Stable user identifier.
    pub user_id: String,

    /// Display/login name.
    pub username: String,

    /// Role names granted to the user.
    pub roles: Vec<String>,

    /// Tenant the user belongs to, if any.
    pub tenant_id: Option<String>,

    /// Additional string claims.
    pub claims: HashMap<String, String>,
}
309
/// Routing decision returned by a `route` hook.
///
/// `PluginManager::execute_route` moves on to the next plugin on `Default`
/// and returns the first non-`Default` result.
#[derive(Debug, Clone)]
pub enum RouteResult {
    /// No preference; use default routing.
    Default,

    /// Route to a specific node; NOTE(review): presumably a node id/name.
    Node(String),

    /// Route to the primary.
    Primary,

    /// Route to a standby.
    Standby,

    /// Route to a named branch.
    Branch(String),

    /// Refuse to route; NOTE(review): presumably the payload is the reason.
    Block(String),
}
335
/// Owns the loaded plugins and dispatches hook calls to them.
pub struct PluginManager {
    /// Shared WASM runtime used to instantiate and call plugins.
    runtime: Arc<WasmPluginRuntime>,

    /// Loaded plugins keyed by manifest name.
    plugins: DashMap<String, Arc<LoadedPlugin>>,

    /// For each hook, the names of subscribed plugins in registration order
    /// (the order in which hook chains are executed).
    hooks: RwLock<HashMap<HookType, Vec<String>>>,

    // Not read by this module; NOTE(review): confirm whether it is still
    // needed, since load paths read `runtime.config()` instead.
    #[allow(dead_code)]
    config: PluginRuntimeConfig,

    /// File watcher; `Some` only when hot reload is enabled in the config.
    hot_reloader: Option<HotReloader>,

    /// Per-plugin / per-hook call statistics.
    metrics: Arc<PluginMetrics>,
}
357
358impl PluginManager {
359 pub fn new(config: PluginRuntimeConfig) -> Result<Self, PluginError> {
361 let runtime = Arc::new(WasmPluginRuntime::new(&config)?);
362 let metrics = Arc::new(PluginMetrics::new());
363
364 let hot_reloader = if config.hot_reload {
365 Some(HotReloader::new(&config.plugin_dir)?)
366 } else {
367 None
368 };
369
370 Ok(Self {
371 runtime,
372 plugins: DashMap::new(),
373 hooks: RwLock::new(HashMap::new()),
374 config,
375 hot_reloader,
376 metrics,
377 })
378 }
379
380 pub fn load_plugin(&self, path: &std::path::Path) -> Result<(), PluginError> {
384 let mut loader = PluginLoader::new();
385 if let Some(ref dir) = self.runtime.config().trust_root {
386 let verifier = SignatureVerifier::from_trust_root(dir)
387 .map_err(|e| PluginError::LoadError(e.to_string()))?;
388 loader = loader.with_signature_verifier(verifier);
389 }
390 let (manifest, wasm_bytes) = loader.load(path)?;
391
392 let plugin = self.runtime.instantiate(&manifest, &wasm_bytes)?;
393 let plugin = Arc::new(plugin);
394
395 {
397 let mut hooks = self.hooks.write();
398 for hook in &manifest.hooks {
399 hooks.entry(*hook).or_default().push(manifest.name.clone());
400 }
401 }
402
403 self.plugins.insert(manifest.name.clone(), plugin);
404
405 tracing::info!(
406 plugin = %manifest.name,
407 version = %manifest.version,
408 hooks = ?manifest.hooks,
409 "Plugin loaded"
410 );
411
412 Ok(())
413 }
414
415 pub fn unload_plugin(&self, name: &str) -> Result<(), PluginError> {
417 if let Some((_, plugin)) = self.plugins.remove(name) {
418 let mut hooks = self.hooks.write();
420 for hook_plugins in hooks.values_mut() {
421 hook_plugins.retain(|p| p != name);
422 }
423
424 if let Err(e) = self.runtime.call_hook(&plugin, HookType::OnDisconnect, &[]) {
426 tracing::warn!(plugin = %name, error = %e, "Error calling on_unload");
427 }
428
429 tracing::info!(plugin = %name, "Plugin unloaded");
430 }
431
432 Ok(())
433 }
434
435 pub fn reload_plugin(&self, name: &str) -> Result<(), PluginError> {
437 if let Some(plugin) = self.plugins.get(name) {
438 let path = plugin.path.clone();
439 drop(plugin);
440
441 self.unload_plugin(name)?;
442 self.load_plugin(&path)?;
443 }
444
445 Ok(())
446 }
447
448 pub fn has_hook(&self, hook: HookType) -> bool {
452 self.hooks
453 .read()
454 .get(&hook)
455 .is_some_and(|names| !names.is_empty())
456 }
457
458 pub fn execute_pre_query(&self, ctx: &QueryContext) -> PreQueryResult {
460 let hooks = self.hooks.read();
461 let plugin_names = hooks.get(&HookType::PreQuery).cloned().unwrap_or_default();
462 drop(hooks);
463
464 for plugin_name in plugin_names {
465 if let Some(plugin) = self.plugins.get(&plugin_name) {
466 let start = std::time::Instant::now();
467
468 match self.runtime.call_pre_query(&plugin, ctx) {
469 Ok(result) => {
470 self.metrics.record_hook_call(
471 &plugin_name,
472 HookType::PreQuery,
473 start.elapsed(),
474 true,
475 );
476
477 match result {
478 PreQueryResult::Continue => continue,
479 other => return other,
480 }
481 }
482 Err(e) => {
483 self.metrics.record_hook_call(
484 &plugin_name,
485 HookType::PreQuery,
486 start.elapsed(),
487 false,
488 );
489 tracing::warn!(
490 plugin = %plugin_name,
491 error = %e,
492 "Pre-query hook failed"
493 );
494 }
495 }
496 }
497 }
498
499 PreQueryResult::Continue
500 }
501
502 pub fn execute_post_query(&self, ctx: &QueryContext, outcome: &PostQueryOutcome) {
509 let hooks = self.hooks.read();
510 let plugin_names = hooks.get(&HookType::PostQuery).cloned().unwrap_or_default();
511 drop(hooks);
512
513 if plugin_names.is_empty() {
514 return;
515 }
516
517 let payload = match serde_json::to_vec(&(ctx, outcome)) {
520 Ok(v) => v,
521 Err(e) => {
522 tracing::warn!(error = %e, "Post-query serialisation failed");
523 return;
524 }
525 };
526
527 for plugin_name in plugin_names {
528 if let Some(plugin) = self.plugins.get(&plugin_name) {
529 let start = std::time::Instant::now();
530
531 match self
532 .runtime
533 .call_hook(&plugin, HookType::PostQuery, &payload)
534 {
535 Ok(_) => {
536 self.metrics.record_hook_call(
537 &plugin_name,
538 HookType::PostQuery,
539 start.elapsed(),
540 true,
541 );
542 }
543 Err(e) => {
544 self.metrics.record_hook_call(
545 &plugin_name,
546 HookType::PostQuery,
547 start.elapsed(),
548 false,
549 );
550 tracing::warn!(
551 plugin = %plugin_name,
552 error = %e,
553 "Post-query hook failed"
554 );
555 }
556 }
557 }
558 }
559 }
560
561 pub fn execute_authenticate(&self, request: &AuthRequest) -> AuthResult {
563 let hooks = self.hooks.read();
564 let plugin_names = hooks
565 .get(&HookType::Authenticate)
566 .cloned()
567 .unwrap_or_default();
568 drop(hooks);
569
570 for plugin_name in plugin_names {
571 if let Some(plugin) = self.plugins.get(&plugin_name) {
572 let start = std::time::Instant::now();
573
574 match self.runtime.call_authenticate(&plugin, request) {
575 Ok(result) => {
576 self.metrics.record_hook_call(
577 &plugin_name,
578 HookType::Authenticate,
579 start.elapsed(),
580 true,
581 );
582
583 match result {
584 AuthResult::Defer => continue,
585 other => return other,
586 }
587 }
588 Err(e) => {
589 self.metrics.record_hook_call(
590 &plugin_name,
591 HookType::Authenticate,
592 start.elapsed(),
593 false,
594 );
595 tracing::warn!(
596 plugin = %plugin_name,
597 error = %e,
598 "Authenticate hook failed"
599 );
600 }
601 }
602 }
603 }
604
605 AuthResult::Defer
606 }
607
608 pub fn execute_route(&self, ctx: &QueryContext) -> RouteResult {
610 let hooks = self.hooks.read();
611 let plugin_names = hooks.get(&HookType::Route).cloned().unwrap_or_default();
612 drop(hooks);
613
614 for plugin_name in plugin_names {
615 if let Some(plugin) = self.plugins.get(&plugin_name) {
616 let start = std::time::Instant::now();
617
618 match self.runtime.call_route(&plugin, ctx) {
619 Ok(result) => {
620 self.metrics.record_hook_call(
621 &plugin_name,
622 HookType::Route,
623 start.elapsed(),
624 true,
625 );
626
627 match result {
628 RouteResult::Default => continue,
629 other => return other,
630 }
631 }
632 Err(e) => {
633 self.metrics.record_hook_call(
634 &plugin_name,
635 HookType::Route,
636 start.elapsed(),
637 false,
638 );
639 tracing::warn!(
640 plugin = %plugin_name,
641 error = %e,
642 "Route hook failed"
643 );
644 }
645 }
646 }
647 }
648
649 RouteResult::Default
650 }
651
652 pub fn list_plugins(&self) -> Vec<PluginInfo> {
654 self.plugins
655 .iter()
656 .map(|entry| {
657 let plugin = entry.value();
658 let stats = self.metrics.get_plugin_stats(&plugin.metadata.name);
659
660 PluginInfo {
661 name: plugin.metadata.name.clone(),
662 version: plugin.metadata.version.clone(),
663 description: plugin.metadata.description.clone(),
664 hooks: plugin.metadata.hooks.clone(),
665 state: plugin.state.clone(),
666 stats,
667 }
668 })
669 .collect()
670 }
671
672 pub fn get_metrics(&self) -> PluginManagerMetrics {
674 PluginManagerMetrics {
675 plugins_loaded: self.plugins.len(),
676 total_hook_calls: self.metrics.total_calls(),
677 total_errors: self.metrics.total_errors(),
678 avg_latency: self.metrics.avg_latency(),
679 plugins: self.list_plugins(),
680 }
681 }
682
683 pub fn check_updates(&self) -> Result<Vec<ReloadEvent>, PluginError> {
685 if let Some(ref reloader) = self.hot_reloader {
686 let events = reloader.check()?;
687
688 for event in &events {
689 match event {
690 ReloadEvent::Modified(name) => {
691 tracing::info!(plugin = %name, "Hot reloading plugin");
692 if let Err(e) = self.reload_plugin(name) {
693 tracing::error!(plugin = %name, error = %e, "Hot reload failed");
694 }
695 }
696 ReloadEvent::Removed(name) => {
697 tracing::info!(plugin = %name, "Plugin file removed, unloading");
698 if let Err(e) = self.unload_plugin(name) {
699 tracing::error!(plugin = %name, error = %e, "Unload failed");
700 }
701 }
702 ReloadEvent::Added(path) => {
703 tracing::info!(path = %path.display(), "New plugin detected, loading");
704 if let Err(e) = self.load_plugin(path) {
705 tracing::error!(path = %path.display(), error = %e, "Load failed");
706 }
707 }
708 }
709 }
710
711 Ok(events)
712 } else {
713 Ok(Vec::new())
714 }
715 }
716}
717
/// Credentials and connection details handed to `authenticate` hooks.
///
/// `Debug` is implemented by hand so that formatting a request (e.g. in a
/// `tracing` field or a panic message) never exposes secrets. The password
/// prints as `<redacted>` and only header *names* are shown, because header
/// values such as `Authorization` may carry credentials.
#[derive(Clone)]
pub struct AuthRequest {
    /// Request headers.
    pub headers: HashMap<String, String>,

    /// Supplied username, if any.
    pub username: Option<String>,

    /// Supplied password, if any (plaintext — never log it).
    pub password: Option<String>,

    /// Client IP address as a string.
    pub client_ip: String,

    /// Requested database, if any.
    pub database: Option<String>,
}

impl std::fmt::Debug for AuthRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sorted so the output is deterministic despite HashMap ordering.
        let mut header_names: Vec<&str> = self.headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();

        f.debug_struct("AuthRequest")
            .field("headers", &header_names)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("client_ip", &self.client_ip)
            .field("database", &self.database)
            .finish()
    }
}
736
/// Summary of one loaded plugin, as returned by `PluginManager::list_plugins`.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Plugin name (from its metadata).
    pub name: String,

    /// Plugin version (from its metadata).
    pub version: String,

    /// Plugin description (from its metadata).
    pub description: String,

    /// Hooks the plugin declares.
    pub hooks: Vec<HookType>,

    /// Current runtime state of the plugin.
    pub state: PluginState,

    /// Call statistics recorded for this plugin.
    pub stats: PluginStats,
}
758
/// Aggregate metrics snapshot, as returned by `PluginManager::get_metrics`.
#[derive(Debug, Clone)]
pub struct PluginManagerMetrics {
    /// Number of plugins currently loaded.
    pub plugins_loaded: usize,

    /// Total hook invocations recorded.
    pub total_hook_calls: u64,

    /// Total failed hook invocations recorded.
    pub total_errors: u64,

    /// Average hook latency reported by the metrics collector.
    pub avg_latency: Duration,

    /// Per-plugin breakdown.
    pub plugins: Vec<PluginInfo>,
}
777
#[cfg(test)]
mod tests {
    use super::*;

    /// Trivial read-only `SELECT 1` context shared by the manager tests.
    fn select_one_ctx() -> QueryContext {
        let sql = "SELECT 1";
        QueryContext {
            query: sql.to_string(),
            normalized: sql.to_string(),
            tables: Vec::new(),
            is_read_only: true,
            hook_context: HookContext::default(),
        }
    }

    /// Manager built from the default runtime configuration.
    fn default_manager() -> PluginManager {
        PluginManager::new(PluginRuntimeConfig::default()).expect("construct PluginManager")
    }

    #[test]
    fn test_hook_type_export_name() {
        let cases = [
            (HookType::PreQuery, "pre_query"),
            (HookType::Authenticate, "authenticate"),
            (HookType::Route, "route"),
        ];
        for (hook, expected) in cases {
            assert_eq!(hook.export_name(), expected);
        }
    }

    #[test]
    fn test_hook_type_from_str() {
        let cases = [
            ("pre_query", Some(HookType::PreQuery)),
            ("authenticate", Some(HookType::Authenticate)),
            ("unknown", None),
        ];
        for (input, expected) in cases {
            assert_eq!(HookType::from_str(input), expected);
        }
    }

    #[test]
    fn test_plugin_metadata_default() {
        let PluginMetadata {
            name, version, hooks, ..
        } = PluginMetadata::default();
        assert!(name.is_empty());
        assert_eq!(version, "0.0.0");
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_hook_context_default() {
        let ctx = HookContext::default();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.client_id.is_none());
    }

    #[test]
    fn test_pre_query_result() {
        assert!(matches!(PreQueryResult::Continue, PreQueryResult::Continue));
        assert!(matches!(
            PreQueryResult::Block("blocked".to_string()),
            PreQueryResult::Block(_)
        ));
    }

    #[test]
    fn test_auth_result() {
        assert!(matches!(
            AuthResult::Denied("invalid".to_string()),
            AuthResult::Denied(_)
        ));
        assert!(matches!(AuthResult::Defer, AuthResult::Defer));
    }

    #[test]
    fn test_route_result() {
        assert!(matches!(RouteResult::Default, RouteResult::Default));
        assert!(matches!(
            RouteResult::Branch("test".to_string()),
            RouteResult::Branch(_)
        ));
    }

    #[test]
    fn test_identity_default() {
        let identity = Identity::default();
        assert!(identity.user_id.is_empty());
        assert!(identity.roles.is_empty());
        assert!(identity.tenant_id.is_none());
    }

    #[test]
    fn test_execute_post_query_no_plugins_is_noop() {
        let pm = default_manager();
        let outcome = PostQueryOutcome {
            success: true,
            target_node: Some("primary".to_string()),
            elapsed_us: 42,
            response_bytes: 128,
            error: None,
        };

        // With no subscribers this must return without recording any call.
        pm.execute_post_query(&select_one_ctx(), &outcome);

        let metrics = pm.get_metrics();
        assert_eq!(metrics.plugins_loaded, 0);
        assert_eq!(metrics.total_hook_calls, 0);
    }

    #[test]
    fn test_execute_pre_query_no_plugins_returns_continue() {
        let pm = default_manager();
        let verdict = pm.execute_pre_query(&select_one_ctx());
        assert!(matches!(verdict, PreQueryResult::Continue));
    }

    #[test]
    fn test_post_query_outcome_serialisation() {
        let outcome = PostQueryOutcome {
            success: false,
            target_node: None,
            elapsed_us: 1234,
            response_bytes: 0,
            error: Some("backend timeout".to_string()),
        };
        let json = serde_json::to_string(&outcome).expect("serialise");
        for needle in ["\"success\":false", "\"elapsed_us\":1234", "backend timeout"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }
}