1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::debug;
11
12use super::types::{PluginCapability, WasmError, WasmResult, WasmValue};
13
/// Severity of a message logged by a plugin.
///
/// `#[repr(u32)]` together with the explicit discriminants pins each level's
/// numeric value (`Trace` = 0 through `Error` = 4); the `From<u32>` impl
/// performs the reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u32)]
pub enum LogLevel {
    /// Most verbose level.
    Trace = 0,
    Debug = 1,
    Info = 2,
    Warn = 3,
    /// Most severe level; `From<u32>` also maps every unrecognised value here.
    Error = 4,
}
24
25impl From<u32> for LogLevel {
26 fn from(v: u32) -> Self {
27 match v {
28 0 => LogLevel::Trace,
29 1 => LogLevel::Debug,
30 2 => LogLevel::Info,
31 3 => LogLevel::Warn,
32 _ => LogLevel::Error,
33 }
34 }
35}
36
/// Direction of a message: `Incoming` or `Outgoing`.
///
/// NOTE(review): not referenced by any code visible in this module; presumably
/// the direction is relative to the plugin — confirm with the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageDirection {
    Incoming,
    Outgoing,
}
43
44pub type HostCallback = Arc<dyn Fn(&str, Vec<WasmValue>) -> WasmResult<WasmValue> + Send + Sync>;
46
/// Per-plugin state that host functions read and mutate.
///
/// Holds the plugin's identity and granted capabilities, plus shared config,
/// byte storage, an outgoing message queue and usage metrics. The shared
/// fields are guarded by `tokio::sync::RwLock`, so access goes through `.await`.
pub struct HostContext {
    /// Identifier of the plugin; included in log records and capability errors.
    pub plugin_id: String,
    /// Capabilities granted to the plugin; checked by `require_capability`.
    pub capabilities: Vec<PluginCapability>,
    /// Config values, accessed via `get_config` / `set_config`.
    config: Arc<RwLock<HashMap<String, WasmValue>>>,
    /// Byte values keyed by string, used by the `storage_*` host functions.
    storage: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    /// Messages queued by `send_message`; emptied by `drain_messages`.
    message_queue: Arc<RwLock<Vec<HostMessage>>>,
    /// Named callbacks invoked by `call_custom`. Not lock-protected:
    /// `register_function` takes `&mut self`, so registration needs exclusive
    /// access to the context.
    custom_functions: HashMap<String, HostCallback>,
    /// Usage counters; a snapshot is returned by `metrics()`.
    metrics: Arc<RwLock<HostMetrics>>,
}
64
/// A message a plugin queued through the `send_message` host function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HostMessage {
    /// Destination, exactly as supplied by the plugin.
    pub target: String,
    /// Opaque message bytes.
    pub payload: Vec<u8>,
    /// As set by `DefaultHostFunctions::send_message`: milliseconds since the
    /// Unix epoch when the message was queued (0 if the clock reads before the epoch).
    pub timestamp: u64,
}
72
/// Usage counters for one `HostContext`.
///
/// `DefaultHostFunctions` bumps a counter only after any required capability
/// check has passed, so the counts reflect accepted calls.
#[derive(Debug, Clone, Default)]
pub struct HostMetrics {
    /// Calls to the `log` host function.
    pub log_calls: u64,
    /// Accepted `get_config` calls.
    pub config_reads: u64,
    /// Accepted `set_config` calls.
    pub config_writes: u64,
    /// Accepted `send_message` calls.
    pub messages_sent: u64,
    /// Accepted `call_tool` calls.
    pub tool_calls: u64,
    /// Accepted `storage_get` calls.
    pub storage_reads: u64,
    /// Accepted storage write calls (incremented by `storage_set`).
    pub storage_writes: u64,
    /// NOTE(review): not written by any code visible here; presumably
    /// accumulated by the runtime that executes the plugin — confirm.
    pub total_execution_time_ns: u64,
}
85
86impl HostContext {
87 pub fn new(plugin_id: &str, capabilities: Vec<PluginCapability>) -> Self {
88 Self {
89 plugin_id: plugin_id.to_string(),
90 capabilities,
91 config: Arc::new(RwLock::new(HashMap::new())),
92 storage: Arc::new(RwLock::new(HashMap::new())),
93 message_queue: Arc::new(RwLock::new(Vec::new())),
94 custom_functions: HashMap::new(),
95 metrics: Arc::new(RwLock::new(HostMetrics::default())),
96 }
97 }
98
99 pub fn has_capability(&self, cap: &PluginCapability) -> bool {
101 self.capabilities.contains(cap)
102 }
103
104 pub fn require_capability(&self, cap: &PluginCapability) -> WasmResult<()> {
106 if self.has_capability(cap) {
107 Ok(())
108 } else {
109 Err(WasmError::HostFunctionError(format!(
110 "Plugin {} lacks required capability: {}",
111 self.plugin_id, cap
112 )))
113 }
114 }
115
116 pub fn register_function(&mut self, name: &str, callback: HostCallback) {
118 self.custom_functions.insert(name.to_string(), callback);
119 }
120
121 pub async fn set_config(&self, key: &str, value: WasmValue) {
123 self.config.write().await.insert(key.to_string(), value);
124 }
125
126 pub async fn get_config(&self, key: &str) -> Option<WasmValue> {
128 self.config.read().await.get(key).cloned()
129 }
130
131 pub async fn drain_messages(&self) -> Vec<HostMessage> {
133 let mut queue = self.message_queue.write().await;
134 std::mem::take(&mut *queue)
135 }
136
137 pub async fn metrics(&self) -> HostMetrics {
139 self.metrics.read().await.clone()
140 }
141}
142
/// Host functions a plugin can call.
///
/// Methods are async (via `async_trait`) and report failures as `WasmError`.
/// `DefaultHostFunctions` is the implementation backed by a `HostContext`; the
/// per-method notes below describe what that implementation does.
#[async_trait]
pub trait HostFunctions: Send + Sync {
    /// Logs `message` at `level`. The default implementation forwards it to
    /// `tracing` tagged with the plugin id and requires no capability.
    async fn log(&self, level: LogLevel, message: &str) -> WasmResult<()>;

