1use std::collections::HashMap;
29use std::path::PathBuf;
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32
33use parking_lot::RwLock;
34use wasmtime::{Engine, Instance, Linker, Module, Store, TypedFunc, Memory};
35
36use super::config::PluginRuntimeConfig;
37use super::host_functions::HostFunctionRegistry;
38use super::host_imports::{register_crypto_imports, register_kv_imports, KvBackend, StoreCtx};
39use super::sandbox::{PluginSandbox, SecurityPolicy, ResourceLimits};
40use super::{
41 AuthRequest, AuthResult, HookType, PluginMetadata, PreQueryResult,
42 QueryContext, RouteResult,
43};
44
/// Errors produced while loading, compiling, or invoking a WASM plugin.
///
/// Every variant carries a human-readable message; see the `Display` impl for the
/// prefix each variant renders with.
#[derive(Debug, Clone)]
pub enum PluginError {
    /// The supplied bytes could not be accepted as a WASM module (e.g. too short or a
    /// bad magic number in `WasmPluginRuntime::instantiate`).
    LoadError(String),

    /// wasmtime failed to compile or instantiate the module.
    InstantiationError(String),

    /// A failure while preparing or running a hook: missing/mistyped exports, guest
    /// traps, memory access errors, or JSON (de)serialisation of hook payloads.
    ExecutionError(String),

    /// A plugin exceeded its time budget. Not constructed in this file.
    Timeout(String),

    /// A plugin exceeded its memory budget. Not constructed in this file.
    MemoryExceeded(String),

    /// A plugin attempted something its security policy forbids. Not constructed in
    /// this file.
    SecurityViolation(String),

    /// The plugin manifest was malformed. Not constructed in this file.
    InvalidManifest(String),

    /// A hook was invoked that the plugin's metadata does not declare.
    HookNotFound(String),

    /// Failures of the runtime itself (engine creation, fuel configuration).
    RuntimeError(String),
}
75
76impl std::fmt::Display for PluginError {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 PluginError::LoadError(msg) => write!(f, "Load error: {}", msg),
80 PluginError::InstantiationError(msg) => write!(f, "Instantiation error: {}", msg),
81 PluginError::ExecutionError(msg) => write!(f, "Execution error: {}", msg),
82 PluginError::Timeout(msg) => write!(f, "Timeout: {}", msg),
83 PluginError::MemoryExceeded(msg) => write!(f, "Memory exceeded: {}", msg),
84 PluginError::SecurityViolation(msg) => write!(f, "Security violation: {}", msg),
85 PluginError::InvalidManifest(msg) => write!(f, "Invalid manifest: {}", msg),
86 PluginError::HookNotFound(msg) => write!(f, "Hook not found: {}", msg),
87 PluginError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg),
88 }
89 }
90}
91
/// Marks `PluginError` as a standard error. No variant wraps an inner error value (each
/// carries only a `String`), so the trait's default `source()` is used.
impl std::error::Error for PluginError {}
93
/// Lifecycle state of a loaded plugin.
///
/// `WasmPluginRuntime::call_hook` refuses to run any plugin whose state is not
/// `Running`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginState {
    /// The plugin is being loaded. Not set in this file.
    Loading,

    /// The plugin accepts hook invocations; the initial state set by `LoadedPlugin::new`.
    Running,

    /// Invocations are temporarily refused. Not set in this file.
    Paused,

    /// The plugin failed; the string describes why. Not set in this file.
    Error(String),

    /// The plugin is being removed. Not set in this file.
    Unloading,
}
112
/// A compiled plugin plus its bookkeeping.
///
/// The compiled `Module` is kept; `WasmPluginRuntime::call_hook` instantiates it into a
/// fresh store on every call.
pub struct LoadedPlugin {
    /// Name, version, declared hooks, permissions and memory bounds.
    pub metadata: PluginMetadata,

    /// Lifecycle state; hooks only run while this is `PluginState::Running`.
    pub state: PluginState,

    /// Filesystem path the plugin was loaded from.
    pub path: PathBuf,

    /// Compiled wasmtime module (must belong to the runtime's engine).
    module: Module,

    /// Sandbox built from the runtime's default policy and the plugin's permissions.
    /// NOTE(review): not read anywhere in this file — confirm where it is enforced.
    sandbox: PluginSandbox,

    /// Resource figures refreshed after each `call_hook`.
    instance_data: RwLock<PluginInstanceData>,

    /// When this value was constructed; basis for `uptime()`.
    loaded_at: Instant,

