use std::fmt::Write as _;
use zeph_llm::LlmError;
use zeph_llm::provider::{LlmProvider, Message, Role};
use crate::tool::McpTool;
/// Outcome of a previous pruning attempt, stored by `PruningCache`.
#[derive(Debug, Clone)]
enum CachedResult {
    /// Pruning succeeded; the selected tool subset is cached.
    Ok(Vec<McpTool>),
    /// Pruning failed; recorded so a repeat of the same request can skip the LLM.
    Failed,
}
/// Single-entry cache for tool-pruning results, keyed by the pair
/// (task-context hash, tool-list hash). Holds at most one entry; a new
/// insert overwrites the previous one.
#[derive(Debug, Default, Clone)]
pub struct PruningCache {
    // (content_hash of the task context, tool_list_hash of the tool list).
    key: Option<(u64, u64)>,
    // Last outcome recorded for `key`; `None` until the first insert.
    result: Option<CachedResult>,
}
/// Result of probing the cache for a (message, tool-list) hash pair.
enum CacheLookup<'a> {
    /// A successful pruning result is cached for this key.
    Hit(&'a [McpTool]),
    /// A failure is recorded for this key; the caller should avoid re-calling the LLM.
    NegativeHit,
    /// Nothing cached for this key.
    Miss,
}
impl PruningCache {
    /// Creates an empty cache with no stored entry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Clears the stored key and result; the next lookup will miss.
    pub fn reset(&mut self) {
        self.key = None;
        self.result = None;
    }

