claude_code_acp/hooks/
callback_registry.rs

//! Callback registry for tool use hooks
//!
//! Stores per-tool-use callbacks that are executed when the matching PostToolUse
//! hook is received from Claude Code.
4
5use dashmap::DashMap;
6
7/// Callback type for PostToolUse events
8pub type PostToolUseCallback = Box<
9    dyn Fn(String, serde_json::Value, serde_json::Value) -> futures::future::BoxFuture<'static, ()>
10        + Send
11        + Sync,
12>;
13
14/// Registry for tool use callbacks
15///
16/// Stores callbacks keyed by tool use ID that are called when
17/// receiving PostToolUse hooks.
18#[derive(Default)]
19pub struct HookCallbackRegistry {
20    /// Callbacks keyed by tool use ID
21    callbacks: DashMap<String, ToolUseCallbacks>,
22}
23
24/// Callbacks for a specific tool use
25struct ToolUseCallbacks {
26    /// Callback for PostToolUse hook
27    on_post_tool_use: Option<PostToolUseCallback>,
28}
29
30impl HookCallbackRegistry {
31    /// Create a new empty callback registry
32    pub fn new() -> Self {
33        Self {
34            callbacks: DashMap::new(),
35        }
36    }
37
38    /// Register a PostToolUse callback for a specific tool use
39    pub fn register_post_tool_use(&self, tool_use_id: String, callback: PostToolUseCallback) {
40        self.callbacks.insert(
41            tool_use_id,
42            ToolUseCallbacks {
43                on_post_tool_use: Some(callback),
44            },
45        );
46    }
47
48    /// Execute and remove the PostToolUse callback for a tool use
49    ///
50    /// Returns None if no callback was registered for this tool use ID.
51    pub async fn execute_post_tool_use(
52        &self,
53        tool_use_id: &str,
54        tool_input: serde_json::Value,
55        tool_response: serde_json::Value,
56    ) -> bool {
57        if let Some((_, callbacks)) = self.callbacks.remove(tool_use_id) {
58            if let Some(callback) = callbacks.on_post_tool_use {
59                callback(tool_use_id.to_string(), tool_input, tool_response).await;
60                return true;
61            }
62        }
63        false
64    }
65
66    /// Check if a callback is registered for a tool use ID
67    pub fn has_callback(&self, tool_use_id: &str) -> bool {
68        self.callbacks.contains_key(tool_use_id)
69    }
70
71    /// Remove a callback without executing it
72    pub fn remove(&self, tool_use_id: &str) {
73        self.callbacks.remove(tool_use_id);
74    }
75
76    /// Get the number of registered callbacks
77    pub fn len(&self) -> usize {
78        self.callbacks.len()
79    }
80
81    /// Check if the registry is empty
82    pub fn is_empty(&self) -> bool {
83        self.callbacks.is_empty()
84    }
85}
86
87impl std::fmt::Debug for HookCallbackRegistry {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("HookCallbackRegistry")
90            .field("count", &self.callbacks.len())
91            .finish()
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};

    /// Build a callback that sets `flag` to `true` when it runs.
    fn flag_callback(flag: Arc<AtomicBool>) -> PostToolUseCallback {
        Box::new(move |_id, _input, _response| {
            let flag = Arc::clone(&flag);
            async move {
                flag.store(true, Ordering::SeqCst);
            }
            .boxed()
        })
    }

    #[tokio::test]
    async fn test_register_and_execute() {
        let registry = HookCallbackRegistry::new();
        let was_called = Arc::new(AtomicBool::new(false));

        registry.register_post_tool_use(
            "test-id".to_string(),
            flag_callback(Arc::clone(&was_called)),
        );
        assert!(registry.has_callback("test-id"));
        assert_eq!(registry.len(), 1);

        let result = registry
            .execute_post_tool_use(
                "test-id",
                serde_json::json!({"command": "ls"}),
                serde_json::json!("output"),
            )
            .await;

        assert!(result);
        assert!(was_called.load(Ordering::SeqCst));
        assert!(!registry.has_callback("test-id"));
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_callback_receives_arguments_and_runs_once() {
        let registry = HookCallbackRegistry::new();
        let seen: Arc<Mutex<Option<(String, serde_json::Value, serde_json::Value)>>> =
            Arc::new(Mutex::new(None));
        let seen_clone = Arc::clone(&seen);

        let callback: PostToolUseCallback = Box::new(
            move |id: String, input: serde_json::Value, response: serde_json::Value| {
                let seen = Arc::clone(&seen_clone);
                async move {
                    *seen.lock().unwrap() = Some((id, input, response));
                }
                .boxed()
            },
        );

        registry.register_post_tool_use("abc".to_string(), callback);

        let handled = registry
            .execute_post_tool_use("abc", serde_json::json!({"x": 1}), serde_json::json!("ok"))
            .await;
        assert!(handled);

        let (id, input, response) = seen
            .lock()
            .unwrap()
            .take()
            .expect("callback should have been executed");
        assert_eq!(id, "abc");
        assert_eq!(input, serde_json::json!({"x": 1}));
        assert_eq!(response, serde_json::json!("ok"));

        // The entry was consumed by the first execution.
        let handled_again = registry
            .execute_post_tool_use("abc", serde_json::json!({}), serde_json::json!({}))
            .await;
        assert!(!handled_again);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_callback() {
        let registry = HookCallbackRegistry::new();
        let first = Arc::new(AtomicBool::new(false));
        let second = Arc::new(AtomicBool::new(false));

        registry.register_post_tool_use("test-id".to_string(), flag_callback(Arc::clone(&first)));
        registry.register_post_tool_use("test-id".to_string(), flag_callback(Arc::clone(&second)));
        assert_eq!(registry.len(), 1);

        let handled = registry
            .execute_post_tool_use("test-id", serde_json::json!({}), serde_json::json!({}))
            .await;

        assert!(handled);
        assert!(!first.load(Ordering::SeqCst));
        assert!(second.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_execute_nonexistent() {
        let registry = HookCallbackRegistry::new();
        let result = registry
            .execute_post_tool_use("nonexistent", serde_json::json!({}), serde_json::json!({}))
            .await;

        assert!(!result);
    }

    #[test]
    fn test_remove() {
        let registry = HookCallbackRegistry::new();
        let callback: PostToolUseCallback = Box::new(|_id, _input, _response| async {}.boxed());

        registry.register_post_tool_use("test-id".to_string(), callback);
        assert!(registry.has_callback("test-id"));

        registry.remove("test-id");
        assert!(!registry.has_callback("test-id"));
        assert!(registry.is_empty());
    }
}