    /// Time of the most recent `record_invocation()`, if any.
    last_invoked: RwLock<Option<Instant>>,

    /// Number of `record_invocation()` calls so far.
    invocation_count: std::sync::atomic::AtomicU64,
}
143
/// Mutable per-plugin resource figures, guarded by `LoadedPlugin::instance_data`.
struct PluginInstanceData {
    /// Guest linear-memory size in bytes observed at the end of the last hook call.
    memory_used: usize,

    /// Fuel consumed by the last hook call (updated only when fuel metering is on).
    fuel_consumed: u64,

    /// Opaque keyed plugin state. NOTE(review): only initialised in this file, never
    /// read or written afterwards — confirm whether anything still uses it.
    state: HashMap<String, Vec<u8>>,
}
155
156impl LoadedPlugin {
157 pub fn new(
159 metadata: PluginMetadata,
160 path: PathBuf,
161 module: Module,
162 sandbox: PluginSandbox,
163 ) -> Self {
164 Self {
165 metadata,
166 state: PluginState::Running,
167 path,
168 module,
169 sandbox,
170 instance_data: RwLock::new(PluginInstanceData {
171 memory_used: 0,
172 fuel_consumed: 0,
173 state: HashMap::new(),
174 }),
175 loaded_at: Instant::now(),
176 last_invoked: RwLock::new(None),
177 invocation_count: std::sync::atomic::AtomicU64::new(0),
178 }
179 }
180
181 pub(crate) fn module(&self) -> &Module {
185 &self.module
186 }
187
188 pub fn memory_used(&self) -> usize {
190 self.instance_data.read().memory_used
191 }
192
193 pub fn invocation_count(&self) -> u64 {
195 self.invocation_count.load(std::sync::atomic::Ordering::Relaxed)
196 }
197
198 pub fn uptime(&self) -> Duration {
200 self.loaded_at.elapsed()
201 }
202
203 pub fn last_invoked(&self) -> Option<Instant> {
205 *self.last_invoked.read()
206 }
207
208 pub fn record_invocation(&self) {
210 self.invocation_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
211 *self.last_invoked.write() = Some(Instant::now());
212 }
213}
214
/// Hosts WASM plugins on a single wasmtime engine and dispatches hook calls to them.
pub struct WasmPluginRuntime {
    /// Configuration the runtime was created with (fuel, limits, timeout, plugin dir).
    config: PluginRuntimeConfig,

    /// Engine every plugin module is compiled against; fuel metering follows
    /// `config.fuel_metering` and epoch interruption is always enabled.
    engine: Engine,

    /// Host function registry. NOTE(review): constructed in `new` but not used elsewhere
    /// in this file — confirm whether it is wired into the linker somewhere else.
    host_functions: Arc<HostFunctionRegistry>,

    /// Key/value backend cloned into every `StoreCtx`, so writes made by plugins through
    /// host imports outlive the per-call store.
    kv: KvBackend,

    /// Compiled modules keyed by manifest path; written by `instantiate`. In this file it
    /// is otherwise only counted by `stats()`.
    module_cache: RwLock<HashMap<PathBuf, Module>>,

    /// Policy cloned into each plugin's sandbox at instantiation.
    default_policy: SecurityPolicy,

