Skip to main content

oxi/extensions/
wasm_hooks.rs

//! WASM extension event hooks — types and hook manager.
//!
//! Provides the bridge types exchanged between WASM extensions and oxi's
//! TUI/agent, plus [`WasmHookManager`], which dispatches hook events to the
//! loaded extensions. Extensions call the `oxi_notify` and `oxi_send_message`
//! host functions, which push onto the manager's notification and message
//! queues; the TUI and the agent loop drain those queues.
6
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10
11/// A pending notification from a WASM extension to be shown in TUI.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PendingNotification {
14    /// Extension name that sent the notification.
15    pub extension: String,
16    /// Notification message text.
17    pub message: String,
18    /// Notification level: "info", "warning", "error".
19    #[serde(default = "default_info")]
20    pub level: String,
21}
/// Serde default for [`PendingNotification::level`].
fn default_info() -> String {
    String::from("info")
}
25
26/// A pending message injection from a WASM extension.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct PendingMessage {
29    /// Extension name that sent the message.
30    pub extension: String,
31    /// Message content.
32    pub content: String,
33    /// Role for the message ("user", "system").
34    #[serde(default = "default_user")]
35    pub role: String,
36}
/// Serde default for [`PendingMessage::role`].
fn default_user() -> String {
    String::from("user")
}
40
41/// Result from a tool_call hook.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ToolCallHookResult {
44    /// Whether to block this tool call.
45    #[serde(default)]
46    pub block: bool,
47    /// Reason for blocking (shown to LLM).
48    #[serde(default)]
49    pub reason: Option<String>,
50}
51
52/// Result from a tool_result hook.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolResultHookResult {
55    /// Optional replacement content.
56    #[serde(default)]
57    pub content: Option<String>,
58    /// Whether to override the error flag.
59    #[serde(default)]
60    pub is_error: Option<bool>,
61}
62
63/// Manages hook dispatching to WASM extensions.
64pub struct WasmHookManager {
65    extensions: Arc<super::wasm::WasmExtensionManager>,
66    notifications: Arc<RwLock<Vec<PendingNotification>>>,
67    messages: Arc<RwLock<Vec<PendingMessage>>>,
68}
69
70impl WasmHookManager {
71    /// Create a new hook manager.
72    pub fn new(extensions: Arc<super::wasm::WasmExtensionManager>) -> Self {
73        Self {
74            extensions,
75            notifications: Arc::new(RwLock::new(Vec::new())),
76            messages: Arc::new(RwLock::new(Vec::new())),
77        }
78    }
79
80    /// Fire `on_tool_call` for all extensions. Returns first block result.
81    pub fn fire_tool_call(
82        &self,
83        tool_name: &str,
84        tool_call_id: &str,
85        input: &serde_json::Value,
86    ) -> Option<ToolCallHookResult> {
87        let mut plugins = self.extensions.plugins.lock();
88        for (ext_name, plugin) in plugins.iter_mut() {
89            let event = serde_json::json!({
90                "event": "tool_call", "tool_name": tool_name,
91                "tool_call_id": tool_call_id, "input": input,
92            });
93            let Ok(s) = serde_json::to_string(&event) else {
94                continue;
95            };
96            if let Ok(output) = plugin.call::<&str, &str>("on_tool_call", &s)
97                && let Ok(result) = serde_json::from_str::<ToolCallHookResult>(output)
98                && result.block
99            {
100                tracing::info!("Extension '{}' blocked tool_call '{}'", ext_name, tool_name);
101                return Some(result);
102            }
103        }
104        None
105    }
106
107    /// Fire `on_tool_result` for all extensions. Returns first modification.
108    pub fn fire_tool_result(
109        &self,
110        tool_name: &str,
111        tool_call_id: &str,
112        content: &str,
113        is_error: bool,
114    ) -> Option<ToolResultHookResult> {
115        let mut plugins = self.extensions.plugins.lock();
116        for (_, plugin) in plugins.iter_mut() {
117            let event = serde_json::json!({
118                "event": "tool_result", "tool_name": tool_name,
119                "tool_call_id": tool_call_id, "content": content, "is_error": is_error,
120            });
121            let Ok(s) = serde_json::to_string(&event) else {
122                continue;
123            };
124            if let Ok(output) = plugin.call::<&str, &str>("on_tool_result", &s)
125                && let Ok(result) = serde_json::from_str::<ToolResultHookResult>(output)
126                && (result.content.is_some() || result.is_error.is_some())
127            {
128                return Some(result);
129            }
130        }
131        None
132    }
133
134    /// Fire `on_session_shutdown` for all extensions.
135    pub fn fire_session_shutdown(&self, reason: &str) {
136        let mut plugins = self.extensions.plugins.lock();
137        for (_, plugin) in plugins.iter_mut() {
138            let event = serde_json::json!({ "event": "session_shutdown", "reason": reason });
139            let Ok(s) = serde_json::to_string(&event) else {
140                continue;
141            };
142            let _ = plugin.call::<&str, &str>("on_session_shutdown", &s);
143        }
144    }
145
146    /// Fire `on_agent_event` for all extensions (fire-and-forget).
147    pub fn fire_agent_event(&self, event_name: &str, event_data: &serde_json::Value) {
148        let mut plugins = self.extensions.plugins.lock();
149        for (_, plugin) in plugins.iter_mut() {
150            let payload = serde_json::json!({ "event": event_name, "data": event_data });
151            let Ok(s) = serde_json::to_string(&payload) else {
152                continue;
153            };
154            let _ = plugin.call::<&str, &str>("on_agent_event", &s);
155        }
156    }
157
158    /// Drain all pending notifications. Called by TUI.
159    pub fn drain_notifications(&self) -> Vec<PendingNotification> {
160        std::mem::take(&mut *self.notifications.write())
161    }
162
163    /// Drain all pending messages. Called by agent loop.
164    pub fn drain_messages(&self) -> Vec<PendingMessage> {
165        std::mem::take(&mut *self.messages.write())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_results() {
        let r: ToolCallHookResult =
            serde_json::from_str(r#"{"block":true,"reason":"Dangerous"}"#).unwrap();
        assert!(r.block);
        assert_eq!(r.reason.as_deref(), Some("Dangerous"));

        let r: ToolResultHookResult = serde_json::from_str(r#"{"content":"Modified"}"#).unwrap();
        assert_eq!(r.content.as_deref(), Some("Modified"));
        assert_eq!(r.is_error, None);
    }

    /// An empty object from a hook must mean "no block" / "no modification".
    #[test]
    fn test_hook_results_default_to_no_op() {
        let r: ToolCallHookResult = serde_json::from_str("{}").unwrap();
        assert!(!r.block);
        assert!(r.reason.is_none());

        let r: ToolResultHookResult = serde_json::from_str("{}").unwrap();
        assert!(r.content.is_none());
        assert!(r.is_error.is_none());
    }

    #[test]
    fn test_notification_serde() {
        let n = PendingNotification {
            extension: "ext".into(),
            message: "Hello".into(),
            level: "warning".into(),
        };
        let json = serde_json::to_string(&n).unwrap();
        let back: PendingNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.message, "Hello");
        assert_eq!(back.level, "warning");
    }

    #[test]
    fn test_message_serde() {
        let m = PendingMessage {
            extension: "ext".into(),
            content: "Hi".into(),
            role: "system".into(),
        };
        let json = serde_json::to_string(&m).unwrap();
        let back: PendingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.content, "Hi");
        assert_eq!(back.role, "system");
    }

    /// Omitted `level` / `role` fields must fall back to their serde defaults.
    #[test]
    fn test_missing_level_and_role_use_defaults() {
        let n: PendingNotification =
            serde_json::from_str(r#"{"extension":"ext","message":"Hi"}"#).unwrap();
        assert_eq!(n.level, "info");

        let m: PendingMessage =
            serde_json::from_str(r#"{"extension":"ext","content":"Hi"}"#).unwrap();
        assert_eq!(m.role, "user");
    }
}