use dashmap::DashMap;
/// Boxed callback stored in a [`HookCallbackRegistry`] and run by
/// [`HookCallbackRegistry::execute_post_tool_use`].
///
/// It is called with, in order, the tool-use id (as an owned `String`), the
/// tool input, and the tool response, and returns a boxed `'static` future
/// that resolves to `()`. The closure itself must be `Send + Sync`.
pub type PostToolUseCallback = Box<
dyn Fn(String, serde_json::Value, serde_json::Value) -> futures::future::BoxFuture<'static, ()>
+ Send
+ Sync,
>;
/// Registry of per-tool-use callbacks, keyed by tool-use id.
///
/// Entries are added with `register_post_tool_use` and consumed (removed and
/// then run) by `execute_post_tool_use`. Every method takes `&self`; all
/// mutation goes through the inner `DashMap`.
#[derive(Default)]
pub struct HookCallbackRegistry {
// Maps a tool-use id to the callbacks registered for it.
callbacks: DashMap<String, ToolUseCallbacks>,
}
/// The set of callbacks stored for a single tool-use id.
struct ToolUseCallbacks {
// Callback to run after the tool has been used. `register_post_tool_use`
// always stores `Some`; `execute_post_tool_use` treats `None` the same as
// a missing entry.
on_post_tool_use: Option<PostToolUseCallback>,
}
impl HookCallbackRegistry {
    /// Creates a registry with no callbacks registered.
    pub fn new() -> Self {
        let callbacks = DashMap::new();
        Self { callbacks }
    }

    /// Stores `callback` under `tool_use_id`.
    ///
    /// An entry already stored under the same id is replaced.
    pub fn register_post_tool_use(&self, tool_use_id: String, callback: PostToolUseCallback) {
        let entry = ToolUseCallbacks {
            on_post_tool_use: Some(callback),
        };
        self.callbacks.insert(tool_use_id, entry);
    }

    /// Takes the entry for `tool_use_id` out of the registry and awaits its callback.
    ///
    /// The entry is removed *before* the callback is invoked, so it is no longer
    /// visible through [`has_callback`](Self::has_callback) while the callback runs.
    /// The callback receives an owned copy of `tool_use_id`, then `tool_input`,
    /// then `tool_response`.
    ///
    /// Returns `true` if a callback was found and awaited, `false` otherwise.
    pub async fn execute_post_tool_use(
        &self,
        tool_use_id: &str,
        tool_input: serde_json::Value,
        tool_response: serde_json::Value,
    ) -> bool {
        // Detach the callback from the map first; only an owned value is kept
        // while awaiting.
        let pending = self
            .callbacks
            .remove(tool_use_id)
            .and_then(|(_, entry)| entry.on_post_tool_use);

        match pending {
            Some(callback) => {
                callback(tool_use_id.to_string(), tool_input, tool_response).await;
                true
            }
            None => false,
        }
    }

    /// Returns `true` if an entry is currently stored under `tool_use_id`.
    pub fn has_callback(&self, tool_use_id: &str) -> bool {
        self.callbacks.contains_key(tool_use_id)
    }

    /// Discards the entry stored under `tool_use_id`, if any, without running it.
    pub fn remove(&self, tool_use_id: &str) {
        drop(self.callbacks.remove(tool_use_id));
    }

    /// Number of tool-use ids that currently have an entry.
    pub fn len(&self) -> usize {
        self.callbacks.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.callbacks.is_empty()
    }
}
/// Debug output shows only how many entries are stored; the callbacks
/// themselves are boxed closures with no `Debug` representation.
impl std::fmt::Debug for HookCallbackRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.callbacks.len();
        let mut builder = f.debug_struct("HookCallbackRegistry");
        builder.field("count", &count);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A registered callback is visible, runs when executed, and is consumed by the call.
    #[tokio::test]
    async fn test_register_and_execute() {
        let registry = HookCallbackRegistry::new();
        let fired = Arc::new(AtomicBool::new(false));

        // The callback flips a shared flag so the test can observe that it ran.
        let fired_handle = Arc::clone(&fired);
        let callback: PostToolUseCallback = Box::new(move |_id, _input, _response| {
            let flag = Arc::clone(&fired_handle);
            async move { flag.store(true, Ordering::SeqCst) }.boxed()
        });
        registry.register_post_tool_use(String::from("test-id"), callback);

        assert!(registry.has_callback("test-id"));
        assert_eq!(registry.len(), 1);

        let input = serde_json::json!({"command": "ls"});
        let response = serde_json::json!("output");
        let executed = registry
            .execute_post_tool_use("test-id", input, response)
            .await;

        assert!(executed);
        assert!(fired.load(Ordering::SeqCst));
        // Executing consumes the entry.
        assert!(!registry.has_callback("test-id"));
        assert!(registry.is_empty());
    }

    /// Executing an id that was never registered reports `false`.
    #[tokio::test]
    async fn test_execute_nonexistent() {
        let registry = HookCallbackRegistry::new();

        let executed = registry
            .execute_post_tool_use("nonexistent", serde_json::json!({}), serde_json::json!({}))
            .await;

        assert!(!executed);
    }

    /// `remove` drops a registered entry without running it.
    #[test]
    fn test_remove() {
        let registry = HookCallbackRegistry::new();
        let noop: PostToolUseCallback = Box::new(|_id, _input, _response| async {}.boxed());

        registry.register_post_tool_use("test-id".to_owned(), noop);
        assert!(registry.has_callback("test-id"));

        registry.remove("test-id");
        assert!(!registry.has_callback("test-id"));
    }
}