    /// Construction time; basis for `RuntimeStats::uptime`.
    created_at: Instant,
}
241
242impl WasmPluginRuntime {
243 pub fn new(config: &PluginRuntimeConfig) -> Result<Self, PluginError> {
245 let host_functions = Arc::new(HostFunctionRegistry::new());
246
247 let mut engine_config = wasmtime::Config::new();
248 if config.fuel_metering {
249 engine_config.consume_fuel(true);
250 }
251 engine_config.epoch_interruption(true);
254
255 let engine = Engine::new(&engine_config).map_err(|e| {
256 PluginError::RuntimeError(format!("wasmtime engine init: {}", e))
257 })?;
258
259 let default_policy = SecurityPolicy {
260 allowed_hosts: vec!["localhost".to_string()],
261 allowed_paths: vec![config.plugin_dir.clone()],
262 max_memory: config.memory_limit,
263 max_execution_time: config.timeout,
264 allow_network: false,
265 allow_filesystem: false,
266 };
267
268 Ok(Self {
269 config: config.clone(),
270 engine,
271 host_functions,
272 kv: KvBackend::new(),
273 module_cache: RwLock::new(HashMap::new()),
274 default_policy,
275 created_at: Instant::now(),
276 })
277 }
278
279 pub fn kv(&self) -> &KvBackend {
282 &self.kv
283 }
284
285 pub(crate) fn engine(&self) -> &Engine {
288 &self.engine
289 }
290
291 pub fn config(&self) -> &PluginRuntimeConfig {
295 &self.config
296 }
297
298 pub fn instantiate(
300 &self,
301 manifest: &super::loader::PluginManifest,
302 wasm_bytes: &[u8],
303 ) -> Result<LoadedPlugin, PluginError> {
304 if wasm_bytes.len() < 8 {
306 return Err(PluginError::LoadError("WASM module too small".to_string()));
307 }
308
309 if &wasm_bytes[0..4] != b"\x00asm" {
311 return Err(PluginError::LoadError("Invalid WASM magic number".to_string()));
312 }
313
314 let metadata = PluginMetadata {
316 name: manifest.name.clone(),
317 version: manifest.version.clone(),
318 description: manifest.description.clone(),
319 author: manifest.author.clone(),
320 hooks: manifest.hooks.clone(),
321 permissions: manifest.permissions.clone(),
322 min_memory: manifest.min_memory,
323 max_memory: manifest.max_memory.min(self.config.memory_limit),
324 };
325
326 let resource_limits = ResourceLimits {
328 max_memory: metadata.max_memory,
329 max_execution_time: self.config.timeout,
330 max_fuel: if self.config.fuel_metering {
331 Some(self.config.fuel_limit)
332 } else {
333 None
334 },
335 max_table_elements: 10000,
336 max_instances: 1,
337 };
338
339 let sandbox = PluginSandbox::new(
340 self.default_policy.clone(),
341 resource_limits,
342 manifest.permissions.clone(),
343 );
344
345 let module = Module::from_binary(&self.engine, wasm_bytes).map_err(|e| {
348 PluginError::InstantiationError(format!("wasmtime compile: {}", e))
349 })?;
350
351 {
353 let mut cache = self.module_cache.write();
354 cache.insert(manifest.path.clone(), module.clone());
355 }
356
357 Ok(LoadedPlugin::new(
358 metadata,
359 manifest.path.clone(),
360 module,
361 sandbox,
362 ))
363 }
364
365 pub fn call_hook(
381 &self,
382 plugin: &LoadedPlugin,
383 hook: HookType,
384 args: &[u8],
385 ) -> Result<Vec<u8>, PluginError> {
386 if !plugin.metadata.hooks.contains(&hook) {
388 return Err(PluginError::HookNotFound(format!(
389 "Plugin {} does not support hook {:?}",
390 plugin.metadata.name, hook
391 )));
392 }
393
394 if plugin.state != PluginState::Running {
396 return Err(PluginError::ExecutionError(format!(
397 "Plugin {} is not running (state: {:?})",
398 plugin.metadata.name, plugin.state
399 )));
400 }
401
402 plugin.record_invocation();
404
405 let store_ctx = StoreCtx {
409 plugin_name: plugin.metadata.name.clone(),
410 kv: self.kv.clone(),
411 };
412 let mut store: Store<StoreCtx> = Store::new(&self.engine, store_ctx);
413 if self.config.fuel_metering {
414 store.set_fuel(self.config.fuel_limit).map_err(|e| {
416 PluginError::RuntimeError(format!("set_fuel: {}", e))
417 })?;
418 }
419 store.set_epoch_deadline(u64::MAX);
427
428 let mut linker: Linker<StoreCtx> = Linker::new(&self.engine);
432 register_kv_imports(&mut linker)?;
433 register_crypto_imports(&mut linker)?;
434 let instance = linker
435 .instantiate(&mut store, &plugin.module)
436 .map_err(|e| {
437 PluginError::InstantiationError(format!(
438 "instantiate {}: {}",
439 plugin.metadata.name, e
440 ))
441 })?;
442
443 let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| {
444 PluginError::ExecutionError(format!(
445 "plugin {} does not export `memory`",
446 plugin.metadata.name
447 ))
448 })?;
449
450 let alloc = get_typed::<_, i32, i32>(&instance, &mut store, "alloc")?;
451 let dealloc = get_typed::<_, (i32, i32), ()>(&instance, &mut store, "dealloc")?;
452
453 let in_len = args.len() as i32;
456 let in_ptr = alloc.call(&mut store, in_len).map_err(|e| {
457 PluginError::ExecutionError(format!("alloc({}): {}", in_len, e))
458 })?;
459 if in_len > 0 {
460 write_memory(&memory, &mut store, in_ptr, args)?;
461 }
462
463 let export_name = hook.export_name();
466 let result_bytes = match get_typed::<_, (i32, i32), i64>(&instance, &mut store, export_name) {
467 Ok(hook_fn) => {
468 let packed = hook_fn.call(&mut store, (in_ptr, in_len)).map_err(|e| {
469 PluginError::ExecutionError(format!(
470 "hook {} call: {}",
471 export_name, e
472 ))
473 })?;
474 let out_ptr = (packed >> 32) as i32;
475 let out_len = (packed & 0xFFFF_FFFF) as i32;
476 if out_len > 0 {
477 let bytes = read_memory(&memory, &store, out_ptr, out_len)?;
478 let _ = dealloc.call(&mut store, (out_ptr, out_len));
480 bytes
481 } else {
482 Vec::new()
483 }
484 }
485 Err(_) => {
486 let observer = get_typed::<_, (i32, i32), ()>(
488 &instance,
489 &mut store,
490 export_name,
491 )?;
492 observer.call(&mut store, (in_ptr, in_len)).map_err(|e| {
493 PluginError::ExecutionError(format!(
494 "observer hook {} call: {}",
495 export_name, e
496 ))
497 })?;
498 Vec::new()
499 }
500 };
501
502 let _ = dealloc.call(&mut store, (in_ptr, in_len));
505
506 if self.config.fuel_metering {
508 if let Ok(remaining) = store.get_fuel() {
509 let consumed = self.config.fuel_limit.saturating_sub(remaining);
510 plugin.instance_data.write().fuel_consumed = consumed;
511 }
512 }
513 plugin.instance_data.write().memory_used =
514 (memory.data_size(&store)) as usize;
515
516 Ok(result_bytes)
517 }
518
519 pub fn call_pre_query(
521 &self,
522 plugin: &LoadedPlugin,
523 ctx: &QueryContext,
524 ) -> Result<PreQueryResult, PluginError> {
525 let args = serde_json::to_vec(ctx).map_err(|e| {
527 PluginError::ExecutionError(format!("Failed to serialize context: {}", e))
528 })?;
529
530 let result = self.call_hook(plugin, HookType::PreQuery, &args)?;
532
533 if result.is_empty() {
535 return Ok(PreQueryResult::Continue);
536 }
537
538 serde_json::from_slice(&result).map_err(|e| {
539 PluginError::ExecutionError(format!("Failed to deserialize result: {}", e))
540 })
541 }
542
543 pub fn call_authenticate(
545 &self,
546 plugin: &LoadedPlugin,
547 request: &AuthRequest,
548 ) -> Result<AuthResult, PluginError> {
549 let args = serde_json::to_vec(request).map_err(|e| {
551 PluginError::ExecutionError(format!("Failed to serialize request: {}", e))
552 })?;
553
554 let result = self.call_hook(plugin, HookType::Authenticate, &args)?;
556
557 if result.is_empty() {
559 return Ok(AuthResult::Defer);
560 }
561
562 serde_json::from_slice(&result).map_err(|e| {
563 PluginError::ExecutionError(format!("Failed to deserialize result: {}", e))
564 })
565 }
566
567 pub fn call_route(
569 &self,
570 plugin: &LoadedPlugin,
571 ctx: &QueryContext,
572 ) -> Result<RouteResult, PluginError> {
573 let args = serde_json::to_vec(ctx).map_err(|e| {
575 PluginError::ExecutionError(format!("Failed to serialize context: {}", e))
576 })?;
577
578 let result = self.call_hook(plugin, HookType::Route, &args)?;
580
581 if result.is_empty() {
583 return Ok(RouteResult::Default);
584 }
585
586 serde_json::from_slice(&result).map_err(|e| {
587 PluginError::ExecutionError(format!("Failed to deserialize result: {}", e))
588 })
589 }
590
591 pub fn stats(&self) -> RuntimeStats {
593 RuntimeStats {
594 uptime: self.created_at.elapsed(),
595 cached_modules: self.module_cache.read().len(),
596 fuel_metering_enabled: self.config.fuel_metering,
597 memory_limit: self.config.memory_limit,
598 timeout: self.config.timeout,
599 }
600 }
601}
602
/// Point-in-time runtime statistics returned by `WasmPluginRuntime::stats`.
#[derive(Debug, Clone)]
pub struct RuntimeStats {
    /// Time since the runtime was created.
    pub uptime: Duration,

