1pub mod config;
40pub mod runtime;
41pub mod loader;
42pub mod host_functions;
43pub mod host_imports;
44pub mod sandbox;
45pub mod hot_reload;
46pub mod metrics;
47
48pub use config::{PluginRuntimeConfig, PluginRuntimeConfigBuilder, PluginConfig};
49pub use runtime::{WasmPluginRuntime, LoadedPlugin, PluginState, PluginError};
50pub use loader::{PluginLoader, PluginManifest, PluginLoadError, SignatureVerifier};
51pub use host_functions::HostFunctionRegistry;
52pub use sandbox::{PluginSandbox, SecurityPolicy, Permission, ResourceLimits};
53pub use hot_reload::{HotReloader, ReloadEvent, ReloadError};
54pub use metrics::{PluginMetrics, PluginStats, HookLatency};
55
56use std::collections::HashMap;
57use std::sync::Arc;
58use std::time::Duration;
59use parking_lot::RwLock;
60use dashmap::DashMap;
61
62#[derive(Debug, Clone)]
64pub struct PluginMetadata {
65 pub name: String,
67
68 pub version: String,
70
71 pub description: String,
73
74 pub author: String,
76
77 pub hooks: Vec<HookType>,
79
80 pub permissions: Vec<Permission>,
82
83 pub min_memory: usize,
85
86 pub max_memory: usize,
88}
89
90impl Default for PluginMetadata {
91 fn default() -> Self {
92 Self {
93 name: String::new(),
94 version: "0.0.0".to_string(),
95 description: String::new(),
96 author: String::new(),
97 hooks: Vec::new(),
98 permissions: Vec::new(),
99 min_memory: 1024 * 1024, max_memory: 64 * 1024 * 1024, }
102 }
103}
104
/// Extension points at which a loaded plugin can be invoked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    /// Pre-query hook (export `pre_query`).
    PreQuery,

    /// Post-query hook (export `post_query`).
    PostQuery,

    /// Authentication hook (export `authenticate`).
    Authenticate,

    /// Authorization hook (export `authorize`).
    Authorize,

    /// Cache lookup hook (export `cache_get`).
    CacheGet,

    /// Cache store hook (export `cache_set`).
    CacheSet,

    /// Routing hook (export `route`).
    Route,

    /// Query rewrite hook (export `rewrite`).
    Rewrite,

    /// Metrics hook (export `metrics`).
    Metrics,

    /// Connection-open hook (export `on_connect`).
    OnConnect,

    /// Connection-close hook (export `on_disconnect`).
    OnDisconnect,

    /// Plugin-defined hook (export `custom_hook`).
    Custom,
}