    /// Reads config entry `key`, or `None` if unset (default impl requires
    /// `ReadConfig`).
    async fn get_config(&self, key: &str) -> WasmResult<Option<WasmValue>>;

    /// Sets config entry `key` to `value` (default impl requires `WriteConfig`).
    async fn set_config(&self, key: &str, value: WasmValue) -> WasmResult<()>;

    /// Sends `payload` to `target`. The default implementation only queues a
    /// timestamped `HostMessage` on the context (requires `SendMessage`).
    async fn send_message(&self, target: &str, payload: &[u8]) -> WasmResult<()>;

    /// Invokes tool `tool_name` with `args`. The default implementation is a
    /// stub that ignores `args` and reports success (requires `CallTool`).
    async fn call_tool(&self, tool_name: &str, args: WasmValue) -> WasmResult<WasmValue>;

    /// Reads the bytes stored under `key`, or `None` if absent (default impl
    /// requires `Storage`).
    async fn storage_get(&self, key: &str) -> WasmResult<Option<Vec<u8>>>;

    /// Stores `value` under `key` (default impl requires `Storage`).
    async fn storage_set(&self, key: &str, value: &[u8]) -> WasmResult<()>;

    /// Removes `key` from storage (default impl requires `Storage`).
    async fn storage_delete(&self, key: &str) -> WasmResult<()>;

    /// Current wall-clock time in milliseconds since the Unix epoch (the
    /// default impl requires no capability).
    async fn now_ms(&self) -> WasmResult<u64>;

    /// Produces `len` random bytes (default impl requires `Random`).
    async fn random_bytes(&self, len: u32) -> WasmResult<Vec<u8>>;

    /// Waits `ms` milliseconds (default impl uses `tokio::time::sleep` and
    /// requires `Timer`).
    async fn sleep_ms(&self, ms: u64) -> WasmResult<()>;

