Skip to main content

mofa_plugins/wasm_runtime/
host.rs

//! Host functions for WASM plugins.
//!
//! Defines the host-side functions that WASM plugins can call.

5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::debug;
11
12use super::types::{PluginCapability, WasmError, WasmResult, WasmValue};
13
14/// Log level for plugin logging
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[repr(u32)]
17pub enum LogLevel {
18    Trace = 0,
19    Debug = 1,
20    Info = 2,
21    Warn = 3,
22    Error = 4,
23}
24
25impl From<u32> for LogLevel {
26    fn from(v: u32) -> Self {
27        match v {
28            0 => LogLevel::Trace,
29            1 => LogLevel::Debug,
30            2 => LogLevel::Info,
31            3 => LogLevel::Warn,
32            _ => LogLevel::Error,
33        }
34    }
35}
36
37/// Message direction
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39pub enum MessageDirection {
40    Incoming,
41    Outgoing,
42}
43
44/// Host callback for custom functions
45pub type HostCallback = Arc<dyn Fn(&str, Vec<WasmValue>) -> WasmResult<WasmValue> + Send + Sync>;
46
47/// Host context provided to plugins
48pub struct HostContext {
49    /// Plugin ID
50    pub plugin_id: String,
51    /// Plugin capabilities
52    pub capabilities: Vec<PluginCapability>,
53    /// Configuration values
54    config: Arc<RwLock<HashMap<String, WasmValue>>>,
55    /// Storage backend
56    storage: Arc<RwLock<HashMap<String, Vec<u8>>>>,
57    /// Message queue (outgoing)
58    message_queue: Arc<RwLock<Vec<HostMessage>>>,
59    /// Custom host functions
60    custom_functions: HashMap<String, HostCallback>,
61    /// Execution metrics
62    metrics: Arc<RwLock<HostMetrics>>,
63}
64
65/// Host message
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct HostMessage {
68    pub target: String,
69    pub payload: Vec<u8>,
70    pub timestamp: u64,
71}
72
/// Counters of host function activity for one plugin.
///
/// `DefaultHostFunctions` bumps the call counters after a call passes its
/// capability check (where one applies).
#[derive(Default, Debug, Clone)]
pub struct HostMetrics {
    /// Number of `log` calls.
    pub log_calls: u64,
    /// Number of permitted config reads.
    pub config_reads: u64,
    /// Number of permitted config writes.
    pub config_writes: u64,
    /// Number of messages queued via `send_message`.
    pub messages_sent: u64,
    /// Number of permitted tool calls.
    pub tool_calls: u64,
    /// Number of permitted storage reads.
    pub storage_reads: u64,
    /// Number of permitted storage writes.
    pub storage_writes: u64,
    /// Accumulated execution time in nanoseconds.
    /// NOTE(review): nothing in this file updates this field — confirm who owns it.
    pub total_execution_time_ns: u64,
}