    /// Number of entries in the module cache.
    pub cached_modules: usize,

    /// Whether the engine was built with fuel metering.
    pub fuel_metering_enabled: bool,

    /// Configured per-plugin memory ceiling (`PluginRuntimeConfig::memory_limit`).
    pub memory_limit: usize,

    /// Configured per-call timeout (`PluginRuntimeConfig::timeout`).
    pub timeout: Duration,
}
621
622impl serde::Serialize for QueryContext {
627 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
628 where
629 S: serde::Serializer,
630 {
631 use serde::ser::SerializeStruct;
632 let mut state = serializer.serialize_struct("QueryContext", 5)?;
633 state.serialize_field("query", &self.query)?;
634 state.serialize_field("normalized", &self.normalized)?;
635 state.serialize_field("tables", &self.tables)?;
636 state.serialize_field("is_read_only", &self.is_read_only)?;
637 state.serialize_field("hook_context", &self.hook_context)?;
638 state.end()
639 }
640}
641
642fn get_typed<T, P, R>(
645 instance: &Instance,
646 store: &mut Store<T>,
647 name: &str,
648) -> Result<TypedFunc<P, R>, PluginError>
649where
650 P: wasmtime::WasmParams,
651 R: wasmtime::WasmResults,
652{
653 instance
654 .get_typed_func::<P, R>(store, name)
655 .map_err(|e| PluginError::ExecutionError(format!("export `{}`: {}", name, e)))
656}
657
658fn write_memory<T>(
661 memory: &Memory,
662 store: &mut Store<T>,
663 ptr: i32,
664 bytes: &[u8],
665) -> Result<(), PluginError> {
666 memory.write(store, ptr as usize, bytes).map_err(|e| {
667 PluginError::ExecutionError(format!("memory.write @ {}: {}", ptr, e))
668 })
669}
670
671fn read_memory<T>(
673 memory: &Memory,
674 store: &Store<T>,
675 ptr: i32,
676 len: i32,
677) -> Result<Vec<u8>, PluginError> {
678 if len <= 0 {
679 return Ok(Vec::new());
680 }
681 let mut out = vec![0u8; len as usize];
682 memory.read(store, ptr as usize, &mut out).map_err(|e| {
683 PluginError::ExecutionError(format!("memory.read @ {}+{}: {}", ptr, len, e))
684 })?;
685 Ok(out)
686}
687
688impl serde::Serialize for AuthRequest {
689 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
690 where
691 S: serde::Serializer,
692 {
693 use serde::ser::SerializeStruct;
694 let mut state = serializer.serialize_struct("AuthRequest", 5)?;
695 state.serialize_field("headers", &self.headers)?;
696 state.serialize_field("username", &self.username)?;
697 state.serialize_field("password", &self.password)?;
698 state.serialize_field("client_ip", &self.client_ip)?;
699 state.serialize_field("database", &self.database)?;
700 state.end()
701 }
702}
703
704impl<'de> serde::Deserialize<'de> for PreQueryResult {
705 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
706 where
707 D: serde::Deserializer<'de>,
708 {
709 #[derive(serde::Deserialize)]
710 struct Helper {
711 action: String,
712 #[serde(default)]
713 value: Option<String>,
714 #[serde(default)]
715 data: Option<Vec<u8>>,
716 }
717
718 let helper = Helper::deserialize(deserializer)?;
719 match helper.action.as_str() {
720 "continue" => Ok(PreQueryResult::Continue),
721 "rewrite" => Ok(PreQueryResult::Rewrite(
722 helper.value.unwrap_or_default(),
723 )),
724 "block" => Ok(PreQueryResult::Block(
725 helper.value.unwrap_or_default(),
726 )),
727 "cached" => Ok(PreQueryResult::Cached(
728 helper.data.unwrap_or_default(),
729 )),
730 _ => Ok(PreQueryResult::Continue),
731 }
732 }
733}
734
735impl<'de> serde::Deserialize<'de> for AuthResult {
736 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
737 where
738 D: serde::Deserializer<'de>,
739 {
740 #[derive(serde::Deserialize)]
741 struct Helper {
742 action: String,
743 #[serde(default)]
744 identity: Option<IdentityHelper>,
745 #[serde(default)]
746 message: Option<String>,
747 }
748
749 #[derive(serde::Deserialize)]
750 struct IdentityHelper {
751 user_id: String,
752 username: String,
753 #[serde(default)]
754 roles: Vec<String>,
755 #[serde(default)]
756 tenant_id: Option<String>,
757 }
758
759 let helper = Helper::deserialize(deserializer)?;
760 match helper.action.as_str() {
761 "success" => {
762 let id = helper.identity.unwrap_or(IdentityHelper {
763 user_id: String::new(),
764 username: String::new(),
765 roles: Vec::new(),
766 tenant_id: None,
767 });
768 Ok(AuthResult::Success(super::Identity {
769 user_id: id.user_id,
770 username: id.username,
771 roles: id.roles,
772 tenant_id: id.tenant_id,
773 claims: std::collections::HashMap::new(),
774 }))
775 }
776 "denied" => Ok(AuthResult::Denied(
777 helper.message.unwrap_or_default(),
778 )),
779 "defer" => Ok(AuthResult::Defer),
780 _ => Ok(AuthResult::Defer),
781 }
782 }
783}
784
785impl<'de> serde::Deserialize<'de> for RouteResult {
786 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
787 where
788 D: serde::Deserializer<'de>,
789 {
790 #[derive(serde::Deserialize)]
791 struct Helper {
792 action: String,
793 #[serde(default)]
794 target: Option<String>,
795 #[serde(default)]
796 reason: Option<String>,
797 }
798
799 let helper = Helper::deserialize(deserializer)?;
800 match helper.action.as_str() {
801 "default" => Ok(RouteResult::Default),
802 "node" => Ok(RouteResult::Node(helper.target.unwrap_or_default())),
803 "primary" => Ok(RouteResult::Primary),
804 "standby" => Ok(RouteResult::Standby),
805 "branch" => Ok(RouteResult::Branch(helper.target.unwrap_or_default())),
806 "block" => Ok(RouteResult::Block(
810 helper.reason.unwrap_or_else(|| "blocked by plugin".to_string()),
811 )),
812 _ => Ok(RouteResult::Default),
813 }
814 }
815}
816
/// Unit tests. The `call_hook` tests compile small hand-written WAT modules against the
/// runtime's own engine, so the real wasmtime instantiate/alloc/call path is exercised.
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles a minimal guest for the hook ABI:
    /// - `alloc` always returns 4096 and `dealloc` is a no-op;
    /// - `pre_query` returns `(1024 << 32) | len` for a static payload stored at offset 1024;
    /// - `post_query` is an observer hook with no result.
    fn build_test_module(engine: &Engine) -> Module {
        const PAYLOAD: &[u8] = b"hello-from-wasm";
        // Encode each payload byte as a WAT `\hh` escape for the data segment.
        let payload_hex: String = PAYLOAD
            .iter()
            .map(|b| format!("\\{:02x}", b))
            .collect();
        let wat = format!(
            r#"
            (module
              (memory (export "memory") 1)

              ;; Trivial alloc: always returns offset 4096 (test inputs
              ;; are tiny so non-overlapping reuse is fine here). Real
              ;; plugins ship a real allocator; the runtime only cares
              ;; that `alloc` returns a writable address.
              (func (export "alloc") (param $size i32) (result i32)
                (i32.const 4096))

              (func (export "dealloc") (param $ptr i32) (param $size i32)
                (drop (local.get $ptr))
                (drop (local.get $size)))

              ;; Result-returning hook: writes PAYLOAD at offset 1024 and
              ;; returns (1024 << 32) | PAYLOAD.len.
              (func (export "pre_query")
                (param $in_ptr i32) (param $in_len i32) (result i64)
                (i64.or
                  (i64.shl (i64.const 1024) (i64.const 32))
                  (i64.const {payload_len})))

              ;; Observer hook: takes args, returns nothing.
              (func (export "post_query")
                (param $in_ptr i32) (param $in_len i32)
                (drop (local.get $in_ptr)))

              (data (i32.const 1024) "{payload}")
            )
            "#,
            payload = payload_hex,
            payload_len = PAYLOAD.len(),
        );
        let bytes = wat::parse_str(&wat).expect("wat parses");
        Module::from_binary(engine, &bytes).expect("module compiles")
    }

    /// `Display` prefixes messages with the variant's category.
    #[test]
    fn test_plugin_error_display() {
        let err = PluginError::LoadError("test".to_string());
        assert!(err.to_string().contains("Load error"));

        let err = PluginError::Timeout("plugin-a".to_string());
        assert!(err.to_string().contains("Timeout"));
    }

    /// Sanity check of the derived equality on `PluginState`.
    #[test]
    fn test_plugin_state() {
        assert_eq!(PluginState::Running, PluginState::Running);
        assert_ne!(PluginState::Running, PluginState::Paused);
    }

    /// The default configuration yields a working engine.
    #[test]
    fn test_runtime_creation() {
        let config = PluginRuntimeConfig::default();
        let runtime = WasmPluginRuntime::new(&config);
        assert!(runtime.is_ok());
    }

    /// A fresh runtime has an empty module cache; the default config is expected to
    /// enable fuel metering (asserted here).
    #[test]
    fn test_runtime_stats() {
        let config = PluginRuntimeConfig::default();
        let runtime = WasmPluginRuntime::new(&config).unwrap();
        let stats = runtime.stats();

        assert_eq!(stats.cached_modules, 0);
        assert!(stats.fuel_metering_enabled);
    }

    /// `record_invocation` increments the counter by exactly one per call.
    #[test]
    fn test_loaded_plugin_invocation_count() {
        let engine = Engine::default();
        let module = build_test_module(&engine);
        let metadata = PluginMetadata::default();
        let sandbox = PluginSandbox::default();
        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/plugin.wasm"),
            module,
            sandbox,
        );

        assert_eq!(plugin.invocation_count(), 0);
        plugin.record_invocation();
        assert_eq!(plugin.invocation_count(), 1);
        plugin.record_invocation();
        assert_eq!(plugin.invocation_count(), 2);
    }

    /// End-to-end: a result-returning hook yields its payload and an observer hook yields
    /// an empty vector; both count as invocations.
    #[test]
    fn test_call_hook_roundtrips_real_wasm() {
        let mut config = PluginRuntimeConfig::default();
        config.fuel_metering = false;
        let runtime = WasmPluginRuntime::new(&config).unwrap();

        // The module must be compiled with the runtime's engine to be instantiable.
        let module = build_test_module(runtime.engine());
        let mut metadata = PluginMetadata::default();
        metadata.name = "test-roundtrip".to_string();
        metadata.hooks = vec![HookType::PreQuery, HookType::PostQuery];

        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/roundtrip.wasm"),
            module,
            PluginSandbox::default(),
        );
        let bytes = runtime
            .call_hook(&plugin, HookType::PreQuery, b"ignored input")
            .expect("pre_query call");
        assert_eq!(bytes, b"hello-from-wasm");
        assert_eq!(plugin.invocation_count(), 1);

        let out = runtime
            .call_hook(&plugin, HookType::PostQuery, b"some bytes")
            .expect("post_query call");
        assert!(out.is_empty());
        assert_eq!(plugin.invocation_count(), 2);
    }

    /// Hooks absent from the plugin's declared list are rejected before any WASM runs.
    #[test]
    fn test_call_hook_rejects_undeclared_hook() {
        let runtime = WasmPluginRuntime::new(&PluginRuntimeConfig::default()).unwrap();
        let module = build_test_module(runtime.engine());
        let mut metadata = PluginMetadata::default();
        metadata.hooks = vec![];
        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/empty.wasm"),
            module,
            PluginSandbox::default(),
        );
        let err = runtime
            .call_hook(&plugin, HookType::PreQuery, &[])
            .unwrap_err();
        assert!(matches!(err, PluginError::HookNotFound(_)));
    }

    /// A declared hook whose export the module lacks surfaces as `ExecutionError`
    /// (presumably `Authenticate` maps to an export `build_test_module` does not define).
    #[test]
    fn test_call_hook_missing_export_returns_error() {
        let runtime = WasmPluginRuntime::new(&PluginRuntimeConfig::default()).unwrap();
        let module = build_test_module(runtime.engine());
        let mut metadata = PluginMetadata::default();
        metadata.hooks = vec![HookType::Authenticate];
        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/missing.wasm"),
            module,
            PluginSandbox::default(),
        );
        let err = runtime
            .call_hook(&plugin, HookType::Authenticate, &[])
            .unwrap_err();
        assert!(matches!(err, PluginError::ExecutionError(_)));
    }

    /// Guest whose `pre_query` calls the `env.kv_set(key_ptr, key_len, val_ptr, val_len)`
    /// host import with key "key" and value "value" from static data segments.
    fn build_kv_test_module(engine: &Engine) -> Module {
        let wat = r#"
            (module
              (import "env" "kv_set"
                (func $kv_set (param i32 i32 i32 i32) (result i32)))
              (memory (export "memory") 1)

              (data (i32.const 100) "key")
              (data (i32.const 200) "value")

              (func (export "alloc") (param i32) (result i32) (i32.const 4096))
              (func (export "dealloc") (param i32 i32))

              ;; pre_query: kv_set("key", "value"); return 0 (no payload).
              (func (export "pre_query")
                (param $in_ptr i32) (param $in_len i32) (result i64)
                (drop (call $kv_set
                  (i32.const 100) (i32.const 3)
                  (i32.const 200) (i32.const 5)))
                (i64.const 0))
            )
        "#;
        let bytes = wat::parse_str(wat).expect("kv-wat parses");
        Module::from_binary(engine, &bytes).expect("kv module compiles")
    }

    /// A value written through the KV host import outlives the per-call store and is
    /// visible only under the writing plugin's name.
    #[test]
    fn test_host_kv_import_persists_value() {
        let mut config = PluginRuntimeConfig::default();
        config.fuel_metering = false;
        let runtime = WasmPluginRuntime::new(&config).unwrap();

        let module = build_kv_test_module(runtime.engine());
        let mut metadata = PluginMetadata::default();
        metadata.name = "kv-test-plugin".to_string();
        metadata.hooks = vec![HookType::PreQuery];

        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/kv.wasm"),
            module,
            PluginSandbox::default(),
        );

        assert_eq!(runtime.kv().get("kv-test-plugin", b"key"), None);

        let _ = runtime
            .call_hook(&plugin, HookType::PreQuery, &[])
            .expect("pre_query call");

        assert_eq!(
            runtime.kv().get("kv-test-plugin", b"key"),
            Some(b"value".to_vec())
        );
        // Namespaced per plugin: another plugin name must not see the value.
        assert_eq!(runtime.kv().get("other-plugin", b"key"), None);
    }

    /// Guest whose `pre_query` calls `env.sha256_hex(ptr=100, len=3, out=200)` on "abc"
    /// and returns the 64 bytes at offset 200 as its output.
    fn build_sha256_test_module(engine: &Engine) -> Module {
        let wat = r#"
            (module
              (import "env" "sha256_hex"
                (func $sha256_hex (param i32 i32 i32) (result i32)))
              (memory (export "memory") 1)

              (data (i32.const 100) "abc")

              (func (export "alloc") (param i32) (result i32) (i32.const 4096))
              (func (export "dealloc") (param i32 i32))

              (func (export "pre_query")
                (param $in_ptr i32) (param $in_len i32) (result i64)
                (drop (call $sha256_hex
                  (i32.const 100) (i32.const 3)
                  (i32.const 200)))
                (i64.or
                  (i64.shl (i64.const 200) (i64.const 32))
                  (i64.const 64)))
            )
        "#;
        let bytes = wat::parse_str(wat).expect("sha256-wat parses");
        Module::from_binary(engine, &bytes).expect("sha256 module compiles")
    }

    /// `"block"` carries the plugin-supplied reason through unchanged.
    #[test]
    fn test_route_result_deserialises_block_with_reason() {
        let json = r#"{"action":"block","reason":"cross-region read forbidden"}"#;
        let r: RouteResult = serde_json::from_str(json).expect("block deserialises");
        match r {
            RouteResult::Block(reason) => {
                assert_eq!(reason, "cross-region read forbidden");
            }
            other => panic!("expected Block, got {:?}", other),
        }
    }

    /// `"block"` without a reason still produces a non-empty default reason.
    #[test]
    fn test_route_result_block_defaults_reason_when_missing() {
        let json = r#"{"action":"block"}"#;
        let r: RouteResult = serde_json::from_str(json).expect("block deserialises");
        match r {
            RouteResult::Block(reason) => {
                assert!(!reason.is_empty(), "default reason should not be empty");
            }
            other => panic!("expected Block, got {:?}", other),
        }
    }

    /// The crypto host import must produce the standard SHA-256("abc") hex digest.
    #[test]
    fn test_host_sha256_import_matches_rfc_6234_vector() {
        const SHA256_OF_ABC: &[u8; 64] =
            b"ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";

        let mut config = PluginRuntimeConfig::default();
        config.fuel_metering = false;
        let runtime = WasmPluginRuntime::new(&config).unwrap();

        let module = build_sha256_test_module(runtime.engine());
        let mut metadata = PluginMetadata::default();
        metadata.name = "sha256-test-plugin".to_string();
        metadata.hooks = vec![HookType::PreQuery];

        let plugin = LoadedPlugin::new(
            metadata,
            PathBuf::from("/test/sha256.wasm"),
            module,
            PluginSandbox::default(),
        );

        let out = runtime
            .call_hook(&plugin, HookType::PreQuery, &[])
            .expect("pre_query call");
        assert_eq!(out.len(), 64);
        assert_eq!(&out[..], SHA256_OF_ABC);
    }
}