impl HookType {
    /// Name of the WASM export that implements this hook.
    pub fn export_name(&self) -> &'static str {
        match self {
            HookType::PreQuery => "pre_query",
            HookType::PostQuery => "post_query",
            HookType::Authenticate => "authenticate",
            HookType::Authorize => "authorize",
            HookType::CacheGet => "cache_get",
            HookType::CacheSet => "cache_set",
            HookType::Route => "route",
            HookType::Rewrite => "rewrite",
            HookType::Metrics => "metrics",
            HookType::OnConnect => "on_connect",
            HookType::OnDisconnect => "on_disconnect",
            HookType::Custom => "custom_hook",
        }
    }

    /// Parses a hook name case-insensitively, returning `None` when it is unknown.
    ///
    /// Accepts every canonical [`HookType::export_name`] plus a few shorthand
    /// aliases, so `HookType::from_str(h.export_name()) == Some(h)` holds for
    /// every variant. That includes `Custom`, whose export name is `custom_hook`.
    // Kept as an inherent `Option`-returning function (not `FromStr`) for
    // backward compatibility with existing callers.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "pre_query" | "prequery" => Some(HookType::PreQuery),
            "post_query" | "postquery" => Some(HookType::PostQuery),
            "authenticate" | "auth" => Some(HookType::Authenticate),
            "authorize" => Some(HookType::Authorize),
            "cache_get" | "cacheget" => Some(HookType::CacheGet),
            "cache_set" | "cacheset" => Some(HookType::CacheSet),
            "route" | "routing" => Some(HookType::Route),
            "rewrite" => Some(HookType::Rewrite),
            "metrics" => Some(HookType::Metrics),
            "on_connect" | "connect" => Some(HookType::OnConnect),
            "on_disconnect" | "disconnect" => Some(HookType::OnDisconnect),
            "custom" | "custom_hook" => Some(HookType::Custom),
            _ => None,
        }
    }
}
183
184#[derive(Debug, Clone, serde::Serialize)]
186pub struct HookContext {
187 pub request_id: String,
189
190 pub client_id: Option<String>,
192
193 pub identity: Option<String>,
195
196 pub database: Option<String>,
198
199 pub branch: Option<String>,
201
202 pub attributes: HashMap<String, String>,
204}
205
206impl Default for HookContext {
207 fn default() -> Self {
208 Self {
209 request_id: uuid::Uuid::new_v4().to_string(),
210 client_id: None,
211 identity: None,
212 database: None,
213 branch: None,
214 attributes: HashMap::new(),
215 }
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct QueryContext {
222 pub query: String,
224
225 pub normalized: String,
227
228 pub tables: Vec<String>,
230
231 pub is_read_only: bool,
233
234 pub hook_context: HookContext,
236}
237
/// Decision returned by a pre-query hook.
///
/// `PluginManager::execute_pre_query` stops at the first plugin that returns
/// anything other than `Continue`.
#[derive(Debug, Clone)]
pub enum PreQueryResult {
    /// No decision: the next plugin (or normal execution) proceeds.
    Continue,

    /// Replace the query; the payload is presumably the rewritten query text.
    Rewrite(String),

    /// Reject the query; the payload presumably carries the reason.
    Block(String),

    /// Answer from cache; the payload is presumably the serialised response bytes.
    Cached(Vec<u8>),
}
253
254#[derive(Debug, Clone, serde::Serialize)]
260pub struct PostQueryOutcome {
261 pub success: bool,
263
264 pub target_node: Option<String>,
266
267 pub elapsed_us: u64,
269
270 pub response_bytes: u64,
272
273 pub error: Option<String>,
275}
276
/// Decision returned by an authentication hook.
///
/// `PluginManager::execute_authenticate` stops at the first plugin that does
/// not return `Defer`, and returns `Defer` when no plugin decides.
#[derive(Debug, Clone)]
pub enum AuthResult {
    /// The caller was authenticated as the given identity.
    Success(Identity),

    /// Authentication was refused; the payload presumably carries the reason.
    Denied(String),

    /// No decision; the next plugin is consulted.
    Defer,
}

/// A principal produced by a successful authentication hook.
///
/// `Default` yields empty `user_id`/`username`, no roles, no tenant and no claims.
/// It is derived because every field's own default is exactly that.
#[derive(Debug, Clone, Default)]
pub struct Identity {
    /// User identifier.
    pub user_id: String,

    /// Username.
    pub username: String,

    /// Role names granted to the user.
    pub roles: Vec<String>,

    /// Tenant identifier, if any.
    pub tenant_id: Option<String>,

    /// Additional key/value claims.
    pub claims: HashMap<String, String>,
}
320
/// Decision returned by a routing hook.
///
/// `PluginManager::execute_route` stops at the first plugin that returns
/// anything other than `Default`, and returns `Default` when no plugin decides.
#[derive(Debug, Clone)]
pub enum RouteResult {
    /// No decision: defer to the next plugin / default routing.
    Default,

    /// Route to a specific node; the payload is presumably a node identifier.
    Node(String),

    /// Route to the primary.
    Primary,

    /// Route to a standby.
    Standby,

    /// Route to a branch; the payload is presumably the branch name.
    Branch(String),

    /// Reject the query; the payload presumably carries the reason.
    Block(String),
}
346
347pub struct PluginManager {
349 runtime: Arc<WasmPluginRuntime>,
351
352 plugins: DashMap<String, Arc<LoadedPlugin>>,
354
355 hooks: RwLock<HashMap<HookType, Vec<String>>>,
357
358 config: PluginRuntimeConfig,
360
361 hot_reloader: Option<HotReloader>,
363
364 metrics: Arc<PluginMetrics>,
366}
367
368impl PluginManager {
369 pub fn new(config: PluginRuntimeConfig) -> Result<Self, PluginError> {
371 let runtime = Arc::new(WasmPluginRuntime::new(&config)?);
372 let metrics = Arc::new(PluginMetrics::new());
373
374 let hot_reloader = if config.hot_reload {
375 Some(HotReloader::new(&config.plugin_dir)?)
376 } else {
377 None
378 };
379
380 Ok(Self {
381 runtime,
382 plugins: DashMap::new(),
383 hooks: RwLock::new(HashMap::new()),
384 config,
385 hot_reloader,
386 metrics,
387 })
388 }
389
390 pub fn load_plugin(&self, path: &std::path::Path) -> Result<(), PluginError> {
394 let mut loader = PluginLoader::new();
395 if let Some(ref dir) = self.runtime.config().trust_root {
396 let verifier = SignatureVerifier::from_trust_root(dir)
397 .map_err(|e| PluginError::LoadError(e.to_string()))?;
398 loader = loader.with_signature_verifier(verifier);
399 }
400 let (manifest, wasm_bytes) = loader.load(path)?;
401
402 let plugin = self.runtime.instantiate(&manifest, &wasm_bytes)?;
403 let plugin = Arc::new(plugin);
404
405 {
407 let mut hooks = self.hooks.write();
408 for hook in &manifest.hooks {
409 hooks
410 .entry(*hook)
411 .or_insert_with(Vec::new)
412 .push(manifest.name.clone());
413 }
414 }
415
416 self.plugins.insert(manifest.name.clone(), plugin);
417
418 tracing::info!(
419 plugin = %manifest.name,
420 version = %manifest.version,
421 hooks = ?manifest.hooks,
422 "Plugin loaded"
423 );
424
425 Ok(())
426 }
427
428 pub fn unload_plugin(&self, name: &str) -> Result<(), PluginError> {
430 if let Some((_, plugin)) = self.plugins.remove(name) {
431 let mut hooks = self.hooks.write();
433 for hook_plugins in hooks.values_mut() {
434 hook_plugins.retain(|p| p != name);
435 }
436
437 if let Err(e) = self.runtime.call_hook(&plugin, HookType::OnDisconnect, &[]) {
439 tracing::warn!(plugin = %name, error = %e, "Error calling on_unload");
440 }
441
442 tracing::info!(plugin = %name, "Plugin unloaded");
443 }
444
445 Ok(())
446 }
447
448 pub fn reload_plugin(&self, name: &str) -> Result<(), PluginError> {
450 if let Some(plugin) = self.plugins.get(name) {
451 let path = plugin.path.clone();
452 drop(plugin);
453
454 self.unload_plugin(name)?;
455 self.load_plugin(&path)?;
456 }
457
458 Ok(())
459 }
460
461 pub fn execute_pre_query(&self, ctx: &QueryContext) -> PreQueryResult {
463 let hooks = self.hooks.read();
464 let plugin_names = hooks.get(&HookType::PreQuery).cloned().unwrap_or_default();
465 drop(hooks);
466
467 for plugin_name in plugin_names {
468 if let Some(plugin) = self.plugins.get(&plugin_name) {
469 let start = std::time::Instant::now();
470
471 match self.runtime.call_pre_query(&plugin, ctx) {
472 Ok(result) => {
473 self.metrics.record_hook_call(
474 &plugin_name,
475 HookType::PreQuery,
476 start.elapsed(),
477 true,
478 );
479
480 match result {
481 PreQueryResult::Continue => continue,
482 other => return other,
483 }
484 }
485 Err(e) => {
486 self.metrics.record_hook_call(
487 &plugin_name,
488 HookType::PreQuery,
489 start.elapsed(),
490 false,
491 );
492 tracing::warn!(
493 plugin = %plugin_name,
494 error = %e,
495 "Pre-query hook failed"
496 );
497 }
498 }
499 }
500 }
501
502 PreQueryResult::Continue
503 }
504
505 pub fn execute_post_query(&self, ctx: &QueryContext, outcome: &PostQueryOutcome) {
512 let hooks = self.hooks.read();
513 let plugin_names = hooks.get(&HookType::PostQuery).cloned().unwrap_or_default();
514 drop(hooks);
515
516 for plugin_name in plugin_names {
517 if let Some(plugin) = self.plugins.get(&plugin_name) {
518 let start = std::time::Instant::now();
519
520 let payload = match serde_json::to_vec(&(ctx, outcome)) {
523 Ok(v) => v,
524 Err(e) => {
525 tracing::warn!(
526 plugin = %plugin_name,
527 error = %e,
528 "Post-query serialisation failed"
529 );
530 continue;
531 }
532 };
533
534 match self.runtime.call_hook(&plugin, HookType::PostQuery, &payload) {
535 Ok(_) => {
536 self.metrics.record_hook_call(
537 &plugin_name,
538 HookType::PostQuery,
539 start.elapsed(),
540 true,
541 );
542 }
543 Err(e) => {
544 self.metrics.record_hook_call(
545 &plugin_name,
546 HookType::PostQuery,
547 start.elapsed(),
548 false,
549 );
550 tracing::warn!(
551 plugin = %plugin_name,
552 error = %e,
553 "Post-query hook failed"
554 );
555 }
556 }
557 }
558 }
559 }
560
561 pub fn execute_authenticate(&self, request: &AuthRequest) -> AuthResult {
563 let hooks = self.hooks.read();
564 let plugin_names = hooks.get(&HookType::Authenticate).cloned().unwrap_or_default();
565 drop(hooks);
566
567 for plugin_name in plugin_names {
568 if let Some(plugin) = self.plugins.get(&plugin_name) {
569 let start = std::time::Instant::now();
570
571 match self.runtime.call_authenticate(&plugin, request) {
572 Ok(result) => {
573 self.metrics.record_hook_call(
574 &plugin_name,
575 HookType::Authenticate,
576 start.elapsed(),
577 true,
578 );
579
580 match result {
581 AuthResult::Defer => continue,
582 other => return other,
583 }
584 }
585 Err(e) => {
586 self.metrics.record_hook_call(
587 &plugin_name,
588 HookType::Authenticate,
589 start.elapsed(),
590 false,
591 );
592 tracing::warn!(
593 plugin = %plugin_name,
594 error = %e,
595 "Authenticate hook failed"
596 );
597 }
598 }
599 }
600 }
601
602 AuthResult::Defer
603 }
604
605 pub fn execute_route(&self, ctx: &QueryContext) -> RouteResult {
607 let hooks = self.hooks.read();
608 let plugin_names = hooks.get(&HookType::Route).cloned().unwrap_or_default();
609 drop(hooks);
610
611 for plugin_name in plugin_names {
612 if let Some(plugin) = self.plugins.get(&plugin_name) {
613 let start = std::time::Instant::now();
614
615 match self.runtime.call_route(&plugin, ctx) {
616 Ok(result) => {
617 self.metrics.record_hook_call(
618 &plugin_name,
619 HookType::Route,
620 start.elapsed(),
621 true,
622 );
623
624 match result {
625 RouteResult::Default => continue,
626 other => return other,
627 }
628 }
629 Err(e) => {
630 self.metrics.record_hook_call(
631 &plugin_name,
632 HookType::Route,
633 start.elapsed(),
634 false,
635 );
636 tracing::warn!(
637 plugin = %plugin_name,
638 error = %e,
639 "Route hook failed"
640 );
641 }
642 }
643 }
644 }
645
646 RouteResult::Default
647 }
648
649 pub fn list_plugins(&self) -> Vec<PluginInfo> {
651 self.plugins
652 .iter()
653 .map(|entry| {
654 let plugin = entry.value();
655 let stats = self.metrics.get_plugin_stats(&plugin.metadata.name);
656
657 PluginInfo {
658 name: plugin.metadata.name.clone(),
659 version: plugin.metadata.version.clone(),
660 description: plugin.metadata.description.clone(),
661 hooks: plugin.metadata.hooks.clone(),
662 state: plugin.state.clone(),
663 stats,
664 }
665 })
666 .collect()
667 }
668
669 pub fn get_metrics(&self) -> PluginManagerMetrics {
671 PluginManagerMetrics {
672 plugins_loaded: self.plugins.len(),
673 total_hook_calls: self.metrics.total_calls(),
674 total_errors: self.metrics.total_errors(),
675 avg_latency: self.metrics.avg_latency(),
676 plugins: self.list_plugins(),
677 }
678 }
679
680 pub fn check_updates(&self) -> Result<Vec<ReloadEvent>, PluginError> {
682 if let Some(ref reloader) = self.hot_reloader {
683 let events = reloader.check()?;
684
685 for event in &events {
686 match event {
687 ReloadEvent::Modified(name) => {
688 tracing::info!(plugin = %name, "Hot reloading plugin");
689 if let Err(e) = self.reload_plugin(name) {
690 tracing::error!(plugin = %name, error = %e, "Hot reload failed");
691 }
692 }
693 ReloadEvent::Removed(name) => {
694 tracing::info!(plugin = %name, "Plugin file removed, unloading");
695 if let Err(e) = self.unload_plugin(name) {
696 tracing::error!(plugin = %name, error = %e, "Unload failed");
697 }
698 }
699 ReloadEvent::Added(path) => {
700 tracing::info!(path = %path.display(), "New plugin detected, loading");
701 if let Err(e) = self.load_plugin(path) {
702 tracing::error!(path = %path.display(), error = %e, "Load failed");
703 }
704 }
705 }
706 }
707
708 Ok(events)
709 } else {
710 Ok(Vec::new())
711 }
712 }
713}
714
/// Credentials and connection details handed to authentication hooks.
///
/// `Debug` is implemented by hand so that secrets never reach logs. The
/// password is shown only as `<redacted>`, and only header *names* are
/// printed, because header values may carry bearer tokens or cookies.
#[derive(Clone)]
pub struct AuthRequest {
    /// Request headers (name -> value).
    pub headers: HashMap<String, String>,

    /// Supplied username, if any.
    pub username: Option<String>,

    /// Supplied password, if any. Never printed by `Debug`.
    pub password: Option<String>,

    /// Client IP address as a string.
    pub client_ip: String,

    /// Requested database, if any.
    pub database: Option<String>,
}

impl std::fmt::Debug for AuthRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header_names: Vec<&String> = self.headers.keys().collect();
        f.debug_struct("AuthRequest")
            .field("headers", &header_names)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("client_ip", &self.client_ip)
            .field("database", &self.database)
            .finish()
    }
}
733
734#[derive(Debug, Clone)]
736pub struct PluginInfo {
737 pub name: String,
739
740 pub version: String,
742
743 pub description: String,
745
746 pub hooks: Vec<HookType>,
748
749 pub state: PluginState,
751
752 pub stats: PluginStats,
754}
755
756#[derive(Debug, Clone)]
758pub struct PluginManagerMetrics {
759 pub plugins_loaded: usize,
761
762 pub total_hook_calls: u64,
764
765 pub total_errors: u64,
767
768 pub avg_latency: Duration,
770
771 pub plugins: Vec<PluginInfo>,
773}
774
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-only `SELECT 1` query context shared by the dispatch tests.
    fn select_one_context() -> QueryContext {
        QueryContext {
            query: "SELECT 1".to_string(),
            normalized: "SELECT 1".to_string(),
            tables: Vec::new(),
            is_read_only: true,
            hook_context: HookContext::default(),
        }
    }

    #[test]
    fn test_hook_type_export_name() {
        let expected = [
            (HookType::PreQuery, "pre_query"),
            (HookType::Authenticate, "authenticate"),
            (HookType::Route, "route"),
        ];
        for (hook, name) in expected.iter() {
            assert_eq!(hook.export_name(), *name);
        }
    }

    #[test]
    fn test_hook_type_from_str() {
        let cases = [
            ("pre_query", Some(HookType::PreQuery)),
            ("authenticate", Some(HookType::Authenticate)),
            ("unknown", None),
        ];
        for (input, want) in cases.iter() {
            assert_eq!(HookType::from_str(input), *want);
        }
    }

    #[test]
    fn test_plugin_metadata_default() {
        let meta = PluginMetadata::default();
        assert!(meta.name.is_empty());
        assert_eq!(meta.version, "0.0.0");
        assert!(meta.hooks.is_empty());
    }

    #[test]
    fn test_hook_context_default() {
        let ctx = HookContext::default();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.client_id.is_none());
    }

    #[test]
    fn test_pre_query_result() {
        assert!(matches!(PreQueryResult::Continue, PreQueryResult::Continue));
        assert!(matches!(
            PreQueryResult::Block("blocked".to_string()),
            PreQueryResult::Block(_)
        ));
    }

    #[test]
    fn test_auth_result() {
        assert!(matches!(
            AuthResult::Denied("invalid".to_string()),
            AuthResult::Denied(_)
        ));
        assert!(matches!(AuthResult::Defer, AuthResult::Defer));
    }

    #[test]
    fn test_route_result() {
        assert!(matches!(RouteResult::Default, RouteResult::Default));
        assert!(matches!(
            RouteResult::Branch("test".to_string()),
            RouteResult::Branch(_)
        ));
    }

    #[test]
    fn test_identity_default() {
        let identity = Identity::default();
        assert!(identity.user_id.is_empty());
        assert!(identity.roles.is_empty());
        assert!(identity.tenant_id.is_none());
    }

    /// With no plugins loaded, post-query dispatch must not record any hook call.
    #[test]
    fn test_execute_post_query_no_plugins_is_noop() {
        let pm = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");
        let outcome = PostQueryOutcome {
            success: true,
            target_node: Some("primary".to_string()),
            elapsed_us: 42,
            response_bytes: 128,
            error: None,
        };

        pm.execute_post_query(&select_one_context(), &outcome);

        let metrics = pm.get_metrics();
        assert_eq!(metrics.plugins_loaded, 0);
        assert_eq!(metrics.total_hook_calls, 0);
    }

    /// With no plugins loaded, pre-query dispatch falls through to `Continue`.
    #[test]
    fn test_execute_pre_query_no_plugins_returns_continue() {
        let pm = PluginManager::new(PluginRuntimeConfig::default())
            .expect("construct PluginManager");
        let result = pm.execute_pre_query(&select_one_context());
        assert!(matches!(result, PreQueryResult::Continue));
    }

    /// The JSON form of `PostQueryOutcome` exposes fields under their Rust names.
    #[test]
    fn test_post_query_outcome_serialisation() {
        let outcome = PostQueryOutcome {
            success: false,
            target_node: None,
            elapsed_us: 1234,
            response_bytes: 0,
            error: Some("backend timeout".to_string()),
        };
        let json = serde_json::to_string(&outcome).expect("serialise");
        for needle in ["\"success\":false", "\"elapsed_us\":1234", "backend timeout"].iter() {
            assert!(json.contains(needle));
        }
    }
}