86impl HostContext {
87    pub fn new(plugin_id: &str, capabilities: Vec<PluginCapability>) -> Self {
88        Self {
89            plugin_id: plugin_id.to_string(),
90            capabilities,
91            config: Arc::new(RwLock::new(HashMap::new())),
92            storage: Arc::new(RwLock::new(HashMap::new())),
93            message_queue: Arc::new(RwLock::new(Vec::new())),
94            custom_functions: HashMap::new(),
95            metrics: Arc::new(RwLock::new(HostMetrics::default())),
96        }
97    }
98
99    /// Check if plugin has a capability
100    pub fn has_capability(&self, cap: &PluginCapability) -> bool {
101        self.capabilities.contains(cap)
102    }
103
104    /// Require a capability, returning error if not present
105    pub fn require_capability(&self, cap: &PluginCapability) -> WasmResult<()> {
106        if self.has_capability(cap) {
107            Ok(())
108        } else {
109            Err(WasmError::HostFunctionError(format!(
110                "Plugin {} lacks required capability: {}",
111                self.plugin_id, cap
112            )))
113        }
114    }
115
116    /// Register a custom host function
117    pub fn register_function(&mut self, name: &str, callback: HostCallback) {
118        self.custom_functions.insert(name.to_string(), callback);
119    }
120
121    /// Set configuration value
122    pub async fn set_config(&self, key: &str, value: WasmValue) {
123        self.config.write().await.insert(key.to_string(), value);
124    }
125
126    /// Get configuration value
127    pub async fn get_config(&self, key: &str) -> Option<WasmValue> {
128        self.config.read().await.get(key).cloned()
129    }
130
131    /// Get all pending messages
132    pub async fn drain_messages(&self) -> Vec<HostMessage> {
133        let mut queue = self.message_queue.write().await;
134        std::mem::take(&mut *queue)
135    }
136
137    /// Get metrics
138    pub async fn metrics(&self) -> HostMetrics {
139        self.metrics.read().await.clone()
140    }
141}
142
143/// Host functions interface that plugins can call
144#[async_trait]
145pub trait HostFunctions: Send + Sync {
146    // === Logging ===
147
148    /// Log a message
149    async fn log(&self, level: LogLevel, message: &str) -> WasmResult<()>;
150
151    // === Configuration ===
152
153    /// Get configuration value
154    async fn get_config(&self, key: &str) -> WasmResult<Option<WasmValue>>;
155
156    /// Set configuration value (requires WriteConfig capability)
157    async fn set_config(&self, key: &str, value: WasmValue) -> WasmResult<()>;
158
159    // === Messaging ===
160
161    /// Send a message to another agent/plugin
162    async fn send_message(&self, target: &str, payload: &[u8]) -> WasmResult<()>;
163
164    // === Tools ===
165
166    /// Call a tool
167    async fn call_tool(&self, tool_name: &str, args: WasmValue) -> WasmResult<WasmValue>;
168
169    // === Storage ===
170
171    /// Get from storage
172    async fn storage_get(&self, key: &str) -> WasmResult<Option<Vec<u8>>>;
173
174    /// Set in storage
175    async fn storage_set(&self, key: &str, value: &[u8]) -> WasmResult<()>;
176
177    /// Delete from storage
178    async fn storage_delete(&self, key: &str) -> WasmResult<()>;
179
180    // === Utilities ===
181
182    /// Get current timestamp (milliseconds)
183    async fn now_ms(&self) -> WasmResult<u64>;
184
185    /// Generate random bytes
186    async fn random_bytes(&self, len: u32) -> WasmResult<Vec<u8>>;
187
188    /// Sleep for specified milliseconds
189    async fn sleep_ms(&self, ms: u64) -> WasmResult<()>;
190
191    // === Custom ===
192
193    /// Call a custom host function
194    async fn call_custom(&self, name: &str, args: Vec<WasmValue>) -> WasmResult<WasmValue>;
195}
196
197/// Default implementation of host functions
198pub struct DefaultHostFunctions {
199    context: Arc<HostContext>,
200}
201
202impl DefaultHostFunctions {
203    pub fn new(context: Arc<HostContext>) -> Self {
204        Self { context }
205    }
206
207    async fn inc_metric(&self, f: impl FnOnce(&mut HostMetrics)) {
208        let mut metrics = self.context.metrics.write().await;
209        f(&mut metrics);
210    }
211}
212
213#[async_trait]
214impl HostFunctions for DefaultHostFunctions {
215    async fn log(&self, level: LogLevel, message: &str) -> WasmResult<()> {
216        self.inc_metric(|m| m.log_calls += 1).await;
217
218        let plugin_id = &self.context.plugin_id;
219        match level {
220            LogLevel::Trace => tracing::trace!(plugin_id, "{}", message),
221            LogLevel::Debug => tracing::debug!(plugin_id, "{}", message),
222            LogLevel::Info => tracing::info!(plugin_id, "{}", message),
223            LogLevel::Warn => tracing::warn!(plugin_id, "{}", message),
224            LogLevel::Error => tracing::error!(plugin_id, "{}", message),
225        }
226        Ok(())
227    }
228
229    async fn get_config(&self, key: &str) -> WasmResult<Option<WasmValue>> {
230        self.context
231            .require_capability(&PluginCapability::ReadConfig)?;
232        self.inc_metric(|m| m.config_reads += 1).await;
233
234        Ok(self.context.get_config(key).await)
235    }
236
237    async fn set_config(&self, key: &str, value: WasmValue) -> WasmResult<()> {
238        self.context
239            .require_capability(&PluginCapability::WriteConfig)?;
240        self.inc_metric(|m| m.config_writes += 1).await;
241
242        self.context.set_config(key, value).await;
243        Ok(())
244    }
245
246    async fn send_message(&self, target: &str, payload: &[u8]) -> WasmResult<()> {
247        self.context
248            .require_capability(&PluginCapability::SendMessage)?;
249        self.inc_metric(|m| m.messages_sent += 1).await;
250
251        let msg = HostMessage {
252            target: target.to_string(),
253            payload: payload.to_vec(),
254            timestamp: std::time::SystemTime::now()
255                .duration_since(std::time::UNIX_EPOCH)
256                .unwrap_or_default()
257                .as_millis() as u64,
258        };
259
260        self.context.message_queue.write().await.push(msg);
261        debug!(
262            "Plugin {} sent message to {}",
263            self.context.plugin_id, target
264        );
265        Ok(())
266    }
267
268    async fn call_tool(&self, tool_name: &str, _args: WasmValue) -> WasmResult<WasmValue> {
269        self.context
270            .require_capability(&PluginCapability::CallTool)?;
271        self.inc_metric(|m| m.tool_calls += 1).await;
272
273        // For now, return a mock response
274        // In real implementation, this would call the actual tool executor
275        debug!(
276            "Plugin {} calling tool: {}",
277            self.context.plugin_id, tool_name
278        );
279        Ok(WasmValue::Map(HashMap::from([
280            ("tool".to_string(), WasmValue::String(tool_name.to_string())),
281            (
282                "status".to_string(),
283                WasmValue::String("success".to_string()),
284            ),
285        ])))
286    }
287
288    async fn storage_get(&self, key: &str) -> WasmResult<Option<Vec<u8>>> {
289        self.context
290            .require_capability(&PluginCapability::Storage)?;
291        self.inc_metric(|m| m.storage_reads += 1).await;
292
293        Ok(self.context.storage.read().await.get(key).cloned())
294    }
295
296    async fn storage_set(&self, key: &str, value: &[u8]) -> WasmResult<()> {
297        self.context
298            .require_capability(&PluginCapability::Storage)?;
299        self.inc_metric(|m| m.storage_writes += 1).await;
300
301        self.context
302            .storage
303            .write()
304            .await
305            .insert(key.to_string(), value.to_vec());
306        Ok(())
307    }
308
309    async fn storage_delete(&self, key: &str) -> WasmResult<()> {
310        self.context
311            .require_capability(&PluginCapability::Storage)?;
312
313        self.context.storage.write().await.remove(key);
314        Ok(())
315    }
316
317    async fn now_ms(&self) -> WasmResult<u64> {
318        Ok(std::time::SystemTime::now()
319            .duration_since(std::time::UNIX_EPOCH)
320            .unwrap_or_default()
321            .as_millis() as u64)
322    }
323
324    async fn random_bytes(&self, len: u32) -> WasmResult<Vec<u8>> {
325        self.context.require_capability(&PluginCapability::Random)?;
326
327        use rand::RngCore;
328        let mut bytes = vec![0u8; len as usize];
329        rand::thread_rng().fill_bytes(&mut bytes);
330        Ok(bytes)
331    }
332
333    async fn sleep_ms(&self, ms: u64) -> WasmResult<()> {
334        self.context.require_capability(&PluginCapability::Timer)?;
335
336        tokio::time::sleep(tokio::time::Duration::from_millis(ms)).await;
337        Ok(())
338    }
339
340    async fn call_custom(&self, name: &str, args: Vec<WasmValue>) -> WasmResult<WasmValue> {
341        if let Some(callback) = self.context.custom_functions.get(name) {
342            callback(name, args)
343        } else {
344            Err(WasmError::HostFunctionError(format!(
345                "Custom function not found: {}",
346                name
347            )))
348        }
349    }
350}
351
352/// Host function registry for wasmtime integration
353pub struct HostFunctionRegistry {
354    /// Function name -> (module, function signature)
355    functions: HashMap<String, HostFunctionInfo>,
356}
357
358/// Host function info
359#[derive(Debug, Clone)]
360pub struct HostFunctionInfo {
361    pub name: String,
362    pub module: String,
363    pub params: Vec<String>,
364    pub returns: Vec<String>,
365    pub required_capability: Option<PluginCapability>,
366}
367
368impl HostFunctionRegistry {
369    pub fn new() -> Self {
370        let mut registry = Self {
371            functions: HashMap::new(),
372        };
373        registry.register_builtin_functions();
374        registry
375    }
376
377    fn register_builtin_functions(&mut self) {
378        // Logging
379        self.register("host_log", "env", vec!["i32", "i32", "i32"], vec![], None);
380
381        // Configuration
382        self.register(
383            "host_get_config",
384            "env",
385            vec!["i32", "i32", "i32"],
386            vec!["i32"],
387            Some(PluginCapability::ReadConfig),
388        );
389        self.register(
390            "host_set_config",
391            "env",
392            vec!["i32", "i32", "i32", "i32"],
393            vec!["i32"],
394            Some(PluginCapability::WriteConfig),
395        );
396
397        // Messaging
398        self.register(
399            "host_send_message",
400            "env",
401            vec!["i32", "i32", "i32", "i32"],
402            vec!["i32"],
403            Some(PluginCapability::SendMessage),
404        );
405
406        // Tools
407        self.register(
408            "host_call_tool",
409            "env",
410            vec!["i32", "i32", "i32", "i32", "i32"],
411            vec!["i32"],
412            Some(PluginCapability::CallTool),
413        );
414
415        // Storage
416        self.register(
417            "host_storage_get",
418            "env",
419            vec!["i32", "i32", "i32", "i32"],
420            vec!["i32"],
421            Some(PluginCapability::Storage),
422        );
423        self.register(
424            "host_storage_set",
425            "env",
426            vec!["i32", "i32", "i32", "i32"],
427            vec!["i32"],
428            Some(PluginCapability::Storage),
429        );
430
431        // Utilities
432        self.register("host_now_ms", "env", vec![], vec!["i64"], None);
433        self.register(
434            "host_random_bytes",
435            "env",
436            vec!["i32", "i32"],
437            vec!["i32"],
438            Some(PluginCapability::Random),
439        );
440        self.register(
441            "host_sleep_ms",
442            "env",
443            vec!["i64"],
444            vec![],
445            Some(PluginCapability::Timer),
446        );
447
448        // Memory management
449        self.register("host_alloc", "env", vec!["i32"], vec!["i32"], None);
450        self.register("host_free", "env", vec!["i32"], vec![], None);
451    }
452
453    fn register(
454        &mut self,
455        name: &str,
456        module: &str,
457        params: Vec<&str>,
458        returns: Vec<&str>,
459        required_capability: Option<PluginCapability>,
460    ) {
461        self.functions.insert(
462            name.to_string(),
463            HostFunctionInfo {
464                name: name.to_string(),
465                module: module.to_string(),
466                params: params.into_iter().map(String::from).collect(),
467                returns: returns.into_iter().map(String::from).collect(),
468                required_capability,
469            },
470        );
471    }
472
473    pub fn get(&self, name: &str) -> Option<&HostFunctionInfo> {
474        self.functions.get(name)
475    }
476
477    pub fn list(&self) -> Vec<&HostFunctionInfo> {
478        self.functions.values().collect()
479    }
480
481    pub fn has_function(&self, name: &str) -> bool {
482        self.functions.contains_key(name)
483    }
484}
485
486impl Default for HostFunctionRegistry {
487    fn default() -> Self {
488        Self::new()
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_level_conversion() {
        assert_eq!(LogLevel::from(0), LogLevel::Trace);
        assert_eq!(LogLevel::from(2), LogLevel::Info);
        assert_eq!(LogLevel::from(99), LogLevel::Error);
    }

    /// Every variant's discriminant must decode back to the same variant.
    #[test]
    fn test_log_level_round_trips_through_u32() {
        let levels = [
            LogLevel::Trace,
            LogLevel::Debug,
            LogLevel::Info,
            LogLevel::Warn,
            LogLevel::Error,
        ];
        for &level in &levels {
            assert_eq!(LogLevel::from(level as u32), level);
        }
    }

    #[test]
    fn test_host_context() {
        let ctx = HostContext::new(
            "test-plugin",
            vec![PluginCapability::ReadConfig, PluginCapability::SendMessage],
        );

        assert!(ctx.has_capability(&PluginCapability::ReadConfig));
        assert!(!ctx.has_capability(&PluginCapability::Storage));
    }

    #[tokio::test]
    async fn test_host_context_config() {
        let ctx = HostContext::new("test", vec![PluginCapability::ReadConfig]);

        ctx.set_config("key1", WasmValue::String("value1".into()))
            .await;

        let val = ctx.get_config("key1").await;
        assert_eq!(val, Some(WasmValue::String("value1".into())));
    }

    #[tokio::test]
    async fn test_default_host_functions() {
        let ctx = Arc::new(HostContext::new(
            "test",
            vec![PluginCapability::ReadConfig, PluginCapability::Timer],
        ));
        let host = DefaultHostFunctions::new(ctx.clone());

        // Test logging (always allowed)
        host.log(LogLevel::Info, "Test message").await.unwrap();

        // Test now_ms (always allowed)
        let ts = host.now_ms().await.unwrap();
        assert!(ts > 0);

        // Test sleep (requires Timer)
        host.sleep_ms(1).await.unwrap();

        // Test storage should fail (no capability)
        let result = host.storage_get("key").await;
        assert!(result.is_err());
    }

    /// Sent messages stay queued until drained, and draining empties the queue.
    #[tokio::test]
    async fn test_send_message_queues_until_drained() {
        let ctx = Arc::new(HostContext::new(
            "sender",
            vec![PluginCapability::SendMessage],
        ));
        let host = DefaultHostFunctions::new(ctx.clone());

        host.send_message("agent-b", b"hello").await.unwrap();

        let drained = ctx.drain_messages().await;
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].target, "agent-b");
        assert_eq!(drained[0].payload, b"hello".to_vec());
        assert!(ctx.drain_messages().await.is_empty());
        assert_eq!(ctx.metrics().await.messages_sent, 1);
    }

    /// A call rejected by the capability check must neither take effect nor
    /// bump its metric.
    #[tokio::test]
    async fn test_denied_call_has_no_effect() {
        let ctx = Arc::new(HostContext::new("ro", vec![PluginCapability::ReadConfig]));
        let host = DefaultHostFunctions::new(ctx.clone());

        let result = host
            .set_config("k", WasmValue::String("v".into()))
            .await;
        assert!(result.is_err());
        assert_eq!(ctx.get_config("k").await, None);
        assert_eq!(ctx.metrics().await.config_writes, 0);
    }

    /// Registered callbacks are reachable by name; unknown names are errors.
    #[tokio::test]
    async fn test_custom_function_dispatch() {
        let mut ctx = HostContext::new("custom", vec![]);
        let echo: HostCallback = Arc::new(
            |name: &str, _args: Vec<WasmValue>| -> WasmResult<WasmValue> {
                Ok(WasmValue::String(name.to_string()))
            },
        );
        ctx.register_function("echo", echo);
        let host = DefaultHostFunctions::new(Arc::new(ctx));

        assert_eq!(
            host.call_custom("echo", vec![]).await.unwrap(),
            WasmValue::String("echo".into())
        );
        assert!(host.call_custom("missing", vec![]).await.is_err());
    }

    #[test]
    fn test_host_function_registry() {
        let registry = HostFunctionRegistry::new();

        assert!(registry.has_function("host_log"));
        assert!(registry.has_function("host_get_config"));
        assert!(!registry.has_function("nonexistent"));

        let info = registry.get("host_log").unwrap();
        assert_eq!(info.module, "env");
        assert!(info.required_capability.is_none());

        let storage = registry.get("host_storage_set").unwrap();
        assert!(storage.required_capability == Some(PluginCapability::Storage));
    }
}