use std::fmt::Write;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use crate::tools::mcp_deferred::{ActivatedToolSet, DeferredMcpToolSet};
use crate::tools::traits::{Tool, ToolResult, ToolSpec};
/// Number of keyword-search matches returned when the caller does not supply
/// a usable `max_results` argument; also advertised as the schema default.
const DEFAULT_MAX_RESULTS: usize = 5;
/// Appends a single `<function>{json}</function>` line describing `spec`.
///
/// The JSON object carries the tool's `name`, `description` and `parameters`.
/// It is produced by `serde_json`, so descriptions containing backslashes,
/// newlines or tabs are escaped and the entry stays on one line as valid JSON.
/// Should serialization ever fail, an empty payload is written instead.
fn write_function_entry(output: &mut String, spec: &ToolSpec) {
    let payload = serde_json::to_string(&serde_json::json!({
        "name": spec.name,
        "description": spec.description,
        "parameters": spec.parameters,
    }))
    .unwrap_or_default();
    // Formatting into a `String` cannot fail, so the `fmt::Result` is discarded.
    let _ = writeln!(output, "<function>{payload}</function>");
}
/// Tool that searches the deferred MCP tool set and returns full schema
/// definitions for the matches, activating them (if not already active) in a
/// shared [`ActivatedToolSet`].
pub struct ToolSearchTool {
    /// Deferred MCP tool stubs that can be searched, looked up and activated.
    deferred: DeferredMcpToolSet,
    /// Activation set shared with the caller; tools returned by a search or a
    /// `select:` query are inserted here.
    activated: Arc<Mutex<ActivatedToolSet>>,
    /// Optional allowlist of prefixed tool names; `None` permits every tool.
    allowed_tools: Option<Vec<String>>,
}
impl ToolSearchTool {
    /// Creates a search tool over `deferred` that records activations in
    /// `activated`. No allowlist is configured, so every deferred tool is
    /// eligible.
    pub fn new(deferred: DeferredMcpToolSet, activated: Arc<Mutex<ActivatedToolSet>>) -> Self {
        let allowed_tools = None;
        Self {
            deferred,
            activated,
            allowed_tools,
        }
    }

    /// Restricts search and selection to the given prefixed tool names.
    ///
    /// `None` removes any restriction; `Some(list)` admits only names that
    /// appear in `list` verbatim.
    #[must_use]
    pub fn with_allowed_tools(self, allowed_tools: Option<Vec<String>>) -> Self {
        Self {
            allowed_tools,
            ..self
        }
    }

    /// Returns `true` when `prefixed_name` passes the configured allowlist.
    fn is_allowed(&self, prefixed_name: &str) -> bool {
        let Some(allowlist) = self.allowed_tools.as_deref() else {
            // No allowlist configured: everything is permitted.
            return true;
        };
        allowlist.iter().any(|allowed| allowed == prefixed_name)
    }
}
#[async_trait]
impl Tool for ToolSearchTool {
    fn name(&self) -> &str {
        "tool_search"
    }

    fn description(&self) -> &str {
        "Fetch full schema definitions for deferred MCP tools so they can be called. \
         Use \"select:name1,name2\" for exact match or keywords to search."
    }

