use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A notification raised by an extension that is waiting to be drained by the host.
///
/// `level` falls back to `"info"` when it is absent from the serialized form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PendingNotification {
    /// Extension that produced the notification.
    pub extension: String,
    /// Notification text.
    pub message: String,
    /// Severity label; deserializes to `"info"` when the field is missing.
    #[serde(default = "default_info")]
    pub level: String,
}
/// Serde fallback for [`PendingNotification::level`] when the field is absent.
fn default_info() -> String {
    String::from("info")
}
/// A message contributed by an extension that is waiting to be drained by the host.
///
/// `role` falls back to `"user"` when it is absent from the serialized form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PendingMessage {
    /// Extension that produced the message.
    pub extension: String,
    /// Message body.
    pub content: String,
    /// Role label for the message; deserializes to `"user"` when the field is missing.
    #[serde(default = "default_user")]
    pub role: String,
}
/// Serde fallback for [`PendingMessage::role`] when the field is absent.
fn default_user() -> String {
    "user".to_owned()
}
/// Response an extension may return from its `on_tool_call` export.
///
/// Every field is optional on the wire, so `{}` deserializes to the same value as
/// `ToolCallHookResult::default()`: the call is not blocked and no reason is given.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallHookResult {
    /// When `true`, the extension vetoes the tool call.
    #[serde(default)]
    pub block: bool,
    /// Optional human-readable explanation accompanying the decision.
    #[serde(default)]
    pub reason: Option<String>,
}
/// Response an extension may return from its `on_tool_result` export.
///
/// Every field is optional on the wire, so `{}` deserializes to the same value as
/// `ToolResultHookResult::default()`, meaning the extension has no changes to request.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResultHookResult {
    /// Presumably replacement content for the tool output.
    /// NOTE(review): how it is applied is decided by the caller of `fire_tool_result`.
    #[serde(default)]
    pub content: Option<String>,
    /// Presumably an override for the tool result's error flag.
    /// NOTE(review): applied by the caller; confirm semantics there.
    #[serde(default)]
    pub is_error: Option<bool>,
}
/// Fans host events out to the WASM extensions owned by a `WasmExtensionManager`.
///
/// It also holds two queues, notifications and messages, that callers empty with
/// `drain_notifications` / `drain_messages`.
pub struct WasmHookManager {
    /// Source of the loaded plugins; its `plugins` lock is taken for every dispatch.
    extensions: Arc<super::wasm::WasmExtensionManager>,

    /// Queued notifications.
    /// NOTE(review): nothing in this file pushes to this queue; confirm who fills it.
    notifications: Arc<RwLock<Vec<PendingNotification>>>,

    /// Queued messages.
    /// NOTE(review): nothing in this file pushes to this queue; confirm who fills it.
    messages: Arc<RwLock<Vec<PendingMessage>>>,
}
impl WasmHookManager {
    /// Creates a hook manager that dispatches events to the plugins held by `extensions`.
    ///
    /// The notification and message queues start out empty.
    pub fn new(extensions: Arc<super::wasm::WasmExtensionManager>) -> Self {
        Self {
            extensions,
            notifications: Arc::new(RwLock::new(Vec::new())),
            messages: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Asks each extension, via its `on_tool_call` export, whether a tool call may proceed.
    ///
    /// The event is sent as the JSON string
    /// `{"event":"tool_call","tool_name":…,"tool_call_id":…,"input":…}`. Extensions are
    /// consulted in the iteration order of the plugin collection. The first response that
    /// parses as a [`ToolCallHookResult`] with `block == true` is returned immediately, and
    /// the remaining extensions are not called.
    ///
    /// Returns `None` if no extension blocks the call. A failed plugin call, or a response
    /// that is not valid `ToolCallHookResult` JSON, counts as "no objection" and is logged at
    /// debug level. The plugin lock is held for the whole dispatch loop.
    pub fn fire_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: &str,
        input: &serde_json::Value,
    ) -> Option<ToolCallHookResult> {
        // Every plugin gets the same payload, so serialize it once, before taking the lock.
        // `Value`'s Display output is the same compact JSON `serde_json::to_string` produces.
        let payload = serde_json::json!({
            "event": "tool_call", "tool_name": tool_name,
            "tool_call_id": tool_call_id, "input": input,
        })
        .to_string();

        let mut plugins = self.extensions.plugins.lock();
        for (ext_name, plugin) in plugins.iter_mut() {
            let output = match plugin.call::<&str, &str>("on_tool_call", &payload) {
                Ok(output) => output,
                Err(err) => {
                    tracing::debug!("Extension '{}' on_tool_call failed: {}", ext_name, err);
                    continue;
                }
            };
            match serde_json::from_str::<ToolCallHookResult>(output) {
                Ok(result) if result.block => {
                    tracing::info!(
                        "Extension '{}' blocked tool_call '{}'",
                        ext_name,
                        tool_name
                    );
                    return Some(result);
                }
                Ok(_) => {}
                Err(err) => tracing::debug!(
                    "Extension '{}' returned an unparsable on_tool_call response: {}",
                    ext_name,
                    err
                ),
            }
        }
        None
    }

    /// Offers a finished tool result to each extension via its `on_tool_result` export.
    ///
    /// The event is sent as the JSON string
    /// `{"event":"tool_result","tool_name":…,"tool_call_id":…,"content":…,"is_error":…}`.
    /// The first response that parses as a [`ToolResultHookResult`] with `content` or
    /// `is_error` set is returned immediately. Later extensions are not called, so they never
    /// see that change.
    ///
    /// Returns `None` if no extension asks for a change. Failed plugin calls and unparsable
    /// responses are logged at debug level and skipped. The plugin lock is held for the whole
    /// dispatch loop.
    pub fn fire_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: &str,
        content: &str,
        is_error: bool,
    ) -> Option<ToolResultHookResult> {
        // Serialize once, outside the lock (see `fire_tool_call`).
        let payload = serde_json::json!({
            "event": "tool_result", "tool_name": tool_name,
            "tool_call_id": tool_call_id, "content": content, "is_error": is_error,
        })
        .to_string();

        let mut plugins = self.extensions.plugins.lock();
        for (ext_name, plugin) in plugins.iter_mut() {
            let output = match plugin.call::<&str, &str>("on_tool_result", &payload) {
                Ok(output) => output,
                Err(err) => {
                    tracing::debug!("Extension '{}' on_tool_result failed: {}", ext_name, err);
                    continue;
                }
            };
            match serde_json::from_str::<ToolResultHookResult>(output) {
                Ok(result) if result.content.is_some() || result.is_error.is_some() => {
                    tracing::debug!(
                        "Extension '{}' modified tool_result '{}'",
                        ext_name,
                        tool_name
                    );
                    return Some(result);
                }
                Ok(_) => {}
                Err(err) => tracing::debug!(
                    "Extension '{}' returned an unparsable on_tool_result response: {}",
                    ext_name,
                    err
                ),
            }
        }
        None
    }

