1use std::collections::HashMap;
4use std::sync::Arc;
5
6use sacp::JrConnectionCx;
7use sacp::link::AgentToClient;
8use sacp::schema::{
9 SessionId, SessionNotification, SessionUpdate, Terminal, ToolCallContent, ToolCallId,
10 ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields,
11};
12use serde::{Deserialize, Serialize};
13
14use super::tools::Tool;
15use crate::session::BackgroundProcessManager;
16use crate::settings::PermissionChecker;
17use crate::terminal::TerminalClient;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ToolResult {
22 pub status: ToolStatus,
24 pub content: String,
26 pub is_error: bool,
28 #[serde(default, skip_serializing_if = "Option::is_none")]
30 pub metadata: Option<serde_json::Value>,
31}
32
33impl ToolResult {
34 pub fn success(content: impl Into<String>) -> Self {
36 Self {
37 status: ToolStatus::Success,
38 content: content.into(),
39 is_error: false,
40 metadata: None,
41 }
42 }
43
44 pub fn error(message: impl Into<String>) -> Self {
46 Self {
47 status: ToolStatus::Error,
48 content: message.into(),
49 is_error: true,
50 metadata: None,
51 }
52 }
53
54 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
56 self.metadata = Some(metadata);
57 self
58 }
59}
60
/// Coarse status of a tool invocation.
///
/// Serialized in lowercase form (`"success"`, `"error"`, `"cancelled"`, `"running"`)
/// because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolStatus {
    /// Status set by `ToolResult::success`.
    Success,
    /// Status set by `ToolResult::error`.
    Error,
    /// Cancelled outcome; no constructor in this file produces it.
    Cancelled,
    /// In-progress status; no constructor in this file produces it.
    Running,
}
74
75#[derive(Debug, Clone)]
77pub struct ToolContext {
78 pub session_id: String,
80 pub cwd: std::path::PathBuf,
82 pub allow_dangerous: bool,
84 background_processes: Option<Arc<BackgroundProcessManager>>,
86 terminal_client: Option<Arc<TerminalClient>>,
88 tool_use_id: Option<String>,
90 connection_cx: Option<JrConnectionCx<AgentToClient>>,
92 pub permission_checker: Option<Arc<tokio::sync::RwLock<PermissionChecker>>>,
94}
95
96impl ToolContext {
97 pub fn new(session_id: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
99 Self {
100 session_id: session_id.into(),
101 cwd: cwd.into(),
102 allow_dangerous: false,
103 background_processes: None,
104 terminal_client: None,
105 tool_use_id: None,
106 connection_cx: None,
107 permission_checker: None,
108 }
109 }
110
111 pub fn with_dangerous(mut self, allow: bool) -> Self {
113 self.allow_dangerous = allow;
114 self
115 }
116
117 pub fn with_background_processes(mut self, manager: Arc<BackgroundProcessManager>) -> Self {
119 self.background_processes = Some(manager);
120 self
121 }
122
123 pub fn with_terminal_client(mut self, client: Arc<TerminalClient>) -> Self {
125 self.terminal_client = Some(client);
126 self
127 }
128
129 pub fn with_tool_use_id(mut self, id: impl Into<String>) -> Self {
131 self.tool_use_id = Some(id.into());
132 self
133 }
134
135 pub fn with_connection_cx(mut self, cx: JrConnectionCx<AgentToClient>) -> Self {
137 self.connection_cx = Some(cx);
138 self
139 }
140
141 pub fn with_permission_checker(
143 mut self,
144 checker: Arc<tokio::sync::RwLock<PermissionChecker>>,
145 ) -> Self {
146 self.permission_checker = Some(checker);
147 self
148 }
149
150 pub fn background_processes(&self) -> Option<&Arc<BackgroundProcessManager>> {
152 self.background_processes.as_ref()
153 }
154
155 pub fn terminal_client(&self) -> Option<&Arc<TerminalClient>> {
160 self.terminal_client.as_ref()
161 }
162
163 pub fn tool_use_id(&self) -> Option<&str> {
165 self.tool_use_id.as_deref()
166 }
167
168 pub fn send_terminal_update(
183 &self,
184 terminal_id: impl Into<String>,
185 status: ToolCallStatus,
186 title: Option<&str>,
187 ) -> Result<(), String> {
188 let Some(connection_cx) = &self.connection_cx else {
189 return Err("No connection context available".to_string());
190 };
191
192 let Some(tool_use_id) = &self.tool_use_id else {
193 return Err("No tool use ID available".to_string());
194 };
195
196 let terminal = Terminal::new(terminal_id.into());
198 let content = vec![ToolCallContent::Terminal(terminal)];
199
200 let mut update_fields = ToolCallUpdateFields::new().status(status).content(content);
202
203 if let Some(title) = title {
204 update_fields = update_fields.title(title);
205 }
206
207 let tool_call_id = ToolCallId::new(tool_use_id.clone());
209 let update = ToolCallUpdate::new(tool_call_id, update_fields);
210 let notification = SessionNotification::new(
211 SessionId::new(self.session_id.as_str()),
212 SessionUpdate::ToolCallUpdate(update),
213 );
214
215 connection_cx
216 .send_notification(notification)
217 .map_err(|e| format!("Failed to send notification: {}", e))
218 }
219}
220
/// Prefix that may be prepended to tool names (e.g. `mcp__acp__Read`).
/// `ToolRegistry::get`, `ToolRegistry::contains` and `ToolRegistry::normalize_name`
/// strip it before falling back to the bare name.
pub const ACP_TOOL_PREFIX: &str = "mcp__acp__";
223
224#[derive(Debug, Default)]
226pub struct ToolRegistry {
227 tools: HashMap<String, Arc<dyn Tool>>,
229}
230
231impl ToolRegistry {
232 pub fn new() -> Self {
234 Self {
235 tools: HashMap::new(),
236 }
237 }
238
239 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
241 let name = tool.name().to_string();
242 self.tools.insert(name, Arc::new(tool));
243 }
244
245 pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
247 let name = tool.name().to_string();
248 self.tools.insert(name, tool);
249 }
250
251 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
256 if let Some(tool) = self.tools.get(name) {
258 return Some(tool.clone());
259 }
260
261 if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
263 if let Some(tool) = self.tools.get(stripped) {
264 return Some(tool.clone());
265 }
266 }
267
268 None
269 }
270
271 pub fn contains(&self, name: &str) -> bool {
273 if self.tools.contains_key(name) {
274 return true;
275 }
276
277 if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
279 return self.tools.contains_key(stripped);
280 }
281
282 false
283 }
284
285 pub fn normalize_name(name: &str) -> &str {
287 name.strip_prefix(ACP_TOOL_PREFIX).unwrap_or(name)
288 }
289
290 pub fn names(&self) -> Vec<&str> {
292 self.tools.keys().map(String::as_str).collect()
293 }
294
295 pub fn len(&self) -> usize {
297 self.tools.len()
298 }
299
300 pub fn is_empty(&self) -> bool {
302 self.tools.is_empty()
303 }
304
305 pub async fn execute(
307 &self,
308 name: &str,
309 input: serde_json::Value,
310 context: &ToolContext,
311 ) -> ToolResult {
312 match self.get(name) {
313 Some(tool) => tool.execute(input, context).await,
314 None => ToolResult::error(format!("Tool not found: {}", name)),
315 }
316 }
317
318 pub fn schemas(&self) -> Vec<ToolSchema> {
320 self.tools
321 .values()
322 .map(|tool| ToolSchema {
323 name: tool.name().to_string(),
324 description: tool.description().to_string(),
325 input_schema: tool.input_schema(),
326 })
327 .collect()
328 }
329}
330
/// Serializable description of one tool, as built by `ToolRegistry::schemas`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Tool name, taken from `Tool::name`.
    pub name: String,
    /// Human-readable description, taken from `Tool::description`.
    pub description: String,
    /// Input schema, taken from `Tool::input_schema`.
    /// NOTE(review): presumably a JSON Schema object; confirm against tool implementations.
    pub input_schema: serde_json::Value,
}
341
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `success` sets the Success status, keeps the content and clears the error flag.
    #[test]
    fn test_tool_result_success() {
        let ToolResult {
            status,
            content,
            is_error,
            ..
        } = ToolResult::success("Hello, World!");
        assert_eq!(status, ToolStatus::Success);
        assert_eq!(content, "Hello, World!");
        assert!(!is_error);
    }

    /// `error` sets the Error status, keeps the message and raises the error flag.
    #[test]
    fn test_tool_result_error() {
        let ToolResult {
            status,
            content,
            is_error,
            ..
        } = ToolResult::error("Something went wrong");
        assert_eq!(status, ToolStatus::Error);
        assert_eq!(content, "Something went wrong");
        assert!(is_error);
    }

    /// `with_metadata` attaches the given value.
    #[test]
    fn test_tool_result_with_metadata() {
        let attached = ToolResult::success("data")
            .with_metadata(json!({"lines": 10}))
            .metadata;
        assert!(attached.is_some());
    }

    /// The constructor and `with_dangerous` populate the public fields.
    #[test]
    fn test_tool_context() {
        let context = ToolContext::new("session-1", "/tmp").with_dangerous(true);
        assert_eq!(context.session_id, "session-1");
        assert_eq!(context.cwd, std::path::PathBuf::from("/tmp"));
        assert!(context.allow_dangerous);
    }

    /// A fresh registry has no tools.
    #[test]
    fn test_registry_new() {
        let empty = ToolRegistry::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    /// Bare names pass through; prefixed names lose the ACP prefix.
    #[test]
    fn test_acp_prefix_normalize() {
        let cases = [
            ("Read", "Read"),
            ("Bash", "Bash"),
            ("mcp__acp__Read", "Read"),
            ("mcp__acp__Bash", "Bash"),
            ("mcp__acp__TodoWrite", "TodoWrite"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                ToolRegistry::normalize_name(input),
                expected,
                "normalizing {input}"
            );
        }
    }

    /// Guards the literal value of the prefix constant.
    #[test]
    fn test_acp_prefix_constant() {
        assert_eq!(ACP_TOOL_PREFIX, "mcp__acp__");
    }
}