    /// Argument schema: a required `query` string and an optional
    /// `max_results` limit whose default is `DEFAULT_MAX_RESULTS`.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "description": "Query to find deferred tools. Use \"select:<tool_name>\" for direct selection, or keywords to search.",
                    "type": "string"
                },
                "max_results": {
                    "description": "Maximum number of results to return (default: 5)",
                    "type": "number",
                    "default": DEFAULT_MAX_RESULTS
                }
            },
            "required": ["query"]
        })
    }

    /// Runs a tool search.
    ///
    /// * A `query` starting with `select:` is a comma-separated list of exact
    ///   tool names and is delegated to `select_tools`.
    /// * Any other non-empty `query` is a keyword search returning at most
    ///   `max_results` (default `DEFAULT_MAX_RESULTS`) tools that pass the
    ///   allowlist.
    ///
    /// Each returned tool is activated in the shared set if it is not already,
    /// and its schema is emitted as a `<function>` line inside `<functions>`.
    /// A missing or blank `query` is reported as `success: false` rather than
    /// as an `Err`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let query = args
            .get("query")
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .trim();
        let max_results = args
            .get("max_results")
            .and_then(|v| v.as_u64())
            .map(|v| usize::try_from(v).unwrap_or(DEFAULT_MAX_RESULTS))
            .unwrap_or(DEFAULT_MAX_RESULTS);
        if query.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("query parameter is required".into()),
            });
        }
        if let Some(names_str) = query.strip_prefix("select:") {
            let names: Vec<&str> = names_str.split(',').map(str::trim).collect();
            return self.select_tools(&names);
        }

        // The allowlist must be applied *before* truncating to `max_results`.
        // Filtering an already-truncated list lets highly ranked disallowed
        // tools crowd allowed matches out of the response. With an allowlist,
        // request every match (bounded by the number of stubs) and truncate
        // after filtering. Without one, the plain limit is used unchanged.
        let search_limit = if self.allowed_tools.is_some() {
            self.deferred.stubs.len()
        } else {
            max_results
        };
        let results: Vec<_> = self
            .deferred
            .search(query, search_limit)
            .into_iter()
            .filter(|stub| self.is_allowed(&stub.prefixed_name))
            .take(max_results)
            .collect();
        if results.is_empty() {
            return Ok(ToolResult {
                success: true,
                output: "No matching deferred tools found.".into(),
                error: None,
            });
        }

        let mut output = String::from("<functions>\n");
        let mut activated_count = 0;
        // A poisoned lock still holds a usable activation set; recover it
        // rather than failing the search.
        let mut guard = self
            .activated
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        for stub in &results {
            if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
                if !guard.is_activated(&stub.prefixed_name) {
                    if let Some(tool) = self.deferred.activate(&stub.prefixed_name) {
                        guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
                        activated_count += 1;
                    }
                }
                write_function_entry(&mut output, &spec);
            }
        }
        output.push_str("</functions>\n");
        drop(guard);

        tracing::debug!(
            "tool_search: query={query:?}, matched={}, activated={activated_count}",
            results.len()
        );
        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}