    /// Notifies every extension that the session is ending, via `on_session_shutdown`.
    ///
    /// The payload is `{"event":"session_shutdown","reason":…}`. Delivery is best-effort:
    /// plugin responses are ignored, and call failures are only logged at debug level.
    pub fn fire_session_shutdown(&self, reason: &str) {
        let payload =
            serde_json::json!({ "event": "session_shutdown", "reason": reason }).to_string();

        let mut plugins = self.extensions.plugins.lock();
        for (ext_name, plugin) in plugins.iter_mut() {
            if let Err(err) = plugin.call::<&str, &str>("on_session_shutdown", &payload) {
                tracing::debug!(
                    "Extension '{}' on_session_shutdown failed: {}",
                    ext_name,
                    err
                );
            }
        }
    }

    /// Broadcasts a generic agent event to every extension via `on_agent_event`.
    ///
    /// The payload is `{"event": event_name, "data": event_data}`. Delivery is best-effort:
    /// plugin responses are ignored, and call failures are only logged at debug level.
    pub fn fire_agent_event(&self, event_name: &str, event_data: &serde_json::Value) {
        let payload =
            serde_json::json!({ "event": event_name, "data": event_data }).to_string();

        let mut plugins = self.extensions.plugins.lock();
        for (ext_name, plugin) in plugins.iter_mut() {
            if let Err(err) = plugin.call::<&str, &str>("on_agent_event", &payload) {
                tracing::debug!("Extension '{}' on_agent_event failed: {}", ext_name, err);
            }
        }
    }

    /// Removes and returns every queued notification, leaving the queue empty.
    pub fn drain_notifications(&self) -> Vec<PendingNotification> {
        std::mem::take(&mut *self.notifications.write())
    }

    /// Removes and returns every queued message, leaving the queue empty.
    pub fn drain_messages(&self) -> Vec<PendingMessage> {
        std::mem::take(&mut *self.messages.write())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit hook-result fields are parsed as given; omitted ones fall back to defaults.
    #[test]
    fn test_hook_results() {
        let r: ToolCallHookResult =
            serde_json::from_str(r#"{"block":true,"reason":"Dangerous"}"#).unwrap();
        assert!(r.block);
        assert_eq!(r.reason.as_deref(), Some("Dangerous"));
        let r: ToolResultHookResult = serde_json::from_str(r#"{"content":"Modified"}"#).unwrap();
        assert_eq!(r.content.as_deref(), Some("Modified"));
        assert_eq!(r.is_error, None);
    }

    /// An empty object must mean "no opinion" for both hook result types.
    #[test]
    fn test_empty_hook_results_use_defaults() {
        let r: ToolCallHookResult = serde_json::from_str("{}").unwrap();
        assert!(!r.block);
        assert!(r.reason.is_none());
        let r: ToolResultHookResult = serde_json::from_str("{}").unwrap();
        assert!(r.content.is_none());
        assert!(r.is_error.is_none());
    }

    /// A notification survives a serialize/deserialize round trip with every field intact.
    #[test]
    fn test_notification_serde() {
        let n = PendingNotification {
            extension: "ext".into(),
            message: "Hello".into(),
            level: "warning".into(),
        };
        let json = serde_json::to_string(&n).unwrap();
        let back: PendingNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(back.extension, "ext");
        assert_eq!(back.message, "Hello");
        assert_eq!(back.level, "warning");
    }

    /// Missing `level` / `role` fields deserialize to their serde defaults.
    #[test]
    fn test_missing_level_and_role_use_defaults() {
        let n: PendingNotification =
            serde_json::from_str(r#"{"extension":"ext","message":"hi"}"#).unwrap();
        assert_eq!(n.level, "info");
        let m: PendingMessage =
            serde_json::from_str(r#"{"extension":"ext","content":"hi"}"#).unwrap();
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hi");
    }
}