    /// Calls the custom function registered as `name` with `args`; the default
    /// impl returns an error when no such function is registered.
    async fn call_custom(&self, name: &str, args: Vec<WasmValue>) -> WasmResult<WasmValue>;
}
196
/// `HostFunctions` implementation backed by a shared `HostContext`.
///
/// It enforces the context's capabilities, keeps config/storage/messages in
/// the context, and records usage in the context's metrics.
pub struct DefaultHostFunctions {
    /// Shared per-plugin state this implementation operates on.
    context: Arc<HostContext>,
}
201
202impl DefaultHostFunctions {
203 pub fn new(context: Arc<HostContext>) -> Self {
204 Self { context }
205 }
206
207 async fn inc_metric(&self, f: impl FnOnce(&mut HostMetrics)) {
208 let mut metrics = self.context.metrics.write().await;
209 f(&mut metrics);
210 }
211}
212
213#[async_trait]
214impl HostFunctions for DefaultHostFunctions {
215 async fn log(&self, level: LogLevel, message: &str) -> WasmResult<()> {
216 self.inc_metric(|m| m.log_calls += 1).await;
217
218 let plugin_id = &self.context.plugin_id;
219 match level {
220 LogLevel::Trace => tracing::trace!(plugin_id, "{}", message),
221 LogLevel::Debug => tracing::debug!(plugin_id, "{}", message),
222 LogLevel::Info => tracing::info!(plugin_id, "{}", message),
223 LogLevel::Warn => tracing::warn!(plugin_id, "{}", message),
224 LogLevel::Error => tracing::error!(plugin_id, "{}", message),
225 }
226 Ok(())
227 }
228
229 async fn get_config(&self, key: &str) -> WasmResult<Option<WasmValue>> {
230 self.context
231 .require_capability(&PluginCapability::ReadConfig)?;
232 self.inc_metric(|m| m.config_reads += 1).await;
233
234 Ok(self.context.get_config(key).await)
235 }
236
237 async fn set_config(&self, key: &str, value: WasmValue) -> WasmResult<()> {
238 self.context
239 .require_capability(&PluginCapability::WriteConfig)?;
240 self.inc_metric(|m| m.config_writes += 1).await;
241
242 self.context.set_config(key, value).await;
243 Ok(())
244 }
245
246 async fn send_message(&self, target: &str, payload: &[u8]) -> WasmResult<()> {
247 self.context
248 .require_capability(&PluginCapability::SendMessage)?;
249 self.inc_metric(|m| m.messages_sent += 1).await;
250
251 let msg = HostMessage {
252 target: target.to_string(),
253 payload: payload.to_vec(),
254 timestamp: std::time::SystemTime::now()
255 .duration_since(std::time::UNIX_EPOCH)
256 .unwrap_or_default()
257 .as_millis() as u64,
258 };
259
260 self.context.message_queue.write().await.push(msg);
261 debug!(
262 "Plugin {} sent message to {}",
263 self.context.plugin_id, target
264 );
265 Ok(())
266 }
267
268 async fn call_tool(&self, tool_name: &str, _args: WasmValue) -> WasmResult<WasmValue> {
269 self.context
270 .require_capability(&PluginCapability::CallTool)?;
271 self.inc_metric(|m| m.tool_calls += 1).await;
272
273 debug!(
276 "Plugin {} calling tool: {}",
277 self.context.plugin_id, tool_name
278 );
279 Ok(WasmValue::Map(HashMap::from([
280 ("tool".to_string(), WasmValue::String(tool_name.to_string())),
281 (
282 "status".to_string(),
283 WasmValue::String("success".to_string()),
284 ),
285 ])))
286 }
287
288 async fn storage_get(&self, key: &str) -> WasmResult<Option<Vec<u8>>> {
289 self.context
290 .require_capability(&PluginCapability::Storage)?;
291 self.inc_metric(|m| m.storage_reads += 1).await;
292
293 Ok(self.context.storage.read().await.get(key).cloned())
294 }
295
296 async fn storage_set(&self, key: &str, value: &[u8]) -> WasmResult<()> {
297 self.context
298 .require_capability(&PluginCapability::Storage)?;
299 self.inc_metric(|m| m.storage_writes += 1).await;
300
301 self.context
302 .storage
303 .write()
304 .await
305 .insert(key.to_string(), value.to_vec());
306 Ok(())
307 }
308
309 async fn storage_delete(&self, key: &str) -> WasmResult<()> {
310 self.context
311 .require_capability(&PluginCapability::Storage)?;
312
313 self.context.storage.write().await.remove(key);
314 Ok(())
315 }
316
317 async fn now_ms(&self) -> WasmResult<u64> {
318 Ok(std::time::SystemTime::now()
319 .duration_since(std::time::UNIX_EPOCH)
320 .unwrap_or_default()
321 .as_millis() as u64)
322 }
323
324 async fn random_bytes(&self, len: u32) -> WasmResult<Vec<u8>> {
325 self.context.require_capability(&PluginCapability::Random)?;
326
327 use rand::RngCore;
328 let mut bytes = vec![0u8; len as usize];
329 rand::thread_rng().fill_bytes(&mut bytes);
330 Ok(bytes)
331 }
332
333 async fn sleep_ms(&self, ms: u64) -> WasmResult<()> {
334 self.context.require_capability(&PluginCapability::Timer)?;
335
336 tokio::time::sleep(tokio::time::Duration::from_millis(ms)).await;
337 Ok(())
338 }
339
340 async fn call_custom(&self, name: &str, args: Vec<WasmValue>) -> WasmResult<WasmValue> {
341 if let Some(callback) = self.context.custom_functions.get(name) {
342 callback(name, args)
343 } else {
344 Err(WasmError::HostFunctionError(format!(
345 "Custom function not found: {}",
346 name
347 )))
348 }
349 }
350}
351
/// Name-indexed catalog of host functions with their module, signature and
/// required capability.
///
/// `HostFunctionRegistry::new` fills it with the built-in `host_*` functions.
pub struct HostFunctionRegistry {
    /// Function metadata keyed by function name.
    functions: HashMap<String, HostFunctionInfo>,
}
357
/// Metadata describing one host function in a `HostFunctionRegistry`.
#[derive(Debug, Clone)]
pub struct HostFunctionInfo {
    /// Function name; also the registry key.
    pub name: String,
    /// Module namespace. Every built-in uses `"env"`; presumably this is the
    /// WASM import module — confirm against the linker.
    pub module: String,
    /// Parameter value types as strings, e.g. `"i32"` or `"i64"`.
    pub params: Vec<String>,
    /// Result value types as strings; empty when the function returns nothing.
    pub returns: Vec<String>,
    /// Capability a plugin must hold to use the function, or `None` if unrestricted.
    pub required_capability: Option<PluginCapability>,
}
367
368impl HostFunctionRegistry {
369 pub fn new() -> Self {
370 let mut registry = Self {
371 functions: HashMap::new(),
372 };
373 registry.register_builtin_functions();
374 registry
375 }
376
377 fn register_builtin_functions(&mut self) {
378 self.register("host_log", "env", vec!["i32", "i32", "i32"], vec![], None);
380
381 self.register(
383 "host_get_config",
384 "env",
385 vec!["i32", "i32", "i32"],
386 vec!["i32"],
387 Some(PluginCapability::ReadConfig),
388 );
389 self.register(
390 "host_set_config",
391 "env",
392 vec!["i32", "i32", "i32", "i32"],
393 vec!["i32"],
394 Some(PluginCapability::WriteConfig),
395 );
396
397 self.register(
399 "host_send_message",
400 "env",
401 vec!["i32", "i32", "i32", "i32"],
402 vec!["i32"],
403 Some(PluginCapability::SendMessage),
404 );
405
406 self.register(
408 "host_call_tool",
409 "env",
410 vec!["i32", "i32", "i32", "i32", "i32"],
411 vec!["i32"],
412 Some(PluginCapability::CallTool),
413 );
414
415 self.register(
417 "host_storage_get",
418 "env",
419 vec!["i32", "i32", "i32", "i32"],
420 vec!["i32"],
421 Some(PluginCapability::Storage),
422 );
423 self.register(
424 "host_storage_set",
425 "env",
426 vec!["i32", "i32", "i32", "i32"],
427 vec!["i32"],
428 Some(PluginCapability::Storage),
429 );
430
431 self.register("host_now_ms", "env", vec![], vec!["i64"], None);
433 self.register(
434 "host_random_bytes",
435 "env",
436 vec!["i32", "i32"],
437 vec!["i32"],
438 Some(PluginCapability::Random),
439 );
440 self.register(
441 "host_sleep_ms",
442 "env",
443 vec!["i64"],
444 vec![],
445 Some(PluginCapability::Timer),
446 );
447
448 self.register("host_alloc", "env", vec!["i32"], vec!["i32"], None);
450 self.register("host_free", "env", vec!["i32"], vec![], None);
451 }
452
453 fn register(
454 &mut self,
455 name: &str,
456 module: &str,
457 params: Vec<&str>,
458 returns: Vec<&str>,
459 required_capability: Option<PluginCapability>,
460 ) {
461 self.functions.insert(
462 name.to_string(),
463 HostFunctionInfo {
464 name: name.to_string(),
465 module: module.to_string(),
466 params: params.into_iter().map(String::from).collect(),
467 returns: returns.into_iter().map(String::from).collect(),
468 required_capability,
469 },
470 );
471 }
472
473 pub fn get(&self, name: &str) -> Option<&HostFunctionInfo> {
474 self.functions.get(name)
475 }
476
477 pub fn list(&self) -> Vec<&HostFunctionInfo> {
478 self.functions.values().collect()
479 }
480
481 pub fn has_function(&self, name: &str) -> bool {
482 self.functions.contains_key(name)
483 }
484}
485
486impl Default for HostFunctionRegistry {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
#[cfg(test)]
mod tests {
    //! Unit tests for log-level conversion, `HostContext`, the default host
    //! functions and the built-in registry.

    use super::*;

    #[test]
    fn test_log_level_conversion() {
        // Known values map to their variant; unknown values saturate to `Error`.
        let cases = [
            (0u32, LogLevel::Trace),
            (2u32, LogLevel::Info),
            (99u32, LogLevel::Error),
        ];
        for (raw, expected) in cases {
            assert_eq!(LogLevel::from(raw), expected);
        }
    }

    #[test]
    fn test_host_context() {
        let granted = vec![PluginCapability::ReadConfig, PluginCapability::SendMessage];
        let context = HostContext::new("test-plugin", granted);

        assert!(context.has_capability(&PluginCapability::ReadConfig));
        assert!(!context.has_capability(&PluginCapability::Storage));
    }

    #[tokio::test]
    async fn test_host_context_config() {
        let context = HostContext::new("test", vec![PluginCapability::ReadConfig]);
        let stored = WasmValue::String("value1".into());

        context.set_config("key1", stored.clone()).await;

        assert_eq!(context.get_config("key1").await, Some(stored));
    }

    #[tokio::test]
    async fn test_default_host_functions() {
        let context = Arc::new(HostContext::new(
            "test",
            vec![PluginCapability::ReadConfig, PluginCapability::Timer],
        ));
        let host = DefaultHostFunctions::new(Arc::clone(&context));

        // `log` needs no capability.
        host.log(LogLevel::Info, "Test message").await.unwrap();

        // `now_ms` needs no capability and reports a post-epoch time.
        assert!(host.now_ms().await.unwrap() > 0);

        // The Timer capability was granted.
        host.sleep_ms(1).await.unwrap();

        // Storage was not granted, so reads are rejected.
        assert!(host.storage_get("key").await.is_err());
    }

    #[test]
    fn test_host_function_registry() {
        let registry = HostFunctionRegistry::new();

        for builtin in ["host_log", "host_get_config"] {
            assert!(registry.has_function(builtin));
        }
        assert!(!registry.has_function("nonexistent"));

        let log_fn = registry.get("host_log").expect("host_log is a built-in");
        assert_eq!(log_fn.module, "env");
    }
}