impl ToolSearchTool {
    /// Handles a `select:name1,name2` query by resolving each name exactly.
    ///
    /// Blank entries are ignored. If no non-blank name remains, the result is
    /// `success: false` with an explanatory error.
    ///
    /// Each resolved, allowed tool is activated in the shared set if it is not
    /// already active, and its schema is written once as a `<function>` line.
    /// This holds even if several requested names resolve to the same tool.
    ///
    /// The following are listed, without duplicates, on a trailing
    /// `Not found:` line:
    /// * names that do not resolve;
    /// * names that fail the allowlist;
    /// * names that resolve but have no schema.
    ///
    /// Lookup misses never turn the result into an error.
    fn select_tools(&self, names: &[&str]) -> anyhow::Result<ToolResult> {
        let requested: Vec<&str> = names
            .iter()
            .copied()
            .filter(|name| !name.is_empty())
            .collect();
        if requested.is_empty() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("select: requires at least one tool name".into()),
            });
        }

        let mut output = String::from("<functions>\n");
        let mut not_found: Vec<&str> = Vec::new();
        // Prefixed names already written to `output`, so no tool is emitted twice.
        let mut emitted: std::collections::HashSet<String> = std::collections::HashSet::new();
        let mut activated_count = 0;
        // A poisoned lock still holds a usable activation set; recover it.
        let mut guard = self
            .activated
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        for &name in &requested {
            let stub = match self.deferred.get_by_name(name) {
                Some(stub) if self.is_allowed(&stub.prefixed_name) => stub,
                // Unknown and disallowed names are reported identically.
                _ => {
                    if !not_found.contains(&name) {
                        not_found.push(name);
                    }
                    continue;
                }
            };
            let full_name = &stub.prefixed_name;
            if !emitted.insert(full_name.clone()) {
                continue;
            }
            // A stub without a schema cannot be offered to the caller; report
            // it rather than silently omitting it from the response.
            let Some(spec) = self.deferred.tool_spec(full_name) else {
                if !not_found.contains(&name) {
                    not_found.push(name);
                }
                continue;
            };
            if !guard.is_activated(full_name) {
                if let Some(tool) = self.deferred.activate(full_name) {
                    guard.activate(full_name.clone(), Arc::from(tool));
                    activated_count += 1;
                }
            }
            write_function_entry(&mut output, &spec);
        }
        output.push_str("</functions>\n");
        drop(guard);

        if !not_found.is_empty() {
            let _ = write!(output, "\nNot found: {}", not_found.join(", "));
        }
        tracing::debug!(
            "tool_search select: requested={}, activated={activated_count}, not_found={}",
            requested.len(),
            not_found.len()
        );
        Ok(ToolResult {
            success: true,
            output,
            error: None,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::mcp_client::McpRegistry;
    use crate::tools::mcp_deferred::DeferredMcpToolStub;
    use crate::tools::mcp_protocol::McpToolDef;

    /// Wraps `stubs` in a deferred set backed by a registry with no servers.
    async fn make_deferred_set(stubs: Vec<DeferredMcpToolStub>) -> DeferredMcpToolSet {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connecting to an empty server list should succeed");
        DeferredMcpToolSet {
            stubs,
            registry: Arc::new(registry),
        }
    }

    /// Builds a stub named `name` with description `desc` and an empty object schema.
    fn make_stub(name: &str, desc: &str) -> DeferredMcpToolStub {
        let owned_name = name.to_string();
        DeferredMcpToolStub::new(
            owned_name.clone(),
            McpToolDef {
                name: owned_name,
                description: Some(desc.to_string()),
                input_schema: serde_json::json!({"type": "object", "properties": {}}),
            },
        )
    }

    /// A fresh, empty activation set to share with the tool under test.
    fn new_activated() -> Arc<Mutex<ActivatedToolSet>> {
        Arc::new(Mutex::new(ActivatedToolSet::new()))
    }

    /// Builds a search tool over `stubs` that records activations in `activated`.
    async fn build_tool(
        stubs: Vec<DeferredMcpToolStub>,
        activated: &Arc<Mutex<ActivatedToolSet>>,
    ) -> ToolSearchTool {
        ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(activated))
    }

    /// Executes `tool` with `{"query": query}` and unwraps the outcome.
    async fn run(tool: &ToolSearchTool, query: &str) -> ToolResult {
        tool.execute(serde_json::json!({ "query": query }))
            .await
            .unwrap()
    }

    /// Whether `name` is currently present in `activated`.
    fn is_active(activated: &Arc<Mutex<ActivatedToolSet>>, name: &str) -> bool {
        activated.lock().unwrap().is_activated(name)
    }

    #[tokio::test]
    async fn tool_metadata() {
        let tool = build_tool(Vec::new(), &new_activated()).await;
        assert_eq!(tool.name(), "tool_search");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
    }

    #[tokio::test]
    async fn empty_query_returns_error() {
        let tool = build_tool(Vec::new(), &new_activated()).await;
        let result = run(&tool, "").await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn select_nonexistent_tool_reports_not_found() {
        let tool = build_tool(Vec::new(), &new_activated()).await;
        let result = run(&tool, "select:nonexistent").await;
        assert!(result.success);
        assert!(result.output.contains("Not found"));
    }

    #[tokio::test]
    async fn keyword_search_no_matches() {
        let tool = build_tool(vec![make_stub("fs__read", "Read file")], &new_activated()).await;
        let result = run(&tool, "zzzzz_nonexistent").await;
        assert!(result.success);
        assert!(result.output.contains("No matching"));
    }

    #[tokio::test]
    async fn keyword_search_finds_match() {
        let activated = new_activated();
        let tool = build_tool(
            vec![make_stub("fs__read", "Read a file from disk")],
            &activated,
        )
        .await;
        let result = run(&tool, "read file").await;
        assert!(result.success);
        assert!(result.output.contains("<function>"));
        assert!(result.output.contains("fs__read"));
        assert!(is_active(&activated, "fs__read"));
    }

    #[tokio::test]
    async fn multiple_servers_stubs_all_searchable() {
        let activated = new_activated();
        let tool = build_tool(
            vec![
                make_stub("server_a__list_files", "List files on server A"),
                make_stub("server_a__read_file", "Read file on server A"),
                make_stub("server_b__query_db", "Query database on server B"),
                make_stub("server_b__insert_row", "Insert row on server B"),
            ],
            &activated,
        )
        .await;

        let files = run(&tool, "file").await;
        assert!(files.success);
        assert!(files.output.contains("server_a__list_files"));
        assert!(files.output.contains("server_a__read_file"));

        let db = run(&tool, "database query").await;
        assert!(db.success);
        assert!(db.output.contains("server_b__query_db"));
    }

    #[tokio::test]
    async fn select_activates_and_persists_across_calls() {
        let activated = new_activated();
        let tool = build_tool(
            vec![
                make_stub("srv__tool_a", "Tool A"),
                make_stub("srv__tool_b", "Tool B"),
            ],
            &activated,
        )
        .await;

        let first = run(&tool, "select:srv__tool_a").await;
        assert!(first.success);
        assert!(is_active(&activated, "srv__tool_a"));
        assert!(!is_active(&activated, "srv__tool_b"));

        let second = run(&tool, "select:srv__tool_b").await;
        assert!(second.success);
        let set = activated.lock().unwrap();
        assert!(set.is_activated("srv__tool_a"));
        assert!(set.is_activated("srv__tool_b"));
        assert_eq!(set.tool_specs().len(), 2);
    }

    #[tokio::test]
    async fn description_with_control_chars_emits_valid_json() {
        let desc = "Reads a file at C:\\Users\\me.\nMatches \\d+ digits.\tDone.";
        let tool = build_tool(vec![make_stub("fs__read", desc)], &new_activated()).await;
        let result = run(&tool, "select:fs__read").await;
        assert!(result.success);

        let entry_line = result
            .output
            .lines()
            .find(|line| line.starts_with("<function>"));
        let payload = entry_line
            .and_then(|line| line.strip_prefix("<function>"))
            .and_then(|line| line.strip_suffix("</function>"))
            .expect("expected a <function> entry");
        let parsed: serde_json::Value =
            serde_json::from_str(payload).expect("emitted entry must be valid JSON");
        assert_eq!(parsed["name"], "fs__read");
        assert_eq!(parsed["description"], desc);
    }

    #[tokio::test]
    async fn reactivation_is_idempotent() {
        let activated = new_activated();
        let tool = build_tool(vec![make_stub("srv__tool", "A tool")], &activated).await;
        for _ in 0..2 {
            run(&tool, "select:srv__tool").await;
        }
        assert_eq!(activated.lock().unwrap().tool_specs().len(), 1);
    }

    #[tokio::test]
    async fn allowlist_filters_keyword_search() {
        let activated = new_activated();
        let tool = build_tool(
            vec![
                make_stub("fs__read", "Read a file from disk"),
                make_stub("net__fetch", "Fetch a file from the network"),
            ],
            &activated,
        )
        .await
        .with_allowed_tools(Some(vec!["fs__read".to_string()]));

        let result = run(&tool, "file").await;
        assert!(result.success);
        assert!(result.output.contains("fs__read"));
        assert!(is_active(&activated, "fs__read"));
        assert!(!result.output.contains("net__fetch"));
        assert!(!is_active(&activated, "net__fetch"));
    }

    #[tokio::test]
    async fn allowlist_blocks_select_of_disallowed_tool() {
        let activated = new_activated();
        let tool = build_tool(
            vec![
                make_stub("fs__read", "Read a file"),
                make_stub("net__fetch", "Fetch from the network"),
            ],
            &activated,
        )
        .await
        .with_allowed_tools(Some(vec!["fs__read".to_string()]));

        let blocked = run(&tool, "select:net__fetch").await;
        assert!(blocked.success);
        assert!(blocked.output.contains("Not found"));
        assert!(!is_active(&activated, "net__fetch"));

        let allowed = run(&tool, "select:fs__read").await;
        assert!(allowed.success);
        assert!(allowed.output.contains("fs__read"));
        assert!(is_active(&activated, "fs__read"));
    }

    #[tokio::test]
    async fn no_allowlist_allows_all_tools() {
        let activated = new_activated();
        let tool = build_tool(
            vec![make_stub("net__fetch", "Fetch from the network")],
            &activated,
        )
        .await
        .with_allowed_tools(None);

        let result = run(&tool, "select:net__fetch").await;
        assert!(result.success);
        assert!(result.output.contains("net__fetch"));
        assert!(is_active(&activated, "net__fetch"));
    }
}