claude_code_acp/hooks/
callback_registry.rs1use dashmap::DashMap;
6
7pub type PostToolUseCallback = Box<
9 dyn Fn(String, serde_json::Value, serde_json::Value) -> futures::future::BoxFuture<'static, ()>
10 + Send
11 + Sync,
12>;
13
14#[derive(Default)]
19pub struct HookCallbackRegistry {
20 callbacks: DashMap<String, ToolUseCallbacks>,
22}
23
24struct ToolUseCallbacks {
26 on_post_tool_use: Option<PostToolUseCallback>,
28}
29
30impl HookCallbackRegistry {
31 pub fn new() -> Self {
33 Self {
34 callbacks: DashMap::new(),
35 }
36 }
37
38 pub fn register_post_tool_use(&self, tool_use_id: String, callback: PostToolUseCallback) {
40 self.callbacks.insert(
41 tool_use_id,
42 ToolUseCallbacks {
43 on_post_tool_use: Some(callback),
44 },
45 );
46 }
47
48 pub async fn execute_post_tool_use(
52 &self,
53 tool_use_id: &str,
54 tool_input: serde_json::Value,
55 tool_response: serde_json::Value,
56 ) -> bool {
57 if let Some((_, callbacks)) = self.callbacks.remove(tool_use_id) {
58 if let Some(callback) = callbacks.on_post_tool_use {
59 callback(tool_use_id.to_string(), tool_input, tool_response).await;
60 return true;
61 }
62 }
63 false
64 }
65
66 pub fn has_callback(&self, tool_use_id: &str) -> bool {
68 self.callbacks.contains_key(tool_use_id)
69 }
70
71 pub fn remove(&self, tool_use_id: &str) {
73 self.callbacks.remove(tool_use_id);
74 }
75
76 pub fn len(&self) -> usize {
78 self.callbacks.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.callbacks.is_empty()
84 }
85}
86
87impl std::fmt::Debug for HookCallbackRegistry {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("HookCallbackRegistry")
90 .field("count", &self.callbacks.len())
91 .finish()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Builds a callback that sets `flag` to `true` when its future is awaited.
    fn flag_callback(flag: Arc<AtomicBool>) -> PostToolUseCallback {
        Box::new(move |_id, _input, _response| {
            let flag = Arc::clone(&flag);
            async move {
                flag.store(true, Ordering::SeqCst);
            }
            .boxed()
        })
    }

    #[tokio::test]
    async fn test_register_and_execute() {
        let registry = HookCallbackRegistry::new();
        let was_called = Arc::new(AtomicBool::new(false));

        registry.register_post_tool_use(
            "test-id".to_string(),
            flag_callback(Arc::clone(&was_called)),
        );
        assert!(registry.has_callback("test-id"));
        assert_eq!(registry.len(), 1);

        let result = registry
            .execute_post_tool_use(
                "test-id",
                serde_json::json!({"command": "ls"}),
                serde_json::json!("output"),
            )
            .await;

        assert!(result);
        assert!(was_called.load(Ordering::SeqCst));
        assert!(!registry.has_callback("test-id"));
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_execute_forwards_arguments() {
        type Call = (String, serde_json::Value, serde_json::Value);

        let registry = HookCallbackRegistry::new();
        let seen: Arc<Mutex<Option<Call>>> = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&seen);

        let callback: PostToolUseCallback = Box::new(move |id, input, response| {
            let sink = Arc::clone(&sink);
            async move {
                *sink.lock().unwrap() = Some((id, input, response));
            }
            .boxed()
        });
        registry.register_post_tool_use("tool-42".to_string(), callback);

        let input = serde_json::json!({"path": "/tmp"});
        let response = serde_json::json!({"ok": true});
        assert!(
            registry
                .execute_post_tool_use("tool-42", input.clone(), response.clone())
                .await
        );

        let recorded = seen.lock().unwrap().take();
        assert_eq!(recorded, Some(("tool-42".to_string(), input, response)));
    }

    #[tokio::test]
    async fn test_callback_runs_at_most_once() {
        let registry = HookCallbackRegistry::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let callback: PostToolUseCallback = Box::new(move |_id, _input, _response| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
            .boxed()
        });
        registry.register_post_tool_use("once".to_string(), callback);

        let first = registry
            .execute_post_tool_use("once", serde_json::json!(null), serde_json::json!(null))
            .await;
        let second = registry
            .execute_post_tool_use("once", serde_json::json!(null), serde_json::json!(null))
            .await;

        assert!(first);
        assert!(!second);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_reregister_replaces_previous_callback() {
        let registry = HookCallbackRegistry::new();
        let first = Arc::new(AtomicBool::new(false));
        let second = Arc::new(AtomicBool::new(false));

        registry.register_post_tool_use("dup".to_string(), flag_callback(Arc::clone(&first)));
        registry.register_post_tool_use("dup".to_string(), flag_callback(Arc::clone(&second)));
        assert_eq!(registry.len(), 1);

        assert!(
            registry
                .execute_post_tool_use("dup", serde_json::json!({}), serde_json::json!({}))
                .await
        );
        assert!(!first.load(Ordering::SeqCst));
        assert!(second.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_execute_nonexistent() {
        let registry = HookCallbackRegistry::new();
        let result = registry
            .execute_post_tool_use("nonexistent", serde_json::json!({}), serde_json::json!({}))
            .await;

        assert!(!result);
    }

    #[test]
    fn test_remove() {
        let registry = HookCallbackRegistry::new();
        let callback: PostToolUseCallback = Box::new(|_id, _input, _response| async {}.boxed());

        registry.register_post_tool_use("test-id".to_string(), callback);
        assert!(registry.has_callback("test-id"));

        registry.remove("test-id");
        assert!(!registry.has_callback("test-id"));
        assert!(registry.is_empty());

        // Removing an id that was never registered is a silent no-op.
        registry.remove("missing");
        assert!(registry.is_empty());
    }
}