oxi/extensions/
wasm_hooks.rs1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PendingNotification {
14 pub extension: String,
16 pub message: String,
18 #[serde(default = "default_info")]
20 pub level: String,
21}
/// Serde fallback for `PendingNotification::level` when the field is missing.
fn default_info() -> String {
    String::from("info")
}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct PendingMessage {
29 pub extension: String,
31 pub content: String,
33 #[serde(default = "default_user")]
35 pub role: String,
36}
/// Serde fallback for `PendingMessage::role` when the field is missing.
fn default_user() -> String {
    String::from("user")
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ToolCallHookResult {
44 #[serde(default)]
46 pub block: bool,
47 #[serde(default)]
49 pub reason: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolResultHookResult {
55 #[serde(default)]
57 pub content: Option<String>,
58 #[serde(default)]
60 pub is_error: Option<bool>,
61}
62
63pub struct WasmHookManager {
65 extensions: Arc<super::wasm::WasmExtensionManager>,
66 notifications: Arc<RwLock<Vec<PendingNotification>>>,
67 messages: Arc<RwLock<Vec<PendingMessage>>>,
68}
69
70impl WasmHookManager {
71 pub fn new(extensions: Arc<super::wasm::WasmExtensionManager>) -> Self {
73 Self {
74 extensions,
75 notifications: Arc::new(RwLock::new(Vec::new())),
76 messages: Arc::new(RwLock::new(Vec::new())),
77 }
78 }
79
80 pub fn fire_tool_call(
82 &self,
83 tool_name: &str,
84 tool_call_id: &str,
85 input: &serde_json::Value,
86 ) -> Option<ToolCallHookResult> {
87 let mut plugins = self.extensions.plugins.lock();
88 for (ext_name, plugin) in plugins.iter_mut() {
89 let event = serde_json::json!({
90 "event": "tool_call", "tool_name": tool_name,
91 "tool_call_id": tool_call_id, "input": input,
92 });
93 let Ok(s) = serde_json::to_string(&event) else {
94 continue;
95 };
96 if let Ok(output) = plugin.call::<&str, &str>("on_tool_call", &s)
97 && let Ok(result) = serde_json::from_str::<ToolCallHookResult>(output)
98 && result.block
99 {
100 tracing::info!("Extension '{}' blocked tool_call '{}'", ext_name, tool_name);
101 return Some(result);
102 }
103 }
104 None
105 }
106
107 pub fn fire_tool_result(
109 &self,
110 tool_name: &str,
111 tool_call_id: &str,
112 content: &str,
113 is_error: bool,
114 ) -> Option<ToolResultHookResult> {
115 let mut plugins = self.extensions.plugins.lock();
116 for (_, plugin) in plugins.iter_mut() {
117 let event = serde_json::json!({
118 "event": "tool_result", "tool_name": tool_name,
119 "tool_call_id": tool_call_id, "content": content, "is_error": is_error,
120 });
121 let Ok(s) = serde_json::to_string(&event) else {
122 continue;
123 };
124 if let Ok(output) = plugin.call::<&str, &str>("on_tool_result", &s)
125 && let Ok(result) = serde_json::from_str::<ToolResultHookResult>(output)
126 && (result.content.is_some() || result.is_error.is_some())
127 {
128 return Some(result);
129 }
130 }
131 None
132 }
133
134 pub fn fire_session_shutdown(&self, reason: &str) {
136 let mut plugins = self.extensions.plugins.lock();
137 for (_, plugin) in plugins.iter_mut() {
138 let event = serde_json::json!({ "event": "session_shutdown", "reason": reason });
139 let Ok(s) = serde_json::to_string(&event) else {
140 continue;
141 };
142 let _ = plugin.call::<&str, &str>("on_session_shutdown", &s);
143 }
144 }
145
146 pub fn fire_agent_event(&self, event_name: &str, event_data: &serde_json::Value) {
148 let mut plugins = self.extensions.plugins.lock();
149 for (_, plugin) in plugins.iter_mut() {
150 let payload = serde_json::json!({ "event": event_name, "data": event_data });
151 let Ok(s) = serde_json::to_string(&payload) else {
152 continue;
153 };
154 let _ = plugin.call::<&str, &str>("on_agent_event", &s);
155 }
156 }
157
158 pub fn drain_notifications(&self) -> Vec<PendingNotification> {
160 std::mem::take(&mut *self.notifications.write())
161 }
162
163 pub fn drain_messages(&self) -> Vec<PendingMessage> {
165 std::mem::take(&mut *self.messages.write())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_results() {
        let r: ToolCallHookResult =
            serde_json::from_str(r#"{"block":true,"reason":"Dangerous"}"#).unwrap();
        assert!(r.block);
        assert_eq!(r.reason.as_deref(), Some("Dangerous"));

        let r: ToolResultHookResult = serde_json::from_str(r#"{"content":"Modified"}"#).unwrap();
        assert_eq!(r.content.as_deref(), Some("Modified"));
        assert_eq!(r.is_error, None);
    }

    #[test]
    fn test_hook_result_defaults() {
        // An empty object must deserialize to the "no objection / no change" value.
        let r: ToolCallHookResult = serde_json::from_str("{}").unwrap();
        assert!(!r.block);
        assert!(r.reason.is_none());

        let r: ToolResultHookResult = serde_json::from_str("{}").unwrap();
        assert!(r.content.is_none());
        assert!(r.is_error.is_none());
    }

    #[test]
    fn test_notification_serde() {
        let n = PendingNotification {
            extension: "ext".into(),
            message: "Hello".into(),
            level: "warning".into(),
        };
        let json = serde_json::to_string(&n).unwrap();
        let back: PendingNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.message, "Hello");
        assert_eq!(back.level, "warning");
    }

    #[test]
    fn test_pending_field_defaults() {
        let n: PendingNotification =
            serde_json::from_str(r#"{"extension":"ext","message":"Hi"}"#).unwrap();
        assert_eq!(n.level, "info");

        let m: PendingMessage =
            serde_json::from_str(r#"{"extension":"ext","content":"Hi"}"#).unwrap();
        assert_eq!(m.role, "user");
    }
}