use std::collections::HashMap;
use std::sync::Arc;
use sacp::JrConnectionCx;
use sacp::link::AgentToClient;
use sacp::schema::{
SessionId, SessionNotification, SessionUpdate, Terminal, ToolCallContent, ToolCallId,
ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields,
};
use serde::{Deserialize, Serialize};
use super::tools::Tool;
use crate::session::BackgroundProcessManager;
use crate::settings::PermissionChecker;
use crate::terminal::TerminalClient;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
pub status: ToolStatus,
pub content: String,
pub is_error: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
impl ToolResult {
pub fn success(content: impl Into<String>) -> Self {
Self {
status: ToolStatus::Success,
content: content.into(),
is_error: false,
metadata: None,
}
}
pub fn error(message: impl Into<String>) -> Self {
Self {
status: ToolStatus::Error,
content: message.into(),
is_error: true,
metadata: None,
}
}
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolStatus {
Success,
Error,
Cancelled,
Running,
}
#[derive(Debug, Clone)]
pub struct ToolContext {
pub session_id: String,
pub cwd: std::path::PathBuf,
pub allow_dangerous: bool,
background_processes: Option<Arc<BackgroundProcessManager>>,
terminal_client: Option<Arc<TerminalClient>>,
tool_use_id: Option<String>,
connection_cx: Option<JrConnectionCx<AgentToClient>>,
pub permission_checker: Option<Arc<tokio::sync::RwLock<PermissionChecker>>>,
}
impl ToolContext {
pub fn new(session_id: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
Self {
session_id: session_id.into(),
cwd: cwd.into(),
allow_dangerous: false,
background_processes: None,
terminal_client: None,
tool_use_id: None,
connection_cx: None,
permission_checker: None,
}
}
pub fn with_dangerous(mut self, allow: bool) -> Self {
self.allow_dangerous = allow;
self
}
pub fn with_background_processes(mut self, manager: Arc<BackgroundProcessManager>) -> Self {
self.background_processes = Some(manager);
self
}
pub fn with_terminal_client(mut self, client: Arc<TerminalClient>) -> Self {
self.terminal_client = Some(client);
self
}
pub fn with_tool_use_id(mut self, id: impl Into<String>) -> Self {
self.tool_use_id = Some(id.into());
self
}
pub fn with_connection_cx(mut self, cx: JrConnectionCx<AgentToClient>) -> Self {
self.connection_cx = Some(cx);
self
}
pub fn with_permission_checker(
mut self,
checker: Arc<tokio::sync::RwLock<PermissionChecker>>,
) -> Self {
self.permission_checker = Some(checker);
self
}
pub fn background_processes(&self) -> Option<&Arc<BackgroundProcessManager>> {
self.background_processes.as_ref()
}
pub fn terminal_client(&self) -> Option<&Arc<TerminalClient>> {
self.terminal_client.as_ref()
}
pub fn tool_use_id(&self) -> Option<&str> {
self.tool_use_id.as_deref()
}
pub fn send_terminal_update(
&self,
terminal_id: impl Into<String>,
status: ToolCallStatus,
title: Option<&str>,
) -> Result<(), String> {
let Some(connection_cx) = &self.connection_cx else {
return Err("No connection context available".to_string());
};
let Some(tool_use_id) = &self.tool_use_id else {
return Err("No tool use ID available".to_string());
};
let terminal = Terminal::new(terminal_id.into());
let content = vec![ToolCallContent::Terminal(terminal)];
let mut update_fields = ToolCallUpdateFields::new().status(status).content(content);
if let Some(title) = title {
update_fields = update_fields.title(title);
}
let tool_call_id = ToolCallId::new(tool_use_id.clone());
let update = ToolCallUpdate::new(tool_call_id, update_fields);
let notification = SessionNotification::new(
SessionId::new(self.session_id.as_str()),
SessionUpdate::ToolCallUpdate(update),
);
connection_cx
.send_notification(notification)
.map_err(|e| format!("Failed to send notification: {}", e))
}
}
pub const ACP_TOOL_PREFIX: &str = "mcp__acp__";
#[derive(Debug, Default)]
pub struct ToolRegistry {
tools: HashMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register<T: Tool + 'static>(&mut self, tool: T) {
let name = tool.name().to_string();
self.tools.insert(name, Arc::new(tool));
}
pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
let name = tool.name().to_string();
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
if let Some(tool) = self.tools.get(name) {
return Some(tool.clone());
}
if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
if let Some(tool) = self.tools.get(stripped) {
return Some(tool.clone());
}
}
None
}
pub fn contains(&self, name: &str) -> bool {
if self.tools.contains_key(name) {
return true;
}
if let Some(stripped) = name.strip_prefix(ACP_TOOL_PREFIX) {
return self.tools.contains_key(stripped);
}
false
}
pub fn normalize_name(name: &str) -> &str {
name.strip_prefix(ACP_TOOL_PREFIX).unwrap_or(name)
}
pub fn names(&self) -> Vec<&str> {
self.tools.keys().map(String::as_str).collect()
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub async fn execute(
&self,
name: &str,
input: serde_json::Value,
context: &ToolContext,
) -> ToolResult {
match self.get(name) {
Some(tool) => tool.execute(input, context).await,
None => ToolResult::error(format!("Tool not found: {}", name)),
}
}
pub fn schemas(&self) -> Vec<ToolSchema> {
self.tools
.values()
.map(|tool| ToolSchema {
name: tool.name().to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(),
})
.collect()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_tool_result_success() {
let result = ToolResult::success("Hello, World!");
assert_eq!(result.status, ToolStatus::Success);
assert_eq!(result.content, "Hello, World!");
assert!(!result.is_error);
}
#[test]
fn test_tool_result_error() {
let result = ToolResult::error("Something went wrong");
assert_eq!(result.status, ToolStatus::Error);
assert_eq!(result.content, "Something went wrong");
assert!(result.is_error);
}
#[test]
fn test_tool_result_with_metadata() {
let result = ToolResult::success("data").with_metadata(json!({"lines": 10}));
assert!(result.metadata.is_some());
}
#[test]
fn test_tool_context() {
let ctx = ToolContext::new("session-1", "/tmp").with_dangerous(true);
assert_eq!(ctx.session_id, "session-1");
assert_eq!(ctx.cwd, std::path::PathBuf::from("/tmp"));
assert!(ctx.allow_dangerous);
}
#[test]
fn test_registry_new() {
let registry = ToolRegistry::new();
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
}
#[test]
fn test_acp_prefix_normalize() {
assert_eq!(ToolRegistry::normalize_name("Read"), "Read");
assert_eq!(ToolRegistry::normalize_name("Bash"), "Bash");
assert_eq!(ToolRegistry::normalize_name("mcp__acp__Read"), "Read");
assert_eq!(ToolRegistry::normalize_name("mcp__acp__Bash"), "Bash");
assert_eq!(
ToolRegistry::normalize_name("mcp__acp__TodoWrite"),
"TodoWrite"
);
}
#[test]
fn test_acp_prefix_constant() {
assert_eq!(ACP_TOOL_PREFIX, "mcp__acp__");
}
}