// zeph_mcp/pruning.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dynamic MCP tool pruning for context optimization (#2204).
5//!
6//! The `prune_tools` free function filters a list of MCP tools to only those relevant
7//! to the current task, using an LLM call with a fast/cheap model. This reduces context
8//! usage and improves tool selection accuracy when MCP servers expose many tools.
9//!
10//! `zeph-mcp` does not depend on `zeph-config` (circular dependency: zeph-config ->
11//! zeph-mcp). Callers in `zeph-core` convert `ToolPruningConfig` into `PruningParams`
12//! before calling `prune_tools`.
13
14use std::fmt::Write as _;
15
16use zeph_llm::LlmError;
17use zeph_llm::provider::{LlmProvider, Message, Role};
18
19use crate::tool::McpTool;
20
21// ── Per-message pruning cache (#2298) ────────────────────────────────────────
22
/// Cached outcome stored by [`PruningCache`].
///
/// [`Ok`] holds the previously-computed pruned tool list; [`Failed`] is a
/// sentinel written when the LLM call failed, so subsequent lookups with the
/// same key return the all-tools fallback without retrying the LLM.
#[derive(Debug, Clone)]
enum CachedResult {
    /// LLM call succeeded; holds the pruned tool list for the cached key.
    Ok(Vec<McpTool>),
    /// LLM call failed; caller should use the full tool list.
    Failed,
}
34
/// Per-message cache for MCP tool pruning results.
///
/// Stores at most one entry keyed on `(message_content_hash, tool_list_hash)`.
/// A cache miss triggers an LLM call; a hit returns the stored result
/// immediately.  Negative entries (`Failed`) prevent retry storms when the
/// pruning LLM is transiently unavailable.
///
/// # Cache contract
///
/// `PruningCache` returns previously-computed pruning results keyed on
/// `(message_content_hash, tool_list_hash)`.
///
/// `tool_list_hash` includes: `server_id`, `name`, `description`, and
/// `input_schema` for every tool.  Any change to tool metadata (not just the
/// name set) produces a different hash and causes a cache miss.
///
/// `PruningCache::reset()` is additionally called on:
/// - New user message (top of `process_user_message_inner`)
/// - `tools/list_changed` notification (in `check_tool_refresh`)
/// - Manual `/mcp add` or `/mcp remove` commands
///
/// `PruningParams` is **not** part of the cache key.  Callers must not change
/// `PruningParams` within a single user turn; this invariant holds because
/// params are derived from `ToolPruningConfig`, which is stable within a turn
/// (config changes trigger a full agent rebuild, not a mid-turn param swap).
///
/// Designed for single-owner use (`&mut` on `Agent`). Not thread-safe.
#[derive(Debug, Default, Clone)]
pub struct PruningCache {
    // `(message_content_hash, tool_list_hash)` of the stored entry; `None` when empty.
    key: Option<(u64, u64)>,
    // Outcome associated with `key`; `None` when empty.
    result: Option<CachedResult>,
}
67
/// Outcome of a [`PruningCache::lookup`] call.
///
/// Borrows from the cache (`'a`), so the cache cannot be mutated while a
/// `Hit` is alive.
enum CacheLookup<'a> {
    /// Positive hit: pruned tool slice from a previous successful call.
    Hit(&'a [McpTool]),
    /// Negative hit: LLM previously failed; caller should use the full tool list.
    NegativeHit,
    /// No entry for this key.
    Miss,
}
77
78impl PruningCache {
79    /// Create a new, empty cache.
80    #[must_use]
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    /// Clear the cached entry.
86    ///
87    /// Must be called at the start of each user turn and whenever the MCP tool
88    /// list changes (via notification, `/mcp add`, or `/mcp remove`).
89    pub fn reset(&mut self) {
90        self.key = None;
91        self.result = None;
92    }
93
94    fn lookup(&self, msg_hash: u64, tool_hash: u64) -> CacheLookup<'_> {
95        match (&self.key, &self.result) {
96            (Some(k), Some(CachedResult::Ok(tools))) if *k == (msg_hash, tool_hash) => {
97                CacheLookup::Hit(tools)
98            }
99            (Some(k), Some(CachedResult::Failed)) if *k == (msg_hash, tool_hash) => {
100                CacheLookup::NegativeHit
101            }
102            _ => CacheLookup::Miss,
103        }
104    }
105
106    fn insert_ok(&mut self, msg_hash: u64, tool_hash: u64, tools: Vec<McpTool>) {
107        self.key = Some((msg_hash, tool_hash));
108        self.result = Some(CachedResult::Ok(tools));
109    }
110
111    fn insert_failed(&mut self, msg_hash: u64, tool_hash: u64) {
112        self.key = Some((msg_hash, tool_hash));
113        self.result = Some(CachedResult::Failed);
114    }
115}
116
117/// Compute a `u64` hash of a string using blake3 (first 8 bytes, little-endian).
118///
119/// # Panics
120///
121/// Never panics in practice: blake3 always produces at least 8 bytes of output.
122#[must_use]
123pub fn content_hash(s: &str) -> u64 {
124    let hash = blake3::hash(s.as_bytes());
125    u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
126}
127
128/// Compute a `u64` hash of the full tool list metadata using blake3.
129///
130/// Hashes `server_id`, `name`, `description`, and `input_schema` for every
131/// tool, sorted by qualified name (`server_id` then `name`) for deterministic
132/// ordering regardless of list order.
133///
134/// **`BTreeMap` assumption**: `serde_json::to_vec` produces deterministic output
135/// because `serde_json::Map` defaults to `BTreeMap`-backed storage (sorted
136/// keys).  If the `preserve_order` feature of `serde_json` is ever enabled
137/// (switching `Map` to `IndexMap`), key order becomes insertion-order and
138/// hashing becomes non-deterministic.  Should `preserve_order` be needed,
139/// sort `Map` keys before serialising here.
140///
141/// # Panics
142///
143/// Never panics in practice: blake3 always produces at least 8 bytes of output.
144#[must_use]
145pub fn tool_list_hash(tools: &[McpTool]) -> u64 {
146    let mut hasher = blake3::Hasher::new();
147    let mut sorted: Vec<&McpTool> = tools.iter().collect();
148    sorted.sort_by(|a, b| a.server_id.cmp(&b.server_id).then(a.name.cmp(&b.name)));
149    for tool in sorted {
150        hasher.update(tool.server_id.as_bytes());
151        hasher.update(b"\0");
152        hasher.update(tool.name.as_bytes());
153        hasher.update(b"\0");
154        hasher.update(tool.description.as_bytes());
155        hasher.update(b"\0");
156        match serde_json::to_vec(&tool.input_schema) {
157            Ok(schema_bytes) => {
158                hasher.update(&schema_bytes);
159            }
160            Err(_) => {
161                hasher.update(b"\x00");
162            }
163        }
164        // Tool separator — prevents adjacent-field collisions.
165        hasher.update(b"\x01");
166    }
167    let hash = hasher.finalize();
168    u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
169}
170
171/// Cache-aware wrapper around [`prune_tools`].
172///
173/// On a **positive cache hit**: returns the previously-computed pruned list
174/// without an LLM call.
175///
176/// On a **negative cache hit** (LLM previously failed for this key): returns
177/// `Ok(all_tools.to_vec())` without retrying the LLM, avoiding retry storms
178/// when the pruning LLM is transiently unavailable.
179///
180/// On a **cache miss**: calls [`prune_tools`], stores the result (success or
181/// failure), and returns.  On LLM failure the negative sentinel is cached and
182/// `Err(PruningError)` is returned so the caller can log and fall back.
183///
184/// # Errors
185///
186/// Propagates `PruningError` from [`prune_tools`] on the first (uncached) LLM
187/// failure.  Subsequent calls with the same key return `Ok(all_tools.to_vec())`
188/// from the negative cache entry.
189pub async fn prune_tools_cached<P: LlmProvider>(
190    cache: &mut PruningCache,
191    all_tools: &[McpTool],
192    task_context: &str,
193    params: &PruningParams,
194    provider: &P,
195) -> Result<Vec<McpTool>, PruningError> {
196    let msg_hash = content_hash(task_context);
197    let tl_hash = tool_list_hash(all_tools);
198
199    match cache.lookup(msg_hash, tl_hash) {
200        CacheLookup::Hit(cached) => return Ok(cached.to_vec()),
201        CacheLookup::NegativeHit => {
202            // Negative cache hit: LLM previously failed for this key.
203            // Return all tools as fallback without retrying to avoid retry storms.
204            tracing::warn!("pruning cache: negative hit, returning all tools without LLM call");
205            return Ok(all_tools.to_vec());
206        }
207        CacheLookup::Miss => {}
208    }
209
210    match prune_tools(all_tools, task_context, params, provider).await {
211        Ok(result) => {
212            cache.insert_ok(msg_hash, tl_hash, result.clone());
213            Ok(result)
214        }
215        Err(e) => {
216            cache.insert_failed(msg_hash, tl_hash);
217            Err(e)
218        }
219    }
220}
221
/// Errors that can occur during tool pruning.
#[derive(Debug, thiserror::Error)]
pub enum PruningError {
    /// LLM call failed.
    ///
    /// `#[from]` lets `?` convert `LlmError` directly in `prune_tools`.
    #[error("pruning LLM call failed: {0}")]
    LlmError(#[from] LlmError),
    /// Could not extract a valid JSON array from the LLM response.
    #[error("failed to parse pruning response as JSON array of tool names")]
    ParseError,
}
232
/// Parameters for the `prune_tools` function.
///
/// Mirrors `zeph_config::ToolPruningConfig` but lives in `zeph-mcp` to avoid a
/// circular crate dependency (`zeph-config` → `zeph-mcp`). Callers in `zeph-core`
/// convert from `ToolPruningConfig`.
#[derive(Debug, Clone)]
pub struct PruningParams {
    /// Maximum number of MCP tools to include after pruning.
    ///
    /// `0` disables the cap.  The cap applies only to LLM-selected
    /// candidates; `always_include` tools bypass it.
    pub max_tools: usize,
    /// Minimum number of MCP tools below which pruning is skipped.
    ///
    /// Lists strictly shorter than this are returned unpruned without any
    /// LLM call.
    pub min_tools_to_prune: usize,
    /// Tool names that are never pruned (always included).
    ///
    /// Matches on bare tool `name` (not qualified `server_id:name`).  When two
    /// MCP servers expose a tool with the same name, both instances are pinned.
    /// This is intentional: the config is user-facing and users specify tool
    /// names, not server-qualified identifiers.
    pub always_include: Vec<String>,
}
252
impl Default for PruningParams {
    fn default() -> Self {
        Self {
            // Cap at 15 LLM-selected tools, only prune lists of 10+, pin nothing.
            max_tools: 15,
            min_tools_to_prune: 10,
            always_include: Vec::new(),
        }
    }
}
262
263/// Prune MCP tools to those relevant to the current task.
264///
265/// Returns a filtered subset of `all_tools` based on the LLM's assessment of relevance
266/// to `task_context`. Tools listed in `params.always_include` bypass the LLM filter.
267///
268/// # Behavior
269///
270/// - If `all_tools.len() < params.min_tools_to_prune`, returns `Ok(all_tools.to_vec())`.
271/// - On LLM failure or parse failure, returns `Err(PruningError)` — the caller should
272///   fall back to the full tool list and log at `WARN` level.
273/// - Result is capped at `params.max_tools` total tools. `max_tools == 0` means no cap.
274///
275/// # Errors
276///
277/// Returns `PruningError::LlmError` if the provider call fails.
278/// Returns `PruningError::ParseError` if the response cannot be parsed as a JSON array.
279pub async fn prune_tools<P: LlmProvider>(
280    all_tools: &[McpTool],
281    task_context: &str,
282    params: &PruningParams,
283    provider: &P,
284) -> Result<Vec<McpTool>, PruningError> {
285    if all_tools.len() < params.min_tools_to_prune {
286        return Ok(all_tools.to_vec());
287    }
288
289    // Partition: always-include tools bypass the LLM filter.
290    let (pinned, candidates): (Vec<_>, Vec<_>) = all_tools
291        .iter()
292        .partition(|t| params.always_include.iter().any(|a| a == &t.name));
293
294    // Build the pruning prompt.
295    // Sanitize tool names and descriptions before interpolation to prevent prompt injection
296    // from attacker-controlled MCP servers.
297    let tool_list = candidates.iter().fold(String::new(), |mut acc, t| {
298        let name = sanitize_tool_name(&t.name);
299        let desc = sanitize_tool_description(&t.description);
300        let _ = writeln!(acc, "- {name}: {desc}");
301        acc
302    });
303
304    let prompt = format!(
305        "Return a JSON array of tool names that are relevant to the task below.\n\
306         Return ONLY the JSON array, no explanation, no markdown.\n\n\
307         Task: {task_context}\n\n\
308         Available tools:\n{tool_list}"
309    );
310
311    let messages = vec![Message::from_legacy(Role::User, prompt)];
312    let response = provider.chat(&messages).await?;
313
314    // Parse: strip markdown fences, find first `[` to last `]`.
315    let relevant_names = parse_name_array(&response)?;
316
317    // always_include tools are added unconditionally and bypass the max_tools cap;
318    // max_tools applies only to LLM-selected candidates.
319    let mut result: Vec<McpTool> = pinned.into_iter().cloned().collect();
320    let mut candidates_added: usize = 0;
321    for tool in &candidates {
322        // max_tools == 0 means no cap on LLM-selected candidates.
323        if params.max_tools > 0 && candidates_added >= params.max_tools {
324            break;
325        }
326        if relevant_names.iter().any(|n| n == &tool.name) {
327            result.push((*tool).clone());
328            candidates_added += 1;
329        }
330    }
331
332    Ok(result)
333}
334
/// Sanitize a tool name before interpolating into an LLM prompt.
///
/// Strips control characters and caps at 64 characters.
fn sanitize_tool_name(name: &str) -> String {
    let mut cleaned = String::with_capacity(name.len().min(64));
    let mut kept = 0_usize;
    for ch in name.chars() {
        if ch.is_control() {
            // Drop control chars (newlines/tabs could smuggle prompt text).
            continue;
        }
        if kept == 64 {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
341
/// Sanitize a tool description before interpolating into an LLM prompt.
///
/// Strips control characters and caps at 200 characters.
fn sanitize_tool_description(desc: &str) -> String {
    let mut cleaned = String::with_capacity(desc.len().min(200));
    let mut kept = 0_usize;
    for ch in desc.chars() {
        if ch.is_control() {
            // Drop control chars (newlines/tabs could smuggle prompt text).
            continue;
        }
        if kept == 200 {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
348
349/// Extract tool names from an LLM response expected to contain a JSON array of strings.
350///
351/// Handles markdown code fences (` ```json ... ``` `) and leading/trailing whitespace.
352fn parse_name_array(response: &str) -> Result<Vec<String>, PruningError> {
353    // Strip markdown code fence lines.
354    let stripped = response
355        .lines()
356        .filter(|l| !l.trim_start().starts_with("```"))
357        .collect::<Vec<_>>()
358        .join("\n");
359
360    // Find the first `[` and last `]` to isolate the JSON array.
361    let start = stripped.find('[').ok_or(PruningError::ParseError)?;
362    let end = stripped.rfind(']').ok_or(PruningError::ParseError)?;
363    if end <= start {
364        return Err(PruningError::ParseError);
365    }
366
367    let json_fragment = &stripped[start..=end];
368    let names: Vec<String> =
369        serde_json::from_str(json_fragment).map_err(|_| PruningError::ParseError)?;
370    Ok(names)
371}
372
#[cfg(test)]
mod tests {
    use zeph_llm::mock::MockProvider;

    use super::*;

    /// Build a test tool on the default `"test"` server.
    fn make_tool(name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: "test".into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Build a test tool with an explicit server id (cross-server cases).
    fn make_tool_with_server(server_id: &str, name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Build params with low `min_tools_to_prune` so tests aren't skipped early.
    fn params_with_max(max_tools: usize) -> PruningParams {
        PruningParams {
            max_tools,
            min_tools_to_prune: 1,
            always_include: Vec::new(),
        }
    }

    #[test]
    fn parse_plain_array() {
        let names = parse_name_array(r#"["bash", "read", "write"]"#).unwrap();
        assert_eq!(names, vec!["bash", "read", "write"]);
    }

    #[test]
    fn parse_array_with_markdown_fences() {
        let input = "```json\n[\"bash\", \"read\"]\n```";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_array_with_preamble() {
        let input = "Here are the relevant tools:\n[\"bash\", \"read\"]";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_empty_array() {
        let names = parse_name_array("[]").unwrap();
        assert!(names.is_empty());
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_name_array("not json").is_err());
        assert!(parse_name_array("").is_err());
        assert!(parse_name_array("{\"key\": \"val\"}").is_err());
    }

    // Replaced below_min_detected tautology (#2300): call prune_tools with a failing
    // mock to verify the early-return path fires before the LLM is ever contacted.
    #[tokio::test]
    async fn below_min_detected_early_return() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // MockProvider::failing() would panic on any LLM call — if prune_tools invokes it,
        // the test will error rather than pass.
        let provider = MockProvider::failing();
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 10, // 5 tools < 10 → early return before LLM
            always_include: Vec::new(),
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "all tools returned when below threshold");
    }

    #[tokio::test]
    async fn always_include_pinned() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        // LLM returns only candidate_a; pinned must still appear.
        let provider = MockProvider::with_responses(vec![r#"["candidate_a"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned must survive pruning"
        );
        assert!(result.iter().any(|t| t.name == "candidate_a"));
    }

    /// S4: `always_include` pins tools by bare name across multiple servers.
    #[tokio::test]
    async fn always_include_matches_bare_name_across_servers() {
        let tools = vec![
            make_tool_with_server("server_a", "search", "search on A"),
            make_tool_with_server("server_b", "search", "search on B"),
            make_tool_with_server("server_a", "other", "other tool"),
        ];
        // LLM returns only "other"; both "search" instances should still be pinned.
        let provider = MockProvider::with_responses(vec![r#"["other"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["search".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 3, "both search tools + other must be present");
        let search_count = result.iter().filter(|t| t.name == "search").count();
        assert_eq!(
            search_count, 2,
            "both server_a:search and server_b:search must be pinned"
        );
        assert!(result.iter().any(|t| t.name == "other"));
    }

    #[tokio::test]
    async fn max_tools_cap_respected() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // LLM returns all 5 as relevant; max_tools=2 must cap candidates.
        let names_json = r#"["t0","t1","t2","t3","t4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);

        let result = prune_tools(&tools, "task", &params_with_max(2), &provider)
            .await
            .unwrap();
        assert_eq!(
            result.len(),
            2,
            "max_tools=2 must cap LLM-selected candidates"
        );
    }

    #[tokio::test]
    async fn llm_failure_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::LlmError(_))));
    }

    #[tokio::test]
    async fn parse_error_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec!["not valid json at all".into()]);
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::ParseError)));
    }

    #[tokio::test]
    async fn max_tools_zero_means_no_cap() {
        let tools: Vec<McpTool> = (0..5)
            .map(|i| make_tool(&format!("tool{i}"), "desc"))
            .collect();
        let names_json = r#"["tool0","tool1","tool2","tool3","tool4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let params = params_with_max(0);

        let result = prune_tools(&tools, "any task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "max_tools=0 must not cap the result");
    }

    #[test]
    fn description_sanitization_strips_control_chars_and_caps() {
        // Newline and tab are control characters.
        let desc = "line1\nline2\tinject";
        let sanitized = sanitize_tool_description(desc);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\t'));

        // Cap at 200 characters.
        let long_desc = "x".repeat(300);
        assert_eq!(sanitize_tool_description(&long_desc).len(), 200);

        // Name capped at 64 characters.
        let long_name = "a".repeat(100);
        assert_eq!(sanitize_tool_name(&long_name).len(), 64);
    }

    #[tokio::test]
    async fn always_include_bypasses_max_tools_cap() {
        // max_tools=1 — only 1 candidate from LLM allowed; but always_include adds unconditionally.
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider =
            MockProvider::with_responses(vec![r#"["candidate_a","candidate_b"]"#.into()]);
        let params = PruningParams {
            max_tools: 1,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();

        // "pinned" is always present regardless of max_tools.
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned tool must bypass cap"
        );
        // Only 1 candidate slot remains after pinned bypasses cap; total = 1 (pinned) + 1 (candidate).
        assert_eq!(result.len(), 2);
    }

    // ── PruningCache tests (#2298, #2300) ────────────────────────────────────

    #[tokio::test]
    async fn cache_positive_hit() {
        // Two tools to exceed min_tools_to_prune=1; MockProvider has exactly one response.
        // The second call must succeed from cache without consuming the (empty) response queue.
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r1.len(), r2.len(), "cache hit must return same result");
    }

    #[tokio::test]
    async fn cache_miss_on_message_change() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider =
            MockProvider::with_responses(vec![r#"["t0","t1"]"#.into(), r#"["t0"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query_a", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query_b", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2, "first call returns both tools");
        assert_eq!(
            r2.len(),
            1,
            "different message triggers cache miss and LLM call"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_tool_list_change() {
        let tools1: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let mut tools2 = tools1.clone();
        tools2.push(make_tool("t2", "new tool"));

        let provider = MockProvider::with_responses(vec![
            r#"["t0","t1"]"#.into(),
            r#"["t0","t1","t2"]"#.into(),
        ]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools1, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools2, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r2.len(), 3, "new tool triggers cache miss");
    }

    #[tokio::test]
    async fn cache_negative_hit_skips_llm() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        // First call: LLM fails → error is returned and negative entry is cached.
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err(), "first call must propagate LLM error");

        // Second call: negative cache hit → returns all tools without calling LLM.
        // MockProvider::failing() would panic on a second LLM call, proving cache is used.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "negative cache hit must return all tools");
    }

    #[tokio::test]
    async fn cache_negative_hit_clears_on_reset() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // Fail on the first LLM call; succeed on the second (after cache.reset()).
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()])
            .with_errors(vec![zeph_llm::LlmError::Other("simulated failure".into())]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        // First call: LLM fails → negative entry cached.
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err());

        // Reset clears the negative entry.
        cache.reset();

        // After reset the LLM is retried; the queued success response is now returned.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "after reset the LLM must be retried");
    }
}