oxi/extensions/
wasm_hooks.rs1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PendingNotification {
14 pub extension: String,
16 pub message: String,
18 #[serde(default = "default_info")]
20 pub level: String,
21}
/// Serde fallback for [`PendingNotification::level`]: the string `"info"`.
fn default_info() -> String {
    String::from("info")
}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct PendingMessage {
29 pub extension: String,
31 pub content: String,
33 #[serde(default = "default_user")]
35 pub role: String,
36}
/// Serde fallback for [`PendingMessage::role`]: the string `"user"`.
fn default_user() -> String {
    String::from("user")
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ToolCallHookResult {
44 #[serde(default)]
46 pub block: bool,
47 #[serde(default)]
49 pub reason: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolResultHookResult {
55 #[serde(default)]
57 pub content: Option<String>,
58 #[serde(default)]
60 pub is_error: Option<bool>,
61}
62
63pub struct WasmHookManager {
65 extensions: Arc<super::wasm::WasmExtensionManager>,
66 notifications: Arc<RwLock<Vec<PendingNotification>>>,
67 messages: Arc<RwLock<Vec<PendingMessage>>>,
68}
69
70impl WasmHookManager {
71 pub fn new(extensions: Arc<super::wasm::WasmExtensionManager>) -> Self {
73 Self {
74 extensions,
75 notifications: Arc::new(RwLock::new(Vec::new())),
76 messages: Arc::new(RwLock::new(Vec::new())),
77 }
78 }
79
80 pub fn fire_tool_call(
82 &self,
83 tool_name: &str,
84 tool_call_id: &str,
85 input: &serde_json::Value,
86 ) -> Option<ToolCallHookResult> {
87 let mut plugins = self.extensions.plugins.lock();
88 for (ext_name, plugin) in plugins.iter_mut() {
89 let event = serde_json::json!({
90 "event": "tool_call", "tool_name": tool_name,
91 "tool_call_id": tool_call_id, "input": input,
92 });
93 let Ok(s) = serde_json::to_string(&event) else {
94 continue;
95 };
96 if let Ok(output) = plugin.call::<&str, &str>("on_tool_call", &s) {
97 if let Ok(result) = serde_json::from_str::<ToolCallHookResult>(output) {
98 if result.block {
99 tracing::info!(
100 "Extension '{}' blocked tool_call '{}'",
101 ext_name,
102 tool_name
103 );
104 return Some(result);
105 }
106 }
107 }
108 }
109 None
110 }
111
112 pub fn fire_tool_result(
114 &self,
115 tool_name: &str,
116 tool_call_id: &str,
117 content: &str,
118 is_error: bool,
119 ) -> Option<ToolResultHookResult> {
120 let mut plugins = self.extensions.plugins.lock();
121 for (_, plugin) in plugins.iter_mut() {
122 let event = serde_json::json!({
123 "event": "tool_result", "tool_name": tool_name,
124 "tool_call_id": tool_call_id, "content": content, "is_error": is_error,
125 });
126 let Ok(s) = serde_json::to_string(&event) else {
127 continue;
128 };
129 if let Ok(output) = plugin.call::<&str, &str>("on_tool_result", &s) {
130 if let Ok(result) = serde_json::from_str::<ToolResultHookResult>(output) {
131 if result.content.is_some() || result.is_error.is_some() {
132 return Some(result);
133 }
134 }
135 }
136 }
137 None
138 }
139
140 pub fn fire_session_shutdown(&self, reason: &str) {
142 let mut plugins = self.extensions.plugins.lock();
143 for (_, plugin) in plugins.iter_mut() {
144 let event = serde_json::json!({ "event": "session_shutdown", "reason": reason });
145 let Ok(s) = serde_json::to_string(&event) else {
146 continue;
147 };
148 let _ = plugin.call::<&str, &str>("on_session_shutdown", &s);
149 }
150 }
151
152 pub fn fire_agent_event(&self, event_name: &str, event_data: &serde_json::Value) {
154 let mut plugins = self.extensions.plugins.lock();
155 for (_, plugin) in plugins.iter_mut() {
156 let payload = serde_json::json!({ "event": event_name, "data": event_data });
157 let Ok(s) = serde_json::to_string(&payload) else {
158 continue;
159 };
160 let _ = plugin.call::<&str, &str>("on_agent_event", &s);
161 }
162 }
163
164 pub fn drain_notifications(&self) -> Vec<PendingNotification> {
166 std::mem::take(&mut *self.notifications.write())
167 }
168
169 pub fn drain_messages(&self) -> Vec<PendingMessage> {
171 std::mem::take(&mut *self.messages.write())
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook replies keep the fields the extension supplied and leave the
    /// others unset.
    #[test]
    fn test_hook_results() {
        let r: ToolCallHookResult =
            serde_json::from_str(r#"{"block":true,"reason":"Dangerous"}"#).unwrap();
        assert!(r.block);
        assert_eq!(r.reason.as_deref(), Some("Dangerous"));

        let r: ToolResultHookResult = serde_json::from_str(r#"{"content":"Modified"}"#).unwrap();
        assert_eq!(r.content.as_deref(), Some("Modified"));
        assert!(r.is_error.is_none());
    }

    /// An empty JSON object is a valid reply for both hooks and requests
    /// nothing.
    #[test]
    fn test_hook_results_default_when_fields_missing() {
        let r: ToolCallHookResult = serde_json::from_str("{}").unwrap();
        assert!(!r.block);
        assert!(r.reason.is_none());

        let r: ToolResultHookResult = serde_json::from_str("{}").unwrap();
        assert!(r.content.is_none());
        assert!(r.is_error.is_none());
    }

    /// A notification survives a serialize/deserialize round trip with every
    /// field intact.
    #[test]
    fn test_notification_serde() {
        let n = PendingNotification {
            extension: "ext".into(),
            message: "Hello".into(),
            level: "warning".into(),
        };
        let json = serde_json::to_string(&n).unwrap();
        let back: PendingNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.message, "Hello");
        assert_eq!(back.level, "warning");
    }

    /// Omitted `level` / `role` fields fall back to their serde defaults.
    #[test]
    fn test_level_and_role_defaults() {
        let n: PendingNotification =
            serde_json::from_str(r#"{"extension":"ext","message":"Hello"}"#).unwrap();
        assert_eq!(n.level, "info");

        let m: PendingMessage =
            serde_json::from_str(r#"{"extension":"ext","content":"Hi"}"#).unwrap();
        assert_eq!(m.extension, "ext");
        assert_eq!(m.content, "Hi");
        assert_eq!(m.role, "user");
    }
}
201}