    /// Probes the cache for the given hash pair, distinguishing a cached
    /// success from a recorded failure.
    fn lookup(&self, msg_hash: u64, tool_hash: u64) -> CacheLookup<'_> {
        // Key mismatch (or empty cache) is always a miss.
        if self.key != Some((msg_hash, tool_hash)) {
            return CacheLookup::Miss;
        }
        match &self.result {
            Some(CachedResult::Ok(tools)) => CacheLookup::Hit(tools),
            Some(CachedResult::Failed) => CacheLookup::NegativeHit,
            None => CacheLookup::Miss,
        }
    }

    /// Records a successful pruning result for the given hash pair.
    fn insert_ok(&mut self, msg_hash: u64, tool_hash: u64, tools: Vec<McpTool>) {
        self.key = Some((msg_hash, tool_hash));
        self.result = Some(CachedResult::Ok(tools));
    }

    /// Records a failed pruning attempt for the given hash pair.
    fn insert_failed(&mut self, msg_hash: u64, tool_hash: u64) {
        self.key = Some((msg_hash, tool_hash));
        self.result = Some(CachedResult::Failed);
    }
}
/// Hashes a string to a 64-bit cache key using the first 8 little-endian
/// bytes of its BLAKE3 digest.
#[must_use]
pub fn content_hash(s: &str) -> u64 {
    let digest = blake3::hash(s.as_bytes());
    let prefix: [u8; 8] = digest.as_bytes()[..8]
        .try_into()
        .expect("blake3 >= 8 bytes");
    u64::from_le_bytes(prefix)
}
/// Hashes a tool list to a 64-bit cache key.
///
/// Tools are sorted by (server_id, name) first so the hash is independent of
/// the order the caller discovered them in. Each field is followed by a NUL
/// separator and each tool by a 0x01 marker to keep field/tool boundaries
/// unambiguous in the hashed byte stream.
#[must_use]
pub fn tool_list_hash(tools: &[McpTool]) -> u64 {
    let mut hasher = blake3::Hasher::new();
    let mut ordered: Vec<&McpTool> = tools.iter().collect();
    ordered.sort_by(|x, y| {
        x.server_id
            .cmp(&y.server_id)
            .then_with(|| x.name.cmp(&y.name))
    });
    for tool in ordered {
        for field in [&tool.server_id, &tool.name, &tool.description] {
            hasher.update(field.as_bytes());
            hasher.update(b"\0");
        }
        // Schema serialization is effectively infallible for a serde_json
        // Value; on the off chance it fails, hash a fixed placeholder byte.
        if let Ok(schema_bytes) = serde_json::to_vec(&tool.input_schema) {
            hasher.update(&schema_bytes);
        } else {
            hasher.update(b"\x00");
        }
        hasher.update(b"\x01");
    }
    let digest = hasher.finalize();
    u64::from_le_bytes(digest.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
}
/// Caching wrapper around [`prune_tools`].
///
/// On a cache hit the stored subset is returned without an LLM call. On a
/// negative hit (the same request previously failed) the full tool list is
/// returned, again without an LLM call. On a miss, [`prune_tools`] runs and
/// its outcome — success or failure — is recorded in the cache.
///
/// # Errors
///
/// Propagates any [`PruningError`] from the underlying [`prune_tools`] call.
pub async fn prune_tools_cached<P: LlmProvider>(
    cache: &mut PruningCache,
    all_tools: &[McpTool],
    task_context: &str,
    params: &PruningParams,
    provider: &P,
) -> Result<Vec<McpTool>, PruningError> {
    let msg_hash = content_hash(task_context);
    let tl_hash = tool_list_hash(all_tools);
    match cache.lookup(msg_hash, tl_hash) {
        CacheLookup::Hit(cached) => Ok(cached.to_vec()),
        CacheLookup::NegativeHit => {
            tracing::warn!("pruning cache: negative hit, returning all tools without LLM call");
            Ok(all_tools.to_vec())
        }
        CacheLookup::Miss => match prune_tools(all_tools, task_context, params, provider).await {
            Ok(tools) => {
                cache.insert_ok(msg_hash, tl_hash, tools.clone());
                Ok(tools)
            }
            Err(err) => {
                cache.insert_failed(msg_hash, tl_hash);
                Err(err)
            }
        },
    }
}
/// Errors produced while pruning the tool list via the LLM.
#[derive(Debug, thiserror::Error)]
pub enum PruningError {
    /// The underlying provider call failed.
    #[error("pruning LLM call failed: {0}")]
    LlmError(#[from] LlmError),
    /// The LLM response did not contain a parseable JSON array of tool names.
    #[error("failed to parse pruning response as JSON array of tool names")]
    ParseError,
}
/// Tunables controlling when and how aggressively tools are pruned.
#[derive(Debug, Clone)]
pub struct PruningParams {
    /// Maximum number of LLM-selected candidates to keep; 0 means no cap.
    pub max_tools: usize,
    /// Pruning is skipped entirely when fewer tools than this are available.
    pub min_tools_to_prune: usize,
    /// Tool names kept unconditionally; they bypass both the LLM selection
    /// and the `max_tools` cap.
    pub always_include: Vec<String>,
}
impl Default for PruningParams {
    /// Defaults: keep at most 15 tools, only prune when 10 or more are
    /// available, and pin nothing.
    fn default() -> Self {
        Self {
            max_tools: 15,
            min_tools_to_prune: 10,
            always_include: vec![],
        }
    }
}
/// Asks the LLM which of `all_tools` are relevant to `task_context` and
/// returns that subset.
///
/// Tools named in `params.always_include` are pinned: they are kept
/// unconditionally and do not count against `params.max_tools` (0 = no cap).
/// When fewer than `params.min_tools_to_prune` tools are available the whole
/// list is returned without an LLM call. Tool names and descriptions are
/// sanitized before being embedded in the prompt.
///
/// # Errors
///
/// Returns [`PruningError::LlmError`] if the provider call fails and
/// [`PruningError::ParseError`] if the response is not a JSON array of names.
pub async fn prune_tools<P: LlmProvider>(
    all_tools: &[McpTool],
    task_context: &str,
    params: &PruningParams,
    provider: &P,
) -> Result<Vec<McpTool>, PruningError> {
    use std::collections::HashSet;

    // Below the threshold an LLM round-trip costs more than it saves.
    if all_tools.len() < params.min_tools_to_prune {
        return Ok(all_tools.to_vec());
    }
    let (pinned, candidates): (Vec<_>, Vec<_>) = all_tools
        .iter()
        .partition(|t| params.always_include.contains(&t.name));
    let tool_list = candidates.iter().fold(String::new(), |mut acc, t| {
        let name = sanitize_tool_name(&t.name);
        let desc = sanitize_tool_description(&t.description);
        let _ = writeln!(acc, "- {name}: {desc}");
        acc
    });
    let prompt = format!(
        "Return a JSON array of tool names that are relevant to the task below.\n\
         Return ONLY the JSON array, no explanation, no markdown.\n\n\
         Task: {task_context}\n\n\
         Available tools:\n{tool_list}"
    );
    let messages = vec![Message::from_legacy(Role::User, prompt)];
    let response = provider.chat(&messages).await?;
    let relevant_names = parse_name_array(&response)?;
    // O(1) membership checks instead of a linear scan per candidate.
    let relevant: HashSet<&str> = relevant_names.iter().map(String::as_str).collect();
    let mut result: Vec<McpTool> = pinned.into_iter().cloned().collect();
    let mut candidates_added: usize = 0;
    for tool in &candidates {
        // The cap applies only to LLM-selected candidates, not pinned tools.
        if params.max_tools > 0 && candidates_added >= params.max_tools {
            break;
        }
        if relevant.contains(tool.name.as_str()) {
            result.push((*tool).clone());
            candidates_added += 1;
        }
    }
    Ok(result)
}
/// Removes control characters from a tool name and truncates it to 64 chars,
/// keeping prompt lines single-line and bounded.
fn sanitize_tool_name(name: &str) -> String {
    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in name.chars() {
        if ch.is_control() {
            continue;
        }
        if kept == 64 {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
/// Removes control characters from a tool description and truncates it to
/// 200 chars, keeping prompt lines single-line and bounded.
fn sanitize_tool_description(desc: &str) -> String {
    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in desc.chars() {
        if ch.is_control() {
            continue;
        }
        if kept == 200 {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
/// Extracts a JSON array of tool names from an LLM response.
///
/// Markdown code-fence lines are dropped first, then the outermost
/// `[` … `]` span is located and parsed, which tolerates preamble or
/// trailing prose around the array.
fn parse_name_array(response: &str) -> Result<Vec<String>, PruningError> {
    let without_fences: Vec<&str> = response
        .lines()
        .filter(|line| !line.trim_start().starts_with("```"))
        .collect();
    let cleaned = without_fences.join("\n");
    let open = cleaned.find('[').ok_or(PruningError::ParseError)?;
    let close = cleaned.rfind(']').ok_or(PruningError::ParseError)?;
    // A closing bracket before the opening one cannot delimit an array.
    if close <= open {
        return Err(PruningError::ParseError);
    }
    serde_json::from_str(&cleaned[open..=close]).map_err(|_| PruningError::ParseError)
}
#[cfg(test)]
mod tests {
    use zeph_llm::mock::MockProvider;

    use super::*;

    // NOTE(fix): every `&params` / `&params_with_max(...)` argument below had
    // been corrupted by HTML-entity decoding (`&para` -> `¶`), producing the
    // invalid token `¶ms`. Restored to `&params`.

    /// Builds a tool on the default "test" server.
    fn make_tool(name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: "test".into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Builds a tool on an explicit server, for cross-server name tests.
    fn make_tool_with_server(server_id: &str, name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Params with a given cap, a threshold of 1 (always prune), and no pins.
    fn params_with_max(max_tools: usize) -> PruningParams {
        PruningParams {
            max_tools,
            min_tools_to_prune: 1,
            always_include: Vec::new(),
        }
    }

    #[test]
    fn parse_plain_array() {
        let names = parse_name_array(r#"["bash", "read", "write"]"#).unwrap();
        assert_eq!(names, vec!["bash", "read", "write"]);
    }

    #[test]
    fn parse_array_with_markdown_fences() {
        let input = "```json\n[\"bash\", \"read\"]\n```";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_array_with_preamble() {
        let input = "Here are the relevant tools:\n[\"bash\", \"read\"]";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_empty_array() {
        let names = parse_name_array("[]").unwrap();
        assert!(names.is_empty());
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_name_array("not json").is_err());
        assert!(parse_name_array("").is_err());
        assert!(parse_name_array("{\"key\": \"val\"}").is_err());
    }

    #[tokio::test]
    async fn below_min_detected_early_return() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // A failing provider proves no LLM call happens below the threshold.
        let provider = MockProvider::failing();
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 10,
            always_include: Vec::new(),
        };
        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "all tools returned when below threshold");
    }

    #[tokio::test]
    async fn always_include_pinned() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["candidate_a"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };
        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned must survive pruning"
        );
        assert!(result.iter().any(|t| t.name == "candidate_a"));
    }

    #[tokio::test]
    async fn always_include_matches_bare_name_across_servers() {
        let tools = vec![
            make_tool_with_server("server_a", "search", "search on A"),
            make_tool_with_server("server_b", "search", "search on B"),
            make_tool_with_server("server_a", "other", "other tool"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["other"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["search".into()],
        };
        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 3, "both search tools + other must be present");
        let search_count = result.iter().filter(|t| t.name == "search").count();
        assert_eq!(
            search_count, 2,
            "both server_a:search and server_b:search must be pinned"
        );
        assert!(result.iter().any(|t| t.name == "other"));
    }

    #[tokio::test]
    async fn max_tools_cap_respected() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let names_json = r#"["t0","t1","t2","t3","t4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let result = prune_tools(&tools, "task", &params_with_max(2), &provider)
            .await
            .unwrap();
        assert_eq!(
            result.len(),
            2,
            "max_tools=2 must cap LLM-selected candidates"
        );
    }

    #[tokio::test]
    async fn llm_failure_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::LlmError(_))));
    }

    #[tokio::test]
    async fn parse_error_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec!["not valid json at all".into()]);
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::ParseError)));
    }

    #[tokio::test]
    async fn max_tools_zero_means_no_cap() {
        let tools: Vec<McpTool> = (0..5)
            .map(|i| make_tool(&format!("tool{i}"), "desc"))
            .collect();
        let names_json = r#"["tool0","tool1","tool2","tool3","tool4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let params = params_with_max(0);
        let result = prune_tools(&tools, "any task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "max_tools=0 must not cap the result");
    }

    #[test]
    fn description_sanitization_strips_control_chars_and_caps() {
        let desc = "line1\nline2\tinject";
        let sanitized = sanitize_tool_description(desc);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\t'));
        let long_desc = "x".repeat(300);
        assert_eq!(sanitize_tool_description(&long_desc).len(), 200);
        let long_name = "a".repeat(100);
        assert_eq!(sanitize_tool_name(&long_name).len(), 64);
    }

    #[tokio::test]
    async fn always_include_bypasses_max_tools_cap() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider =
            MockProvider::with_responses(vec![r#"["candidate_a","candidate_b"]"#.into()]);
        let params = PruningParams {
            max_tools: 1,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };
        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned tool must bypass cap"
        );
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn cache_positive_hit() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // Only one response queued: a second LLM call would fail the test.
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r1.len(), 2);
        assert_eq!(r1.len(), r2.len(), "cache hit must return same result");
    }

    #[tokio::test]
    async fn cache_miss_on_message_change() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider =
            MockProvider::with_responses(vec![r#"["t0","t1"]"#.into(), r#"["t0"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();
        let r1 = prune_tools_cached(&mut cache, &tools, "query_a", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query_b", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r1.len(), 2, "first call returns both tools");
        assert_eq!(
            r2.len(),
            1,
            "different message triggers cache miss and LLM call"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_tool_list_change() {
        let tools1: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let mut tools2 = tools1.clone();
        tools2.push(make_tool("t2", "new tool"));
        let provider = MockProvider::with_responses(vec![
            r#"["t0","t1"]"#.into(),
            r#"["t0","t1","t2"]"#.into(),
        ]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();
        let r1 = prune_tools_cached(&mut cache, &tools1, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools2, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r1.len(), 2);
        assert_eq!(r2.len(), 3, "new tool triggers cache miss");
    }

    #[tokio::test]
    async fn cache_negative_hit_skips_llm() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let params = params_with_max(0);
        let mut cache = PruningCache::new();
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err(), "first call must propagate LLM error");
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "negative cache hit must return all tools");
    }

    #[tokio::test]
    async fn cache_negative_hit_clears_on_reset() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // First call errors, second call (after reset) succeeds.
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()])
            .with_errors(vec![zeph_llm::LlmError::Other("simulated failure".into())]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err());
        cache.reset();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "after reset the LLM must be retried");
    }
}