Skip to main content

oxi/extensions/
wasm_hooks.rs

//! WASM extension event hooks — types and hook manager.
//!
//! Provides bridge types for communication between WASM extensions and oxi's
//! TUI and agent loop. Extensions call the `oxi_notify` and `oxi_send_message`
//! host functions, which push onto the notification and message queues held by
//! [`WasmHookManager`]; the TUI and agent loop drain those queues.
6
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10
11/// A pending notification from a WASM extension to be shown in TUI.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PendingNotification {
14    /// Extension name that sent the notification.
15    pub extension: String,
16    /// Notification message text.
17    pub message: String,
18    /// Notification level: "info", "warning", "error".
19    #[serde(default = "default_info")]
20    pub level: String,
21}
/// Serde fallback for [`PendingNotification::level`]: notifications without an
/// explicit level are treated as informational.
fn default_info() -> String {
    String::from("info")
}
25
26/// A pending message injection from a WASM extension.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct PendingMessage {
29    /// Extension name that sent the message.
30    pub extension: String,
31    /// Message content.
32    pub content: String,
33    /// Role for the message ("user", "system").
34    #[serde(default = "default_user")]
35    pub role: String,
36}
/// Serde fallback for [`PendingMessage::role`]: injected messages without an
/// explicit role are attributed to the user.
fn default_user() -> String {
    String::from("user")
}
40
41/// Result from a tool_call hook.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ToolCallHookResult {
44    /// Whether to block this tool call.
45    #[serde(default)]
46    pub block: bool,
47    /// Reason for blocking (shown to LLM).
48    #[serde(default)]
49    pub reason: Option<String>,
50}
51
52/// Result from a tool_result hook.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolResultHookResult {
55    /// Optional replacement content.
56    #[serde(default)]
57    pub content: Option<String>,
58    /// Whether to override the error flag.
59    #[serde(default)]
60    pub is_error: Option<bool>,
61}
62
63/// Manages hook dispatching to WASM extensions.
64pub struct WasmHookManager {
65    extensions: Arc<super::wasm::WasmExtensionManager>,
66    notifications: Arc<RwLock<Vec<PendingNotification>>>,
67    messages: Arc<RwLock<Vec<PendingMessage>>>,
68}
69
70impl WasmHookManager {
71    /// Create a new hook manager.
72    pub fn new(extensions: Arc<super::wasm::WasmExtensionManager>) -> Self {
73        Self {
74            extensions,
75            notifications: Arc::new(RwLock::new(Vec::new())),
76            messages: Arc::new(RwLock::new(Vec::new())),
77        }
78    }
79
80    /// Fire `on_tool_call` for all extensions. Returns first block result.
81    pub fn fire_tool_call(
82        &self,
83        tool_name: &str,
84        tool_call_id: &str,
85        input: &serde_json::Value,
86    ) -> Option<ToolCallHookResult> {
87        let mut plugins = self.extensions.plugins.lock();
88        for (ext_name, plugin) in plugins.iter_mut() {
89            let event = serde_json::json!({
90                "event": "tool_call", "tool_name": tool_name,
91                "tool_call_id": tool_call_id, "input": input,
92            });
93            let Ok(s) = serde_json::to_string(&event) else {
94                continue;
95            };
96            if let Ok(output) = plugin.call::<&str, &str>("on_tool_call", &s) {
97                if let Ok(result) = serde_json::from_str::<ToolCallHookResult>(output) {
98                    if result.block {
99                        tracing::info!(
100                            "Extension '{}' blocked tool_call '{}'",
101                            ext_name,
102                            tool_name
103                        );
104                        return Some(result);
105                    }
106                }
107            }
108        }
109        None
110    }
111
112    /// Fire `on_tool_result` for all extensions. Returns first modification.
113    pub fn fire_tool_result(
114        &self,
115        tool_name: &str,
116        tool_call_id: &str,
117        content: &str,
118        is_error: bool,
119    ) -> Option<ToolResultHookResult> {
120        let mut plugins = self.extensions.plugins.lock();
121        for (_, plugin) in plugins.iter_mut() {
122            let event = serde_json::json!({
123                "event": "tool_result", "tool_name": tool_name,
124                "tool_call_id": tool_call_id, "content": content, "is_error": is_error,
125            });
126            let Ok(s) = serde_json::to_string(&event) else {
127                continue;
128            };
129            if let Ok(output) = plugin.call::<&str, &str>("on_tool_result", &s) {
130                if let Ok(result) = serde_json::from_str::<ToolResultHookResult>(output) {
131                    if result.content.is_some() || result.is_error.is_some() {
132                        return Some(result);
133                    }
134                }
135            }
136        }
137        None
138    }
139
140    /// Fire `on_session_shutdown` for all extensions.
141    pub fn fire_session_shutdown(&self, reason: &str) {
142        let mut plugins = self.extensions.plugins.lock();
143        for (_, plugin) in plugins.iter_mut() {
144            let event = serde_json::json!({ "event": "session_shutdown", "reason": reason });
145            let Ok(s) = serde_json::to_string(&event) else {
146                continue;
147            };
148            let _ = plugin.call::<&str, &str>("on_session_shutdown", &s);
149        }
150    }
151
152    /// Fire `on_agent_event` for all extensions (fire-and-forget).
153    pub fn fire_agent_event(&self, event_name: &str, event_data: &serde_json::Value) {
154        let mut plugins = self.extensions.plugins.lock();
155        for (_, plugin) in plugins.iter_mut() {
156            let payload = serde_json::json!({ "event": event_name, "data": event_data });
157            let Ok(s) = serde_json::to_string(&payload) else {
158                continue;
159            };
160            let _ = plugin.call::<&str, &str>("on_agent_event", &s);
161        }
162    }
163
164    /// Drain all pending notifications. Called by TUI.
165    pub fn drain_notifications(&self) -> Vec<PendingNotification> {
166        std::mem::take(&mut *self.notifications.write())
167    }
168
169    /// Drain all pending messages. Called by agent loop.
170    pub fn drain_messages(&self) -> Vec<PendingMessage> {
171        std::mem::take(&mut *self.messages.write())
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_results() {
        let r: ToolCallHookResult =
            serde_json::from_str(r#"{"block":true,"reason":"Dangerous"}"#).unwrap();
        assert!(r.block);
        assert_eq!(r.reason.as_deref(), Some("Dangerous"));

        let r: ToolResultHookResult = serde_json::from_str(r#"{"content":"Modified"}"#).unwrap();
        assert_eq!(r.content.as_deref(), Some("Modified"));
        assert_eq!(r.is_error, None);
    }

    #[test]
    fn test_empty_hook_replies_request_no_action() {
        // `{}` means "no opinion": don't block and don't modify.
        let r: ToolCallHookResult = serde_json::from_str("{}").unwrap();
        assert!(!r.block);
        assert!(r.reason.is_none());

        let r: ToolResultHookResult = serde_json::from_str("{}").unwrap();
        assert!(r.content.is_none());
        assert!(r.is_error.is_none());
    }

    #[test]
    fn test_notification_serde() {
        let n = PendingNotification {
            extension: "ext".into(),
            message: "Hello".into(),
            level: "warning".into(),
        };
        let json = serde_json::to_string(&n).unwrap();
        let back: PendingNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.message, "Hello");
        assert_eq!(back.level, "warning");
    }

    #[test]
    fn test_missing_level_and_role_use_defaults() {
        let n: PendingNotification =
            serde_json::from_str(r#"{"extension":"ext","message":"Hi"}"#).unwrap();
        assert_eq!(n.level, "info");

        let m: PendingMessage =
            serde_json::from_str(r#"{"extension":"ext","content":"Hello"}"#).unwrap();
        assert_eq!(m.role, "user");
    }
}